@@ -67,18 +67,34 @@ func TestUnion(t *testing.T) {
6767 members = append (members , NewUnionMember (f ))
6868 }
6969
70- // Create mock extractors that return predefined values instead of
71- // actually extracting from the object.
72- extractors := make ([]ExtractorFn [* testMember , bool ], len (tc .fieldValues ))
73- for i , val := range tc .fieldValues {
74- extractors [i ] = func (_ * testMember ) bool { return val }
75- }
70+ t .Run ("pointer" , func (t * testing.T ) {
71+ // Create mock extractors that return predefined values instead of
72+ // actually extracting from the object.
73+ extractors := make ([]ExtractorFn [* testMember , bool ], len (tc .fieldValues ))
74+ for i , val := range tc .fieldValues {
75+ extractors [i ] = func (_ * testMember ) bool { return val }
76+ }
7677
77- got := Union (context .Background (), operation.Operation {}, nil , & testMember {}, nil ,
78- NewUnionMembership (members ... ), extractors ... )
79- if ! reflect .DeepEqual (got , tc .expected ) {
80- t .Errorf ("got %v want %v" , got , tc .expected )
81- }
78+ got := Union (context .Background (), operation.Operation {}, nil , & testMember {}, nil ,
79+ NewUnionMembership (members ... ), extractors ... )
80+ if ! reflect .DeepEqual (got , tc .expected ) {
81+ t .Errorf ("got %v want %v" , got , tc .expected )
82+ }
83+ })
84+ t .Run ("value" , func (t * testing.T ) {
85+ // Create mock extractors that return predefined values instead of
86+ // actually extracting from the object.
87+ extractors := make ([]ExtractorFn [testMember , bool ], len (tc .fieldValues ))
88+ for i , val := range tc .fieldValues {
89+ extractors [i ] = func (_ testMember ) bool { return val }
90+ }
91+
92+ got := Union (context .Background (), operation.Operation {}, nil , testMember {}, testMember {},
93+ NewUnionMembership (members ... ), extractors ... )
94+ if ! reflect .DeepEqual (got , tc .expected ) {
95+ t .Errorf ("got %v want %v" , got , tc .expected )
96+ }
97+ })
8298 })
8399 }
84100}
@@ -131,33 +147,53 @@ func TestDiscriminatedUnion(t *testing.T) {
131147 }
132148
133149 for _ , tc := range testCases {
150+ members := []UnionMember {}
151+ for _ , f := range tc .fields {
152+ members = append (members , NewDiscriminatedUnionMember (f [0 ], f [1 ]))
153+ }
154+
134155 t .Run (tc .name , func (t * testing.T ) {
135- members := []UnionMember {}
136- for _ , f := range tc .fields {
137- members = append (members , NewDiscriminatedUnionMember (f [0 ], f [1 ]))
138- }
156+ t .Run ("pointer" , func (t * testing.T ) {
157+ discriminatorExtractor := func (_ * testMember ) string { return tc .discriminatorValue }
139158
140- discriminatorExtractor := func (_ * testMember ) string { return tc .discriminatorValue }
159+ // Create mock extractors that return predefined values instead of
160+ // actually extracting from the object.
161+ extractors := make ([]ExtractorFn [* testMember , bool ], len (tc .fieldValues ))
162+ for i , val := range tc .fieldValues {
163+ extractors [i ] = func (_ * testMember ) bool { return val }
164+ }
141165
142- // Create mock extractors that return predefined values instead of
143- // actually extracting from the object.
144- extractors := make ([]ExtractorFn [* testMember , bool ], len (tc .fieldValues ))
145- for i , val := range tc .fieldValues {
146- extractors [i ] = func (_ * testMember ) bool { return val }
147- }
166+ got := DiscriminatedUnion (context .Background (), operation.Operation {}, nil , & testMember {}, nil ,
167+ NewDiscriminatedUnionMembership (tc .discriminatorField , members ... ), discriminatorExtractor , extractors ... )
168+ if ! reflect .DeepEqual (got , tc .expected ) {
169+ t .Errorf ("got %v want %v" , got .ToAggregate (), tc .expected .ToAggregate ())
170+ }
171+ })
172+ t .Run ("value" , func (t * testing.T ) {
173+ discriminatorExtractor := func (_ testMember ) string { return tc .discriminatorValue }
148174
149- got := DiscriminatedUnion (context .Background (), operation.Operation {}, nil , & testMember {}, nil ,
150- NewDiscriminatedUnionMembership (tc .discriminatorField , members ... ), discriminatorExtractor , extractors ... )
151- if ! reflect .DeepEqual (got , tc .expected ) {
152- t .Errorf ("got %v want %v" , got .ToAggregate (), tc .expected .ToAggregate ())
153- }
175+ // Create mock extractors that return predefined values instead of
176+ // actually extracting from the object.
177+ extractors := make ([]ExtractorFn [testMember , bool ], len (tc .fieldValues ))
178+ for i , val := range tc .fieldValues {
179+ extractors [i ] = func (_ testMember ) bool { return val }
180+ }
181+
182+ got := DiscriminatedUnion (context .Background (), operation.Operation {}, nil , testMember {}, testMember {},
183+ NewDiscriminatedUnionMembership (tc .discriminatorField , members ... ), discriminatorExtractor , extractors ... )
184+ if ! reflect .DeepEqual (got , tc .expected ) {
185+ t .Errorf ("got %v want %v" , got .ToAggregate (), tc .expected .ToAggregate ())
186+ }
187+ })
154188 })
155189 }
156190}
157191
158192type testStruct struct {
159- M1 * m1 `json:"m1"`
160- M2 * m2 `json:"m2"`
193+ M1 * m1 `json:"m1"`
194+ M2 * m2 `json:"m2"`
195+ M3 []string `json:"m3"`
196+ M4 map [string ]string `json:"m4"`
161197}
162198
163199type m1 struct {}
@@ -176,6 +212,18 @@ var extractors = []ExtractorFn[*testStruct, bool]{
176212 }
177213 return s .M2 != nil
178214 },
215+ func (s * testStruct ) bool {
216+ if s == nil {
217+ return false
218+ }
219+ return len (s .M3 ) != 0
220+ },
221+ func (s * testStruct ) bool {
222+ if s == nil {
223+ return false
224+ }
225+ return len (s .M4 ) != 0
226+ },
179227}
180228
181229func TestUnionRatcheting (t * testing.T ) {
@@ -186,9 +234,12 @@ func TestUnionRatcheting(t *testing.T) {
186234 expected field.ErrorList
187235 }{
188236 {
189- name : "both nil" ,
237+ name : "old nil - no ratcheting " ,
190238 oldStruct : nil ,
191- newStruct : nil ,
239+ newStruct : & testStruct {},
240+ expected : field.ErrorList {
241+ field .Invalid (nil , "" , "must specify one of: `m1`, `m2`, `m3`, `m4`" ),
242+ }.WithOrigin ("union" ),
192243 },
193244 {
194245 name : "both empty struct" ,
@@ -216,14 +267,40 @@ func TestUnionRatcheting(t *testing.T) {
216267 M2 : & m2 {},
217268 },
218269 expected : field.ErrorList {
219- field .Invalid (nil , "{m1, m2}" , "must specify exactly one of: `m1`, `m2`" ),
270+ field .Invalid (nil , "{m1, m2}" , "must specify exactly one of: `m1`, `m2`, `m3`, `m4`" ),
271+ }.WithOrigin ("union" ),
272+ },
273+ {
274+ name : "slice member ratcheting: unchanged membership" ,
275+ oldStruct : & testStruct {M3 : []string {"a" }},
276+ newStruct : & testStruct {M3 : []string {"b" }},
277+ },
278+ {
279+ name : "map member ratcheting: unchanged membership" ,
280+ oldStruct : & testStruct {M4 : map [string ]string {"k" : "v1" }},
281+ newStruct : & testStruct {M4 : map [string ]string {"k" : "v2" }},
282+ },
283+ {
284+ name : "empty slice is not set" ,
285+ oldStruct : nil ,
286+ newStruct : & testStruct {M3 : []string {}},
287+ expected : field.ErrorList {
288+ field .Invalid (nil , "" , "must specify one of: `m1`, `m2`, `m3`, `m4`" ),
289+ }.WithOrigin ("union" ),
290+ },
291+ {
292+ name : "empty map is not set" ,
293+ oldStruct : nil ,
294+ newStruct : & testStruct {M4 : map [string ]string {}},
295+ expected : field.ErrorList {
296+ field .Invalid (nil , "" , "must specify one of: `m1`, `m2`, `m3`, `m4`" ),
220297 }.WithOrigin ("union" ),
221298 },
222299 }
223300
224301 for _ , tc := range testCases {
225302 t .Run (tc .name , func (t * testing.T ) {
226- members := []UnionMember {NewUnionMember ("m1" ), NewUnionMember ("m2" )}
303+ members := []UnionMember {NewUnionMember ("m1" ), NewUnionMember ("m2" ), NewUnionMember ( "m3" ), NewUnionMember ( "m4" ) }
227304 got := Union (context .Background (), operation.Operation {Type : operation .Update }, nil , tc .newStruct , tc .oldStruct ,
228305 NewUnionMembership (members ... ), extractors ... )
229306 if ! reflect .DeepEqual (got , tc .expected ) {
@@ -234,9 +311,11 @@ func TestUnionRatcheting(t *testing.T) {
234311}
235312
236313type testDiscriminatedStruct struct {
237- D string `json:"d"`
238- M1 * m1 `json:"m1"`
239- M2 * m2 `json:"m2"`
314+ D string `json:"d"`
315+ M1 * m1 `json:"m1"`
316+ M2 * m2 `json:"m2"`
317+ M3 []string `json:"m3"`
318+ M4 map [string ]string `json:"m4"`
240319}
241320
242321var testDiscriminatorExtractor = func (s * testDiscriminatedStruct ) string {
@@ -258,6 +337,18 @@ var testDiscriminatedExtractors = []ExtractorFn[*testDiscriminatedStruct, bool]{
258337 }
259338 return s .M2 != nil
260339 },
340+ func (s * testDiscriminatedStruct ) bool {
341+ if s == nil {
342+ return false
343+ }
344+ return len (s .M3 ) != 0
345+ },
346+ func (s * testDiscriminatedStruct ) bool {
347+ if s == nil {
348+ return false
349+ }
350+ return len (s .M4 ) != 0
351+ },
261352}
262353
263354func TestDiscriminatedUnionRatcheting (t * testing.T ) {
@@ -329,11 +420,61 @@ func TestDiscriminatedUnionRatcheting(t *testing.T) {
329420 field .Invalid (field .NewPath ("m2" ), "" , "must be specified when `d` is \" m2\" " ),
330421 }.WithOrigin ("union" ),
331422 },
423+ {
424+ name : "slice member ratcheting: unchanged membership" ,
425+ oldStruct : & testDiscriminatedStruct {
426+ D : "m3" ,
427+ M3 : []string {"a" },
428+ },
429+ newStruct : & testDiscriminatedStruct {
430+ D : "m3" ,
431+ M3 : []string {"b" },
432+ },
433+ },
434+ {
435+ name : "map member ratcheting: unchanged membership" ,
436+ oldStruct : & testDiscriminatedStruct {
437+ D : "m4" ,
438+ M4 : map [string ]string {"k" : "v1" },
439+ },
440+ newStruct : & testDiscriminatedStruct {
441+ D : "m4" ,
442+ M4 : map [string ]string {"k" : "v2" },
443+ },
444+ },
445+ {
446+ name : "empty slice is not set" ,
447+ oldStruct : & testDiscriminatedStruct {
448+ D : "m3" ,
449+ M3 : []string {"a" },
450+ },
451+ newStruct : & testDiscriminatedStruct {
452+ D : "m3" ,
453+ M3 : []string {},
454+ },
455+ expected : field.ErrorList {
456+ field .Invalid (field .NewPath ("m3" ), "" , "must be specified when `d` is \" m3\" " ),
457+ }.WithOrigin ("union" ),
458+ },
459+ {
460+ name : "empty map is not set" ,
461+ oldStruct : & testDiscriminatedStruct {
462+ D : "m4" ,
463+ M4 : map [string ]string {"k" : "v" },
464+ },
465+ newStruct : & testDiscriminatedStruct {
466+ D : "m4" ,
467+ M4 : map [string ]string {},
468+ },
469+ expected : field.ErrorList {
470+ field .Invalid (field .NewPath ("m4" ), "" , "must be specified when `d` is \" m4\" " ),
471+ }.WithOrigin ("union" ),
472+ },
332473 }
333474
334475 for _ , tc := range testCases {
335476 t .Run (tc .name , func (t * testing.T ) {
336- members := []UnionMember {NewDiscriminatedUnionMember ("m1" , "m1" ), NewDiscriminatedUnionMember ("m2" , "m2" )}
477+ members := []UnionMember {NewDiscriminatedUnionMember ("m1" , "m1" ), NewDiscriminatedUnionMember ("m2" , "m2" ), NewDiscriminatedUnionMember ( "m3" , "m3" ), NewDiscriminatedUnionMember ( "m4" , "m4" ) }
337478 got := DiscriminatedUnion (context .Background (), operation.Operation {Type : operation .Update }, nil , tc .newStruct , tc .oldStruct ,
338479 NewDiscriminatedUnionMembership ("d" , members ... ), testDiscriminatorExtractor , testDiscriminatedExtractors ... )
339480 if ! reflect .DeepEqual (got , tc .expected ) {
0 commit comments