Skip to content

Commit f613905

Browse files
committed
feat(task_sdk): add AssetEventCollectionResponse
1 parent b44653d commit f613905

9 files changed

Lines changed: 186 additions & 75 deletions

File tree

airflow/api_fastapi/execution_api/datamodels/asset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def redact_extra(cls, v: dict):
7676
return redact(v)
7777

7878

79+
class AssetEventCollectionResponse(BaseModel):
80+
"""Collection of AssetEventResponse."""
81+
82+
asset_events: list[AssetEventResponse]
83+
84+
7985
class AssetProfile(BaseModel):
8086
"""
8187
Profile of an Asset.

airflow/api_fastapi/execution_api/routes/asset_events.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919

2020
from typing import Annotated
2121

22-
from fastapi import HTTPException, Query, status
22+
from fastapi import Query, status
2323
from sqlalchemy import and_, select
2424

2525
from airflow.api_fastapi.common.db.common import SessionDep
2626
from airflow.api_fastapi.common.router import AirflowRouter
27-
from airflow.api_fastapi.execution_api.datamodels.asset import AssetEventResponse
27+
from airflow.api_fastapi.execution_api.datamodels.asset import (
28+
AssetEventCollectionResponse,
29+
)
2830
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
2931

3032
# TODO: Add dependency on JWT token
@@ -36,41 +38,22 @@
3638
)
3739

3840

39-
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
40-
"""
41-
List-like interface to lazily access AssetEvent rows.
42-
43-
:meta private:
44-
"""
45-
46-
@staticmethod
47-
def _rebuild_select(stmt: TextClause) -> Select:
48-
return select(AssetEvent).from_statement(stmt)
49-
50-
@staticmethod
51-
def _process_row(row: Row) -> AssetEvent:
52-
return row[0]
53-
54-
55-
def _get_asset_event_through_sql_clause(
41+
def _get_asset_events_through_sql_clauses(
5642
*, join_clause, where_clause, session: SessionDep
57-
) -> AssetEventResponse:
58-
asset_event = LazyAssetEventSelectSequence.from_select(
59-
select(AssetEvent).join(join_clause).where(where_clause),
60-
order_by=[AssetEvent.timestamp],
61-
session=session,
43+
) -> AssetEventCollectionResponse:
44+
asset_events = session.scalars(
45+
select(AssetEvent).join(join_clause).where(where_clause).order_by(AssetEvent.timestamp)
6246
)
63-
_raise_if_not_found(asset_event=asset_event, msg="Not found")
64-
return AssetEventResponse.model_validate(asset_event)
47+
return AssetEventCollectionResponse.model_validate({"asset_events": asset_events or []})
6548

6649

6750
@router.get("/by-asset-name-uri")
6851
def get_asset_event_by_asset_name_uri(
6952
name: Annotated[str, Query(description="The name of the Asset")],
7053
uri: Annotated[str, Query(description="The URI of the Asset")],
7154
session: SessionDep,
72-
) -> AssetEventResponse:
73-
return _get_asset_event_through_sql_clause(
55+
) -> AssetEventCollectionResponse:
56+
return _get_asset_events_through_sql_clauses(
7457
join_clause=AssetEvent.asset,
7558
where_clause=and_(AssetModel.name == name, AssetModel.uri == uri),
7659
session=session,
@@ -81,8 +64,8 @@ def get_asset_event_by_asset_name_uri(
8164
def get_asset_event_by_uri(
8265
uri: Annotated[str, Query(description="The URI of the Asset")],
8366
session: SessionDep,
84-
) -> AssetEventResponse:
85-
return _get_asset_event_through_sql_clause(
67+
) -> AssetEventCollectionResponse:
68+
return _get_asset_events_through_sql_clauses(
8669
join_clause=AssetEvent.asset,
8770
where_clause=and_(AssetModel.uri == uri, AssetModel.active.has()),
8871
session=session,
@@ -93,8 +76,8 @@ def get_asset_event_by_uri(
9376
def get_asset_event_by_name(
9477
name: Annotated[str, Query(description="The name of the Asset")],
9578
session: SessionDep,
96-
) -> AssetEventResponse:
97-
return _get_asset_event_through_sql_clause(
79+
) -> AssetEventCollectionResponse:
80+
return _get_asset_events_through_sql_clauses(
9881
join_clause=AssetEvent.asset,
9982
where_clause=and_(AssetModel.uri == name, AssetModel.active.has()),
10083
session=session,
@@ -105,20 +88,9 @@ def get_asset_event_by_name(
10588
def get_asset_event_by_alias_name(
10689
name: Annotated[str, Query(description="The name of the Asset Alias")],
10790
session: SessionDep,
108-
) -> AssetEventResponse:
109-
return _get_asset_event_through_sql_clause(
91+
) -> AssetEventCollectionResponse:
92+
return _get_asset_events_through_sql_clauses(
11093
join_clause=AssetEvent.source_aliases,
11194
where_clause=(AssetAliasModel.name == name),
11295
session=session,
11396
)
114-
115-
116-
def _raise_if_not_found(asset_event, msg):
117-
if asset_event is None:
118-
raise HTTPException(
119-
status.HTTP_404_NOT_FOUND,
120-
detail={
121-
"reason": "not_found",
122-
"message": msg,
123-
},
124-
)

task_sdk/src/airflow/sdk/api/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
3636
from airflow.sdk import __version__
3737
from airflow.sdk.api.datamodels._generated import (
38-
AssetEventResponse,
38+
AssetEventCollectionResponse,
3939
AssetResponse,
4040
ConnectionResponse,
4141
DagRunType,
@@ -340,7 +340,7 @@ def __init__(self, client: Client):
340340

341341
def get(
342342
self, name: str | None = None, uri: str | None = None, alias_name: str | None = None
343-
) -> AssetResponse:
343+
) -> AssetEventCollectionResponse:
344344
"""Get Asset value from the API server."""
345345
if name and uri:
346346
resp = self.client.get("asset-events/by-asset-name-uri", params={"name": name, "uri": uri})
@@ -353,7 +353,7 @@ def get(
353353
else:
354354
raise ValueError("Either `name`, `uri` or `alias_name` must be provided")
355355

356-
return AssetEventResponse.model_validate_json(resp.read())
356+
return AssetEventCollectionResponse.model_validate_json(resp.read())
357357

358358

359359
class BearerAuth(httpx.Auth):

task_sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#
21
# Licensed to the Apache Software Foundation (ASF) under one
32
# or more contributor license agreements. See the NOTICE file
43
# distributed with this work for additional information
@@ -370,3 +369,11 @@ class TITerminalStatePayload(BaseModel):
370369
)
371370
state: TerminalStateNonSuccess
372371
end_date: Annotated[datetime, Field(title="End Date")]
372+
373+
374+
class AssetEventCollectionResponse(BaseModel):
375+
"""
376+
Collection of AssetEventResponse.
377+
"""
378+
379+
asset_events: Annotated[list[AssetEventResponse], Field(title="Asset Events")]

task_sdk/src/airflow/sdk/definitions/asset/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import contextlib
21+
import datetime
2122
import logging
2223
import operator
2324
import os
@@ -694,6 +695,39 @@ def as_expression(self) -> Any:
694695
return {"all": [o.as_expression() for o in self.objects]}
695696

696697

698+
@attrs.define(kw_only=True)
699+
class DagRunAssetReference:
700+
run_id: str
701+
dag_id: str
702+
start_date: datetime.datetime
703+
state: str
704+
end_date: datetime.datetime | None = None
705+
706+
logical_date: datetime.datetime | None = None
707+
data_interval_start: datetime.datetime | None = None
708+
data_interval_end: datetime.datetime | None = None
709+
710+
711+
@attrs.define(kw_only=True)
712+
class AssetEvent:
713+
"""Representation of asset event to be triggered by an asset alias."""
714+
715+
id: int
716+
asset_id: int
717+
718+
created_dagruns: list[DagRunAssetReference]
719+
timestamp: datetime.datetime
720+
721+
uri: str | None = None
722+
name: str | None = None
723+
group: str | None = None
724+
extra: dict[str, Any] | None = None
725+
source_task_id: str | None = None
726+
source_dag_id: str | None = None
727+
source_run_id: str | None = None
728+
source_map_index: int | None = None
729+
730+
697731
@attrs.define
698732
class AssetAliasEvent:
699733
"""Representation of asset event to be triggered by an asset alias."""

task_sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pydantic import BaseModel, ConfigDict, Field, JsonValue
5252

5353
from airflow.sdk.api.datamodels._generated import (
54-
AssetEventResponse,
54+
AssetEventCollectionResponse,
5555
AssetResponse,
5656
BundleInfo,
5757
ConnectionResponse,
@@ -105,23 +105,28 @@ def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
105105
return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")
106106

107107

108-
class AssetEventResult(AssetEventResponse):
109-
"""Response to ReadXCom request."""
108+
class AssetEventCollectionResult(AssetEventCollectionResponse):
109+
"""Response to GetAssetEvent request."""
110110

111-
type: Literal["AssetEventResult"] = "AssetEventResult"
111+
type: Literal["AssetEventCollectionResult"] = "AssetEventCollectionResult"
112112

113113
@classmethod
114-
def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
114+
def from_asset_event_collection_response(
115+
cls, asset_event_collection_response: AssetEventCollectionResponse
116+
) -> AssetEventCollectionResult:
115117
"""
116-
Get AssetEventResult from AssetEventResponse.
118+
Get AssetEventCollectionResult from AssetEventCollectionResponse.
117119
118120
AssetEventCollectionResponse is autogenerated from the API schema, so we need to convert it to AssetEventCollectionResult
119121
for communication between the Supervisor and the task process.
120122
"""
121123
# Exclude defaults to avoid sending unnecessary data
122124
# Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True
123125
# to avoid sending unset fields (which are defaults in our case).
124-
return cls(**asset_event_response.model_dump(exclude_defaults=True), type="AssetEventResult")
126+
return cls(
127+
**asset_event_collection_response.model_dump(exclude_defaults=True),
128+
type="AssetEventCollectionResult",
129+
)
125130

126131

127132
class XComResult(XComResponse):
@@ -204,7 +209,7 @@ class OKResponse(BaseModel):
204209
ToTask = Annotated[
205210
Union[
206211
AssetResult,
207-
AssetEventResult,
212+
AssetEventCollectionResult,
208213
ConnectionResult,
209214
ErrorResponse,
210215
PrevSuccessfulDagRunResult,

task_sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,21 @@
2424
import attrs
2525
import structlog
2626

27-
from airflow.models.asset import AssetEvent
2827
from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
2928
from airflow.sdk.definitions._internal.types import NOTSET
3029
from airflow.sdk.definitions.asset import (
3130
Asset,
3231
AssetAlias,
3332
AssetAliasEvent,
3433
AssetAliasUniqueKey,
34+
AssetEvent,
3535
AssetNameRef,
3636
AssetRef,
3737
AssetUniqueKey,
3838
AssetUriRef,
3939
BaseAssetUniqueKey,
4040
)
4141
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
42-
from airflow.sdk.execution_time.comms import (
43-
AssetEventResult,
44-
ErrorResponse,
45-
GetAssetEventByAliasName,
46-
GetAssetEventByName,
47-
)
4842

4943
if TYPE_CHECKING:
5044
from uuid import UUID
@@ -314,16 +308,16 @@ def __init__(self, inlets: list) -> None:
314308
def __iter__(self) -> Iterator[Asset | AssetAlias]:
315309
return iter(self._inlets)
316310

317-
AssetEventCollectionResult,
318-
ErrorResponse,
319-
GetAssetEventByAliasName,
320-
GetAssetEventByName,
321311
def __len__(self) -> int:
322312
return len(self._inlets)
323313

324314
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef):
325315
from airflow.sdk.definitions.asset import Asset
326316
from airflow.sdk.execution_time.comms import (
317+
AssetEventCollectionResult,
318+
ErrorResponse,
319+
GetAssetEventByAliasName,
320+
GetAssetEventByName,
327321
GetAssetEventByNameUri,
328322
GetAssetEventByUri,
329323
)
@@ -362,8 +356,8 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef):
362356
raise AirflowRuntimeError(msg)
363357

364358
if TYPE_CHECKING:
365-
assert isinstance(msg, AssetEventResult)
366-
return AssetEvent(**msg.model_dump(excldue={"type"}))
359+
assert isinstance(msg, AssetEventCollectionResult)
360+
return [AssetEvent(**event) for event in msg.model_dump()["asset_events"]]
367361

368362

369363
@cache # Prevent multiple API access.

task_sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
VariableResponse,
6262
)
6363
from airflow.sdk.execution_time.comms import (
64-
AssetEventResult,
64+
AssetEventCollectionResult,
6565
AssetResult,
6666
ConnectionResult,
6767
DeferTask,
@@ -831,19 +831,27 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
831831
resp = asset_result.model_dump_json(exclude_unset=True).encode()
832832
elif isinstance(msg, GetAssetEventByNameUri):
833833
asset_event_resp = self.client.asset_events.get(uri=msg.uri, name=msg.name)
834-
asset_event_result = AssetEventResult.from_asset_event_response(asset_event_resp)
834+
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
835+
asset_event_resp
836+
)
835837
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
836838
elif isinstance(msg, GetAssetEventByName):
837839
asset_event_resp = self.client.asset_events.get(name=msg.name)
838-
asset_event_result = AssetEventResult.from_asset_event_response(asset_event_resp)
840+
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
841+
asset_event_resp
842+
)
839843
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
840844
elif isinstance(msg, GetAssetEventByUri):
841845
asset_event_resp = self.client.asset_events.get(uri=msg.uri)
842-
asset_event_result = AssetEventResult.from_asset_event_response(asset_event_resp)
846+
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
847+
asset_event_resp
848+
)
843849
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
844850
elif isinstance(msg, GetAssetEventByAliasName):
845851
asset_event_resp = self.client.asset_events.get(name=msg.alias_name)
846-
asset_event_result = AssetEventResult.from_asset_event_response(asset_event_resp)
852+
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
853+
asset_event_resp
854+
)
847855
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
848856
elif isinstance(msg, GetPrevSuccessfulDagRun):
849857
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)

0 commit comments

Comments
 (0)