Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions storage/transfermanager/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"hash/crc32"
"io"
"io/fs"
"log"
"math"
"os"
"path/filepath"
Expand Down Expand Up @@ -200,7 +199,6 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
return fmt.Errorf("transfermanager: DownloadDirectory failed to verify path: %w", err)
}
if !isUnder {
log.Printf("transfermanager: skipping object with unsafe path after stripping prefix %q", objectWithoutPrefix)
// skipped files will later be added in the results
illegalPathObjects = append(illegalPathObjects, DownloadObjectInput{
Bucket: input.Bucket,
Expand Down
25 changes: 15 additions & 10 deletions storage/transfermanager/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ func TestIsSubPath(t *testing.T) {
tempDir := t.TempDir()

testCases := []struct {
name string
localDirectory string
filePath string
wantIsSub bool
wantErr bool
name string
localDirectory string
filePath string
wantIsSub bool
wantErr bool
Comment thread
krishnamd-jkp marked this conversation as resolved.
Outdated
wantErrorMessage string
}{
{
name: "filePath is a child",
Expand Down Expand Up @@ -309,10 +310,11 @@ func TestIsSubPath(t *testing.T) {
wantIsSub: false,
},
{
name: "IsSubPath returns error when base dir is changed",
localDirectory: "foo",
filePath: "bar",
wantErr: true,
name: "IsSubPath returns error when base dir is changed",
localDirectory: "foo",
filePath: "bar",
wantErr: true,
wantErrorMessage: "no such file or directory",
},
}

Expand All @@ -331,7 +333,10 @@ func TestIsSubPath(t *testing.T) {
isSub, err := isSubPath(tc.localDirectory, tc.filePath)

if (err != nil) != tc.wantErr {
Comment thread
krishnamd-jkp marked this conversation as resolved.
Outdated
t.Errorf("isSubPath() error = %v, wantErr %v", err, tc.wantErr)
t.Fatalf("isSubPath() error = %v, wantErr %v", err, tc.wantErr)
}
if tc.wantErr && !strings.Contains(err.Error(), tc.wantErrorMessage) {
t.Errorf("isSubPath() error = %s, want err containing %s", err.Error(), tc.wantErrorMessage)
return
}
if isSub != tc.wantIsSub {
Expand Down
283 changes: 283 additions & 0 deletions storage/transfermanager/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"math/rand"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -456,6 +457,288 @@ func TestIntegration_DownloadDirectoryError(t *testing.T) {
})
}

// TestIntegration_DownloadDirectorySkippedFiles tests that an error is returned if any
// object download escapes target directory
func TestIntegration_DownloadDirectoryWithSkippedFiles(t *testing.T) {
multiTransportTest(context.Background(), t, func(t *testing.T, ctx context.Context, c *storage.Client, tb downloadTestBucket) {
localDir := t.TempDir()
callbackMu := sync.Mutex{}
numCallbacks := 0
numErrors := 0
illegalObjs := []string{
"../objA",
"../objB",
}

b := c.Bucket(tb.bucket)
for _, obj := range illegalObjs {
size := randomInt64(minObjectSize, maxObjectSize)
_, err := generateFileInGCS(ctx, b.Object(obj), size)
if err != nil {
t.Errorf("could not create file in GCS: %v for obj %s", err, obj)
}
}
t.Cleanup(func() {
for _, obj := range illegalObjs {
if err := b.Object(obj).Delete(ctx); err != nil {
t.Errorf("failed to delete object in cleanup for %s: %v", obj, err)
}
}
})
d, err := NewDownloader(c, WithWorkers(8), WithPartSize(maxObjectSize/2))
if err != nil {
t.Fatalf("NewDownloader: %v", err)
}

if err := d.DownloadDirectory(ctx, &DownloadDirectoryInput{
Bucket: tb.bucket,
LocalDirectory: localDir,
OnObjectDownload: func(got *DownloadOutput) {
callbackMu.Lock()
numCallbacks++
if got.Err != nil {
numErrors++
}
callbackMu.Unlock()

if slices.Contains(illegalObjs, got.Object) {
if got.Err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, got.Err)
}
return
}
if got.Err != nil {
t.Errorf("result.Err: %v", got.Err)
}

if got, want := got.Attrs.Size, tb.objectSizes[got.Object]; want != got {
t.Errorf("expected object size %d, got %d", want, got)
}

path := filepath.Join(localDir, got.Object)

f, err := os.Open(path)
if err != nil {
t.Errorf("os.Open(%q): %v", path, err)
}
defer f.Close()

b := bytes.NewBuffer(make([]byte, 0, got.Attrs.Size))
if _, err := io.Copy(b, f); err != nil {
t.Errorf("io.Copy: %v", err)
}

if wantCRC, gotCRC := tb.contentHashes[got.Object], crc32c(b.Bytes()); gotCRC != wantCRC {
Comment thread
krishnamd-jkp marked this conversation as resolved.
Outdated
t.Errorf("object(%q) at filepath(%q): content crc32c does not match; got: %v, expected: %v", got.Object, path, gotCRC, wantCRC)
}
got.Object = "modifying this shouldn't be a problem"
},
}); err != nil {
t.Errorf("d.DownloadDirectory: %v", err)
}

results, err := d.WaitAndClose()
if err == nil {
t.Errorf("expected error but did got nil: %v", err)
}

if len(results) != len(tb.objects)+len(illegalObjs) {
t.Errorf("expected to receive %d results, got %d results", len(tb.objects), len(results))
}
if numCallbacks != len(tb.objects)+len(illegalObjs) {
t.Errorf("expected to receive %d callbacks, got %d", len(tb.objects), numCallbacks)
}
if numErrors != len(illegalObjs) {
t.Errorf("expected to receive %d errors, got %d", numErrors, len(illegalObjs))
}

for _, got := range results {
if slices.Contains(illegalObjs, got.Object) {
if got.Err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, got.Err)
}
// verify that file path has not been created
if _, err := os.Stat(filepath.Join(localDir, got.Object)); err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, err)
}
continue
}

if got, want := got.Attrs.Size, tb.objectSizes[got.Object]; want != got {
Comment thread
krishnamd-jkp marked this conversation as resolved.
Outdated
t.Errorf("expected object size %d, got %d", want, got)
}

path := filepath.Join(localDir, got.Object)
f, err := os.Open(path)
if err != nil {
t.Errorf("os.Open(%q): %v", path, err)
}
defer f.Close()

b := bytes.NewBuffer(make([]byte, 0, got.Attrs.Size))
if _, err := io.Copy(b, f); err != nil {
t.Errorf("io.Copy: %v", err)
}

if wantCRC, gotCRC := tb.contentHashes[got.Object], crc32c(b.Bytes()); gotCRC != wantCRC {
t.Errorf("object(%q) at filepath(%q): content crc32c does not match; got: %v, expected: %v", got.Object, path, gotCRC, wantCRC)
}
}
})
}

