Skip to content

Commit d07c722

Browse files
committed
fix(task_sdk/context): add _get_asset_events_from_db for fixing tests
1 parent add9adf commit d07c722

6 files changed

Lines changed: 92 additions & 33 deletions

File tree

airflow/utils/context.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,37 @@
2929
)
3030

3131
import attrs
32-
from sqlalchemy import select
32+
from sqlalchemy import and_, select
3333

3434
from airflow.models.asset import (
35+
AssetAliasModel,
36+
AssetEvent,
3537
AssetModel,
3638
)
39+
from airflow.sdk.definitions.asset import (
40+
Asset,
41+
AssetAlias,
42+
AssetAliasUniqueKey,
43+
AssetNameRef,
44+
AssetRef,
45+
AssetUniqueKey,
46+
AssetUriRef,
47+
)
3748
from airflow.sdk.definitions.context import Context
3849
from airflow.sdk.execution_time.context import (
3950
ConnectionAccessor as ConnectionAccessorSDK,
4051
InletEventsAccessors as InletEventsAccessorsSDK,
4152
OutletEventAccessors as OutletEventAccessorsSDK,
4253
VariableAccessor as VariableAccessorSDK,
4354
)
55+
from airflow.utils.db import LazySelectSequence
4456
from airflow.utils.session import create_session
4557
from airflow.utils.types import NOTSET
4658

4759
if TYPE_CHECKING:
60+
from sqlalchemy.engine import Row
61+
from sqlalchemy.sql.expression import Select, TextClause
62+
4863
from airflow.sdk.types import OutletEventAccessorsProtocol
4964

5065
# NOTE: Please keep this in sync with the following:
@@ -150,6 +165,22 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
150165
return asset.to_public()
151166

152167

168+
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
    """
    Lazy, list-like sequence whose items are ``AssetEvent`` rows.

    Specialises ``LazySelectSequence`` by supplying the two hooks below.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        """Wrap *stmt* in a ``select(AssetEvent)`` via ``from_statement``."""
        asset_event_select = select(AssetEvent)
        return asset_event_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> AssetEvent:
        """Return the first column of *row*, which is annotated as the ``AssetEvent``."""
        event = row[0]
        return event
182+
183+
153184
@attrs.define(init=False)
154185
class InletEventsAccessors(InletEventsAccessorsSDK):
155186
"""
@@ -158,6 +189,37 @@ class InletEventsAccessors(InletEventsAccessorsSDK):
158189
:meta private:
159190
"""
160191

192+
def _get_asset_events_from_db(self, obj: Asset | AssetAlias | AssetRef):
    """
    Return a lazy sequence of asset events for *obj*, ordered by timestamp.

    The join target and filter depend on the kind of *obj*:

    * ``Asset``: events joined through ``AssetEvent.asset`` and filtered on
      both the name and URI of the matching entry in ``self._assets``.
    * ``AssetAlias``: events joined through ``AssetEvent.source_aliases`` and
      filtered on the alias name of the matching entry in
      ``self._asset_aliases``.
    * ``AssetNameRef`` / ``AssetUriRef``: the first entry in ``self._assets``
      whose key has the same name (or URI) is used; events are joined through
      ``AssetEvent.asset`` and filtered on that name (or URI) together with
      ``AssetModel.active.has()``.

    :param obj: The asset, asset alias or asset reference to look up.
    :return: A ``LazyAssetEventSelectSequence`` built from the select.
    :raises KeyError: If a name or URI reference matches nothing in
        ``self._assets``.
    :raises ValueError: If *obj* is not one of the supported types.
    """
    if isinstance(obj, Asset):
        asset = self._assets[AssetUniqueKey.from_asset(obj)]
        join_clause = AssetEvent.asset
        where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
    elif isinstance(obj, AssetAlias):
        asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
        join_clause = AssetEvent.source_aliases
        where_clause = AssetAliasModel.name == asset_alias.name
    elif isinstance(obj, AssetNameRef):
        try:
            asset = next(a for k, a in self._assets.items() if k.name == obj.name)
        except StopIteration:
            raise KeyError(obj) from None
        join_clause = AssetEvent.asset
        where_clause = and_(AssetModel.name == asset.name, AssetModel.active.has())
    elif isinstance(obj, AssetUriRef):
        try:
            asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
        except StopIteration:
            raise KeyError(obj) from None
        join_clause = AssetEvent.asset
        where_clause = and_(AssetModel.uri == asset.uri, AssetModel.active.has())
    else:
        # Without this branch join_clause/where_clause would be unbound below
        # and the caller would see an opaque UnboundLocalError.
        raise ValueError(obj)

    with create_session() as session:
        return LazyAssetEventSelectSequence.from_select(
            select(AssetEvent).join(join_clause).where(where_clause),
            order_by=[AssetEvent.timestamp],
            session=session,
        )
222+
161223

162224
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
163225
"""

