|
43 | 43 | from airflow.listeners.listener import get_listener_manager |
44 | 44 | from airflow.providers.standard.operators.python import PythonOperator |
45 | 45 | from airflow.sdk import DAG, BaseOperator, Connection, dag as dag_decorator, get_current_context |
46 | | -from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState |
| 46 | +from airflow.sdk.api.datamodels._generated import ( |
| 47 | + AssetEventResponse, |
| 48 | + AssetProfile, |
| 49 | + AssetResponse, |
| 50 | + TaskInstance, |
| 51 | + TerminalTIState, |
| 52 | +) |
47 | 53 | from airflow.sdk.definitions.asset import Asset, AssetAlias |
48 | 54 | from airflow.sdk.definitions.param import DagParam |
49 | 55 | from airflow.sdk.definitions.variable import Variable |
50 | 56 | from airflow.sdk.execution_time.comms import ( |
| 57 | + AssetEventsResult, |
51 | 58 | BundleInfo, |
52 | 59 | ConnectionResult, |
53 | 60 | DeferTask, |
@@ -731,6 +738,50 @@ def test_run_with_asset_outlets( |
731 | 738 | mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY) |
732 | 739 |
|
733 | 740 |
|
| 741 | +def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): |
| 742 | + """Test running a basic task that contains asset inlets.""" |
| 743 | + asset_event_resp = AssetEventResponse( |
| 744 | + id=1, |
| 745 | + created_dagruns=[], |
| 746 | + timestamp=datetime.now(), |
| 747 | + asset=AssetResponse(name="test", uri="test", group="asset"), |
| 748 | + ) |
| 749 | + events_result = AssetEventsResult(asset_events=[asset_event_resp]) |
| 750 | + mock_supervisor_comms.get_message.return_value = events_result |
| 751 | + |
| 752 | + from airflow.providers.standard.operators.bash import BashOperator |
| 753 | + |
| 754 | + task = BashOperator( |
| 755 | + inlets=[Asset(name="test", uri="test://uri"), AssetAlias(name="alias-name")], |
| 756 | + task_id="asset-outlet-task", |
| 757 | + bash_command="echo 0", |
| 758 | + ) |
| 759 | + |
| 760 | + ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task") |
| 761 | + run(ti, log=mock.MagicMock()) |
| 762 | + inlet_events = ti.get_template_context()["inlet_events"] |
| 763 | + |
| 764 | + # access the asset events of Asset(name="test", uri="test://uri") |
| 765 | + assert inlet_events[0] == [asset_event_resp] |
| 766 | + assert inlet_events[-2] == [asset_event_resp] |
| 767 | + assert inlet_events[Asset(name="test", uri="test://uri")] == [asset_event_resp] |
| 768 | + |
| 769 | + # access the asset events of AssetAlias(name="alias-name") |
| 770 | + assert inlet_events[1] == [asset_event_resp] |
| 771 | + assert inlet_events[-1] == [asset_event_resp] |
| 772 | + assert inlet_events[AssetAlias(name="alias-name")] == [asset_event_resp] |
| 773 | + |
| 774 | + # access with invalid index |
| 775 | + with pytest.raises(IndexError): |
| 776 | + inlet_events[2] |
| 777 | + |
| 778 | + with pytest.raises(IndexError): |
| 779 | + inlet_events[-3] |
| 780 | + |
| 781 | + with pytest.raises(KeyError): |
| 782 | + inlet_events[Asset(name="no such asset in inlets")] |
| 783 | + |
| 784 | + |
734 | 785 | @pytest.mark.parametrize( |
735 | 786 | ["ok", "last_expected_msg"], |
736 | 787 | [ |
|
0 commit comments