Skip to content

Commit 2cf2b29

Browse files
authored
feat(parquet/pqarrow): Correctly handle Variant types in schema (#433)
### Rationale for this change

Updating the `pqarrow` package to handle the variant extension type when converting between arrow and parquet schemas.

### What changes are included in this PR?

Replacing the TODOs with implementations to handle shredded variant structures in schema conversion.

### Are these changes tested?

A unit test is added for shredded variant handling.

### Are there any user-facing changes?

Only that this is now supported instead of erroring.
1 parent 8598fb3 commit 2cf2b29

3 files changed

Lines changed: 127 additions & 7 deletions

File tree

arrow/extensions/variant.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ func (v *VariantType) Value() arrow.Field {
168168
return v.StorageType().(*arrow.StructType).Field(v.valueFieldIdx)
169169
}
170170

171+
// TypedValue returns the typed_value field of the variant's storage
// struct, used when the variant is shredded. If the storage struct has
// no typed_value field (typedValueFieldIdx == -1, i.e. an unshredded
// variant), a zero arrow.Field is returned; callers can detect this
// case by checking for a nil Type on the result.
func (v *VariantType) TypedValue() arrow.Field {
172+
// -1 is the sentinel meaning the storage struct carries no typed_value column.
if v.typedValueFieldIdx == -1 {
173+
return arrow.Field{}
174+
}
175+
176+
return v.StorageType().(*arrow.StructType).Field(v.typedValueFieldIdx)
177+
}
178+
171179
func (*VariantType) ExtensionName() string { return "parquet.variant" }
172180

173181
func (v *VariantType) String() string {

parquet/pqarrow/schema.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,19 @@ func variantToNode(t *extensions.VariantType, field arrow.Field, props *parquet.
253253
return nil, err
254254
}
255255

256-
//TODO: implement shredding
256+
fields := schema.FieldList{metadataNode, valueNode}
257+
258+
typedField := t.TypedValue()
259+
if typedField.Type != nil {
260+
typedNode, err := fieldToNode("typed_value", typedField, props, arrProps)
261+
if err != nil {
262+
return nil, err
263+
}
264+
fields = append(fields, typedNode)
265+
}
257266

258267
return schema.NewGroupNodeLogical(field.Name, repFromNullable(field.Nullable),
259-
schema.FieldList{metadataNode, valueNode}, schema.VariantLogicalType{},
260-
fieldIDFromMeta(field.Metadata))
268+
fields, schema.VariantLogicalType{}, fieldIDFromMeta(field.Metadata))
261269
}
262270

263271
func structToNode(field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) {
@@ -857,10 +865,10 @@ func mapToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *sc
857865
}
858866