task_sdk/src/airflow/sdk/api/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
3636
from airflow.sdk import __version__
3737
from airflow.sdk.api.datamodels._generated import (
38-
AssetEventCollectionResponse,
38+
AssetEventsResponse,
3939
AssetResponse,
4040
ConnectionResponse,
4141
DagRunType,
@@ -340,7 +340,7 @@ def __init__(self, client: Client):
340340

341341
def get(
342342
self, name: str | None = None, uri: str | None = None, alias_name: str | None = None
343-
) -> AssetEventCollectionResponse:
343+
) -> AssetEventsResponse:
344344
"""Get Asset event from the API server."""
345345
if name or uri:
346346
resp = self.client.get("asset-events/by-asset", params={"name": name, "uri": uri})
@@ -349,7 +349,7 @@ def get(
349349
else:
350350
raise ValueError("Either `name`, `uri` or `alias_name` must be provided")
351351

352-
return AssetEventCollectionResponse.model_validate_json(resp.read())
352+
return AssetEventsResponse.model_validate_json(resp.read())
353353

354354

355355
class BearerAuth(httpx.Auth):

task_sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class AssetEventResponse(BaseModel):
368368
timestamp: Annotated[datetime, Field(title="Timestamp")]
369369

370370

371-
class AssetEventCollectionResponse(BaseModel):
371+
class AssetEventsResponse(BaseModel):
372372
"""
373373
Collection of AssetEventResponse.
374374
"""

task_sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from pydantic import BaseModel, ConfigDict, Field, JsonValue
5252

5353
from airflow.sdk.api.datamodels._generated import (
54-
AssetEventCollectionResponse,
54+
AssetEventsResponse,
5555
AssetResponse,
5656
BundleInfo,
5757
ConnectionResponse,
@@ -105,27 +105,25 @@ def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
105105
return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")
106106

107107

108-
class AssetEventCollectionResult(AssetEventCollectionResponse):
108+
class AssetEventsResult(AssetEventsResponse):
109109
"""Response to GetAssetEvent request."""
110110

111-
type: Literal["AssetEventCollectionResult"] = "AssetEventCollectionResult"
111+
type: Literal["AssetEventsResult"] = "AssetEventsResult"
112112

113113
@classmethod
114-
def from_asset_event_collection_response(
115-
cls, asset_event_collection_response: AssetEventCollectionResponse
116-
) -> AssetEventCollectionResult:
114+
def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
117115
"""
118-
Get AssetEventCollectionResult from AssetEventCollectionResponse.
116+
Get AssetEventsResult from AssetEventsResponse.
119117
120-
AssetEventCollectionResponse is autogenerated from the API schema, so we need to convert it to AssetEventCollectionResult
118+
AssetEventsResponse is autogenerated from the API schema, so we need to convert it to AssetEventsResult
121119
for communication between the Supervisor and the task process.
122120
"""
123121
# Exclude defaults to avoid sending unnecessary data
124122
# Pass the type as AssetEventsResult explicitly so we can then call model_dump_json with exclude_unset=True
125123
# to avoid sending unset fields (which are defaults in our case).
126124
return cls(
127-
**asset_event_collection_response.model_dump(exclude_defaults=True),
128-
type="AssetEventCollectionResult",
125+
**asset_events_response.model_dump(exclude_defaults=True),
126+
type="AssetEventsResult",
129127
)
130128

131129

@@ -209,7 +207,7 @@ class OKResponse(BaseModel):
209207
ToTask = Annotated[
210208
Union[
211209
AssetResult,
212-
AssetEventCollectionResult,
210+
AssetEventsResult,
213211
ConnectionResult,
214212
ErrorResponse,
215213
PrevSuccessfulDagRunResult,

task_sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from airflow.sdk.definitions.context import Context
4949
from airflow.sdk.definitions.variable import Variable
5050
from airflow.sdk.execution_time.comms import (
51+
AssetEventsResult,
5152
AssetResult,
5253
ConnectionResult,
5354
PrevSuccessfulDagRunResponse,
@@ -313,13 +314,6 @@ def __len__(self) -> int:
313314

314315
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef):
315316
from airflow.sdk.definitions.asset import Asset
316-
from airflow.sdk.execution_time.comms import (
317-
AssetEventCollectionResult,
318-
ErrorResponse,
319-
GetAssetEventByAsset,
320-
GetAssetEventByAssetAlias,
321-
)
322-
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
323317

324318
if isinstance(key, int): # Support index access; it's easier for trivial cases.
325319
obj = self._inlets[key]
@@ -328,6 +322,17 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef):
328322
else:
329323
obj = key
330324

325+
return self._get_asset_events_from_db(obj)
326+
327+
# TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
328+
def _get_asset_events_from_db(self, obj: Asset | AssetAlias | AssetRef) -> list[AssetEvent]:
329+
from airflow.sdk.execution_time.comms import (
330+
ErrorResponse,
331+
GetAssetEventByAsset,
332+
GetAssetEventByAssetAlias,
333+
)
334+
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
335+
331336
if isinstance(obj, Asset):
332337
asset = self._assets[AssetUniqueKey.from_asset(obj)]
333338
SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri))
@@ -346,15 +351,13 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef):
346351
elif isinstance(obj, AssetAlias):
347352
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
348353
SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name))
349-
else:
350-
raise ValueError(key)
351354

