Skip to content

Commit 9b7ca2d

Browse files
authored
Merge pull request #1173 from fluxcd/refactor-tar-idioms
tar: various internal improvements
2 parents ad86bcd + 154b6d4 commit 9b7ca2d

3 files changed

Lines changed: 79 additions & 118 deletions

File tree

tar/symlinks.go

Lines changed: 32 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ func ResolveSymlinks(srcDir, dstDir string) error {
5656
return fmt.Errorf("srcDir %s is not a directory", absSrc)
5757
}
5858

59-
if err := checkDstDir(dstDir); err != nil {
59+
if err = checkDstDir(dstDir); err != nil {
6060
return err
6161
}
6262

63-
return copyResolvedDir(realSrc, dstDir, make(map[string]bool))
63+
return copyDir("", realSrc, dstDir, make(map[string]bool))
6464
}
6565

6666
// ResolveSymlinksRoot is the confined variant of ResolveSymlinks: every
@@ -109,11 +109,11 @@ func ResolveSymlinksRoot(rootDir, srcDir, dstDir string) error {
109109
return fmt.Errorf("srcDir %s is not a directory", absSrc)
110110
}
111111

112-
if err := checkDstDir(dstDir); err != nil {
112+
if err = checkDstDir(dstDir); err != nil {
113113
return err
114114
}
115115

116-
return copyConfinedDir(realRoot, realSrc, dstDir, make(map[string]bool))
116+
return copyDir(realRoot, realSrc, dstDir, make(map[string]bool))
117117
}
118118

119119
// checkDstDir verifies that dstDir exists and is a directory.
@@ -128,13 +128,16 @@ func checkDstDir(dstDir string) error {
128128
return nil
129129
}
130130

