Skip to content

Commit cf9c38f

Browse files
Lee-Wnailo2c
authored and committed
feat(task_sdk): add support for inlet_events in Task Context (apache#45960)
* feat(task_sdk): add support for inlet_events in Task Context
* feat(task_sdk): add AssetEventCollectionResponse
* refactor(task_sdk): combine asset event uris
* refactor(api_fastapi): extract asset_event datamodels from asset
* fix(task_sdk): revert unrelated datamodels change
* fix(task_sdk/context): add _get_asset_events_from_db for fixing tests
* test(task_sdk): add test cases for execution_time context inlet access
* test(task_sdk): extend test_handle_requests to include asset event calls
* test(execution_api): add tests to asset event apis
* fix(execution_api): remove unnecessary redact
* feat(task_sdk): extract asset response from asset event response
* feat(task_sdk): add missing http exception
* feat(task_sdk): extract asset response from asset event response
* feat(task_sdk): remove duplicate inlet logic
* feat(task_sdk): remove AssetEvent form definitions
* test(task_sdk): add test case test_run_with_asset_inlets
* docs(newsfragments): add description of how inlet_events access has been changed
1 parent c2f12b1 commit cf9c38f

18 files changed

Lines changed: 880 additions & 141 deletions

File tree

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from datetime import datetime
21+
22+
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
23+
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
24+
25+
26+
class DagRunAssetReference(StrictBaseModel):
    """
    Subset of DagRun fields exposed inside asset event responses.

    Used as the element type of ``AssetEventResponse.created_dagruns``.
    """

    # Identifiers of the referenced run.
    run_id: str
    dag_id: str

    # Timestamps of the run; ``start_date`` is the only one that may not be None.
    logical_date: datetime | None
    start_date: datetime
    end_date: datetime | None

    # Run state, carried as a plain string.
    state: str

    # Bounds of the run's data interval; either may be None.
    data_interval_start: datetime | None
    data_interval_end: datetime | None
37+
38+
39+
class AssetEventResponse(BaseModel):
    """
    Asset event payload restricted to the fields needed at task runtime.

    Carries the event itself, the asset it belongs to, the DAG runs recorded
    in ``created_dagruns``, and optional coordinates of the task instance
    the event originated from.
    """

    # Core event data; ``extra`` defaults to None when omitted.
    id: int
    timestamp: datetime
    extra: dict | None = None

    # Asset the event is attached to, and the DAG run references for this event.
    asset: AssetResponse
    created_dagruns: list[DagRunAssetReference]

    # Source task-instance coordinates. All optional; the map index
    # defaults to -1.
    source_task_id: str | None = None
    source_dag_id: str | None = None
    source_run_id: str | None = None
    source_map_index: int = -1
53+
54+
55+
class AssetEventsResponse(BaseModel):
    """Envelope holding a list of ``AssetEventResponse`` items."""

    # The events, in whatever order the producer supplied them.
    asset_events: list[AssetEventResponse]

airflow/api_fastapi/execution_api/routes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from airflow.api_fastapi.common.router import AirflowRouter
2020
from airflow.api_fastapi.execution_api.routes import (
21+
asset_events,
2122
assets,
2223
connections,
2324
health,
@@ -28,6 +29,7 @@
2829

2930
execution_api_router = AirflowRouter()
3031
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
32+
execution_api_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
3133
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
3234
execution_api_router.include_router(health.router, tags=["Health"])
3335
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import Annotated
21+
22+
from fastapi import HTTPException, Query, status
23+
from sqlalchemy import and_, select
24+
25+
from airflow.api_fastapi.common.db.common import SessionDep
26+
from airflow.api_fastapi.common.router import AirflowRouter
27+
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
28+
from airflow.api_fastapi.execution_api.datamodels.asset_event import (
29+
AssetEventResponse,
30+
AssetEventsResponse,
31+
)
32+
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
33+
34+
# TODO: Add dependency on JWT token
# Router for the asset-event endpoints defined in this module. The extra
# response entries document the 404 and 401 outcomes for every route on it.
router = AirflowRouter(
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
    },
)
41+
42+
43+
def _get_asset_events_through_sql_clauses(
    *, join_clause, where_clause, session: SessionDep
) -> AssetEventsResponse:
    """
    Fetch asset events matching the given clauses and wrap them in a response.

    :param join_clause: Target passed to ``select(AssetEvent).join(...)``.
    :param where_clause: Filter passed to ``.where(...)``.
    :param session: Database session used to execute the query.
    :return: The matching events, ordered by ``AssetEvent.timestamp``.
    """
    query = select(AssetEvent).join(join_clause).where(where_clause).order_by(AssetEvent.timestamp)

    serialized: list[AssetEventResponse] = []
    for row in session.scalars(query):
        # Only these asset attributes are copied into the response.
        source_asset = row.asset
        serialized.append(
            AssetEventResponse(
                id=row.id,
                timestamp=row.timestamp,
                extra=row.extra,
                asset=AssetResponse(
                    name=source_asset.name,
                    uri=source_asset.uri,
                    group=source_asset.group,
                    extra=source_asset.extra,
                ),
                created_dagruns=row.created_dagruns,
                source_task_id=row.source_task_id,
                source_dag_id=row.source_dag_id,
                source_run_id=row.source_run_id,
                source_map_index=row.source_map_index,
            )
        )

    return AssetEventsResponse.model_validate({"asset_events": serialized})
72+
73+
74+
@router.get("/by-asset")
def get_asset_event_by_asset_name_uri(
    name: Annotated[str | None, Query(description="The name of the Asset")],
    uri: Annotated[str | None, Query(description="The URI of the Asset")],
    session: SessionDep,
) -> AssetEventsResponse:
    """
    Return the asset events of the asset identified by name and/or URI.

    When both ``name`` and ``uri`` are given, the asset must match both.
    When only one is given, the match is additionally filtered with
    ``AssetModel.active.has()``.

    :raises HTTPException: 400 when neither ``name`` nor ``uri`` is truthy.
    """
    # NOTE(review): neither query parameter declares a default, so FastAPI
    # presumably treats both as required -- confirm against the client.
    if not name and not uri:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "reason": "Missing parameter",
                "message": "name and uri cannot both be None",
            },
        )

    if name and uri:
        asset_filter = and_(AssetModel.name == name, AssetModel.uri == uri)
    elif uri:
        asset_filter = and_(AssetModel.uri == uri, AssetModel.active.has())
    else:
        # Only ``name`` is truthy at this point.
        asset_filter = and_(AssetModel.name == name, AssetModel.active.has())

    return _get_asset_events_through_sql_clauses(
        join_clause=AssetEvent.asset,
        where_clause=asset_filter,
        session=session,
    )
