Skip to content

Commit 409bcbe

Browse files
Refactor match output compiling to accept user-defined logic. (#1246)
1 parent e9f15ea commit 409bcbe

1 file changed

Lines changed: 55 additions & 5 deletions

File tree

policy/compiler.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,49 @@ func MaxNestedExpressions(limit int) CompilerOption {
189189
}
190190
}
191191

192+
// MatchCompiler is an interface that provides the necessary functionality for compiling the output
193+
// of a match.
194+
type MatchCompiler interface {
195+
// Env returns the rule environment for the match compiler.
196+
Env() *cel.Env
197+
// RelSource returns a RelativeSource for a given expression id.
198+
RelSource(ValueString) *RelativeSource
199+
// ReportErrorAtID reports an error at the given expression id.
200+
ReportErrorAtID(int64, string, ...any)
201+
}
202+
203+
type matchCompilerImpl struct {
204+
env *cel.Env
205+
c *compiler
206+
iss *cel.Issues
207+
}
208+
209+
// Env returns the current environment available to the match expression being compiled.
210+
func (mc *matchCompilerImpl) Env() *cel.Env {
211+
return mc.env
212+
}
213+
214+
// RelSource exposes the compiler's relative source function.
215+
func (mc *matchCompilerImpl) RelSource(pstr ValueString) *RelativeSource {
216+
return mc.c.relSource(pstr)
217+
}
218+
219+
// ReportErrorAtID exposes the issue reporter for the match compiler.
220+
func (mc *matchCompilerImpl) ReportErrorAtID(id int64, message string, args ...any) {
221+
mc.iss.ReportErrorAtID(id, message, args...)
222+
}
223+
224+
// CompileMatchOutputFunc is a function that compiles the output of a match.
225+
type CompileMatchOutputFunc func(mc MatchCompiler, m *Match, p *Policy) (*cel.Ast, *cel.Issues)
226+
227+
// CompileMatchOutput sets a custom match output compiling function.
228+
func CompileMatchOutput(f CompileMatchOutputFunc) CompilerOption {
229+
return func(c *compiler) error {
230+
c.compileMatchOutput = f
231+
return nil
232+
}
233+
}
234+
192235
// Compile combines the policy compilation and composition steps into a single call.
193236
//
194237
// This generates a single CEL AST from a collection of policy expressions associated with a
@@ -211,6 +254,7 @@ func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule
211254
env: env,
212255
info: p.SourceInfo(),
213256
src: p.Source(),
257+
compileMatchOutput: defaultCompileMatchOutput,
214258
maxNestedExpressions: defaultMaxNestedExpressions,
215259
}
216260
var err error
@@ -248,19 +292,20 @@ func CompileRule(env *cel.Env, p *Policy, opts ...CompilerOption) (*CompiledRule
248292
c.env = env
249293
}
250294
}
251-
return c.compileRule(p.Rule(), c.env, iss)
295+
return c.compileRule(p.Rule(), p, c.env, iss)
252296
}
253297

254298
type compiler struct {
255299
env *cel.Env
256300
info *ast.SourceInfo
257301
src *Source
258302

303+
compileMatchOutput CompileMatchOutputFunc
259304
maxNestedExpressions int
260305
nestedCount int
261306
}
262307

263-
func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
308+
func (c *compiler) compileRule(r *Rule, p *Policy, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
264309
compiledVars := make([]*CompiledVariable, len(r.Variables()))
265310
for i, v := range r.Variables() {
266311
exprSrc := c.relSource(v.Expression())
@@ -315,8 +360,8 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
315360
continue
316361
}
317362
if m.HasOutput() {
318-
outSrc := c.relSource(m.Output())
319-
outAST, outIss := ruleEnv.CompileSource(outSrc)
363+
mc := &matchCompilerImpl{env: ruleEnv, c: c, iss: iss}
364+
outAST, outIss := c.compileMatchOutput(mc, m, p)
320365
iss = iss.Append(outIss)
321366
compiledMatches = append(compiledMatches, &CompiledMatch{
322367
exprID: m.exprID,
@@ -329,7 +374,7 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
329374
continue
330375
}
331376
if m.HasRule() {
332-
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
377+
nestedRule, ruleIss := c.compileRule(m.Rule(), p, ruleEnv, iss)
333378
iss = iss.Append(ruleIss)
334379
compiledMatches = append(compiledMatches, &CompiledMatch{
335380
exprID: m.exprID,
@@ -362,6 +407,11 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
362407
return rule, iss
363408
}
364409

410+
func defaultCompileMatchOutput(mc MatchCompiler, m *Match, p *Policy) (*cel.Ast, *cel.Issues) {
411+
outSrc := mc.RelSource(m.Output())
412+
return mc.Env().CompileSource(outSrc)
413+
}
414+
365415
func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issues) {
366416
var outputType *cel.Type
367417
for _, m := range rule.Matches() {

0 commit comments

Comments
 (0)