Skip to content

Commit 6648c1d

Browse files
authored
fix(arrow/cdata, arrow/flight): fix handling of colons in values and fix potential panics (#761)
### fix(arrow/cdata): importSchema: handle colons in values Use strings.Cut, both as an optimization, and to prevent values containing a colon (e.g. "tsu:+01:00") from being mis-interpreted. This patch also removes some intermediate variables, and redundant handling of "defaulttz", which assigned an empty string if the value was empty. ### fix(arrow/cdata): importSchema: fix potential panic and optimize Rewrite the code with strings.Cut and strings.SplitSeq to reduce allocations, and to fix a potential panic. Before this patch, the code would panic if a colon was missing; CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors ./arrow/cdata/ --- FAIL: TestUnionSchemaErrors (0.00s) --- FAIL: TestUnionSchemaErrors/+us (0.00s) panic: runtime error: index out of range [1] with length 1 [recovered, repanicked] goroutine 9 [running]: testing.tRunner.func1.2({0x7fc7c0, 0x4000026ab0}) /usr/local/go/src/testing/testing.go:1872 +0x190 testing.tRunner.func1() /usr/local/go/src/testing/testing.go:1875 +0x31c panic({0x7fc7c0?, 0x4000026ab0?}) /usr/local/go/src/runtime/panic.go:783 +0x120 github.com/apache/arrow-go/v18/arrow/cdata.importSchema(0x40001c36d0) /foo/arrow/cdata/cdata.go:306 +0x1520 github.com/apache/arrow-go/v18/arrow/cdata.ImportCArrowField(...) /foo/arrow/cdata/interface.go:43 github.com/apache/arrow-go/v18/arrow/cdata.TestUnionSchemaErrors.func1(0x40000e0a80) /foo/arrow/cdata/cdata_test.go:188 +0xb0 testing.tRunner(0x40000e0a80, 0x400020c060) /usr/local/go/src/testing/testing.go:1934 +0xc8 created by testing.(*T).Run in goroutine 8 /usr/local/go/src/testing/testing.go:1997 +0x364 FAIL github.com/apache/arrow-go/v18/arrow/cdata 0.007s FAIL With this patch applied, the code handles the invalid value gracefully; CGO_ENABLED=1 go test -v -tags test -run TestUnionSchemaErrors ./arrow/cdata/ === RUN TestUnionSchemaErrors === RUN TestUnionSchemaErrors/+us === RUN TestUnionSchemaErrors/+ud --- PASS: TestUnionSchemaErrors (0.00s) --- PASS: TestUnionSchemaErrors/+us (0.00s) --- PASS: TestUnionSchemaErrors/+ud (0.00s) PASS ok github.com/apache/arrow-go/v18/arrow/cdata 0.003s ### fix(arrow/flight): avoid panic on malformed authorization header Rewrite the code with strings.Cut for readability and ensue missing credentials in Basic/Bearer authorization headers return Unauthenticated instead of panicking. Before this patch, the code could panic; go test -run TestBasicAuthMissingCredential ./arrow/flight/ panic: runtime error: index out of range [1] with length 1 goroutine 7 [running]: github.com/apache/arrow-go/v18/arrow/flight_test.TestBasicAuthMissingCredential.CreateServerBasicAuthMiddleware.createServerBearerTokenStreamInterceptor.func3({0x8d8240, 0x40002134a0}, {0xa73e68, 0x40000e2000}, 0x40000100c0, 0x96b628) /foo/arrow/flight/server_auth.go:188 +0x49c .... With this patch applied, the code handles the invalid header gracefully; go test -run TestBasicAuthMissingCredential ./arrow/flight/ ok github.com/apache/arrow-go/v18/arrow/flight 0.010s ### Rationale for this change ### What changes are included in this PR? ### Are these changes tested? ### Are there any user-facing changes? --------- Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
1 parent 2e44b72 commit 6648c1d

4 files changed

Lines changed: 92 additions & 39 deletions

File tree

arrow/cdata/cdata.go

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -202,41 +202,23 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) {
202202
}
203203

