Skip to content

Commit 25b7d39

Browse files
tanasecucliciuChimou0
authored andcommitted
feat: Added support for async watcher callbacks apache#340 (apache#341)
1 parent 4cd113c commit 25b7d39

File tree

2 files changed

+215
-12
lines changed

2 files changed

+215
-12
lines changed

casbin/async_internal_enforcer.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import copy
15+
import inspect
1516

1617
from casbin.core_enforcer import CoreEnforcer
1718
from casbin.model import Model, FunctionMap
@@ -105,8 +106,12 @@ async def save_policy(self):
105106
await self.adapter.save_policy(self.model)
106107

107108
if self.watcher:
108-
if callable(getattr(self.watcher, "update_for_save_policy", None)):
109-
self.watcher.update_for_save_policy(self.model)
109+
update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None)
110+
if callable(update_for_save_policy):
111+
if inspect.iscoroutinefunction(update_for_save_policy):
112+
await update_for_save_policy(self.model)
113+
else:
114+
update_for_save_policy(self.model)
110115
else:
111116
self.watcher.update()
112117

@@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule):
122127
return False
123128

124129
if self.watcher and self.auto_notify_watcher:
125-
if callable(getattr(self.watcher, "update_for_add_policy", None)):
126-
self.watcher.update_for_add_policy(sec, ptype, rule)
130+
update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None)
131+
if callable(update_for_add_policy):
132+
if inspect.iscoroutinefunction(update_for_add_policy):
133+
await update_for_add_policy(sec, ptype, rule)
134+
else:
135+
update_for_add_policy(sec, ptype, rule)
127136
else:
128137
self.watcher.update()
129138

@@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules):
144153
return False
145154

146155
if self.watcher and self.auto_notify_watcher:
147-
if callable(getattr(self.watcher, "update_for_add_policies", None)):
148-
self.watcher.update_for_add_policies(sec, ptype, rules)
156+
update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None)
157+
if callable(update_for_add_policies):
158+
if inspect.iscoroutinefunction(update_for_add_policies):
159+
await update_for_add_policies(sec, ptype, rules)
160+
else:
161+
update_for_add_policies(sec, ptype, rules)
149162
else:
150163
self.watcher.update()
151164

@@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule):
224237
return False
225238

226239
if self.watcher and self.auto_notify_watcher:
227-
if callable(getattr(self.watcher, "update_for_remove_policy", None)):
228-
self.watcher.update_for_remove_policy(sec, ptype, rule)
240+
update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None)
241+
if callable(update_for_remove_policy):
242+
if inspect.iscoroutinefunction(update_for_remove_policy):
243+
await update_for_remove_policy(sec, ptype, rule)
244+
else:
245+
update_for_remove_policy(sec, ptype, rule)
229246
else:
230247
self.watcher.update()
231248

@@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules):
246263
return False
247264

248265
if self.watcher and self.auto_notify_watcher:
249-
if callable(getattr(self.watcher, "update_for_remove_policies", None)):
250-
self.watcher.update_for_remove_policies(sec, ptype, rules)
266+
update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None)
267+
if callable(update_for_remove_policies):
268+
if inspect.iscoroutinefunction(update_for_remove_policies):
269+
await update_for_remove_policies(sec, ptype, rules)
270+
else:
271+
update_for_remove_policies(sec, ptype, rules)
251272
else:
252273
self.watcher.update()
253274

@@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
265286
return False
266287

267288
if self.watcher and self.auto_notify_watcher:
268-
if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)):
269-
self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
289+
update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None)
290+
if callable(update_for_remove_filtered_policy):
291+
if inspect.iscoroutinefunction(update_for_remove_filtered_policy):
292+
await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
293+
else:
294+
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
270295
else:
271296
self.watcher.update()
272297

tests/test_watcher_ex.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import casbin
1616
from tests.test_enforcer import get_examples, TestCaseBase
17+
from unittest import IsolatedAsyncioTestCase
1718

1819

1920
class SampleWatcher:
@@ -113,6 +114,103 @@ def start_watch(self):
113114
pass
114115

115116

