Skip to content

Commit 84ed6ee

Browse files
authored
feat: finish up update_policy and update_policies (#48)
Signed-off-by: ffyuanda <46557895+ffyuanda@users.noreply.github.com>
1 parent 09f1c9c commit 84ed6ee

File tree

2 files changed

+114
-5
lines changed

2 files changed

+114
-5
lines changed

casbin_sqlalchemy_adapter/adapter.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class Filter:
4343
v5 = []
4444

4545

46-
class Adapter(persist.Adapter):
46+
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
4747
"""the interface for Casbin adapters."""
4848

4949
def __init__(self, engine, db_class=None, filtered=False):
@@ -152,7 +152,7 @@ def remove_policy(self, sec, ptype, rule):
152152
return True if r > 0 else False
153153

154154
def remove_policies(self, sec, ptype, rules):
155-
"""removes a policy rules from the storage."""
155+
"""remove policy rules from the storage."""
156156
if not rules:
157157
return
158158
with self._session_scope() as session:
@@ -168,15 +168,65 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
168168
This is part of the Auto-Save feature.
169169
"""
170170
with self._session_scope() as session:
171-
query = session.query(self._db_class)
172-
query = query.filter(self._db_class.ptype == ptype)
171+
query = (session
172+
.query(self._db_class)
173+
.filter(self._db_class.ptype == ptype))
174+
173175
if not (0 <= field_index <= 5):
174176
return False
175177
if not (1 <= field_index + len(field_values) <= 6):
176178
return False
177179
for i, v in enumerate(field_values):
178180
if v != '':
179-
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
181+
v_value = getattr(self._db_class, "v{}".format(field_index + i))
182+
query = query.filter(v_value == v)
180183
r = query.delete()
181184

182185
return True if r > 0 else False
186+
187+
def update_policy(self, sec: str, ptype: str, old_rule: [str], new_rule: [str]) -> None:
188+
"""
189+
Update the old_rule with the new_rule in the database (storage).
190+
191+
:param sec: section type
192+
:param ptype: policy type
193+
:param old_rule: the old rule that needs to be modified
194+
:param new_rule: the new rule to replace the old rule
195+
196+
:return: None
197+
"""
198+
199+
with self._session_scope() as session:
200+
query = (session
201+
.query(self._db_class)
202+
.filter(self._db_class.ptype == ptype))
203+
204+
# locate the old rule
205+
for index, value in enumerate(old_rule):
206+
v_value = getattr(self._db_class, "v{}".format(index))
207+
query = query.filter(v_value == value)
208+
209+
# need the length of the longest_rule to perform overwrite
210+
longest_rule = old_rule if len(old_rule) > len(new_rule) else new_rule
211+
old_rule_line = query.one()
212+
213+
# overwrite the old rule with the new rule
214+
for index in range(len(longest_rule)):
215+
if index < len(new_rule):
216+
exec(f"old_rule_line.v{index} = new_rule[{index}]")
217+
else:
218+
exec(f"old_rule_line.v{index} = None")
219+
220+
def update_policies(self, sec: str, ptype: str, old_rules: [[str], ], new_rules: [[str], ]) -> None:
221+
"""
222+
Update the old_rules with the new_rules in the database (storage).
223+
224+
:param sec: section type
225+
:param ptype: policy type
226+
:param old_rules: the old rules that need to be modified
227+
:param new_rules: the new rules to replace the old rules
228+
229+
:return: None
230+
"""
231+
for i in range(len(old_rules)):
232+
self.update_policy(sec, ptype, old_rules[i], new_rules[i])

tests/test_adapter.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,62 @@ def test_filtered_policy(self):
279279
self.assertTrue(e.enforce('bob', 'data2', 'write'))
280280
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
281281
self.assertTrue(e.enforce('data2_admin', 'data2', 'write'))
282+
283+
def test_update_policy(self):
284+
e = get_enforcer()
285+
example_p = ['mike', 'cookie', 'eat']
286+
287+
self.assertTrue(e.enforce('alice', 'data1', 'read'))
288+
e.update_policy(['alice', 'data1', 'read'], ['alice', 'data1', 'no_read'])
289+
self.assertFalse(e.enforce('alice', 'data1', 'read'))
290+
291+
self.assertFalse(e.enforce('bob', 'data1', 'read'))
292+
e.add_policy(example_p)
293+
e.update_policy(example_p, ['bob', 'data1', 'read'])
294+
self.assertTrue(e.enforce('bob', 'data1', 'read'))
295+
296+
self.assertFalse(e.enforce('bob', 'data1', 'write'))
297+
e.update_policy(['bob', 'data1', 'read'], ['bob', 'data1', 'write'])
298+
self.assertTrue(e.enforce('bob', 'data1', 'write'))
299+
300+
self.assertTrue(e.enforce('bob', 'data2', 'write'))
301+
e.update_policy(['bob', 'data2', 'write'], ['bob', 'data2', 'read'])
302+
self.assertFalse(e.enforce('bob', 'data2', 'write'))
303+
304+
self.assertTrue(e.enforce('bob', 'data2', 'read'))
305+
e.update_policy(['bob', 'data2', 'read'], ['carl', 'data2', 'write'])
306+
self.assertFalse(e.enforce('bob', 'data2', 'write'))
307+
308+
self.assertTrue(e.enforce('carl', 'data2', 'write'))
309+
e.update_policy(['carl', 'data2', 'write'], ['carl', 'data2', 'no_write'])
310+
self.assertFalse(e.enforce('bob', 'data2', 'write'))
311+
312+
def test_update_policies(self):
313+
e = get_enforcer()
314+
315+
old_rule_0 = ['alice', 'data1', 'read']
316+
old_rule_1 = ['bob', 'data2', 'write']
317+
old_rule_2 = ['data2_admin', 'data2', 'read']
318+
old_rule_3 = ['data2_admin', 'data2', 'write']
319+
320+
new_rule_0 = ['alice', 'data_test', 'read']
321+
new_rule_1 = ['bob', 'data_test', 'write']
322+
new_rule_2 = ['data2_admin', 'data_test', 'read']
323+
new_rule_3 = ['data2_admin', 'data_test', 'write']
324+
325+
old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3]
326+
new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3]
327+
328+
e.update_policies(old_rules, new_rules)
329+
330+
self.assertFalse(e.enforce('alice', 'data1', 'read'))
331+
self.assertTrue(e.enforce('alice', 'data_test', 'read'))
332+
333+
self.assertFalse(e.enforce('bob', 'data2', 'write'))
334+
self.assertTrue(e.enforce('bob', 'data_test', 'write'))
335+
336+
self.assertFalse(e.enforce('data2_admin', 'data2', 'read'))
337+
self.assertTrue(e.enforce('data2_admin', 'data_test', 'read'))
338+
339+
self.assertFalse(e.enforce('data2_admin', 'data2', 'write'))
340+
self.assertTrue(e.enforce('data2_admin', 'data_test', 'write'))

0 commit comments

Comments
 (0)