func TestIntegration_DownloadDirectoryWithSkippedFilesAsync(t *testing.T) {
multiTransportTest(context.Background(), t, func(t *testing.T, ctx context.Context, c *storage.Client, tb downloadTestBucket) {
localDir := t.TempDir()
illegalObjs := []string{
"../objA",
"../objB",
}
b := c.Bucket(tb.bucket)
for _, obj := range illegalObjs {
size := randomInt64(minObjectSize, maxObjectSize)
_, err := generateFileInGCS(ctx, b.Object(obj), size)
if err != nil {
t.Errorf("could not create file in GCS: %v for obj %s", err, obj)
}
}
t.Cleanup(func() {
for _, obj := range illegalObjs {
if err := b.Object(obj).Delete(ctx); err != nil {
t.Errorf("failed to delete object in cleanup for %s: %v", obj, err)
}
}
})
d, err := NewDownloader(c, WithWorkers(2), WithPartSize(maxObjectSize/2), WithCallbacks())
if err != nil {
t.Fatalf("NewDownloader: %v", err)
}
objectDownloaded := make(chan bool)
objectSkipped := make(chan bool)
done := make(chan bool)

trackObjectsDownloaded := func(objectsDownloadedCount, objectSkippedCount *int) {
for {
select {
case <-done:
return
case <-objectDownloaded:
*objectsDownloadedCount++
case <-objectSkipped:
*objectSkippedCount++
}
}
}

objectsDownloadedCount := 0
objectSkippedCount := 0
go trackObjectsDownloaded(&objectsDownloadedCount, &objectSkippedCount)

if err := d.DownloadDirectory(ctx, &DownloadDirectoryInput{
Bucket: tb.bucket,
LocalDirectory: localDir,
OnObjectDownload: func(got *DownloadOutput) {
if got.Err != nil {
objectSkipped <- true
} else {
objectDownloaded <- true
}

if slices.Contains(illegalObjs, got.Object) {
if got.Err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, got.Err)
}
return
}

if got.Err != nil {
t.Errorf("result.Err: %v", got.Err)
}

if got, want := got.Attrs.Size, tb.objectSizes[got.Object]; want != got {
t.Errorf("expected object size %d, got %d", want, got)
}

path := filepath.Join(localDir, got.Object)

f, err := os.Open(path)
if err != nil {
t.Errorf("os.Open(%q): %v", path, err)
}
defer f.Close()

b := bytes.NewBuffer(make([]byte, 0, got.Attrs.Size))
if _, err := io.Copy(b, f); err != nil {
t.Errorf("io.Copy: %v", err)
}

if wantCRC, gotCRC := tb.contentHashes[got.Object], crc32c(b.Bytes()); gotCRC != wantCRC {
t.Errorf("object(%q) at filepath(%q): content crc32c does not match; got: %v, expected: %v", got.Object, path, gotCRC, wantCRC)
}
got.Object = "modifying this shouldn't be a problem"
},
Callback: func(outs []DownloadOutput) {
Comment thread
krishnamd-jkp marked this conversation as resolved.
if len(outs) != len(tb.objects)+len(illegalObjs) {
t.Errorf("expected to receive %d results, got %d results", len(tb.objects)+len(illegalObjs), len(outs))
}

for _, got := range outs {
if slices.Contains(illegalObjs, got.Object) {
if got.Err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, got.Err)
}
// verify that file path has not been created
if _, err := os.Stat(filepath.Join(localDir, got.Object)); err == nil {
t.Errorf("Expected error but got nil for object %v: %v", got.Object, err)
}
continue
}

if got.Err != nil {
t.Errorf("result.Err: %v", got.Err)
continue
}

if got, want := got.Attrs.Size, tb.objectSizes[got.Object]; want != got {
t.Errorf("expected object size %d, got %d", want, got)
}

path := filepath.Join(localDir, got.Object)
f, err := os.Open(path)
if err != nil {
t.Errorf("os.Open(%q): %v", path, err)
}
defer f.Close()

b := bytes.NewBuffer(make([]byte, 0, got.Attrs.Size))
if _, err := io.Copy(b, f); err != nil {
t.Errorf("io.Copy: %v", err)
}

if wantCRC, gotCRC := tb.contentHashes[got.Object], crc32c(b.Bytes()); gotCRC != wantCRC {
t.Errorf("object(%q) at filepath(%q): content crc32c does not match; got: %v, expected: %v", got.Object, path, gotCRC, wantCRC)
}
}
},
}); err != nil {
t.Errorf("d.DownloadDirectory: %v", err)
}

_, err = d.WaitAndClose()
if err == nil {
t.Errorf("expected error but did got nil: %v", err)
}
done <- true

if want, got := len(tb.objects), objectsDownloadedCount; want != got {
t.Errorf("expected to receive %d callbacks, got %d", want, got)
}
if objectSkippedCount != len(illegalObjs) {
t.Errorf("expected to receive %d errors, got %d", objectSkippedCount, len(illegalObjs))
}
})
}

func TestIntegration_DownloaderSynchronous(t *testing.T) {
multiTransportTest(context.Background(), t, func(t *testing.T, ctx context.Context, c *storage.Client, tb downloadTestBucket) {
objects := tb.objects
Expand Down
Loading