352355
msg = SUPERVISOR_COMMS.get_message()
353356
if isinstance(msg, ErrorResponse):
354357
raise AirflowRuntimeError(msg)
355358

356359
if TYPE_CHECKING:
357-
assert isinstance(msg, AssetEventCollectionResult)
360+
assert isinstance(msg, AssetEventsResult)
358361
return [AssetEvent(**event) for event in msg.model_dump()["asset_events"]]
359362

360363

task_sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
VariableResponse,
6262
)
6363
from airflow.sdk.execution_time.comms import (
64-
AssetEventCollectionResult,
64+
AssetEventsResult,
6565
AssetResult,
6666
ConnectionResult,
6767
DeferTask,
@@ -829,15 +829,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
829829
resp = asset_result.model_dump_json(exclude_unset=True).encode()
830830
elif isinstance(msg, GetAssetEventByAsset):
831831
asset_event_resp = self.client.asset_events.get(uri=msg.uri, name=msg.name)
832-
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
833-
asset_event_resp
834-
)
832+
asset_event_result = AssetEventsResult.from_asset_events_response(asset_event_resp)
835833
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
836834
elif isinstance(msg, GetAssetEventByAssetAlias):
837835
asset_event_resp = self.client.asset_events.get(name=msg.alias_name)
838-
asset_event_result = AssetEventCollectionResult.from_asset_event_collection_response(
839-
asset_event_resp
840-
)
836+
asset_event_result = AssetEventsResult.from_asset_events_response(asset_event_resp)
841837
resp = asset_event_result.model_dump_json(exclude_unset=True).encode()
842838
elif isinstance(msg, GetPrevSuccessfulDagRun):
843839
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)

0 commit comments

Comments
 (0)