131-
// copyResolvedDir recursively copies srcDir (already resolved via
132-
// EvalSymlinks) into dstDir. visited tracks resolved directory paths
133-
// currently on the call stack so that a re-entry via a symlink does
134-
// not loop. Entries are removed when the call returns, so the same
135-
// directory may be copied again through a different symlink — this is
136-
// intentional (both link sites need the content).
137-
func copyResolvedDir(srcDir, dstDir string, visited map[string]bool) error {
131+
// copyDir recursively copies srcDir into dstDir, resolving all
132+
// symlinks. srcDir must already be fully resolved (no symlink
133+
// components). If confineRoot is non-empty, every symlink target must
134+
// resolve within confineRoot; a violation fails the call. visited is a
135+
// stack-based cycle breaker: directories currently on the call stack
136+
// are skipped to prevent infinite loops from symlink cycles. Entries
137+
// are removed when the call returns, so the same directory may be
138+
// copied again through a different symlink — this is intentional
139+
// (both link sites need the content).
140+
func copyDir(confineRoot, srcDir, dstDir string, visited map[string]bool) error {
138141
if visited[srcDir] {
139142
return nil
140143
}
@@ -150,74 +153,36 @@ func copyResolvedDir(srcDir, dstDir string, visited map[string]bool) error {
150153
srcPath := filepath.Join(srcDir, entry.Name())
151154
dstPath := filepath.Join(dstDir, entry.Name())
152155

153-
realPath, err := filepath.EvalSymlinks(srcPath)
154-
if err != nil {
155-
return fmt.Errorf("resolving symlink %s: %w", srcPath, err)
156-
}
157-
realInfo, err := os.Stat(realPath)
158-
if err != nil {
159-
return fmt.Errorf("stat resolved path %s: %w", realPath, err)
160-
}
156+
isLink := entry.Type()&os.ModeSymlink != 0
161157

162-
if realInfo.IsDir() {
163-
if err := os.MkdirAll(dstPath, realInfo.Mode()); err != nil {
164-
return err
158+
realPath := srcPath
159+
if isLink {
160+
realPath, err = filepath.EvalSymlinks(srcPath)
161+
if err != nil {
162+
return fmt.Errorf("resolving symlink %s: %w", srcPath, err)
165163
}
166-
if err := copyResolvedDir(realPath, dstPath, visited); err != nil {
167-
return err
164+
// Report the logical path, not the resolved target,
165+
// to avoid leaking filesystem layout.
166+
if confineRoot != "" && !isWithin(confineRoot, realPath) {
167+
return fmt.Errorf("symlink %s resolves outside rootDir", srcPath)
168168
}
169-
continue
170-
}
171-
172-
if !realInfo.Mode().IsRegular() {
173-
continue
174169
}
175170

176-
if err := copyResolvedFile(realPath, dstPath, realInfo.Mode()); err != nil {
177-
return err
178-
}
179-
}
180-
return nil
181-
}
182-
183-
// copyConfinedDir is the root-confined equivalent of copyResolvedDir.
184-
// srcDir is assumed already resolved and already verified as within
185-
// realRoot. visited is a stack-based cycle breaker (see copyResolvedDir).
186-
func copyConfinedDir(realRoot, srcDir, dstDir string, visited map[string]bool) error {
187-
if visited[srcDir] {
188-
return nil
189-
}
190-
visited[srcDir] = true
191-
defer delete(visited, srcDir)
192-
193-
entries, err := os.ReadDir(srcDir)
194-
if err != nil {
195-
return err
196-
}
197-
198-
for _, entry := range entries {
199-
srcPath := filepath.Join(srcDir, entry.Name())
200-
dstPath := filepath.Join(dstDir, entry.Name())
201-
202-
realPath, err := filepath.EvalSymlinks(srcPath)
203-
if err != nil {
204-
return fmt.Errorf("resolving %s: %w", srcPath, err)
205-
}
206-
// Report the logical path of the offending symlink, not the
207-
// resolved target, to avoid leaking filesystem layout.
208-
if !isWithin(realRoot, realPath) {
209-
return fmt.Errorf("symlink %s resolves outside rootDir", srcPath)
171+
var realInfo os.FileInfo
172+
if isLink {
173+
realInfo, err = os.Stat(realPath)
174+
} else {
175+
realInfo, err = entry.Info()
210176
}
211-
realInfo, err := os.Stat(realPath)
212177
if err != nil {
213178
return fmt.Errorf("stat %s: %w", realPath, err)
214179
}
215180

216181
if realInfo.IsDir() {
217-
if err := os.MkdirAll(dstPath, realInfo.Mode()); err != nil {
182+
if err = os.MkdirAll(dstPath, realInfo.Mode()); err != nil {
218183
return err
219184
}
220-
if err := copyConfinedDir(realRoot, realPath, dstPath, visited); err != nil {
185+
if err = copyDir(confineRoot, realPath, dstPath, visited); err != nil {
221186
return err
222187
}
223188
continue
@@ -227,7 +192,7 @@ func copyConfinedDir(realRoot, srcDir, dstDir string, visited map[string]bool) e
227192
continue
228193
}
229194

230-
if err := copyResolvedFile(realPath, dstPath, realInfo.Mode()); err != nil {
195+
if err = copyResolvedFile(realPath, dstPath, realInfo.Mode()); err != nil {
231196
return err
232197
}
233198
}

tar/tar.go

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"compress/gzip"
2222
"fmt"
2323
"io"
24+
"io/fs"
2425
"os"
2526
"path/filepath"
2627
"time"
@@ -44,34 +45,43 @@ func Tar(dir string, w io.Writer, opts ...Option) (int64, error) {
4445
return 0, err
4546
}
4647

47-
if fi, err := os.Stat(absDir); err != nil {
48+
fi, err := os.Stat(absDir)
49+
if err != nil {
4850
return 0, fmt.Errorf("invalid dir path %s: %w", absDir, err)
49-
} else if !fi.IsDir() {
51+
}
52+
if !fi.IsDir() {
5053
return 0, fmt.Errorf("not a directory: %s", absDir)
5154
}
5255

5356
cw := &countWriter{w: w}
5457

55-
var gw *gzip.Writer
5658
var tw *tar.Writer
59+
var closers []io.Closer
5760
if o.skipGzip {
5861
tw = tar.NewWriter(cw)
62+
closers = []io.Closer{tw}
5963
} else {
60-
gw = gzip.NewWriter(cw)
64+
gw := gzip.NewWriter(cw)
6165
tw = tar.NewWriter(gw)
66+
closers = []io.Closer{tw, gw}
6267
}
6368

6469
buf := make([]byte, bufferSize)
65-
if err := filepath.Walk(absDir, func(p string, fi os.FileInfo, err error) error {
70+
walkErr := filepath.WalkDir(absDir, func(p string, d fs.DirEntry, err error) error {
6671
if err != nil {
6772
return err
6873
}
6974

7075
// Skip symlinks and other non-regular, non-directory entries.
71-
if m := fi.Mode(); !(m.IsRegular() || m.IsDir()) {
76+
if t := d.Type(); !t.IsRegular() && !t.IsDir() {
7277
return nil
7378
}
7479

80+
fi, err := d.Info()
81+
if err != nil {
82+
return err
83+
}
84+
7585
if o.filter != nil && o.filter(p, fi) {
7686
return nil
7787
}
@@ -96,7 +106,7 @@ func Tar(dir string, w io.Writer, opts ...Option) (int64, error) {
96106
header.AccessTime = time.Time{}
97107
header.ChangeTime = time.Time{}
98108

99-
if err := tw.WriteHeader(header); err != nil {
109+
if err = tw.WriteHeader(header); err != nil {
100110
return err
101111
}
102112

@@ -113,27 +123,14 @@ func Tar(dir string, w io.Writer, opts ...Option) (int64, error) {
113123
err = closeErr
114124
}
115125
return err
116-
}); err != nil {
117-
_ = tw.Close()
118-
if gw != nil {
119-
_ = gw.Close()
120-
}
121-
return cw.n, err
122-
}
126+
})
123127

124-
if err := tw.Close(); err != nil {
125-
if gw != nil {
126-
_ = gw.Close()
128+
for _, c := range closers {
129+
if closeErr := c.Close(); closeErr != nil && walkErr == nil {
130+
walkErr = closeErr
127131
}
128-
return cw.n, err
129132
}
130-
if gw != nil {
131-
if err := gw.Close(); err != nil {
132-
return cw.n, err
133-
}
134-
}
135-
136-
return cw.n, nil
133+
return cw.n, walkErr
137134
}
138135

139136
// countWriter wraps an io.Writer and counts the bytes written.

tar/untar.go

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ const (
4747
// If dir is a relative path, it cannot ascend from the current working
4848
// directory. If dir exists, it must be a directory; otherwise it is
4949
// created.
50-
func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
50+
func Untar(r io.Reader, dir string, inOpts ...Option) error {
5151
opts := tarOpts{
5252
maxUntarSize: DefaultMaxUntarSize,
5353
}
@@ -60,8 +60,7 @@ func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
6060
return err
6161
}
6262

63-
dir, err = securejoin.SecureJoin(cwd, dir)
64-
if err != nil {
63+
if dir, err = securejoin.SecureJoin(cwd, dir); err != nil {
6564
return err
6665
}
6766
}
@@ -77,35 +76,33 @@ func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
7776
}
7877

7978
madeDir := map[string]bool{}
80-
var tr *tar.Reader
81-
if opts.skipGzip {
82-
tr = tar.NewReader(r)
83-
} else {
84-
zr, err := gzip.NewReader(r)
79+
80+
var rc = io.NopCloser(r)
81+
if !opts.skipGzip {
82+
var err error
83+
rc, err = gzip.NewReader(r)
8584
if err != nil {
8685
return fmt.Errorf("requires gzip-compressed body: %w", err)
8786
}
88-
89-
tr = tar.NewReader(zr)
9087
}
88+
tr := tar.NewReader(rc)
9189

92-
processedBytes := 0
90+
var processedBytes int64
9391
t0 := time.Now()
9492

95-
// For improved concurrency, this could be optimised by sourcing
96-
// the buffer from a sync.Pool.
93+
// Reuse a single buffer for all file copies.
9794
buf := make([]byte, bufferSize)
9895
for {
9996
f, err := tr.Next()
100-
if err == io.EOF {
97+
if errors.Is(err, io.EOF) {
10198
break
10299
}
103100
if err != nil {
104101
return fmt.Errorf("tar error: %w", err)
105102
}
106-
processedBytes += int(f.Size)
103+
processedBytes += f.Size
107104
if opts.maxUntarSize > UnlimitedUntarSize &&
108-
processedBytes > opts.maxUntarSize {
105+
processedBytes > int64(opts.maxUntarSize) {
109106
return fmt.Errorf("tar %q is bigger than max archive size of %d bytes", f.Name, opts.maxUntarSize)
110107
}
111108
if !validRelPath(f.Name) {
@@ -127,12 +124,12 @@ func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
127124
// already be made by a directory entry in the tar
128125
// beforehand. Thus, don't check for errors; the next
129126
// write will fail with the same error.
130-
dir := filepath.Dir(abs)
131-
if !madeDir[dir] {
132-
if err := os.MkdirAll(filepath.Dir(abs), 0o750); err != nil {
127+
parentDir := filepath.Dir(abs)
128+
if !madeDir[parentDir] {
129+
if err := os.MkdirAll(parentDir, 0o750); err != nil {
133130
return err
134131
}
135-
madeDir[dir] = true
132+
madeDir[parentDir] = true
136133
}
137134
if runtime.GOOS == "darwin" && mode&0111 != 0 {
138135
// The darwin kernel caches binary signatures
@@ -146,13 +143,13 @@ func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
146143
return err
147144
}
148145
}
149-
wf, err := os.OpenFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm())
146+
wf, err := os.OpenFile(abs, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm())
150147
if err != nil {
151148
return err
152149
}
153150

154151
n, err := copyBuffer(wf, tr, buf)
155-
if err != nil && err != io.EOF {
152+
if err != nil && !errors.Is(err, io.EOF) {
156153
return fmt.Errorf("error copying buffer: %w", err)
157154
}
158155

@@ -194,7 +191,7 @@ func Untar(r io.Reader, dir string, inOpts ...Option) (err error) {
194191
return fmt.Errorf("tar file entry %s contained unsupported file type %v", f.Name, mode)
195192
}
196193
}
197-
return nil
194+
return rc.Close()
198195
}
199196

200197
// Uses a variant of io.CopyBuffer which ensures that a buffer is being used.
@@ -211,11 +208,13 @@ func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err er
211208
for {
212209
nr, er := src.Read(buf)
213210
if nr > 0 {
214-
nw, ew := dst.Write(buf[0:nr])
211+
nw, ew := dst.Write(buf[:nr])
212+
// Guard against a broken Writer: negative byte count
213+
// or claiming more bytes written than provided.
215214
if nw < 0 || nr < nw {
216215
nw = 0
217216
if ew == nil {
218-
ew = fmt.Errorf("errInvalidWrite")
217+
ew = errors.New("invalid write result")
219218
}
220219
}
221220
written += int64(nw)
@@ -229,7 +228,7 @@ func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err er
229228
}
230229
}
231230
if er != nil {
232-
if er != io.EOF {
231+
if !errors.Is(er, io.EOF) {
233232
err = er
234233
}
235234
break

0 commit comments

Comments
 (0)