Skip to content

Commit 5007079

Browse files
feat(limit-count): support configuring multiple rules
Signed-off-by: Abhishek Choudhary <shreemaan.abhishek@gmail.com>
1 parent 3ba27f6 commit 5007079

File tree

3 files changed

+516
-43
lines changed

3 files changed

+516
-43
lines changed

apisix/plugins/limit-count/init.lua

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ local get_phase = ngx.get_phase
2525
local tonumber = tonumber
2626
local type = type
2727
local tostring = tostring
28+
local str_format = string.format
2829

2930
local limit_redis_cluster_new
3031
local limit_redis_new
@@ -82,6 +83,28 @@ local schema = {
8283
{type = "string"},
8384
},
8485
},
86+
rules = {
87+
type = "array",
88+
items = {
89+
type = "object",
90+
properties = {
91+
count = {
92+
oneOf = {
93+
{type = "integer", exclusiveMinimum = 0},
94+
{type = "string"},
95+
},
96+
},
97+
time_window = {
98+
oneOf = {
99+
{type = "integer", exclusiveMinimum = 0},
100+
{type = "string"},
101+
},
102+
},
103+
key = {type = "string"},
104+
},
105+
required = {"count", "time_window", "key"},
106+
},
107+
},
85108
group = {type = "string"},
86109
key = {type = "string", default = "remote_addr"},
87110
key_type = {type = "string",
@@ -102,7 +125,14 @@ local schema = {
102125
allow_degradation = {type = "boolean", default = false},
103126
show_limit_quota_header = {type = "boolean", default = true}
104127
},
105-
required = {"count", "time_window"},
128+
oneOf = {
129+
{
130+
required = {"count", "time_window"},
131+
},
132+
{
133+
required = {"rules"},
134+
}
135+
},
106136
["if"] = {
107137
properties = {
108138
policy = {
@@ -180,51 +210,34 @@ function _M.check_schema(conf, schema_type)
180210
end
181211
end
182212

183-
return true
184-
end
185-
186-
187-
local function create_limit_obj(conf, ctx, plugin_name)
188-
core.log.info("create new " .. plugin_name .. " plugin instance")
189-
190-
local count = conf.count
191-
if type(count) == "string" then
192-
local err, _
193-
count, err, _ = core.utils.resolve_var(count, ctx.var)
194-
if err then
195-
return nil, "could not resolve vars in count: " .. err
196-
end
197-
count = tonumber(count)
198-
if not count then
199-
return nil, "resolved count is not a number: " .. tostring(count)
213+
local keys = {}
214+
for _, rule in ipairs(conf.rules or {}) do
215+
if keys[rule.key] then
216+
return false, str_format("duplicate key '%s' in rules", rule.key)
200217
end
218+
keys[rule.key] = true
201219
end
202220

203-
local time_window = conf.time_window
204-
if type(time_window) == "string" then
205-
local err, _
206-
time_window, err, _ = core.utils.resolve_var(time_window, ctx.var)
207-
if err then
208-
return nil, "could not resolve vars in time_window: " .. err
209-
end
210-
time_window = tonumber(time_window)
211-
if not time_window then
212-
return nil, "resolved time_window is not a number: " .. tostring(time_window)
213-
end
214-
end
221+
return true
222+
end
223+
215224

216-
core.log.info("limit count: ", count, ", time_window: ", time_window)
225+
local function create_limit_obj(conf, rule, plugin_name)
226+
core.log.info("create new " .. plugin_name .. " plugin instance",
227+
", rule: ", core.json.delay_encode(rule, true))
217228

218229
if not conf.policy or conf.policy == "local" then
219-
return limit_local_new("plugin-" .. plugin_name, count, time_window)
230+
return limit_local_new("plugin-" .. plugin_name, rule.count,
231+
rule.time_window)
220232
end
221233

222234
if conf.policy == "redis" then
223-
return limit_redis_new("plugin-" .. plugin_name, count, time_window, conf)
235+
return limit_redis_new("plugin-" .. plugin_name, rule.count, rule.time_window, conf)
224236
end
225237

226238
if conf.policy == "redis-cluster" then
227-
return limit_redis_cluster_new("plugin-" .. plugin_name, count, time_window, conf)
239+
return limit_redis_cluster_new("plugin-" .. plugin_name, rule.count,
240+
rule.time_window, conf)
228241
end
229242

230243
return nil
@@ -258,11 +271,71 @@ local function gen_limit_key(conf, ctx, key)
258271
end
259272

260273

261-
function _M.rate_limit(conf, ctx, name, cost, dry_run)
262-
core.log.info("ver: ", ctx.conf_version)
263-
core.log.info("conf: ", core.json.delay_encode(conf, true))
274+
local function resolve_var(ctx, value)
275+
if type(value) == "string" then
276+
local err, _
277+
value, err, _ = core.utils.resolve_var(value, ctx.var)
278+
if err then
279+
return nil, "could not resolve var for value: " .. value .. ", err: " .. err
280+
end
281+
value = tonumber(value)
282+
if not value then
283+
return nil, "resolved value is not a number: " .. tostring(value)
284+
end
285+
end
286+
return value
287+
end
264288

265-
local lim, err = create_limit_obj(conf, ctx, name)
289+
290+
local function get_rules(ctx, conf)
291+
if not conf.rules then
292+
local count, err = resolve_var(ctx, conf.count)
293+
if err then
294+
return nil, err
295+
end
296+
local time_window, err2 = resolve_var(ctx, conf.time_window)
297+
if err2 then
298+
return nil, err2
299+
end
300+
return {
301+
{
302+
count = count,
303+
time_window = time_window,
304+
key = conf.key,
305+
key_type = conf.key_type,
306+
}
307+
}
308+
end
309+
310+
local rules = {}
311+
for _, rule in ipairs(conf.rules) do
312+
local count, err = resolve_var(ctx, rule.count)
313+
if err then
314+
goto CONTINUE
315+
end
316+
local time_window, err2 = resolve_var(ctx, rule.time_window)
317+
if err2 then
318+
goto CONTINUE
319+
end
320+
local key, _, n_resolved = core.utils.resolve_var(rule.key, ctx.var)
321+
if n_resolved == 0 then
322+
goto CONTINUE
323+
end
324+
core.table.insert(rules, {
325+
count = count,
326+
time_window = time_window,
327+
key_type = "constant",
328+
key = key,
329+
})
330+
331+
::CONTINUE::
332+
end
333+
return rules
334+
end
335+
336+
337+
local function run_rate_limit(conf, rule, ctx, name, cost, dry_run)
338+
local lim, err = create_limit_obj(conf, rule, name)
266339

267340
if not lim then
268341
core.log.error("failed to fetch limit.count object: ", err)
@@ -272,9 +345,9 @@ function _M.rate_limit(conf, ctx, name, cost, dry_run)
272345
return 500
273346
end
274347

275-
local conf_key = conf.key
348+
local conf_key = rule.key
276349
local key
277-
if conf.key_type == "var_combination" then
350+
if rule.key_type == "var_combination" then
278351
local err, n_resolved
279352
key, err, n_resolved = core.utils.resolve_var(conf_key, ctx.var)
280353
if err then
@@ -284,7 +357,7 @@ function _M.rate_limit(conf, ctx, name, cost, dry_run)
284357
if n_resolved == 0 then
285358
key = nil
286359
end
287-
elseif conf.key_type == "constant" then
360+
elseif rule.key_type == "constant" then
288361
key = conf_key
289362
else
290363
key = ctx.var[conf_key]
@@ -353,4 +426,25 @@ function _M.rate_limit(conf, ctx, name, cost, dry_run)
353426
end
354427

355428

429+
function _M.rate_limit(conf, ctx, name, cost, dry_run)
430+
core.log.info("ver: ", ctx.conf_version)
431+
432+
local rules, err = get_rules(ctx, conf)
433+
if not rules or #rules == 0 then
434+
core.log.error("failed to get rate limit rules: ", err)
435+
if conf.allow_degradation then
436+
return
437+
end
438+
return 500
439+
end
440+
441+
for _, rule in ipairs(rules) do
442+
local code, msg = run_rate_limit(conf, rule, ctx, name, cost, dry_run)
443+
if code then
444+
return code, msg
445+
end
446+
end
447+
end
448+
449+
356450
return _M

0 commit comments

Comments
 (0)