204204
// handle types with params via colon
205-
typs := strings.Split(f, ":")
206-
defaulttz := ""
207-
switch typs[0] {
205+
switch key, val, _ := strings.Cut(f, ":"); key {
208206
case "tss":
209-
tz := typs[1]
210-
if len(typs[1]) == 0 {
211-
tz = defaulttz
212-
}
213-
dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: tz}
207+
dt = &arrow.TimestampType{Unit: arrow.Second, TimeZone: val}
214208
case "tsm":
215-
tz := typs[1]
216-
if len(typs[1]) == 0 {
217-
tz = defaulttz
218-
}
219-
dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: tz}
209+
dt = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: val}
220210
case "tsu":
221-
tz := typs[1]
222-
if len(typs[1]) == 0 {
223-
tz = defaulttz
224-
}
225-
dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: tz}
211+
dt = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: val}
226212
case "tsn":
227-
tz := typs[1]
228-
if len(typs[1]) == 0 {
229-
tz = defaulttz
230-
}
231-
dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: tz}
213+
dt = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: val}
232214
case "w": // fixed size binary is "w:##" where ## is the byteWidth
233-
byteWidth, err := strconv.Atoi(typs[1])
215+
byteWidth, err := strconv.Atoi(val)
234216
if err != nil {
235217
return ret, err
236218
}
237219
dt = &arrow.FixedSizeBinaryType{ByteWidth: byteWidth}
238220
case "d": // decimal types are d:<precision>,<scale>[,<bitsize>] size is assumed 128 if left out
239-
props := typs[1]
221+
props := val
240222
propList := strings.Split(props, ",")
241223
bitwidth := 128
242224
var precision, scale int
@@ -317,9 +299,12 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) {
317299
return
318300
}
319301