117+
class AsyncSampleWatcher:
118+
def __init__(self):
119+
self.callback = None
120+
self.notify_message = None
121+
122+
async def close(self):
123+
pass
124+
125+
async def set_update_callback(self, callback):
126+
"""
127+
sets the callback function to be called when the policy is updated
128+
:param callable callback: callback(event)
129+
- event: event received from the rabbitmq
130+
:return:
131+
"""
132+
self.callback = callback
133+
134+
async def update(self, msg):
135+
"""
136+
update the policy
137+
"""
138+
self.notify_message = msg
139+
return True
140+
141+
async def update_for_add_policy(self, section, ptype, *params):
142+
"""
143+
update for add policy
144+
:param section: section
145+
:param ptype: policy type
146+
:param params: other params
147+
:return: True if updated
148+
"""
149+
message = "called add policy"
150+
return await self.update(message)
151+
152+
async def update_for_remove_policy(self, section, ptype, *params):
153+
"""
154+
update for remove policy
155+
:param section: section
156+
:param ptype: policy type
157+
:param params: other params
158+
:return: True if updated
159+
"""
160+
message = "called remove policy"
161+
return await self.update(message)
162+
163+
async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
164+
"""
165+
update for remove filtered policy
166+
:param section: section
167+
:param ptype: policy type
168+
:param field_index: field index
169+
:param params: other params
170+
:return:
171+
"""
172+
message = "called remove filtered policy"
173+
return await self.update(message)
174+
175+
async def update_for_save_policy(self, model: casbin.Model):
176+
"""
177+
update for save policy
178+
:param model: casbin model
179+
:return:
180+
"""
181+
message = "called save policy"
182+
return await self.update(message)
183+
184+
async def update_for_add_policies(self, section, ptype, *params):
185+
"""
186+
update for add policies
187+
:param section: section
188+
:param ptype: policy type
189+
:param params: other params
190+
:return:
191+
"""
192+
message = "called add policies"
193+
return await self.update(message)
194+
195+
async def update_for_remove_policies(self, section, ptype, *params):
196+
"""
197+
update for remove policies
198+
:param section: section
199+
:param ptype: policy type
200+
:param params: other params
201+
:return:
202+
"""
203+
message = "called remove policies"
204+
return await self.update(message)
205+
206+
async def start_watch(self):
207+
"""
208+
starts the watch thread
209+
:return:
210+
"""
211+
pass
212+
213+
116214
class TestWatcherEx(TestCaseBase):
117215
def get_enforcer(self, model=None, adapter=None):
118216
return casbin.Enforcer(
@@ -187,3 +285,83 @@ def test_auto_notify_disabled(self):
187285

188286
e.remove_policies(rules)
189287
self.assertEqual(w.notify_message, None)
288+
289+
290+
class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
291+
def get_enforcer(self, model=None, adapter=None):
292+
return casbin.AsyncEnforcer(
293+
model,
294+
adapter,
295+
)
296+
297+
async def test_auto_notify_enabled(self):
298+
e = self.get_enforcer(
299+
get_examples("basic_model.conf"),
300+
get_examples("basic_policy.csv"),
301+
)
302+
await e.load_policy()
303+
304+
w = AsyncSampleWatcher()
305+
e.set_watcher(w)
306+
e.enable_auto_notify_watcher(True)
307+
308+
await e.save_policy()
309+
self.assertEqual(w.notify_message, "called save policy")
310+
311+
await e.add_policy("admin", "data1", "read")
312+
self.assertEqual(w.notify_message, "called add policy")
313+
314+
await e.remove_policy("admin", "data1", "read")
315+
self.assertEqual(w.notify_message, "called remove policy")
316+
317+
await e.remove_filtered_policy(1, "data1")
318+
self.assertEqual(w.notify_message, "called remove filtered policy")
319+
320+
rules = [
321+
["jack", "data4", "read"],
322+
["katy", "data4", "write"],
323+
["leyo", "data4", "read"],
324+
["ham", "data4", "write"],
325+
]
326+
await e.add_policies(rules)
327+
self.assertEqual(w.notify_message, "called add policies")
328+
329+
await e.remove_policies(rules)
330+
self.assertEqual(w.notify_message, "called remove policies")
331+
332+
async def test_auto_notify_disabled(self):
333+
e = self.get_enforcer(
334+
get_examples("basic_model.conf"),
335+
get_examples("basic_policy.csv"),
336+
)
337+
await e.load_policy()
338+
339+
w = SampleWatcher()
340+
e.set_watcher(w)
341+
e.enable_auto_notify_watcher(False)
342+
343+
await e.save_policy()
344+
self.assertEqual(w.notify_message, "called save policy")
345+
346+
w.notify_message = None
347+
348+
await e.add_policy("admin", "data1", "read")
349+
self.assertEqual(w.notify_message, None)
350+
351+
await e.remove_policy("admin", "data1", "read")
352+
self.assertEqual(w.notify_message, None)
353+
354+
await e.remove_filtered_policy(1, "data1")
355+
self.assertEqual(w.notify_message, None)
356+
357+
rules = [
358+
["jack", "data4", "read"],
359+
["katy", "data4", "write"],
360+
["leyo", "data4", "read"],
361+
["ham", "data4", "write"],
362+
]
363+
await e.add_policies(rules)
364+
self.assertEqual(w.notify_message, None)
365+
366+
await e.remove_policies(rules)
367+
self.assertEqual(w.notify_message, None)

0 commit comments

Comments
 (0)