100+
101+
102+
@router.get("/by-asset-alias")
def get_asset_event_by_asset_alias(
    name: Annotated[str, Query(description="The name of the Asset Alias")],
    session: SessionDep,
) -> AssetEventsResponse:
    """
    Return the asset events linked to the asset alias called ``name``.

    Events are joined through ``AssetEvent.source_aliases`` and filtered on
    the alias name.
    """
    alias_filter = AssetAliasModel.name == name
    return _get_asset_events_through_sql_clauses(
        join_clause=AssetEvent.source_aliases,
        where_clause=alias_filter,
        session=session,
    )

airflow/models/taskinstance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
110110
from airflow.sdk.definitions.param import process_params
111111
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
112+
from airflow.sdk.execution_time.context import InletEventsAccessors
112113
from airflow.sentry import Sentry
113114
from airflow.settings import task_instance_mutation_hook
114115
from airflow.stats import Stats
@@ -119,7 +120,6 @@
119120
from airflow.utils.context import (
120121
ConnectionAccessor,
121122
Context,
122-
InletEventsAccessors,
123123
OutletEventAccessors,
124124
VariableAccessor,
125125
context_get_outlet_events,
@@ -973,7 +973,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
973973
context.update(
974974
{
975975
"outlet_events": OutletEventAccessors(),
976-
"inlet_events": InletEventsAccessors(task.inlets, session=session),
976+
"inlet_events": InletEventsAccessors(task.inlets),
977977
"macros": macros,
978978
"params": validated_params,
979979
"prev_data_interval_start_success": get_prev_data_interval_start_success(),

airflow/utils/context.py

Lines changed: 2 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,29 @@
2121

2222
from collections.abc import (
2323
Container,
24-
Iterator,
25-
Mapping,
2624
)
2725
from typing import (
2826
TYPE_CHECKING,
2927
Any,
30-
Union,
3128
cast,
3229
)
3330

34-
import attrs
35-
from sqlalchemy import and_, select
31+
from sqlalchemy import select
3632

3733
from airflow.models.asset import (
38-
AssetAliasModel,
39-
AssetEvent,
4034
AssetModel,
41-
fetch_active_assets_by_name,
42-
fetch_active_assets_by_uri,
43-
)
44-
from airflow.sdk.definitions.asset import (
45-
Asset,
46-
AssetAlias,
47-
AssetAliasUniqueKey,
48-
AssetNameRef,
49-
AssetRef,
50-
AssetUniqueKey,
51-
AssetUriRef,
5235
)
5336
from airflow.sdk.definitions.context import Context
5437
from airflow.sdk.execution_time.context import (
5538
ConnectionAccessor as ConnectionAccessorSDK,
5639
OutletEventAccessors as OutletEventAccessorsSDK,
5740
VariableAccessor as VariableAccessorSDK,
5841
)
59-
from airflow.utils.db import LazySelectSequence
6042
from airflow.utils.session import create_session
6143
from airflow.utils.types import NOTSET
6244

6345
if TYPE_CHECKING:
64-
from sqlalchemy.engine import Row
65-
from sqlalchemy.orm import Session
66-
from sqlalchemy.sql.expression import Select, TextClause
67-
46+
from airflow.sdk.definitions.asset import Asset
6847
from airflow.sdk.types import OutletEventAccessorsProtocol
6948

7049
# NOTE: Please keep this in sync with the following:
@@ -170,106 +149,6 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
170149
return asset.to_public()
171150

172151

173-
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
    """
    Lazily-evaluated, list-like view over ``AssetEvent`` rows.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Wrap the textual statement so results come back as AssetEvent entities.
        rebuilt = select(AssetEvent).from_statement(stmt)
        return rebuilt

    @staticmethod
    def _process_row(row: Row) -> AssetEvent:
        # The AssetEvent is the first element of the row.
        event = row[0]
        return event
187-
188-
189-
@attrs.define(init=False)
class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]):
    """
    Mapping from a task's inlets to lazy sequences of their asset events.

    A lookup accepts a position in the inlet list, an ``Asset``, an
    ``AssetAlias`` or an asset reference, and returns a
    ``LazyAssetEventSelectSequence`` bound to the session supplied at
    construction time.

    :meta private:
    """

    _inlets: list[Any]
    _assets: dict[AssetUniqueKey, Asset]
    _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]
    _session: Session

    def __init__(self, inlets: list, *, session: Session) -> None:
        """
        Index the given inlets and resolve name/URI references to assets.

        :param inlets: The task's inlets; items of other types are kept in
            the list but not indexed.
        :param session: Session used to resolve references now and to run
            event queries later.
        """
        self._inlets = inlets
        self._session = session
        self._assets = {}
        self._asset_aliases = {}

        names_to_resolve: list[str] = []
        uris_to_resolve: list[str] = []
        for item in inlets:
            if isinstance(item, Asset):
                self._assets[AssetUniqueKey.from_asset(item)] = item
            elif isinstance(item, AssetAlias):
                self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(item)] = item
            elif isinstance(item, AssetNameRef):
                names_to_resolve.append(item.name)
            elif isinstance(item, AssetUriRef):
                uris_to_resolve.append(item.uri)

        # References are resolved to concrete assets in one lookup per kind.
        if names_to_resolve:
            resolved_by_name = fetch_active_assets_by_name(names_to_resolve, self._session)
            for found in resolved_by_name.values():
                self._assets[AssetUniqueKey.from_asset(found)] = found
        if uris_to_resolve:
            resolved_by_uri = fetch_active_assets_by_uri(uris_to_resolve, self._session)
            for found in resolved_by_uri.values():
                self._assets[AssetUniqueKey.from_asset(found)] = found

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        """Iterate over the inlets in the order they were given."""
        return iter(self._inlets)

    def __len__(self) -> int:
        """Return the number of inlets."""
        return len(self._inlets)

    def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence:
        """
        Return the lazy, timestamp-ordered event sequence for one inlet.

        :param key: Index into the inlet list, or the inlet object itself.
        :raises IndexError: If ``key`` is an index whose inlet is not an
            asset, alias or reference (or is out of range).
        :raises KeyError: If the asset, alias or reference is not known.
        :raises ValueError: If ``key`` is of an unsupported type.
        """
        target = key
        if isinstance(key, int):  # Support index access; it's easier for trivial cases.
            target = self._inlets[key]
            if not isinstance(target, (Asset, AssetAlias, AssetRef)):
                raise IndexError(key)

        if isinstance(target, Asset):
            known_asset = self._assets[AssetUniqueKey.from_asset(target)]
            join_on = AssetEvent.asset
            condition = and_(AssetModel.name == known_asset.name, AssetModel.uri == known_asset.uri)
        elif isinstance(target, AssetAlias):
            known_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(target)]
            join_on = AssetEvent.source_aliases
            condition = AssetAliasModel.name == known_alias.name
        elif isinstance(target, AssetNameRef):
            # First indexed asset whose key carries the referenced name.
            known_asset = next(
                (candidate for k, candidate in self._assets.items() if k.name == target.name),
                None,
            )
            if known_asset is None:
                raise KeyError(target) from None
            join_on = AssetEvent.asset
            condition = and_(AssetModel.name == known_asset.name, AssetModel.active.has())
        elif isinstance(target, AssetUriRef):
            # First indexed asset whose key carries the referenced URI.
            known_asset = next(
                (candidate for k, candidate in self._assets.items() if k.uri == target.uri),
                None,
            )
            if known_asset is None:
                raise KeyError(target) from None
            join_on = AssetEvent.asset
            condition = and_(AssetModel.uri == known_asset.uri, AssetModel.active.has())
        else:
            raise ValueError(key)

        statement = select(AssetEvent).join(join_on).where(condition)
        return LazyAssetEventSelectSequence.from_select(
            statement,
            order_by=[AssetEvent.timestamp],
            session=self._session,
        )
271-
272-
273152
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
274153
"""
275154
Merge parameters into an existing context.

0 commit comments

Comments
 (0)