Skip to content

Commit 1480c24

Browse files
samuelbray32CBroz1
andauthored
Apply restrictions on parent tables in fetch_nwb (#1086)
* include upstream restriction in fetch_nwb * update changelog * lint * resolve case of string restrictions * move extract_merge_id to class method * Apply suggestions from code review Co-authored-by: Chris Broz <Chris.Broz@ucsf.edu> * fix lint --------- Co-authored-by: Chris Broz <Chris.Broz@ucsf.edu>
1 parent a711874 commit 1480c24

2 files changed

Lines changed: 61 additions & 8 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
2020
#1108
2121
- Add docstrings to all public methods #1076
2222
- Update DataJoint to 0.14.2 #1081
23+
- Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086
2324

2425
### Pipelines
2526

src/spyglass/utils/dj_merge_tables.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pprint import pprint
44
from re import sub as re_sub
55
from time import time
6-
from typing import Union
6+
from typing import Union, List
77

88
import datajoint as dj
99
from datajoint.condition import make_condition
@@ -532,17 +532,22 @@ def fetch_nwb(
532532
if isinstance(self, dict):
533533
raise ValueError("Try replacing Merge.method with Merge().method")
534534
restriction = restriction or self.restriction or True
535-
sources = set((self & restriction).fetch(self._reserved_sk))
535+
merge_restriction = self.extract_merge_id(restriction)
536+
sources = set((self & merge_restriction).fetch(self._reserved_sk))
536537
nwb_list = []
537538
merge_ids = []
538539
for source in sources:
539540
source_restr = (
540-
self & {self._reserved_sk: source} & restriction
541+
self & {self._reserved_sk: source} & merge_restriction
541542
).fetch("KEY")
542543
nwb_list.extend(
543-
self.merge_restrict_class(
544-
source_restr, permit_multiple_rows=True
545-
).fetch_nwb()
544+
(self & source_restr)
545+
.merge_restrict_class(
546+
restriction,
547+
permit_multiple_rows=True,
548+
add_invalid_restrict=False,
549+
)
550+
.fetch_nwb()
546551
)
547552
if return_merge_ids:
548553
merge_ids.extend([k[self._reserved_pk] for k in source_restr])
@@ -738,10 +743,15 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
738743
return ret
739744

740745
def merge_restrict_class(
741-
self, key: dict, permit_multiple_rows: bool = False
746+
self,
747+
key: dict,
748+
permit_multiple_rows: bool = False,
749+
add_invalid_restrict=True,
742750
) -> dj.Table:
743751
"""Returns native parent class, restricted with key."""
744-
parent = self.merge_get_parent(key)
752+
parent = self.merge_get_parent(
753+
key, add_invalid_restrict=add_invalid_restrict
754+
)
745755
parent_key = parent.fetch("KEY", as_dict=True)
746756

747757
if not permit_multiple_rows and len(parent_key) > 1:
@@ -834,6 +844,48 @@ def super_delete(self, warn=True, *args, **kwargs):
834844
self._log_delete(start=time(), super_delete=True)
835845
super().delete(*args, **kwargs)
836846

847+
@classmethod
848+
def extract_merge_id(cls, restriction) -> Union[dict, list]:
849+
"""Utility function to extract merge_id from a restriction
850+
851+
Removes all other restricted attributes, and defaults to a
852+
universal set (either empty dict or True) when there is no
853+
merge_id present in the input, relying on parent func to
854+
restrict on secondary or part-parent key(s).
855+
856+
Parameters
857+
----------
858+
restriction : str, dict, or dj.condition.AndList
859+
A datajoint restriction
860+
861+
Returns
862+
-------
863+
restriction
864+
A restriction containing only the merge_id key
865+
"""
866+
if restriction is None:
867+
return None
868+
if isinstance(restriction, dict):
869+
if merge_id := restriction.get("merge_id"):
870+
return {"merge_id": merge_id}
871+
else:
872+
return {}
873+
merge_restr = []
874+
if isinstance(restriction, dj.condition.AndList) or isinstance(
875+
restriction, List
876+
):
877+
merge_id_list = [cls.extract_merge_id(r) for r in restriction]
878+
merge_restr = [x for x in merge_id_list if x is not None]
879+
elif isinstance(restriction, str):
880+
parsed = [x.split(")")[0] for x in restriction.split("(") if x]
881+
merge_restr = dj.condition.AndList(
882+
[x for x in parsed if "merge_id" in x]
883+
)
884+
885+
if len(merge_restr) == 0:
886+
return True
887+
return merge_restr
888+
837889

838890
_Merge = Merge
839891

0 commit comments

Comments
 (0)