Skip to content

Commit 1cf88c4

Browse files
committed
test(task_sdk): add test case test_run_with_asset_inlets
1 parent 259ae05 commit 1cf88c4

1 file changed

Lines changed: 52 additions & 1 deletion

File tree

task_sdk/tests/execution_time/test_task_runner.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,18 @@
4343
from airflow.listeners.listener import get_listener_manager
4444
from airflow.providers.standard.operators.python import PythonOperator
4545
from airflow.sdk import DAG, BaseOperator, Connection, dag as dag_decorator, get_current_context
46-
from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState
46+
from airflow.sdk.api.datamodels._generated import (
47+
AssetEventResponse,
48+
AssetProfile,
49+
AssetResponse,
50+
TaskInstance,
51+
TerminalTIState,
52+
)
4753
from airflow.sdk.definitions.asset import Asset, AssetAlias
4854
from airflow.sdk.definitions.param import DagParam
4955
from airflow.sdk.definitions.variable import Variable
5056
from airflow.sdk.execution_time.comms import (
57+
AssetEventsResult,
5158
BundleInfo,
5259
ConnectionResult,
5360
DeferTask,
@@ -731,6 +738,50 @@ def test_run_with_asset_outlets(
731738
mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)
732739

733740

741+
def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms):
742+
"""Test running a basic task that contains asset inlets."""
743+
asset_event_resp = AssetEventResponse(
744+
id=1,
745+
created_dagruns=[],
746+
timestamp=datetime.now(),
747+
asset=AssetResponse(name="test", uri="test", group="asset"),
748+
)
749+
events_result = AssetEventsResult(asset_events=[asset_event_resp])
750+
mock_supervisor_comms.get_message.return_value = events_result
751+
752+
from airflow.providers.standard.operators.bash import BashOperator
753+
754+
task = BashOperator(
755+
inlets=[Asset(name="test", uri="test://uri"), AssetAlias(name="alias-name")],
756+
task_id="asset-outlet-task",
757+
bash_command="echo 0",
758+
)
759+
760+
ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task")
761+
run(ti, log=mock.MagicMock())
762+
inlet_events = ti.get_template_context()["inlet_events"]
763+
764+
# access the asset events of Asset(name="test", uri="test://uri")
765+
assert inlet_events[0] == [asset_event_resp]
766+
assert inlet_events[-2] == [asset_event_resp]
767+
assert inlet_events[Asset(name="test", uri="test://uri")] == [asset_event_resp]
768+
769+
# access the asset events of AssetAlias(name="alias-name")
770+
assert inlet_events[1] == [asset_event_resp]
771+
assert inlet_events[-1] == [asset_event_resp]
772+
assert inlet_events[AssetAlias(name="alias-name")] == [asset_event_resp]
773+
774+
# access with invalid index
775+
with pytest.raises(IndexError):
776+
inlet_events[2]
777+
778+
with pytest.raises(IndexError):
779+
inlet_events[-3]
780+
781+
with pytest.raises(KeyError):
782+
inlet_events[Asset(name="no such asset in inlets")]
783+
784+
734785
@pytest.mark.parametrize(
735786
["ok", "last_expected_msg"],
736787
[

0 commit comments

Comments
 (0)