Skip to content

Commit d41e0b5

Browse files
authored
fix(parquet): strip repetition_type from root SchemaElement during serialization (#723)
### Rationale The Parquet spec says the root of the schema doesn't have a `repetition_type`. arrow-go writes `REPEATED` for the root `SchemaElement` in the Thrift footer, which breaks interop with consumers like Snowflake. #722 has the full writeup and cross-implementation comparison. ### What changes are included in this PR? `ToThrift()` now nils out `RepetitionType` on the root element before returning, stripping it from the serialized output. This matches how parquet-java and arrow-rs handle the root. The in-memory representation and `WithRootRepetition` API are unaffected. `FromParquet` [already tolerates a nil root repetition type](https://github.com/apache/arrow-go/blob/main/parquet/schema/schema.go#L78-L79), so this is backwards-compatible for both readers and writers. ### Are these changes tested? Updated the existing `TestNestedExample` and added `TestToThriftRootRepetitionStripped` which checks that the root's `repetition_type` is stripped for all three repetition variants and that non-root elements keep theirs. Closes #722
1 parent b0287d7 commit d41e0b5

3 files changed

Lines changed: 23 additions & 3 deletions

File tree

parquet/pqarrow/file_writer_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func TestFileWriterTotalBytes(t *testing.T) {
172172

173173
// Verify total bytes & compressed bytes are correct
174174
assert.Equal(t, int64(408), writer.TotalCompressedBytes())
175-
assert.Equal(t, int64(912), writer.TotalBytesWritten())
175+
assert.Equal(t, int64(910), writer.TotalBytesWritten())
176176
}
177177

178178
func TestFileWriterTotalBytesBuffered(t *testing.T) {
@@ -206,5 +206,5 @@ func TestFileWriterTotalBytesBuffered(t *testing.T) {
206206

207207
// Verify total bytes & compressed bytes are correct
208208
assert.Equal(t, int64(596), writer.TotalCompressedBytes())
209-
assert.Equal(t, int64(1308), writer.TotalBytesWritten())
209+
assert.Equal(t, int64(1306), writer.TotalBytesWritten())
210210
}

parquet/schema/schema.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ func (t *toThriftVisitor) VisitPost(Node) {}
272272
func ToThrift(schema *GroupNode) []*format.SchemaElement {
273273
t := &toThriftVisitor{make([]*format.SchemaElement, 0)}
274274
schema.Visit(t)
275+
t.elements[0].RepetitionType = nil
275276
return t.elements
276277
}
277278

parquet/schema/schema_flatten_test.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@ func (s *SchemaFlattenSuite) TestDecimalMetadata() {
9292

9393
func (s *SchemaFlattenSuite) TestNestedExample() {
9494
elements := make([]*format.SchemaElement, 0)
95+
root := NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */)
96+
root.RepetitionType = nil
9597
elements = append(elements,
96-
NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */),
98+
root,
9799
NewPrimitive("a" /* name */, format.FieldRepetitionType_REQUIRED, format.Type_INT32, 1 /* fieldID */),
98100
NewGroup("bag" /* name */, format.FieldRepetitionType_OPTIONAL, 1 /* numChildren */, 2 /* fieldID */))
99101

@@ -120,6 +122,23 @@ func TestSchemaFlatten(t *testing.T) {
120122
suite.Run(t, new(SchemaFlattenSuite))
121123
}
122124

125+
func TestToThriftRootRepetitionStripped(t *testing.T) {
126+
for _, rep := range []parquet.Repetition{
127+
parquet.Repetitions.Repeated,
128+
parquet.Repetitions.Required,
129+
parquet.Repetitions.Optional,
130+
} {
131+
group := MustGroup(NewGroupNode("schema", rep, FieldList{
132+
NewInt32Node("a", parquet.Repetitions.Required, -1),
133+
}, -1))
134+
elements := ToThrift(group)
135+
assert.False(t, elements[0].IsSetRepetitionType(),
136+
"root element should not have repetition_type set (was %v)", rep)
137+
assert.True(t, elements[1].IsSetRepetitionType(),
138+
"non-root element must have repetition_type set")
139+
}
140+
}
141+
123142
func TestInvalidConvertedTypeInDeserialize(t *testing.T) {
124143
n := MustPrimitive(NewPrimitiveNodeLogical("string" /* name */, parquet.Repetitions.Required, StringLogicalType{},
125144
parquet.Types.ByteArray, -1 /* type len */, -1 /* fieldID */))

0 commit comments

Comments
 (0)