Skip to content

Commit 5c0d6df

Browse files
authored
feat(firestore): Adding vector search (#10548)
* feat(firestore): Adding vector search * feat(firestore): refactoring code * feat(firestore): Resolving vet failures * feat(firestore): Adding unit and integration tests * feat(firestore): Fixing tests and refactoring code * feat(firestore): Resolving vet failures * feat(firestore): Refactoring code * feat(firestore): Resolving review comments
1 parent 6b32871 commit 5c0d6df

8 files changed

Lines changed: 964 additions & 122 deletions

File tree

firestore/document.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ func (d *DocumentSnapshot) Data() map[string]interface{} {
9797
// Slices are resized to the incoming value's size, while arrays that are too
9898
// long have excess elements filled with zero values. If the array is too short,
9999
// excess incoming values will be dropped.
100+
// - Vectors convert to []float64
100101
// - Maps convert to map[string]interface{}. When setting a struct field,
101102
// maps of key type string and any value type are permitted, and are populated
102103
// recursively.

firestore/from_value.go

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -32,221 +32,236 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error {
3232
return setReflectFromProtoValue(v.Elem(), vproto, c)
3333
}
3434

35-
// setReflectFromProtoValue sets v from a Firestore Value.
36-
// v must be a settable value.
37-
func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error {
35+
// setReflectFromProtoValue sets vDest from a Firestore Value.
36+
// vDest must be a settable value.
37+
func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error {
3838
typeErr := func() error {
39-
return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto))
39+
return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc))
4040
}
4141

