Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pprint import pprint
from re import sub as re_sub
from time import time
from typing import Union
from typing import Union, List

import datajoint as dj
from datajoint.condition import make_condition
Expand Down Expand Up @@ -532,16 +532,17 @@ def fetch_nwb(
if isinstance(self, dict):
raise ValueError("Try replacing Merge.method with Merge().method")
restriction = restriction or self.restriction or True
sources = set((self & restriction).fetch(self._reserved_sk))
merge_restriction = extract_merge_id(restriction)
sources = set((self & merge_restriction).fetch(self._reserved_sk))
nwb_list = []
merge_ids = []
for source in sources:
source_restr = (
self & {self._reserved_sk: source} & restriction
self & {self._reserved_sk: source} & merge_restriction
).fetch("KEY")
nwb_list.extend(
(self & source_restr)
.merge_restrict_class(restriction, permit_multiple_rows=True)
.merge_restrict_class(restriction, permit_multiple_rows=True, add_invalid_restrict=False)
.fetch_nwb()
)
if return_merge_ids:
Expand Down Expand Up @@ -737,10 +738,10 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
return ret

def merge_restrict_class(
self, key: dict, permit_multiple_rows: bool = False
self, key: dict, permit_multiple_rows: bool = False, add_invalid_restrict=True
) -> dj.Table:
"""Returns native parent class, restricted with key."""
parent = self.merge_get_parent(key)
parent = self.merge_get_parent(key, add_invalid_restrict=add_invalid_restrict)
parent_key = parent.fetch("KEY", as_dict=True)

if not permit_multiple_rows and len(parent_key) > 1:
Expand Down Expand Up @@ -860,3 +861,35 @@ def delete_downstream_merge(
table = table if isinstance(table, dj.Table) else table()

return table.delete_downstream_parts(**kwargs)

def extract_merge_id(restriction):
Comment thread
samuelbray32 marked this conversation as resolved.
Outdated
"""Utility function to extract merge_id from a restriction

Parameters
----------
restriction : str, dict, or dj.condition.AndList
A datajoint restriction

Returns
-------
restriction
A restriction containing only the merge_id key
"""
if restriction is None:
return None
if isinstance(restriction, dict):
if merge_id := restriction.get("merge_id"):
return {"merge_id": merge_id}
else:
return {}
merge_restr = []
if isinstance(restriction, dj.condition.AndList) or isinstance(restriction, List):
merge_id_list = [extract_merge_id(r) for r in restriction]
merge_restr = [x for x in merge_id_list if x is not None]
elif isinstance(restriction, str):
parsed = [x.split(")")[0] for x in restriction.split("(") if x]
merge_restr = dj.condition.AndList([x for x in parsed if "merge_id" in x])

if len(merge_restr) == 0:
return True
return merge_restr