Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions ext/formatting.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {

// stringFormatValidator implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidator struct{}
type stringFormatValidator struct {
maxPrecision int
}

// Name returns the name of the validator.
func (stringFormatValidator) Name() string {
Expand All @@ -427,7 +429,7 @@ func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error

// Validate parses all literal format strings and type checks the format clause against the argument
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
func (v stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
root := ast.NavigateAST(a)
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
for _, e := range formatCallExprs {
Expand All @@ -439,7 +441,7 @@ func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *as
ast: a,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US")
_, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US", v.maxPrecision)
if err != nil {
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err)
continue
Expand Down Expand Up @@ -778,7 +780,7 @@ type formatListArgs interface {

// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string) (string, error) {
func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string, maxPrecision int) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
Expand All @@ -802,7 +804,7 @@ func parseFormatString(formatStr string, callback formatStringInterpolator, list
if int64(argIndex) >= list.Size() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale, maxPrecision)
if refErr != nil {
return "", refErr
}
Expand All @@ -826,9 +828,9 @@ func parseFormatString(formatStr string, callback formatStringInterpolator, list

// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string) (int, string, error) {
func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string, maxPrecision int) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
read, formatter, err := parseFormattingClause(formatStr[i:], callback, maxPrecision)
i += read
if err != nil {
return -1, "", newParseFormatError("could not parse formatting clause", err)
Expand All @@ -841,9 +843,9 @@ func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringIn
return i, valStr, nil
}

func parseFormattingClause(formatStr string, callback formatStringInterpolator) (int, clauseImpl, error) {
func parseFormattingClause(formatStr string, callback formatStringInterpolator, maxPrecision int) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
read, precision, err := parsePrecision(formatStr[i:], maxPrecision)
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
Expand All @@ -870,7 +872,7 @@ func parseFormattingClause(formatStr string, callback formatStringInterpolator)
}
}

func parsePrecision(formatStr string) (int, *int, error) {
func parsePrecision(formatStr string, maxPrecision int) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
Expand All @@ -891,6 +893,9 @@ func parsePrecision(formatStr string) (int, *int, error) {
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
if maxPrecision > 0 && precision > maxPrecision {
return -1, nil, fmt.Errorf("precision %d exceeds maximum allowed precision %d", precision, maxPrecision)
}
return i, &precision, nil
}

Expand Down
8 changes: 8 additions & 0 deletions ext/formatting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,14 @@ func TestStringFormat(t *testing.T) {
formatArgs: "3.14",
err: "error during formatting: octal clause can only be used on integers",
},
{
name: "high precision allowed on older version",
format: "%.200f",
formatArgs: "1.0",
expectedOutput: "1." + strings.Repeat("0", 200),
skipCompileCheck: true,
locale: "en_US",
},
}
evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) {
t.Logf("evaluating expr: %s", expr)
Expand Down
25 changes: 15 additions & 10 deletions ext/formatting_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ func (c *stringFormatterV2) Octal(arg ref.Val) (string, error) {

// stringFormatValidatorV2 implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidatorV2 struct{}
type stringFormatValidatorV2 struct {
maxPrecision int
}

// Name returns the name of the validator.
func (stringFormatValidatorV2) Name() string {
Expand All @@ -419,7 +421,7 @@ func (stringFormatValidatorV2) Configure(config cel.MutableValidatorConfig) erro

// Validate parses all literal format strings and type checks the format clause against the argument
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
func (stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
func (v stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
root := ast.NavigateAST(a)
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
for _, e := range formatCallExprs {
Expand All @@ -431,7 +433,7 @@ func (stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *
ast: a,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := parseFormatStringV2(formatStr, formatCheck, formatCheck)
_, err := parseFormatStringV2(formatStr, formatCheck, formatCheck, v.maxPrecision)
if err != nil {
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err)
continue
Expand Down Expand Up @@ -668,7 +670,7 @@ type formatStringInterpolatorV2 interface {

// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs) (string, error) {
func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs, maxPrecision int) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
Expand All @@ -692,7 +694,7 @@ func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2,
if int64(argIndex) >= list.Size() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list)
numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list, maxPrecision)
if refErr != nil {
return "", refErr
}
Expand All @@ -716,9 +718,9 @@ func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2,

// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs) (int, string, error) {
func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs, maxPrecision int) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback)
read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback, maxPrecision)
i += read
if err != nil {
return -1, "", newParseFormatError("could not parse formatting clause", err)
Expand All @@ -731,9 +733,9 @@ func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatString
return i, valStr, nil
}

func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2) (int, clauseImplV2, error) {
func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2, maxPrecision int) (int, clauseImplV2, error) {
i := 0
read, precision, err := parsePrecisionV2(formatStr[i:])
read, precision, err := parsePrecisionV2(formatStr[i:], maxPrecision)
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
Expand All @@ -760,7 +762,7 @@ func parseFormattingClauseV2(formatStr string, callback formatStringInterpolator
}
}

func parsePrecisionV2(formatStr string) (int, int, error) {
func parsePrecisionV2(formatStr string, maxPrecision int) (int, int, error) {
i := 0
if formatStr[i] != '.' {
return i, defaultPrecision, nil
Expand All @@ -784,5 +786,8 @@ func parsePrecisionV2(formatStr string) (int, int, error) {
if precision < 0 {
return -1, -1, fmt.Errorf("negative precision: %d", precision)
}
if maxPrecision > 0 && precision > maxPrecision {
return -1, -1, fmt.Errorf("precision %d exceeds maximum allowed precision %d", precision, maxPrecision)
}
return i, precision, nil
}
7 changes: 7 additions & 0 deletions ext/formatting_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,13 @@ func TestStringFormatV2(t *testing.T) {
formatArgs: "3.14",
err: "octal clause can only be used on ints and uints, was given double",
},
{
name: "precision exceeds maximum",
format: "%.9999999f",
formatArgs: "3.14",
skipCompileCheck: true,
err: "precision 9999999 exceeds maximum allowed precision 100",
},
}
evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) {
t.Logf("evaluating expr: %s", expr)
Expand Down
30 changes: 24 additions & 6 deletions ext/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,9 @@ func Strings(options ...StringsOption) cel.EnvOption {
}

type stringLib struct {
locale string
version uint32
locale string
version uint32
maxPrecision int
}

// LibraryName implements the SingletonLibrary interface method.
Expand Down Expand Up @@ -353,6 +354,16 @@ func StringsValidateFormatCalls(value bool) StringsOption {
}
}

// StringsMaxPrecision configures the maximum precision for floating-point format clauses.
//
// If not set, the default is 100 for version >= 5, and no limit for earlier versions.
func StringsMaxPrecision(limit int) StringsOption {
return func(lib *stringLib) *stringLib {
lib.maxPrecision = limit
return lib
}
}

// CompileOptions implements the Library interface method.
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
Expand Down Expand Up @@ -470,22 +481,29 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
return stringOrError(upperASCII(string(s)))
}))),
}
// maxPrecision is unbounded (0) for versions < 5 to maintain backward
// compatibility. For version >= 5, the default is 100 if not explicitly
// configured via StringsMaxPrecision().
maxPrecision := lib.maxPrecision
if maxPrecision == 0 && lib.version >= 5 {
maxPrecision = 100
}
if lib.version >= 1 {
if lib.version >= 4 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(parseFormatStringV2(s, &stringFormatterV2{}, &stringArgList{formatArgs}))
return stringOrError(parseFormatStringV2(s, &stringFormatterV2{}, &stringArgList{formatArgs}, maxPrecision))
}))))
} else {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale, maxPrecision))
}))))
}
opts = append(opts,
Expand Down Expand Up @@ -544,9 +562,9 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
}
if lib.version >= 1 {
if lib.version >= 4 {
opts = append(opts, cel.ASTValidators(stringFormatValidatorV2{}))
opts = append(opts, cel.ASTValidators(stringFormatValidatorV2{maxPrecision: maxPrecision}))
Comment thread
Flo354 marked this conversation as resolved.
} else {
opts = append(opts, cel.ASTValidators(stringFormatValidator{}))
opts = append(opts, cel.ASTValidators(stringFormatValidator{maxPrecision: maxPrecision}))
}
}
return opts
Expand Down