Skip to content

Commit 9c0e407

Browse files
authored
Merge pull request #1122 from ethho/dev-tests-plat-146-aggr
PLAT-146: Migrate test_aggr_regressions.py
2 parents 90209a6 + 110d642 commit 9c0e407

5 files changed

Lines changed: 247 additions & 1 deletion

File tree

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
schema_advanced,
2323
schema_adapted,
2424
schema_external,
25+
schema_uuid as schema_uuid_module,
2526
)
2627

2728

@@ -307,6 +308,20 @@ def schema_ext(connection_test, stores_config, enable_filepath_feature):
307308
schema.drop()
308309

309310

311+
@pytest.fixture
312+
def schema_uuid(connection_test):
313+
schema = dj.Schema(
314+
PREFIX + "_test1",
315+
context=schema_uuid_module.LOCALS_UUID,
316+
connection=connection_test,
317+
)
318+
schema(schema_uuid_module.Basic)
319+
schema(schema_uuid_module.Topic)
320+
schema(schema_uuid_module.Item)
321+
yield schema
322+
schema.drop()
323+
324+
310325
@pytest.fixture(scope="session")
311326
def http_client():
312327
# Initialize httpClient with relevant timeout.

tests/schema_aggr_regress.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import datajoint as dj
2+
import itertools
3+
import inspect
4+
5+
6+
class R(dj.Lookup):
7+
definition = """
8+
r : char(1)
9+
"""
10+
contents = zip("ABCDFGHIJKLMNOPQRST")
11+
12+
13+
class Q(dj.Lookup):
14+
definition = """
15+
-> R
16+
"""
17+
contents = zip("ABCDFGH")
18+
19+
20+
class S(dj.Lookup):
21+
definition = """
22+
-> R
23+
s : int
24+
"""
25+
contents = itertools.product("ABCDF", range(10))
26+
27+
28+
class A(dj.Lookup):
29+
definition = """
30+
id: int
31+
"""
32+
contents = zip(range(10))
33+
34+
35+
class B(dj.Lookup):
36+
definition = """
37+
-> A
38+
id2: int
39+
"""
40+
contents = zip(range(5), range(5, 10))
41+
42+
43+
class X(dj.Lookup):
44+
definition = """
45+
id: int
46+
"""
47+
contents = zip(range(10))
48+
49+
50+
LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)}
51+
__all__ = list(LOCALS_AGGR_REGRESS)

tests/schema_uuid.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import uuid
2+
import inspect
3+
import datajoint as dj
4+
from . import PREFIX, CONN_INFO
5+
6+
top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000")
7+
8+
9+
class Basic(dj.Manual):
10+
definition = """
11+
item : uuid
12+
---
13+
number : int
14+
"""
15+
16+
17+
class Topic(dj.Manual):
18+
definition = """
19+
# A topic for items
20+
topic_id : uuid # internal identification of a topic, reflects topic name
21+
---
22+
topic : varchar(8000) # full topic name used to generate the topic id
23+
"""
24+
25+
def add(self, topic):
26+
"""add a new topic with a its UUID"""
27+
self.insert1(
28+
dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic)
29+
)
30+
31+
32+
class Item(dj.Computed):
33+
definition = """
34+
item_id : uuid # internal identification of
35+
---
36+
-> Topic
37+
word : varchar(8000)
38+
"""
39+
40+
key_source = Topic # test key source that is not instantiated
41+
42+
def make(self, key):
43+
for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"):
44+
self.insert1(
45+
dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word))
46+
)
47+
48+
49+
LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)}
50+
__all__ = list(LOCALS_UUID)

