|
3 | 3 | from pprint import pprint |
4 | 4 | from re import sub as re_sub |
5 | 5 | from time import time |
6 | | -from typing import Union |
| 6 | +from typing import Union, List |
7 | 7 |
|
8 | 8 | import datajoint as dj |
9 | 9 | from datajoint.condition import make_condition |
@@ -532,17 +532,22 @@ def fetch_nwb( |
532 | 532 | if isinstance(self, dict): |
533 | 533 | raise ValueError("Try replacing Merge.method with Merge().method") |
534 | 534 | restriction = restriction or self.restriction or True |
535 | | - sources = set((self & restriction).fetch(self._reserved_sk)) |
| 535 | + merge_restriction = self.extract_merge_id(restriction) |
| 536 | + sources = set((self & merge_restriction).fetch(self._reserved_sk)) |
536 | 537 | nwb_list = [] |
537 | 538 | merge_ids = [] |
538 | 539 | for source in sources: |
539 | 540 | source_restr = ( |
540 | | - self & {self._reserved_sk: source} & restriction |
| 541 | + self & {self._reserved_sk: source} & merge_restriction |
541 | 542 | ).fetch("KEY") |
542 | 543 | nwb_list.extend( |
543 | | - self.merge_restrict_class( |
544 | | - source_restr, permit_multiple_rows=True |
545 | | - ).fetch_nwb() |
| 544 | + (self & source_restr) |
| 545 | + .merge_restrict_class( |
| 546 | + restriction, |
| 547 | + permit_multiple_rows=True, |
| 548 | + add_invalid_restrict=False, |
| 549 | + ) |
| 550 | + .fetch_nwb() |
546 | 551 | ) |
547 | 552 | if return_merge_ids: |
548 | 553 | merge_ids.extend([k[self._reserved_pk] for k in source_restr]) |
@@ -738,10 +743,15 @@ def merge_get_parent_class(self, source: str) -> dj.Table: |
738 | 743 | return ret |
739 | 744 |
|
740 | 745 | def merge_restrict_class( |
741 | | - self, key: dict, permit_multiple_rows: bool = False |
| 746 | + self, |
| 747 | + key: dict, |
| 748 | + permit_multiple_rows: bool = False, |
| 749 | + add_invalid_restrict=True, |
742 | 750 | ) -> dj.Table: |
743 | 751 | """Returns native parent class, restricted with key.""" |
744 | | - parent = self.merge_get_parent(key) |
| 752 | + parent = self.merge_get_parent( |
| 753 | + key, add_invalid_restrict=add_invalid_restrict |
| 754 | + ) |
745 | 755 | parent_key = parent.fetch("KEY", as_dict=True) |
746 | 756 |
|
747 | 757 | if not permit_multiple_rows and len(parent_key) > 1: |
@@ -834,6 +844,48 @@ def super_delete(self, warn=True, *args, **kwargs): |
834 | 844 | self._log_delete(start=time(), super_delete=True) |
835 | 845 | super().delete(*args, **kwargs) |
836 | 846 |
|
| 847 | + @classmethod |
| 848 | + def extract_merge_id(cls, restriction) -> Union[dict, list]: |
| 849 | + """Utility function to extract merge_id from a restriction |
| 850 | +
|
| 851 | + Removes all other restricted attributes, and defaults to a |
| 852 | + universal set (either empty dict or True) when there is no |
| 853 | + merge_id present in the input, relying on parent func to |
| 854 | + restrict on secondary or part-parent key(s). |
| 855 | +
|
| 856 | + Parameters |
| 857 | + ---------- |
| 858 | + restriction : str, dict, or dj.condition.AndList |
| 859 | + A datajoint restriction |
| 860 | +
|
| 861 | + Returns |
| 862 | + ------- |
| 863 | + restriction |
| 864 | + A restriction containing only the merge_id key |
| 865 | + """ |
| 866 | + if restriction is None: |
| 867 | + return None |
| 868 | + if isinstance(restriction, dict): |
| 869 | + if merge_id := restriction.get("merge_id"): |
| 870 | + return {"merge_id": merge_id} |
| 871 | + else: |
| 872 | + return {} |
| 873 | + merge_restr = [] |
| 874 | + if isinstance(restriction, dj.condition.AndList) or isinstance( |
| 875 | + restriction, List |
| 876 | + ): |
| 877 | + merge_id_list = [cls.extract_merge_id(r) for r in restriction] |
| 878 | + merge_restr = [x for x in merge_id_list if x is not None] |
| 879 | + elif isinstance(restriction, str): |
| 880 | + parsed = [x.split(")")[0] for x in restriction.split("(") if x] |
| 881 | + merge_restr = dj.condition.AndList( |
| 882 | + [x for x in parsed if "merge_id" in x] |
| 883 | + ) |
| 884 | + |
| 885 | + if len(merge_restr) == 0: |
| 886 | + return True |
| 887 | + return merge_restr |
| 888 | + |
837 | 889 |
|
838 | 890 | _Merge = Merge |
839 | 891 |
|
|
0 commit comments