Skip to content

Commit ac99cd8

Browse files
authored
feat(ai-rate-limiting): add expression-based limit strategy (#13191)
1 parent f8e88f3 commit ac99cd8

File tree

2 files changed

+724
-2
lines changed

2 files changed

+724
-2
lines changed

apisix/plugins/ai-rate-limiting.lua

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ local require = require
1818
local setmetatable = setmetatable
1919
local ipairs = ipairs
2020
local type = type
21+
local pairs = pairs
22+
local pcall = pcall
23+
local load = load
24+
local math_floor = math.floor
25+
local math_huge = math.huge
2126
local core = require("apisix.core")
2227
local limit_count = require("apisix.plugins.limit-count.init")
2328

@@ -61,10 +66,19 @@ local schema = {
6166
show_limit_quota_header = {type = "boolean", default = true},
6267
limit_strategy = {
6368
type = "string",
64-
enum = {"total_tokens", "prompt_tokens", "completion_tokens"},
69+
enum = {"total_tokens", "prompt_tokens", "completion_tokens", "expression"},
6570
default = "total_tokens",
6671
description = "The strategy to limit the tokens"
6772
},
73+
cost_expr = {
74+
type = "string",
75+
minLength = 1,
76+
description = "Lua arithmetic expression for dynamic token cost calculation. "
77+
.. "Variables are injected from the LLM API raw usage response fields. "
78+
.. "Missing variables default to 0. "
79+
.. "Only valid when limit_strategy is 'expression'. "
80+
.. "Example: input_tokens + cache_creation_input_tokens + output_tokens",
81+
},
6882
instances = {
6983
type = "array",
7084
items = instance_limit_schema,
@@ -136,8 +150,42 @@ local limit_conf_cache = core.lrucache.new({
136150
})
137151

138152

153+
-- Names a cost expression may reference besides the usage variables.
-- `math` is exposed through a read-only proxy rather than by reference: an
-- expression may contain a function literal, so handing out the real table
-- would let `math.floor = nil` (or similar) corrupt the library for every
-- other piece of code running in the same worker.
local expr_safe_env = {
    math = setmetatable({}, {
        -- reads fall through to the real math library
        __index = math,
        -- writes are silently dropped so the shared library stays intact
        __newindex = function() end,
        -- hide the metatable from getmetatable()
        __metatable = false,
    }),
    abs = math.abs,
    ceil = math.ceil,
    floor = math.floor,
    max = math.max,
    min = math.min,
}
162+
163+
--- Syntax-check a cost expression.
-- The expression is wrapped in a `return` statement and handed to `load`
-- in text-only mode; the chunk is compiled but never executed here.
-- @param expr_str the expression text taken from the plugin configuration
-- @return the wrapped chunk source on success, or nil plus the compiler
--         error message when the expression does not parse
local function compile_cost_expr(expr_str)
    local chunk_src = "return " .. expr_str
    local compiled, load_err = load(chunk_src, "cost_expr", "t", expr_safe_env)
    if compiled then
        return chunk_src
    end
    return nil, load_err
end
172+
173+
139174
--- Validate a plugin configuration.
-- Runs the JSON-schema check first, then, for the "expression" strategy,
-- verifies that `cost_expr` is present, non-blank and syntactically valid.
-- @param conf the plugin configuration table
-- @return true on success, or false plus an error message
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    if conf.limit_strategy ~= "expression" then
        return true
    end

    local expr = conf.cost_expr
    -- A whitespace-only expression would compile to a bare `return` and then
    -- fail on every request with "must return a number", so it is rejected
    -- here together with the missing/empty cases.
    if not expr or expr:find("^%s*$") then
        return false, "cost_expr is required when limit_strategy is 'expression'"
    end

    local compiled, compile_err = compile_cost_expr(expr)
    if not compiled then
        return false, "invalid cost_expr: " .. compile_err
    end

    return true
end
142190

143191

@@ -264,7 +312,57 @@ function _M.check_instance_status(conf, ctx, instance_name)
264312
end
265313

266314

315+
-- Compiled cost expressions keyed by their source text, so an expression is
-- parsed once instead of on every request.
local cost_expr_cache = {}
local cost_expr_cache_size = 0
-- Upper bound on cached expressions; the cache is reset when it is reached so
-- repeated configuration changes cannot grow it without limit.
local COST_EXPR_CACHE_MAX = 256

-- Usage table consulted by cost_expr_env while an expression is running.
local cost_expr_usage

-- Environment shared by every compiled expression. Name lookup order:
-- safe math helpers, then numeric fields of the current usage table, then 0.
-- Assignments are dropped so no state can leak between evaluations.
local cost_expr_env = setmetatable({}, {
    __index = function(_, k)
        local safe = expr_safe_env[k]
        if safe ~= nil then
            return safe
        end
        local usage = cost_expr_usage
        local v = usage and usage[k]
        if type(v) == "number" then
            return v
        end
        -- missing or non-numeric variables default to 0
        return 0
    end,
    __newindex = function() end,
})


--- Evaluate a cost expression against one usage table.
-- @param conf_cost_expr the expression text from the plugin configuration
-- @param raw table of usage fields; only numeric fields are visible to the
--            expression, and names in expr_safe_env take precedence
-- @return the result clamped to >= 0 and rounded half-up to an integer, or
--         nil plus an error message when the expression fails to compile,
--         raises, returns a non-number, or returns NaN / +-infinity
local function eval_cost_expr(conf_cost_expr, raw)
    local fn = cost_expr_cache[conf_cost_expr]
    if not fn then
        local err
        fn, err = load("return " .. conf_cost_expr, "cost_expr", "t",
                       cost_expr_env)
        if not fn then
            return nil, "failed to compile cost_expr: " .. err
        end
        if cost_expr_cache_size >= COST_EXPR_CACHE_MAX then
            cost_expr_cache = {}
            cost_expr_cache_size = 0
        end
        cost_expr_cache[conf_cost_expr] = fn
        cost_expr_cache_size = cost_expr_cache_size + 1
    end

    -- expose this request's usage for the duration of the call only
    cost_expr_usage = raw
    local ok, result = pcall(fn)
    cost_expr_usage = nil

    if not ok then
        return nil, "failed to evaluate cost_expr: " .. result
    end
    if type(result) ~= "number" then
        return nil, "cost_expr must return a number, got: " .. type(result)
    end
    -- result ~= result is true only for NaN
    if result ~= result or result == math_huge or result == -math_huge then
        return nil, "cost_expr returned non-finite value"
    end
    if result < 0 then
        result = 0
    end
    -- round half-up to the nearest integer
    return math_floor(result + 0.5)
end
351+
267352
local function get_token_usage(conf, ctx)
353+
if conf.limit_strategy == "expression" then
354+
local raw = ctx.llm_raw_usage
355+
if not raw then
356+
return
357+
end
358+
local result, err = eval_cost_expr(conf.cost_expr, raw)
359+
if not result then
360+
core.log.error(err)
361+
return
362+
end
363+
return result
364+
end
365+
268366
local usage = ctx.ai_token_usage
269367
if not usage then
270368
return
@@ -288,6 +386,10 @@ function _M.log(conf, ctx)
288386
core.log.error("failed to get token usage for llm service")
289387
return
290388
end
389+
if used_tokens == 0 then
390+
core.log.info("token usage is 0, skip rate limiting")
391+
return
392+
end
291393

292394
core.log.info("instance name: ", instance_name, " used tokens: ", used_tokens)
293395

0 commit comments

Comments
 (0)