tests/test_aggr_regressions.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections.
3+
"""
4+
5+
import pytest
6+
import datajoint as dj
7+
from . import PREFIX
8+
import uuid
9+
from .schema_uuid import Topic, Item, top_level_namespace_id
10+
from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS
11+
12+
13+
@pytest.fixture(scope="function")
14+
def schema_aggr_reg(connection_test):
15+
context = LOCALS_AGGR_REGRESS
16+
schema = dj.Schema(
17+
PREFIX + "_aggr_regress",
18+
context=context,
19+
connection=connection_test,
20+
)
21+
schema(R)
22+
schema(Q)
23+
schema(S)
24+
yield schema
25+
schema.drop()
26+
27+
28+
@pytest.fixture(scope="function")
29+
def schema_aggr_reg_with_abx(connection_test):
30+
context = LOCALS_AGGR_REGRESS
31+
schema = dj.Schema(
32+
PREFIX + "_aggr_regress_with_abx",
33+
context=context,
34+
connection=connection_test,
35+
)
36+
schema(R)
37+
schema(Q)
38+
schema(S)
39+
schema(A)
40+
schema(B)
41+
schema(X)
42+
yield schema
43+
schema.drop()
44+
45+
46+
def test_issue386(schema_aggr_reg):
47+
"""
48+
--------------- ISSUE 386 -------------------
49+
Issue 386 resulted from the loss of aggregated attributes when the aggregation was used as the restrictor
50+
Q & (R.aggr(S, n='count(*)') & 'n=2')
51+
Error: Unknown column 'n' in HAVING
52+
"""
53+
result = R.aggr(S, n="count(*)") & "n=10"
54+
result = Q & result
55+
result.fetch()
56+
57+
58+
def test_issue449(schema_aggr_reg):
59+
"""
60+
---------------- ISSUE 449 ------------------
61+
Issue 449 arises from incorrect group by attributes after joining with a dj.U()
62+
"""
63+
result = dj.U("n") * R.aggr(S, n="max(s)")
64+
result.fetch()
65+
66+
67+
def test_issue484(schema_aggr_reg):
68+
"""
69+
---------------- ISSUE 484 -----------------
70+
Issue 484
71+
"""
72+
q = dj.U().aggr(S, n="max(s)")
73+
n = q.fetch("n")
74+
n = q.fetch1("n")
75+
q = dj.U().aggr(S, n="avg(s)")
76+
result = dj.U().aggr(q, m="max(n)")
77+
result.fetch()
78+
79+
80+
def test_union_join(schema_aggr_reg_with_abx):
81+
"""
82+
This test fails if it runs after TestIssue558.
83+
84+
https://github.com/datajoint/datajoint-python/issues/930
85+
"""
86+
A.insert(zip([100, 200, 300, 400, 500, 600]))
87+
B.insert([(100, 11), (200, 22), (300, 33), (400, 44)])
88+
q1 = B & "id < 300"
89+
q2 = B & "id > 300"
90+
91+
expected_data = [
92+
{"id": 0, "id2": 5},
93+
{"id": 1, "id2": 6},
94+
{"id": 2, "id2": 7},
95+
{"id": 3, "id2": 8},
96+
{"id": 4, "id2": 9},
97+
{"id": 100, "id2": 11},
98+
{"id": 200, "id2": 22},
99+
{"id": 400, "id2": 44},
100+
]
101+
102+
assert ((q1 + q2) * A).fetch(as_dict=True) == expected_data
103+
104+
105+
class TestIssue558:
106+
"""
107+
--------------- ISSUE 558 ------------------
108+
Issue 558 resulted from the fact that DataJoint saves subqueries and often combines a restriction followed
109+
by a projection into a single SELECT statement, which in several unusual cases produces unexpected results.
110+
"""
111+
112+
def test_issue558_part1(self, schema_aggr_reg_with_abx):
113+
q = (A - B).proj(id2="3")
114+
assert len(A - B) == len(q)
115+
116+
def test_issue558_part2(self, schema_aggr_reg_with_abx):
117+
d = dict(id=3, id2=5)
118+
assert len(X & d) == len((X & d).proj(id2="3"))
119+
120+
121+
def test_left_join_len(schema_uuid):
122+
Topic().add("jeff")
123+
Item.populate()
124+
Topic().add("jeff2")
125+
Topic().add("jeff3")
126+
q = Topic.join(
127+
Item - dict(topic_id=uuid.uuid5(top_level_namespace_id, "jeff")), left=True
128+
)
129+
qf = q.fetch()
130+
assert len(q) == len(qf)

tests/test_erd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_make_image(schema_simp):
5858

5959
def test_part_table_parsing(schema_simp):
6060
# https://github.com/datajoint/datajoint-python/issues/882
61-
erd = dj.Di(schema_simp)
61+
erd = dj.Di(schema_simp, context=LOCALS_SIMPLE)
6262
graph = erd._make_graph()
6363
assert "OutfitLaunch" in graph.nodes()
6464
assert "OutfitLaunch.OutfitPiece" in graph.nodes()

0 commit comments

Comments
 (0)