Skip to content

Commit 2d6809b

Browse files
committed
Add history graph endpoint
1 parent 52d4f3f commit 2d6809b

5 files changed

Lines changed: 876 additions & 0 deletions

File tree

Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
from datetime import datetime
2+
from typing import (
3+
Any,
4+
Optional,
5+
)
6+
7+
from sqlalchemy import (
8+
func,
9+
select,
10+
)
11+
from sqlalchemy.orm import Session
12+
13+
from galaxy.model import (
14+
Dataset,
15+
DatasetCollection,
16+
HistoryDatasetAssociation,
17+
HistoryDatasetCollectionAssociation,
18+
Job,
19+
JobToOutputDatasetAssociation,
20+
JobToOutputDatasetCollectionAssociation,
21+
ToolRequest,
22+
ToolRequestImplicitCollectionAssociation,
23+
)
24+
from galaxy.schema.history_graph import (
25+
ExternalRef,
26+
GraphEdge,
27+
GraphNode,
28+
HistoryGraphResponse,
29+
)
30+
from galaxy.security.idencoding import IdEncodingHelper
31+
32+
33+
def _iso(dt: Optional[datetime]) -> Optional[datetime]:
34+
return dt
35+
36+
37+
class HistoryGraphBuilder:
38+
def __init__(self, sa_session: Session, security: IdEncodingHelper, history_id: int, encoded_history_id: str):
39+
self.sa_session = sa_session
40+
self.security = security
41+
self.history_id = history_id
42+
self.encoded_history_id = encoded_history_id
43+
44+
self.nodes: list[GraphNode] = []
45+
self.edges: list[GraphEdge] = []
46+
self.external_refs: list[ExternalRef] = []
47+
48+
self.referenced_hda_ids: set[int] = set()
49+
self.referenced_hdca_ids: set[int] = set()
50+
51+
def build(self) -> HistoryGraphResponse:
52+
tool_requests = self._query_tool_requests()
53+
self._derive_input_edges(tool_requests)
54+
self._query_output_edges(tool_requests)
55+
self._fetch_hda_metadata()
56+
self._fetch_hdca_metadata()
57+
self._deduplicate_edges()
58+
self._sort()
59+
return HistoryGraphResponse(
60+
history_id=self.encoded_history_id,
61+
node_count=len(self.nodes),
62+
edge_count=len(self.edges),
63+
nodes=self.nodes,
64+
edges=self.edges,
65+
external_refs=self.external_refs,
66+
)
67+
68+
def _query_tool_requests(self) -> list[dict[str, Any]]:
69+
first_job = (
70+
select(
71+
Job.tool_request_id,
72+
func.min(Job.id).label("first_job_id"),
73+
)
74+
.where(Job.tool_request_id.isnot(None))
75+
.group_by(Job.tool_request_id)
76+
.subquery()
77+
)
78+
stmt = (
79+
select(
80+
ToolRequest.id,
81+
ToolRequest.state,
82+
ToolRequest.request,
83+
Job.tool_id,
84+
Job.tool_version,
85+
)
86+
.outerjoin(first_job, ToolRequest.id == first_job.c.tool_request_id)
87+
.outerjoin(Job, Job.id == first_job.c.first_job_id)
88+
.where(ToolRequest.history_id == self.history_id)
89+
)
90+
tool_requests = []
91+
for row in self.sa_session.execute(stmt):
92+
tool_id = row.tool_id
93+
if tool_id and tool_id.startswith("__"):
94+
continue
95+
tool_requests.append(
96+
{
97+
"id": row.id,
98+
"state": row.state,
99+
"request": row.request,
100+
"tool_id": tool_id,
101+
"tool_version": row.tool_version,
102+
}
103+
)
104+
self.nodes.append(
105+
GraphNode(
106+
id=f"r{self.security.encode_id(row.id)}",
107+
type="tool_request",
108+
state=row.state,
109+
tool_id=tool_id,
110+
tool_version=row.tool_version,
111+
)
112+
)
113+
return tool_requests
114+
115+
def _derive_input_edges(self, tool_requests: list[dict[str, Any]]):
116+
for tr in tool_requests:
117+
request_data = tr["request"]
118+
if not request_data:
119+
continue
120+
encoded_tr_id = f"r{self.security.encode_id(tr['id'])}"
121+
self._walk_request(request_data, encoded_tr_id, name=None)
122+
123+
def _walk_request(self, obj: Any, target: str, name: Optional[str]):
124+
if isinstance(obj, dict):
125+
src = obj.get("src")
126+
raw_id = obj.get("id")
127+
if src in ("hda", "hdca") and isinstance(raw_id, int):
128+
if src == "hda":
129+
self.referenced_hda_ids.add(raw_id)
130+
encoded_source = f"d{self.security.encode_id(raw_id)}"
131+
else:
132+
self.referenced_hdca_ids.add(raw_id)
133+
encoded_source = f"c{self.security.encode_id(raw_id)}"
134+
self.edges.append(
135+
GraphEdge(
136+
source=encoded_source,
137+
target=target,
138+
type="input",
139+
name=name,
140+
)
141+
)
142+
else:
143+
for key, value in obj.items():
144+
self._walk_request(value, target, name=key)
145+
elif isinstance(obj, list):
146+
for item in obj:
147+
self._walk_request(item, target, name=name)
148+
149+
def _query_output_edges(self, tool_requests: list[dict[str, Any]]):
150+
tr_ids = [tr["id"] for tr in tool_requests]
151+
if not tr_ids:
152+
return
153+
self._query_implicit_collection_outputs(tr_ids)
154+
self._query_hda_outputs(tr_ids)
155+
self._query_explicit_collection_outputs(tr_ids)
156+
157+
def _query_implicit_collection_outputs(self, tr_ids: list[int]):
158+
stmt = (
159+
select(
160+
ToolRequestImplicitCollectionAssociation.tool_request_id,
161+
ToolRequestImplicitCollectionAssociation.dataset_collection_id,
162+
ToolRequestImplicitCollectionAssociation.output_name,
163+
)
164+
.join(
165+
HistoryDatasetCollectionAssociation,
166+
ToolRequestImplicitCollectionAssociation.dataset_collection_id
167+
== HistoryDatasetCollectionAssociation.id,
168+
)
169+
.where(
170+
ToolRequestImplicitCollectionAssociation.tool_request_id.in_(tr_ids),
171+
HistoryDatasetCollectionAssociation.history_id == self.history_id,
172+
)
173+
)
174+
for row in self.sa_session.execute(stmt):
175+
self.referenced_hdca_ids.add(row.dataset_collection_id)
176+
self.edges.append(
177+
GraphEdge(
178+
source=f"r{self.security.encode_id(row.tool_request_id)}",
179+
target=f"c{self.security.encode_id(row.dataset_collection_id)}",
180+
type="output",
181+
name=row.output_name,
182+
)
183+
)
184+
185+
def _query_hda_outputs(self, tr_ids: list[int]):
186+
stmt = (
187+
select(
188+
Job.tool_request_id,
189+
JobToOutputDatasetAssociation.dataset_id,
190+
JobToOutputDatasetAssociation.name,
191+
)
192+
.distinct()
193+
.join(Job, JobToOutputDatasetAssociation.job_id == Job.id)
194+
.join(
195+
HistoryDatasetAssociation,
196+
JobToOutputDatasetAssociation.dataset_id == HistoryDatasetAssociation.id,
197+
)
198+
.where(
199+
Job.tool_request_id.in_(tr_ids),
200+
HistoryDatasetAssociation.history_id == self.history_id,
201+
)
202+
)
203+
for row in self.sa_session.execute(stmt):
204+
self.referenced_hda_ids.add(row.dataset_id)
205+
self.edges.append(
206+
GraphEdge(
207+
source=f"r{self.security.encode_id(row.tool_request_id)}",
208+
target=f"d{self.security.encode_id(row.dataset_id)}",
209+
type="output",
210+
name=row.name,
211+
)
212+
)
213+
214+
def _query_explicit_collection_outputs(self, tr_ids: list[int]):
215+
stmt = (
216+
select(
217+
Job.tool_request_id,
218+
JobToOutputDatasetCollectionAssociation.dataset_collection_id,
219+
JobToOutputDatasetCollectionAssociation.name,
220+
)
221+
.distinct()
222+
.join(Job, JobToOutputDatasetCollectionAssociation.job_id == Job.id)
223+
.join(
224+
HistoryDatasetCollectionAssociation,
225+
JobToOutputDatasetCollectionAssociation.dataset_collection_id
226+
== HistoryDatasetCollectionAssociation.id,
227+
)
228+
.where(
229+
Job.tool_request_id.in_(tr_ids),
230+
HistoryDatasetCollectionAssociation.history_id == self.history_id,
231+
)
232+
)
233+
for row in self.sa_session.execute(stmt):
234+
self.referenced_hdca_ids.add(row.dataset_collection_id)
235+
self.edges.append(
236+
GraphEdge(
237+
source=f"r{self.security.encode_id(row.tool_request_id)}",
238+
target=f"c{self.security.encode_id(row.dataset_collection_id)}",
239+
type="output",
240+
name=row.name,
241+
)
242+
)
243+
244+
def _fetch_hda_metadata(self):
245+
if not self.referenced_hda_ids:
246+
return
247+
stmt = (
248+
select(
249+
HistoryDatasetAssociation.id,
250+
HistoryDatasetAssociation.history_id,
251+
HistoryDatasetAssociation.hid,
252+
HistoryDatasetAssociation.name,
253+
HistoryDatasetAssociation._state,
254+
Dataset.state.label("dataset_state"),
255+
HistoryDatasetAssociation.extension,
256+
HistoryDatasetAssociation.deleted,
257+
HistoryDatasetAssociation.visible,
258+
HistoryDatasetAssociation.create_time,
259+
HistoryDatasetAssociation.update_time,
260+
)
261+
.join(Dataset, HistoryDatasetAssociation.dataset_id == Dataset.id)
262+
.where(HistoryDatasetAssociation.id.in_(self.referenced_hda_ids))
263+
)
264+
found: set[int] = set()
265+
for row in self.sa_session.execute(stmt):
266+
found.add(row.id)
267+
if row.history_id == self.history_id:
268+
state = row._state if row._state else row.dataset_state
269+
self.nodes.append(
270+
GraphNode(
271+
id=f"d{self.security.encode_id(row.id)}",
272+
type="dataset",
273+
hid=row.hid,
274+
name=row.name,
275+
state=state,
276+
extension=row.extension,
277+
deleted=row.deleted,
278+
visible=row.visible,
279+
create_time=_iso(row.create_time),
280+
update_time=_iso(row.update_time),
281+
)
282+
)
283+
else:
284+
self.external_refs.append(
285+
ExternalRef(
286+
id=f"d{self.security.encode_id(row.id)}",
287+
type="dataset",
288+
history_id=self.security.encode_id(row.history_id) if row.history_id else None,
289+
name=row.name,
290+
)
291+
)
292+
for missing_id in self.referenced_hda_ids - found:
293+
self.external_refs.append(
294+
ExternalRef(
295+
id=f"d{self.security.encode_id(missing_id)}",
296+
type="dataset",
297+
name=None,
298+
)
299+
)
300+
301+
def _fetch_hdca_metadata(self):
302+
if not self.referenced_hdca_ids:
303+
return
304+
stmt = (
305+
select(
306+
HistoryDatasetCollectionAssociation.id,
307+
HistoryDatasetCollectionAssociation.history_id,
308+
HistoryDatasetCollectionAssociation.hid,
309+
HistoryDatasetCollectionAssociation.name,
310+
HistoryDatasetCollectionAssociation.deleted,
311+
HistoryDatasetCollectionAssociation.visible,
312+
HistoryDatasetCollectionAssociation.create_time,
313+
HistoryDatasetCollectionAssociation.update_time,
314+
DatasetCollection.collection_type,
315+
DatasetCollection.element_count,
316+
)
317+
.join(
318+
DatasetCollection,
319+
HistoryDatasetCollectionAssociation.collection_id == DatasetCollection.id,
320+
)
321+
.where(HistoryDatasetCollectionAssociation.id.in_(self.referenced_hdca_ids))
322+
)
323+
found: set[int] = set()
324+
for row in self.sa_session.execute(stmt):
325+
found.add(row.id)
326+
if row.history_id == self.history_id:
327+
self.nodes.append(
328+
GraphNode(
329+
id=f"c{self.security.encode_id(row.id)}",
330+
type="collection",
331+
hid=row.hid,
332+
name=row.name,
333+
collection_type=row.collection_type,
334+
element_count=row.element_count,
335+
deleted=row.deleted,
336+
visible=row.visible,
337+
create_time=_iso(row.create_time),
338+
update_time=_iso(row.update_time),
339+
)
340+
)
341+
else:
342+
self.external_refs.append(
343+
ExternalRef(
344+
id=f"c{self.security.encode_id(row.id)}",
345+
type="collection",
346+
history_id=self.security.encode_id(row.history_id) if row.history_id else None,
347+
name=row.name,
348+
)
349+
)
350+
for missing_id in self.referenced_hdca_ids - found:
351+
self.external_refs.append(
352+
ExternalRef(
353+
id=f"c{self.security.encode_id(missing_id)}",
354+
type="collection",
355+
name=None,
356+
)
357+
)
358+
359+
def _deduplicate_edges(self):
360+
seen: set[tuple] = set()
361+
unique: list[GraphEdge] = []
362+
for edge in self.edges:
363+
key = (edge.source, edge.target, edge.type, edge.name)
364+
if key not in seen:
365+
seen.add(key)
366+
unique.append(edge)
367+
self.edges = unique
368+
369+
def _sort(self):
370+
type_order = {"dataset": 0, "collection": 1, "tool_request": 2}
371+
self.nodes.sort(key=lambda n: (type_order.get(n.type, 99), n.hid or 0, n.create_time or "", n.id))
372+
373+
edge_type_order = {"input": 0, "output": 1}
374+
self.edges.sort(key=lambda e: (edge_type_order.get(e.type, 99), e.source, e.target))
375+
376+
ref_type_order = {"dataset": 0, "collection": 1}
377+
self.external_refs.sort(key=lambda r: (ref_type_order.get(r.type, 99), r.id))

0 commit comments

Comments
 (0)