320-
codes := strings.Split(strings.Split(f, ":")[1], ",")
321-
typeCodes := make([]arrow.UnionTypeCode, 0, len(codes))
322-
for _, i := range codes {
302+
_, val, ok := strings.Cut(f, ":")
303+
if !ok {
304+
return ret, fmt.Errorf("invalid union type code spec %q", f)
305+
}
306+
var typeCodes []arrow.UnionTypeCode
307+
for i := range strings.SplitSeq(val, ",") {
323308
v, e := strconv.ParseInt(i, 10, 8)
324309
if e != nil {
325310
err = fmt.Errorf("%w: invalid type code: %s", arrow.ErrInvalid, e)

arrow/cdata/cdata_test.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,24 @@ func TestDecimalSchemaErrors(t *testing.T) {
174174
}
175175
}
176176

177+
func TestUnionSchemaErrors(t *testing.T) {
178+
tests := []struct {
179+
fmt string
180+
}{
181+
{"+us"}, // missing ":<type_codes>"
182+
{"+ud"}, // missing ":<type_codes>"
183+
}
184+
185+
for _, tt := range tests {
186+
t.Run(tt.fmt, func(t *testing.T) {
187+
sc := testPrimitive(tt.fmt)
188+
189+
_, err := ImportCArrowField(&sc)
190+
assert.Error(t, err)
191+
})
192+
}
193+
}
194+
177195
func TestImportTemporalSchema(t *testing.T) {
178196
tests := []struct {
179197
typ arrow.DataType
@@ -195,9 +213,12 @@ func TestImportTemporalSchema(t *testing.T) {
195213
{arrow.FixedWidthTypes.Timestamp_s, "tss:UTC"},
196214
{&arrow.TimestampType{Unit: arrow.Second}, "tss:"},
197215
{&arrow.TimestampType{Unit: arrow.Second, TimeZone: "Europe/Paris"}, "tss:Europe/Paris"},
216+
{&arrow.TimestampType{Unit: arrow.Second, TimeZone: "Etc/GMT+1"}, "tss:Etc/GMT+1"},
217+
{&arrow.TimestampType{Unit: arrow.Second, TimeZone: "+01:00"}, "tss:+01:00"},
198218
{arrow.FixedWidthTypes.Timestamp_ms, "tsm:UTC"},
199219
{&arrow.TimestampType{Unit: arrow.Millisecond}, "tsm:"},
200220
{&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: "Europe/Paris"}, "tsm:Europe/Paris"},
221+
{&arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: "-07:30"}, "tsm:-07:30"},
201222
{arrow.FixedWidthTypes.Timestamp_us, "tsu:UTC"},
202223
{&arrow.TimestampType{Unit: arrow.Microsecond}, "tsu:"},
203224
{&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "Europe/Paris"}, "tsu:Europe/Paris"},
@@ -207,7 +228,7 @@ func TestImportTemporalSchema(t *testing.T) {
207228
}
208229

209230
for _, tt := range tests {
210-
t.Run(tt.typ.Name(), func(t *testing.T) {
231+
t.Run(tt.fmt, func(t *testing.T) {
211232
sc := testPrimitive(tt.fmt)
212233

213234
f, err := ImportCArrowField(&sc)

arrow/flight/basic_auth_flight_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,46 @@ func TestBasicAuthHelpers(t *testing.T) {
206206
t.Fatal("should have received carebears")
207207
}
208208
}
209+
210+
func TestBasicAuthMissingCredential(t *testing.T) {
211+
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{flight.CreateServerBasicAuthMiddleware(&validator{})})
212+
s.Init("localhost:0")
213+
f := &HeaderAuthTestFlight{}
214+
s.RegisterFlightService(f)
215+
go s.Serve()
216+
defer s.Shutdown()
217+
218+
client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithTransportCredentials(insecure.NewCredentials()))
219+
if err != nil {
220+
t.Fatal(err)
221+
}
222+
223+
ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{
224+
"authorization": "Basic",
225+
}))
226+
227+
fc, err := client.Handshake(ctx)
228+
if err != nil {
229+
st, ok := status.FromError(err)
230+
if !ok {
231+
t.Fatalf("expected gRPC status error, got %T: %v", err, err)
232+
}
233+
if got, want := st.Code(), codes.Unauthenticated; got != want {
234+
t.Fatalf("unexpected code: got %v, want %v", got, want)
235+
}
236+
return
237+
}
238+
239+
_, err = fc.Recv()
240+
if err == nil {
241+
t.Fatal("expected error")
242+
}
243+
244+
st, ok := status.FromError(err)
245+
if !ok {
246+
t.Fatalf("expected gRPC status error, got %T: %v", err, err)
247+
}
248+
if got, want := st.Code(), codes.Unauthenticated; got != want {
249+
t.Fatalf("unexpected code: got %v, want %v", got, want)
250+
}
251+
}

arrow/flight/server_auth.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,24 +170,26 @@ func createServerBearerTokenUnaryInterceptor(validator BasicAuthValidator) grpc.
170170

171171
func createServerBearerTokenStreamInterceptor(validator BasicAuthValidator) grpc.StreamServerInterceptor {
172172
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
173-
var auth []string
173+
var scheme, credential string
174174
md, ok := metadata.FromIncomingContext(stream.Context())
175175
if ok {
176-
auth = md.Get(basicAuthHeader)
176+
auth := md.Get(basicAuthHeader)
177177
if len(auth) > 0 {
178-
auth = strings.Split(auth[0], " ")
178+
s := strings.TrimSpace(auth[0])
179+
scheme, credential, _ = strings.Cut(s, " ")
180+
credential = strings.TrimLeft(credential, " ") // only trim SP per HTTP auth format, keep trailing spaces.
179181
}
180182
}
181183

182-
if len(auth) == 0 {
184+
if scheme == "" || credential == "" {
183185
return status.Error(codes.Unauthenticated, "must authenticate first")
184186
}
185187

186188
if strings.HasSuffix(info.FullMethod, "/Handshake") {
187-
if auth[0] == basicAuthPrefix {
188-
val, err := base64.RawStdEncoding.DecodeString(auth[1])
189+
if scheme == basicAuthPrefix {
190+
val, err := base64.RawStdEncoding.DecodeString(credential)
189191
if err != nil {
190-
val, err = base64.StdEncoding.DecodeString(auth[1])
192+
val, err = base64.StdEncoding.DecodeString(credential)
191193
if err != nil {
192194
return status.Errorf(codes.Unauthenticated, "invalid basic auth encoding: %s", err)
193195
}
@@ -199,14 +201,16 @@ func createServerBearerTokenStreamInterceptor(validator BasicAuthValidator) grpc
199201
return err
200202
}
201203

202-
stream.SetTrailer(metadata.New(map[string]string{basicAuthHeader: strings.Join([]string{bearerTokenPrefix, token}, " ")}))
204+
stream.SetTrailer(metadata.New(map[string]string{
205+
basicAuthHeader: bearerTokenPrefix + " " + token,
206+
}))
203207
return handler(srv, stream)
204208
}
205209
return status.Errorf(codes.Unauthenticated, "only Basic Auth implemented")
206210
}
207211

208-
if auth[0] == bearerTokenPrefix {
209-
identity, err := validator.IsValid(auth[1])
212+
if scheme == bearerTokenPrefix {
213+
identity, err := validator.IsValid(credential)
210214
if err != nil {
211215
return err
212216
}

0 commit comments

Comments
 (0)