Skip to content

Commit ae49cd0

Browse files
Json field names runtime support (#1286)
* Add checker, ast, and type-provider support for JSON names * Support backwards compatible JSON field / proto field resolution * Runtime support for JSON field names * Updated JSON name runtime support with fallback for proto field names * Adjusted registry json support to be an in-place update * Update to SourceInfo.HasExtension
1 parent 3624b64 commit ae49cd0

19 files changed

Lines changed: 481 additions & 141 deletions

cel/cel_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,6 +2234,62 @@ func TestContextProto(t *testing.T) {
22342234
}
22352235
}
22362236

2237+
func TestContextProtoJSONFieldNames(t *testing.T) {
2238+
descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor()
2239+
env := testEnv(t, JSONFieldNames(true), DeclareContextProto(descriptor))
2240+
expression := `
2241+
singleInt64 == 1
2242+
&& singleDouble == 1.0
2243+
&& singleBool == true
2244+
&& singleString == ''
2245+
&& singleNestedMessage == google.expr.proto3.test.TestAllTypes.NestedMessage{}
2246+
&& standaloneEnum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
2247+
&& singleDuration == duration('5s')
2248+
&& singleTimestamp == timestamp(63154820)
2249+
&& singleAny == null
2250+
&& singleUint32Wrapper == null
2251+
&& singleUint64Wrapper == 0u
2252+
&& repeatedInt32 == [1,2]
2253+
&& mapStringString == {'': ''}
2254+
&& mapInt64NestedType == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
2255+
ast, iss := env.Compile(expression)
2256+
if iss.Err() != nil {
2257+
t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err())
2258+
}
2259+
prg, err := env.Program(ast)
2260+
if err != nil {
2261+
t.Fatalf("env.Program() failed: %v", err)
2262+
}
2263+
in := &proto3pb.TestAllTypes{
2264+
SingleInt64: 1,
2265+
SingleDouble: 1.0,
2266+
SingleBool: true,
2267+
NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
2268+
SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
2269+
},
2270+
StandaloneEnum: proto3pb.TestAllTypes_FOO,
2271+
SingleDuration: &durationpb.Duration{Seconds: 5},
2272+
SingleTimestamp: &timestamppb.Timestamp{
2273+
Seconds: 63154820,
2274+
},
2275+
SingleUint64Wrapper: wrapperspb.UInt64(0),
2276+
RepeatedInt32: []int32{1, 2},
2277+
MapStringString: map[string]string{"": ""},
2278+
MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}},
2279+
}
2280+
vars, err := ContextProtoVars(in, types.JSONFieldNames(true))
2281+
if err != nil {
2282+
t.Fatalf("ContextProtoVars(%v) failed: %v", in, err)
2283+
}
2284+
out, _, err := prg.Eval(vars)
2285+
if err != nil {
2286+
t.Fatalf("prg.Eval() failed: %v", err)
2287+
}
2288+
if out.Equal(types.True) != types.True {
2289+
t.Errorf("prg.Eval() got %v, wanted true", out)
2290+
}
2291+
}
2292+
22372293
func TestRegexOptimizer(t *testing.T) {
22382294
var stringTests = []struct {
22392295
expr string
@@ -3607,6 +3663,125 @@ func TestAstProgramNilValue(t *testing.T) {
36073663
}
36083664
}
36093665

3666+
func TestJSONFieldNames(t *testing.T) {
3667+
tests := []struct {
3668+
name string
3669+
expr string
3670+
jsonFieldNames bool
3671+
}{
3672+
{
3673+
name: "proto simple field",
3674+
expr: `msg.single_int32 == 1`,
3675+
},
3676+
{
3677+
name: "proto map field",
3678+
expr: `msg.map_string_string['key'] == 'value'`,
3679+
},
3680+
{
3681+
name: "json simple field",
3682+
expr: `msg.singleInt32 == 1`,
3683+
jsonFieldNames: true,
3684+
},
3685+
{
3686+
name: "json repeated field",
3687+
expr: `msg.mapStringString['key'] == 'value'`,
3688+
jsonFieldNames: true,
3689+
},
3690+
{
3691+
name: "message with json field",
3692+
expr: `TestAllTypes{singleInt32: 1} != msg`,
3693+
jsonFieldNames: true,
3694+
},
3695+
{
3696+
name: "message with json field and proto fallback",
3697+
expr: `dyn(TestAllTypes{singleInt32: 2}).single_int32 == 2`,
3698+
jsonFieldNames: true,
3699+
},
3700+
{
3701+
name: "json with proto fallback",
3702+
expr: `dyn(msg).single_int32 == dyn(msg).singleInt32`,
3703+
jsonFieldNames: true,
3704+
},
3705+
}
3706+
msg := &proto3pb.TestAllTypes{
3707+
SingleInt32: 1,
3708+
MapStringString: map[string]string{
3709+
"key": "value",
3710+
},
3711+
}
3712+
for _, tst := range tests {
3713+
tc := tst
3714+
t.Run(tc.name, func(t *testing.T) {
3715+
env, err := NewEnv(
3716+
JSONFieldNames(tc.jsonFieldNames),
3717+
Types(msg),
3718+
Container(string(msg.ProtoReflect().Descriptor().ParentFile().Package())),
3719+
Variable("msg", ObjectType(string(msg.ProtoReflect().Descriptor().FullName()))),
3720+
)
3721+
if err != nil {
3722+
t.Fatalf("NewEnv() failed: %v", err)
3723+
}
3724+
ast, iss := env.Compile(tc.expr)
3725+
if iss.Err() != nil {
3726+
t.Fatalf("env.Compile() failed: %v", iss.Err())
3727+
}
3728+
prg, err := env.Program(ast)
3729+
if err != nil {
3730+
t.Fatalf("env.Program() failed: %v", err)
3731+
}
3732+
out, _, err := prg.Eval(map[string]any{"msg": msg})
3733+
if err != nil {
3734+
t.Fatalf("prg.Eval() failed: %v", err)
3735+
}
3736+
if out != types.True {
3737+
t.Errorf("prg.Eval() got %v, wanted 'true'", out)
3738+
}
3739+
3740+
if tc.jsonFieldNames {
3741+
noJSONEnv, err := env.Extend(JSONFieldNames(false))
3742+
if err != nil {
3743+
t.Fatalf("env.Extend() failed: %v", err)
3744+
}
3745+
_, err = noJSONEnv.Program(ast)
3746+
if err == nil {
3747+
t.Fatal("env with json disabled allowed program with json extension to be planned")
3748+
}
3749+
} else {
3750+
jsonEnv, err := env.Extend(JSONFieldNames(true))
3751+
if err != nil {
3752+
t.Fatalf("env.Extend() failed: %v", err)
3753+
}
3754+
prg, err = jsonEnv.Program(ast)
3755+
if err != nil {
3756+
t.Fatalf("env.Program() failed: %v", err)
3757+
}
3758+
out, _, err := prg.Eval(map[string]any{"msg": msg})
3759+
if err != nil {
3760+
t.Fatalf("prg.Eval() failed: %v", err)
3761+
}
3762+
if out != types.True {
3763+
t.Errorf("prg.Eval() got %v, wanted 'true'", out)
3764+
}
3765+
}
3766+
})
3767+
}
3768+
}
3769+
3770+
func TestJSONFieldNamesInvalidProvider(t *testing.T) {
3771+
type wrapperRegistry struct {
3772+
*types.Registry
3773+
}
3774+
reg, err := types.NewProtoRegistry(types.JSONFieldNames(true))
3775+
if err != nil {
3776+
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
3777+
}
3778+
wrapped := wrapperRegistry{Registry: reg}
3779+
_, err = NewEnv(CustomTypeProvider(wrapped), CustomTypeAdapter(reg), JSONFieldNames(true))
3780+
if err == nil {
3781+
t.Error("NewEnv() created a CEL environment successfully despite incompatible configs")
3782+
}
3783+
}
3784+
36103785
// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
36113786
type testCostEstimator struct {
36123787
hints map[string]uint64

cel/env.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,16 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
184184
conf.AddImports(env.NewImport(typeName))
185185
}
186186

187+
// Serialize features
188+
for featID, enabled := range e.features {
189+
featName, found := featureNameByID(featID)
190+
if !found {
191+
// If the feature isn't named, it isn't intended to be publicly exposed
192+
continue
193+
}
194+
conf.AddFeatures(env.NewFeature(featName, enabled))
195+
}
196+
187197
libOverloads := map[string][]string{}
188198
for libName, lib := range e.libraries {
189199
// Track the options which have been configured by a library and
@@ -244,7 +254,7 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
244254
fields := e.contextProto.Fields()
245255
for i := 0; i < fields.Len(); i++ {
246256
field := fields.Get(i)
247-
variable, err := fieldToVariable(field)
257+
variable, err := fieldToVariable(field, e.HasFeature(featureJSONFieldNames))
248258
if err != nil {
249259
return nil, fmt.Errorf("could not serialize context field variable %q, reason: %w", field.FullName(), err)
250260
}
@@ -279,16 +289,6 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
279289
}
280290
}
281291

282-
// Serialize features
283-
for featID, enabled := range e.features {
284-
featName, found := featureNameByID(featID)
285-
if !found {
286-
// If the feature isn't named, it isn't intended to be publicly exposed
287-
continue
288-
}
289-
conf.AddFeatures(env.NewFeature(featName, enabled))
290-
}
291-
292292
for id, val := range e.limits {
293293
limitName, found := limitNameByID(id)
294294
if !found || val == 0 {
@@ -361,7 +361,7 @@ func NewEnv(opts ...EnvOption) (*Env, error) {
361361
// See the EnvOption helper functions for the options that can be used to configure the
362362
// environment.
363363
func NewCustomEnv(opts ...EnvOption) (*Env, error) {
364-
registry, err := types.NewRegistry()
364+
registry, err := types.NewProtoRegistry()
365365
if err != nil {
366366
return nil, err
367367
}
@@ -554,6 +554,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
554554
}
555555
validatorsCopy := make([]ASTValidator, len(e.validators))
556556
copy(validatorsCopy, e.validators)
557+
557558
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
558559
copy(costOptsCopy, e.costOptions)
559560

@@ -847,6 +848,18 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
847848
return nil, err
848849
}
849850

851+
// Enable JSON field names is using a proto-based *types.Registry
852+
if e.HasFeature(featureJSONFieldNames) {
853+
reg, isReg := e.provider.(*types.Registry)
854+
if !isReg {
855+
return nil, fmt.Errorf("JSONFieldNames() option is only compatible with *types.Registry providers")
856+
}
857+
err := reg.WithJSONFieldNames(true)
858+
if err != nil {
859+
return nil, err
860+
}
861+
}
862+
850863
// Ensure that the checker init happens eagerly rather than lazily.
851864
if e.HasFeature(featureEagerlyValidateDeclarations) {
852865
_, err := e.initChecker()
@@ -865,6 +878,8 @@ func (e *Env) initChecker() (*checker.Env, error) {
865878
chkOpts = append(chkOpts,
866879
checker.CrossTypeNumericComparisons(
867880
e.HasFeature(featureCrossTypeNumericComparisons)))
881+
chkOpts = append(chkOpts,
882+
checker.JSONFieldNames(e.HasFeature(featureJSONFieldNames)))
868883

869884
ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
870885
if err != nil {

cel/env_test.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,9 @@ func TestEnvPartialVarsError(t *testing.T) {
198198
}
199199

200200
func TestTypeProviderInterop(t *testing.T) {
201-
reg, err := types.NewRegistry(&proto3pb.TestAllTypes{})
201+
reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto3pb.TestAllTypes{}))
202202
if err != nil {
203-
t.Fatalf("types.NewRegistry() failed: %v", err)
203+
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
204204
}
205205
tests := []struct {
206206
name string
@@ -399,6 +399,14 @@ func TestEnvToConfig(t *testing.T) {
399399
env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")),
400400
)),
401401
},
402+
{
403+
name: "json field names",
404+
opts: []EnvOption{
405+
JSONFieldNames(true),
406+
},
407+
want: env.NewConfig("json field names").
408+
AddFeatures(env.NewFeature("cel.feature.json_field_names", true)),
409+
},
402410
{
403411
name: "context proto - with extra variable",
404412
opts: []EnvOption{
@@ -495,7 +503,7 @@ func TestEnvFromConfig(t *testing.T) {
495503
{
496504
name: "std env - imports",
497505
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
498-
conf: env.NewConfig("std env - context proto").
506+
conf: env.NewConfig("std env - imports").
499507
AddImports(env.NewImport("google.expr.proto3.test.TestAllTypes")),
500508
exprs: []exprCase{
501509
{
@@ -520,6 +528,22 @@ func TestEnvFromConfig(t *testing.T) {
520528
},
521529
},
522530
},
531+
{
532+
name: "std env - context proto w/ json field names",
533+
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
534+
conf: env.NewConfig("std env - context proto w/ json field names").
535+
SetContainer("google.expr.proto3.test").
536+
SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")).
537+
AddFeatures(env.NewFeature("cel.feature.json_field_names", true)),
538+
exprs: []exprCase{
539+
{
540+
name: "field select literal",
541+
in: mustContextProto(t, &proto3pb.TestAllTypes{SingleInt64: 10}, types.JSONFieldNames(true)),
542+
expr: "TestAllTypes{singleInt64: singleInt64}.singleInt64",
543+
out: types.Int(10),
544+
},
545+
},
546+
},
523547
{
524548
name: "custom env - variables",
525549
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
@@ -1154,9 +1178,9 @@ func BenchmarkEnvExtendEagerDecls(b *testing.B) {
11541178
}
11551179
}
11561180

1157-
func mustContextProto(t *testing.T, pb proto.Message) Activation {
1181+
func mustContextProto(t *testing.T, pb proto.Message, opts ...types.RegistryOption) Activation {
11581182
t.Helper()
1159-
ctx, err := ContextProtoVars(pb)
1183+
ctx, err := ContextProtoVars(pb, opts...)
11601184
if err != nil {
11611185
t.Fatalf("ContextProtoVars() failed: %v", err)
11621186
}

0 commit comments

Comments
 (0)