88from casbin .util import generate_g_function , SimpleEval , util
99
1010
11+ class EnforceContext :
12+ """
13+ EnforceContext is used as the first element of the parameter "rvals" in method "enforce"
14+ """
15+ def __init__ (self , rtype : str , ptype : str , etype : str , mtype : str ):
16+ self .rtype : str = rtype
17+ self .ptype : str = ptype
18+ self .etype : str = etype
19+ self .mtype : str = mtype
20+
21+
1122class CoreEnforcer :
1223 """CoreEnforcer defines the core functionality of an enforcer."""
1324
@@ -24,6 +35,11 @@ class CoreEnforcer:
2435 auto_save = False
2536 auto_build_role_links = False
2637
38+ rtype = 'r'
39+ ptype = 'p'
40+ etype = 'e'
41+ mtype = 'm'
42+
2743 def __init__ (self , model = None , adapter = None ):
2844 self .logger = logging .getLogger (__name__ )
2945 if isinstance (model , str ):
@@ -75,7 +91,7 @@ def init_with_model_and_adapter(self, m, adapter=None):
7591
7692 def _initialize (self ):
7793 self .rm_map = dict ()
78- self .eft = get_effector (self .model ["e" ]["e" ].value )
94+ self .eft = get_effector (self .model ["e" ][self . etype ].value )
7995 self .watcher = None
8096
8197 self .enabled = True
@@ -250,6 +266,15 @@ def add_named_domain_matching_func(self, ptype, fn):
250266
251267 return False
252268
269+ def new_enforce_context (self , suffix : str ) -> EnforceContext :
270+
271+ return EnforceContext (
272+ rtype = 'r' + suffix ,
273+ ptype = 'p' + suffix ,
274+ etype = 'e' + suffix ,
275+ mtype = 'm' + suffix ,
276+ )
277+
253278 def enforce (self , * rvals ):
254279 """decides whether a "subject" can access a "object" with the operation "action",
255280 input parameters are usually: (sub, obj, act).
@@ -273,19 +298,28 @@ def enforce_ex(self, *rvals):
273298 rm = ast .rm
274299 functions [key ] = generate_g_function (rm )
275300
301+ if len (rvals ) != 0 :
302+ if isinstance (rvals [0 ], EnforceContext ):
303+ enforce_context = rvals [0 ]
304+ self .rtype = enforce_context .rtype
305+ self .ptype = enforce_context .ptype
306+ self .etype = enforce_context .etype
307+ self .mtype = enforce_context .mtype
308+ rvals = rvals [1 :]
309+
276310 if "m" not in self .model .keys ():
277311 raise RuntimeError ("model is undefined" )
278312
279313 if "m" not in self .model ["m" ].keys ():
280314 raise RuntimeError ("model is undefined" )
281315
282- r_tokens = self .model ["r" ]["r" ].tokens
283- p_tokens = self .model ["p" ]["p" ].tokens
316+ r_tokens = self .model ["r" ][self . rtype ].tokens
317+ p_tokens = self .model ["p" ][self . ptype ].tokens
284318
285319 if len (r_tokens ) != len (rvals ):
286320 raise RuntimeError ("invalid request size" )
287321
288- exp_string = self .model ["m" ]["m" ].value
322+ exp_string = self .model ["m" ][self . mtype ].value
289323 has_eval = util .has_eval (exp_string )
290324 if not has_eval :
291325 expression = self ._get_expression (exp_string , functions )
@@ -294,11 +328,11 @@ def enforce_ex(self, *rvals):
294328
295329 r_parameters = dict (zip (r_tokens , rvals ))
296330
297- policy_len = len (self .model ["p" ]["p" ].policy )
331+ policy_len = len (self .model ["p" ][self . ptype ].policy )
298332
299333 explain_index = - 1
300334 if not 0 == policy_len :
301- for i , pvals in enumerate (self .model ["p" ]["p" ].policy ):
335+ for i , pvals in enumerate (self .model ["p" ][self . ptype ].policy ):
302336 if len (p_tokens ) != len (pvals ):
303337 raise RuntimeError ("invalid policy size" )
304338
@@ -327,8 +361,9 @@ def enforce_ex(self, *rvals):
327361 else :
328362 raise RuntimeError ("matcher result should be bool, int or float" )
329363
330- if "p_eft" in parameters .keys ():
331- eft = parameters ["p_eft" ]
364+ p_eft_key = self .ptype + "_eft"
365+ if p_eft_key in parameters .keys ():
366+ eft = parameters [p_eft_key ]
332367 if "allow" == eft :
333368 policy_effects .add (Effector .ALLOW )
334369 elif "deny" == eft :
@@ -353,7 +388,7 @@ def enforce_ex(self, *rvals):
353388
354389 parameters = r_parameters .copy ()
355390
356- for token in self .model ["p" ]["p" ].tokens :
391+ for token in self .model ["p" ][self . ptype ].tokens :
357392 parameters [token ] = ""
358393
359394 result = expression .eval (parameters )
@@ -380,7 +415,7 @@ def enforce_ex(self, *rvals):
380415
381416 explain_rule = []
382417 if explain_index != - 1 and explain_index < policy_len :
383- explain_rule = self .model ["p" ]["p" ].policy [explain_index ]
418+ explain_rule = self .model ["p" ][self . ptype ].policy [explain_index ]
384419
385420 return result , explain_rule
386421
0 commit comments