859867
func variantToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, _, out *SchemaField) error {
860-
// this is for unshredded variants. shredded variants may have more fields
861-
// TODO: implement support for shredded variants
862-
if n.NumFields() != 2 {
863-
return errors.New("VARIANT group must have exactly 2 children")
868+
switch n.NumFields() {
869+
case 2, 3:
870+
default:
871+
return errors.New("VARIANT group must have exactly 2 or 3 children")
864872
}
865873

866874
var err error

parquet/pqarrow/schema_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,107 @@ func TestConvertSchemaParquetVariant(t *testing.T) {
534534
require.NoError(t, err)
535535
assert.True(t, pqschema.Equals(sc), pqschema.String(), sc.String())
536536
}
537+
538+
// TestShreddedVariantSchema round-trips a shredded variant schema in both
// directions: arrow -> parquet via ToParquet and parquet -> arrow via
// FromParquet, asserting each conversion produces the expected structure.
func TestShreddedVariantSchema(t *testing.T) {
539+
// Field metadata marking every field as having no explicit parquet field id (-1).
metaNoFieldID := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"})
540+
541+
// Storage struct for a shredded variant: required metadata, optional value,
// and an optional typed_value group containing one sub-group per shredded
// field, each with its own (value, typed_value) pair.
s := arrow.StructOf(
542+
arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Metadata: metaNoFieldID},
543+
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true, Metadata: metaNoFieldID},
544+
arrow.Field{Name: "typed_value", Type: arrow.StructOf(
545+
arrow.Field{Name: "tsmicro", Type: arrow.StructOf(
546+
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true, Metadata: metaNoFieldID},
547+
arrow.Field{Name: "typed_value", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true, Metadata: metaNoFieldID},
548+
), Metadata: metaNoFieldID},
549+
arrow.Field{Name: "strval", Type: arrow.StructOf(
550+
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true, Metadata: metaNoFieldID},
551+
arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: true, Metadata: metaNoFieldID},
552+
), Metadata: metaNoFieldID},
553+
arrow.Field{Name: "bool", Type: arrow.StructOf(
554+
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true, Metadata: metaNoFieldID},
555+
arrow.Field{Name: "typed_value", Type: arrow.FixedWidthTypes.Boolean, Nullable: true, Metadata: metaNoFieldID},
556+
), Metadata: metaNoFieldID},
557+
arrow.Field{Name: "uuid", Type: arrow.StructOf(
558+
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true, Metadata: metaNoFieldID},
559+
arrow.Field{Name: "typed_value", Type: extensions.NewUUIDType(), Nullable: true, Metadata: metaNoFieldID},
560+
), Metadata: metaNoFieldID},
561+
), Nullable: true, Metadata: metaNoFieldID})
562+
563+
vt, err := extensions.NewVariantType(s)
564+
require.NoError(t, err)
565+
566+
arrSchema := arrow.NewSchema([]arrow.Field{
567+
{Name: "variant_col", Type: vt, Nullable: true, Metadata: metaNoFieldID},
568+
{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: false, Metadata: metaNoFieldID},
569+
}, nil)
570+
571+
// Forward conversion: arrow schema -> parquet schema.
sc, err := pqarrow.ToParquet(arrSchema, nil, pqarrow.DefaultWriterProps())
572+
require.NoError(t, err)
573+
574+
// the equivalent shredded variant parquet schema looks like this:
575+
// repeated group field_id=-1 schema {
576+
// optional group field_id=-1 variant_col (Variant) {
577+
// required byte_array field_id=-1 metadata;
578+
// optional byte_array field_id=-1 value;
579+
// optional group field_id=-1 typed_value {
580+
// required group field_id=-1 tsmicro {
581+
// optional byte_array field_id=-1 value;
582+
// optional int64 field_id=-1 typed_value (Timestamp(isAdjustedToUTC=true, timeUnit=microseconds, is_from_converted_type=false, force_set_converted_type=true));
583+
// }
584+
// required group field_id=-1 strval {
585+
// optional byte_array field_id=-1 value;
586+
// optional byte_array field_id=-1 typed_value (String);
587+
// }
588+
// required group field_id=-1 bool {
589+
// optional byte_array field_id=-1 value;
590+
// optional boolean field_id=-1 typed_value;
591+
// }
592+
// required group field_id=-1 uuid {
593+
// optional byte_array field_id=-1 value;
594+
// optional fixed_len_byte_array field_id=-1 typed_value (UUID);
595+
// }
596+
// }
597+
// }
598+
// required int64 field_id=-1 id (Int(bitWidth=64, isSigned=true));
599+
// }
600+
601+
// Hand-built parquet schema mirroring the comment above, used as the
// expected value for the forward conversion.
expected := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema",
602+
parquet.Repetitions.Repeated, schema.FieldList{
603+
schema.Must(schema.NewGroupNodeLogical("variant_col", parquet.Repetitions.Optional, schema.FieldList{
604+
schema.MustPrimitive(schema.NewPrimitiveNode("metadata", parquet.Repetitions.Required, parquet.Types.ByteArray, -1, -1)),
605+
schema.MustPrimitive(schema.NewPrimitiveNode("value", parquet.Repetitions.Optional, parquet.Types.ByteArray, -1, -1)),
606+
schema.MustGroup(schema.NewGroupNode("typed_value", parquet.Repetitions.Optional, schema.FieldList{
607+
schema.MustGroup(schema.NewGroupNode("tsmicro", parquet.Repetitions.Required, schema.FieldList{
608+
schema.MustPrimitive(schema.NewPrimitiveNode("value", parquet.Repetitions.Optional, parquet.Types.ByteArray, -1, -1)),
609+
schema.MustPrimitive(schema.NewPrimitiveNodeLogical("typed_value", parquet.Repetitions.Optional, schema.NewTimestampLogicalTypeWithOpts(
610+
schema.WithTSTimeUnitType(schema.TimeUnitMicros), schema.WithTSIsAdjustedToUTC(), schema.WithTSForceConverted(),
611+
), parquet.Types.Int64, -1, -1)),
612+
}, -1)),
613+
schema.MustGroup(schema.NewGroupNode("strval", parquet.Repetitions.Required, schema.FieldList{
614+
schema.MustPrimitive(schema.NewPrimitiveNode("value", parquet.Repetitions.Optional, parquet.Types.ByteArray, -1, -1)),
615+
schema.MustPrimitive(schema.NewPrimitiveNodeLogical("typed_value", parquet.Repetitions.Optional,
616+
schema.StringLogicalType{}, parquet.Types.ByteArray, -1, -1)),
617+
}, -1)),
618+
schema.MustGroup(schema.NewGroupNode("bool", parquet.Repetitions.Required, schema.FieldList{
619+
schema.MustPrimitive(schema.NewPrimitiveNode("value", parquet.Repetitions.Optional, parquet.Types.ByteArray, -1, -1)),
620+
schema.MustPrimitive(schema.NewPrimitiveNode("typed_value", parquet.Repetitions.Optional,
621+
parquet.Types.Boolean, -1, -1)),
622+
}, -1)),
623+
schema.MustGroup(schema.NewGroupNode("uuid", parquet.Repetitions.Required, schema.FieldList{
624+
schema.MustPrimitive(schema.NewPrimitiveNode("value", parquet.Repetitions.Optional, parquet.Types.ByteArray, -1, -1)),
625+
schema.MustPrimitive(schema.NewPrimitiveNodeLogical("typed_value", parquet.Repetitions.Optional,
626+
schema.UUIDLogicalType{}, parquet.Types.FixedLenByteArray, 16, -1)),
627+
}, -1)),
628+
}, -1)),
629+
}, schema.VariantLogicalType{}, -1)),
630+
schema.MustPrimitive(schema.NewPrimitiveNodeLogical("id", parquet.Repetitions.Required,
631+
schema.NewIntLogicalType(64, true), parquet.Types.Int64, -1, -1)),
632+
}, -1)))
633+
634+
assert.True(t, sc.Equals(expected), "expected: %s\ngot: %s", expected, sc)
635+
636+
// Reverse conversion: parquet schema -> arrow schema should reproduce
// the original arrow schema, including the variant extension type.
arrsc, err := pqarrow.FromParquet(sc, nil, metadata.KeyValueMetadata{})
637+
require.NoError(t, err)
638+
639+
assert.True(t, arrSchema.Equal(arrsc), "expected: %s\ngot: %s", arrSchema, arrsc)
640+
}

0 commit comments

Comments
 (0)