42-
val := vproto.ValueType
42+
valTypeSrc := vprotoSrc.ValueType
4343
// A Null value sets anything nullable to nil, and has no effect
4444
// on anything else.
45-
if _, ok := val.(*pb.Value_NullValue); ok {
46-
switch v.Kind() {
45+
if _, ok := valTypeSrc.(*pb.Value_NullValue); ok {
46+
switch vDest.Kind() {
4747
case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
48-
v.Set(reflect.Zero(v.Type()))
48+
vDest.Set(reflect.Zero(vDest.Type()))
4949
}
5050
return nil
5151
}
5252

5353
// Handle special types first.
54-
switch v.Type() {
54+
switch vDest.Type() {
5555
case typeOfByteSlice:
56-
x, ok := val.(*pb.Value_BytesValue)
56+
x, ok := valTypeSrc.(*pb.Value_BytesValue)
5757
if !ok {
5858
return typeErr()
5959
}
60-
v.SetBytes(x.BytesValue)
60+
vDest.SetBytes(x.BytesValue)
6161
return nil
6262

6363
case typeOfGoTime:
64-
x, ok := val.(*pb.Value_TimestampValue)
64+
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
6565
if !ok {
6666
return typeErr()
6767
}
6868
if err := x.TimestampValue.CheckValid(); err != nil {
6969
return err
7070
}
71-
v.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
71+
vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
7272
return nil
7373

7474
case typeOfProtoTimestamp:
75-
x, ok := val.(*pb.Value_TimestampValue)
75+
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
7676
if !ok {
7777
return typeErr()
7878
}
79-
v.Set(reflect.ValueOf(x.TimestampValue))
79+
vDest.Set(reflect.ValueOf(x.TimestampValue))
8080
return nil
8181

8282
case typeOfLatLng:
83-
x, ok := val.(*pb.Value_GeoPointValue)
83+
x, ok := valTypeSrc.(*pb.Value_GeoPointValue)
8484
if !ok {
8585
return typeErr()
8686
}
87-
v.Set(reflect.ValueOf(x.GeoPointValue))
87+
vDest.Set(reflect.ValueOf(x.GeoPointValue))
8888
return nil
8989

9090
case typeOfDocumentRef:
91-
x, ok := val.(*pb.Value_ReferenceValue)
91+
x, ok := valTypeSrc.(*pb.Value_ReferenceValue)
9292
if !ok {
9393
return typeErr()
9494
}
9595
dr, err := pathToDoc(x.ReferenceValue, c)
9696
if err != nil {
9797
return err
9898
}
99-
v.Set(reflect.ValueOf(dr))
99+
vDest.Set(reflect.ValueOf(dr))
100+
return nil
101+
102+
case typeOfVector32:
103+
val, err := vector32FromProtoValue(vprotoSrc)
104+
if err != nil {
105+
return err
106+
}
107+
vDest.Set(reflect.ValueOf(val))
108+
return nil
109+
case typeOfVector64:
110+
val, err := vector64FromProtoValue(vprotoSrc)
111+
if err != nil {
112+
return err
113+
}
114+
vDest.Set(reflect.ValueOf(val))
100115
return nil
101116
}
102117

103-
switch v.Kind() {
118+
switch vDest.Kind() {
104119
case reflect.Bool:
105-
x, ok := val.(*pb.Value_BooleanValue)
120+
x, ok := valTypeSrc.(*pb.Value_BooleanValue)
106121
if !ok {
107122
return typeErr()
108123
}
109-
v.SetBool(x.BooleanValue)
124+
vDest.SetBool(x.BooleanValue)
110125

111126
case reflect.String:
112-
x, ok := val.(*pb.Value_StringValue)
127+
x, ok := valTypeSrc.(*pb.Value_StringValue)
113128
if !ok {
114129
return typeErr()
115130
}
116-
v.SetString(x.StringValue)
131+
vDest.SetString(x.StringValue)
117132

118133
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
119134
var i int64
120-
switch x := val.(type) {
135+
switch x := valTypeSrc.(type) {
121136
case *pb.Value_IntegerValue:
122137
i = x.IntegerValue
123138
case *pb.Value_DoubleValue:
124139
f := x.DoubleValue
125140
i = int64(f)
126141
if float64(i) != f {
127-
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
142+
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
128143
}
129144
default:
130145
return typeErr()
131146
}
132-
if v.OverflowInt(i) {
133-
return overflowErr(v, i)
147+
if vDest.OverflowInt(i) {
148+
return overflowErr(vDest, i)
134149
}
135-
v.SetInt(i)
150+
vDest.SetInt(i)
136151

137152
case reflect.Uint8, reflect.Uint16, reflect.Uint32:
138153
var u uint64
139-
switch x := val.(type) {
154+
switch x := valTypeSrc.(type) {
140155
case *pb.Value_IntegerValue:
141156
u = uint64(x.IntegerValue)
142157
case *pb.Value_DoubleValue:
143158
f := x.DoubleValue
144159
u = uint64(f)
145160
if float64(u) != f {
146-
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
161+
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
147162
}
148163
default:
149164
return typeErr()
150165
}
151-
if v.OverflowUint(u) {
152-
return overflowErr(v, u)
166+
if vDest.OverflowUint(u) {
167+
return overflowErr(vDest, u)
153168
}
154-
v.SetUint(u)
169+
vDest.SetUint(u)
155170

156171
case reflect.Float32, reflect.Float64:
157172
var f float64
158-
switch x := val.(type) {
173+
switch x := valTypeSrc.(type) {
159174
case *pb.Value_DoubleValue:
160175
f = x.DoubleValue
161176
case *pb.Value_IntegerValue:
162177
f = float64(x.IntegerValue)
163178
if int64(f) != x.IntegerValue {
164-
return overflowErr(v, x.IntegerValue)
179+
return overflowErr(vDest, x.IntegerValue)
165180
}
166181
default:
167182
return typeErr()
168183
}
169-
if v.OverflowFloat(f) {
170-
return overflowErr(v, f)
184+
if vDest.OverflowFloat(f) {
185+
return overflowErr(vDest, f)
171186
}
172-
v.SetFloat(f)
187+
vDest.SetFloat(f)
173188

174189
case reflect.Slice:
175-
x, ok := val.(*pb.Value_ArrayValue)
190+
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
176191
if !ok {
177192
return typeErr()
178193
}
179194
vals := x.ArrayValue.Values
180-
vlen := v.Len()
195+
vlen := vDest.Len()
181196
xlen := len(vals)
182197
// Make a slice of the right size, avoiding allocation if possible.
183198
switch {
184199
case vlen < xlen:
185-
v.Set(reflect.MakeSlice(v.Type(), xlen, xlen))
200+
vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen))
186201
case vlen > xlen:
187-
v.SetLen(xlen)
202+
vDest.SetLen(xlen)
188203
}
189-
return populateRepeated(v, vals, xlen, c)
204+
return populateRepeated(vDest, vals, xlen, c)
190205

191206
case reflect.Array:
192-
x, ok := val.(*pb.Value_ArrayValue)
207+
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
193208
if !ok {
194209
return typeErr()
195210
}
196211
vals := x.ArrayValue.Values
197212
xlen := len(vals)
198-
vlen := v.Len()
213+
vlen := vDest.Len()
199214
minlen := vlen
200215
// Set extra elements to their zero value.
201216
if vlen > xlen {
202-
z := reflect.Zero(v.Type().Elem())
217+
z := reflect.Zero(vDest.Type().Elem())
203218
for i := xlen; i < vlen; i++ {
204-
v.Index(i).Set(z)
219+
vDest.Index(i).Set(z)
205220
}
206221
minlen = xlen
207222
}
208-
return populateRepeated(v, vals, minlen, c)
223+
return populateRepeated(vDest, vals, minlen, c)
209224

210225
case reflect.Map:
211-
x, ok := val.(*pb.Value_MapValue)
226+
x, ok := valTypeSrc.(*pb.Value_MapValue)
212227
if !ok {
213228
return typeErr()
214229
}
215-
return populateMap(v, x.MapValue.Fields, c)
230+
return populateMap(vDest, x.MapValue.Fields, c)
216231

217232
case reflect.Ptr:
218233
// If the pointer is nil, set it to a zero value.
219-
if v.IsNil() {
220-
v.Set(reflect.New(v.Type().Elem()))
234+
if vDest.IsNil() {
235+
vDest.Set(reflect.New(vDest.Type().Elem()))
221236
}
222-
return setReflectFromProtoValue(v.Elem(), vproto, c)
237+
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)
223238

224239
case reflect.Struct:
225-
x, ok := val.(*pb.Value_MapValue)
240+
x, ok := valTypeSrc.(*pb.Value_MapValue)
226241
if !ok {
227242
return typeErr()
228243
}
229-
return populateStruct(v, x.MapValue.Fields, c)
244+
return populateStruct(vDest, x.MapValue.Fields, c)
230245

231246
case reflect.Interface:
232-
if v.NumMethod() == 0 { // empty interface
247+
if vDest.NumMethod() == 0 { // empty interface
233248
// If v holds a pointer, set the pointer.
234-
if !v.IsNil() && v.Elem().Kind() == reflect.Ptr {
235-
return setReflectFromProtoValue(v.Elem(), vproto, c)
249+
if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr {
250+
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)
236251
}
237252
// Otherwise, create a fresh value.
238-
x, err := createFromProtoValue(vproto, c)
253+
x, err := createFromProtoValue(vprotoSrc, c)
239254
if err != nil {
240255
return err
241256
}
242-
v.Set(reflect.ValueOf(x))
257+
vDest.Set(reflect.ValueOf(x))
243258
return nil
244259
}
245260
// Any other kind of interface is an error.
246261
fallthrough
247262

248263
default:
249-
return fmt.Errorf("firestore: cannot set type %s", v.Type())
264+
return fmt.Errorf("firestore: cannot set type %s", vDest.Type())
250265
}
251266
return nil
252267
}
@@ -389,8 +404,15 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) {
389404
}
390405
ret[k] = r
391406
}
392-
return ret, nil
393407

408+
typeVal, ok := ret[typeKey]
409+
if !ok || typeVal != typeValVector {
410+
// Map is not a vector. Return the map
411+
return ret, nil
412+
}
413+
414+
// Special handling for vector
415+
return vectorFromProtoValue(vproto)
394416
default:
395417
return nil, fmt.Errorf("firestore: unknown value type %T", v)
396418
}

0 commit comments

Comments
 (0)