Skip to content

Commit 94e6110

Browse files
authored
[airflow] Implement task-branch-as-short-circuit (AIR004) (astral-sh#23579)
## Summary Adds a new rule `AIR004` that detects `@task.branch` decorated functions that could be replaced with `@task.short_circuit`. In Airflow, [`@task.branch`](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html#branching) selects which downstream tasks to run by returning a list of task IDs (or an empty list to skip all). When the function has at least two `return` statements and exactly one of them returns a non-empty list, it is effectively acting as a boolean short-circuit (i.e. either run one specific set of downstream tasks or skip them all). In that case, [`@task.short_circuit`](https://www.astronomer.io/docs/learn/airflow-branch-operator#taskshort_circuit-shortcircuitoperator) is a simpler and more readable alternative that returns `True`/`False` instead. ```python # Before (AIR004) @task.branch def my_task(): if condition: return ["my_downstream_task"] return [] # After @task.short_circuit def my_task(): return condition ``` ### Implementation details - Resolves the `@task.branch` decorator via the semantic model (`airflow.decorators.task` + `.branch` attribute), handling both `@task.branch` and `@task.branch()` call forms via `map_callable`. - Uses `ReturnStatementVisitor` to collect all `return` statements recursively (including those inside nested `if`/`else`/`for`/`while` blocks). - Flags the function when: `len(returns) >= 2` and exactly one return has a non-empty list value. ### What it does NOT flag - Functions with multiple non-empty list returns (genuine branching logic). - Functions with all-empty returns (no downstream tasks selected at all). - Functions with only a single return statement. - Functions not decorated with `@task.branch`. - Functions returning non-list values (strings, `None`, etc.). ## Test Plan <!-- How was it tested? --> Added snapshot tests in `AIR004.py` covering both violation and non-violation cases: - two returns with one non-empty list - three returns with one non-empty list - nested returns - multiple non-empty returns - all-empty returns - single return - undecorated functions - `@task.short_circuit` decorated functions
1 parent 3efc690 commit 94e6110

12 files changed

Lines changed: 506 additions & 0 deletions

File tree

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from airflow.decorators import task
2+
from airflow.operators.python import BranchPythonOperator, ShortCircuitOperator
3+
from airflow.providers.standard.operators.python import (
4+
BranchPythonOperator as ProviderBranchPythonOperator,
5+
)
6+
7+
condition1 = True
8+
condition2 = True
9+
condition3 = True
10+
11+
12+
# Violations (should trigger AIR004):
13+
14+
@task.branch
15+
def two_returns_one_non_empty(): # AIR004
16+
if condition1:
17+
return ["my_downstream_task"]
18+
return []
19+
20+
21+
@task.branch
22+
def three_returns_one_non_empty(): # AIR004
23+
if condition1:
24+
return []
25+
if condition2:
26+
return ["another_downstream_task"]
27+
return []
28+
29+
30+
@task.branch
31+
def nested_returns_one_non_empty(): # AIR004
32+
if condition1:
33+
if condition2:
34+
return []
35+
return ["downstream_task"]
36+
return []
37+
38+
39+
@task.branch()
40+
def with_parens(): # AIR004
41+
if condition1:
42+
return ["downstream_task"]
43+
return []
44+
45+
46+
@task.branch
47+
def bare_return_and_list(): # AIR004
48+
if condition1:
49+
return ["downstream_task"]
50+
return
51+
52+
53+
@task.branch
54+
def none_return_and_list(): # AIR004
55+
if condition1:
56+
return ["downstream_task"]
57+
return None
58+
59+
60+
# No violations:
61+
62+
@task.branch
63+
def two_non_empty_returns():
64+
if condition1:
65+
return ["task_a"]
66+
if condition2:
67+
return ["task_b"]
68+
return []
69+
70+
71+
@task.branch
72+
def all_empty_returns():
73+
if condition1:
74+
return []
75+
if condition2:
76+
return []
77+
return []
78+
79+
80+
@task.branch
81+
def single_return():
82+
return ["downstream_task"]
83+
84+
85+
def not_decorated():
86+
if condition1:
87+
return ["downstream_task"]
88+
return []
89+
90+
91+
@task.short_circuit
92+
def already_short_circuit():
93+
if condition1:
94+
return True
95+
return False
96+
97+
98+
# BranchPythonOperator violations (should trigger AIR004):
99+
100+
101+
def operator_short_circuit_candidate():
102+
if condition1:
103+
return ["downstream_task"]
104+
return []
105+
106+
107+
BranchPythonOperator(task_id="task", python_callable=operator_short_circuit_candidate) # AIR004
108+
109+
110+
def provider_short_circuit_candidate():
111+
if condition1:
112+
return ["downstream_task"]
113+
return []
114+
115+
116+
ProviderBranchPythonOperator( # AIR004
117+
task_id="task", python_callable=provider_short_circuit_candidate
118+
)
119+
120+
121+
# BranchPythonOperator non-violations:
122+
123+
124+
def operator_two_non_empty():
125+
if condition1:
126+
return ["task_a"]
127+
if condition2:
128+
return ["task_b"]
129+
return []
130+
131+
132+
BranchPythonOperator(task_id="task", python_callable=operator_two_non_empty)
133+
134+
135+
def operator_single_return():
136+
return ["downstream_task"]
137+
138+
139+
BranchPythonOperator(task_id="task", python_callable=operator_single_return)
140+
141+
ShortCircuitOperator(task_id="task", python_callable=operator_short_circuit_candidate)
142+
143+
BranchPythonOperator(task_id="task", python_callable=lambda: ["downstream_task"])
144+
145+
BranchPythonOperator(task_id="task")
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from airflow.sdk import task
2+
3+
condition1 = True
4+
condition2 = True
5+
6+
7+
# Violations (should trigger AIR004):
8+
9+
@task.branch
10+
def sdk_two_returns_one_non_empty(): # AIR004
11+
if condition1:
12+
return ["my_downstream_task"]
13+
return []
14+
15+
16+
# No violations:
17+
18+
@task.branch
19+
def sdk_two_non_empty_returns():
20+
if condition1:
21+
return ["task_a"]
22+
if condition2:
23+
return ["task_b"]
24+
return []

crates/ruff_linter/src/checkers/ast/analyze/expression.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,9 @@ pub(crate) fn expression(expr: &Expr, checker: &Checker) {
13221322
if checker.is_rule_enabled(Rule::Airflow3DagDynamicValue) {
13231323
airflow::rules::airflow_3_dag_dynamic_value(checker, call);
13241324
}
1325+
if checker.is_rule_enabled(Rule::AirflowTaskBranchAsShortCircuit) {
1326+
airflow::rules::branch_python_operator_as_short_circuit(checker, call);
1327+
}
13251328
if checker.is_rule_enabled(Rule::UnnecessaryCastToInt) {
13261329
ruff::rules::unnecessary_cast_to_int(checker, call);
13271330
}

crates/ruff_linter/src/checkers/ast/analyze/statement.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
339339
if checker.is_rule_enabled(Rule::PytestParameterWithDefaultArgument) {
340340
flake8_pytest_style::rules::parameter_with_default_argument(checker, function_def);
341341
}
342+
if checker.is_rule_enabled(Rule::AirflowTaskBranchAsShortCircuit) {
343+
airflow::rules::task_branch_as_short_circuit(checker, function_def);
344+
}
342345
if checker.is_rule_enabled(Rule::Airflow3Removal) {
343346
airflow::rules::airflow_3_removal_function_def(checker, function_def);
344347
}

crates/ruff_linter/src/codes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> {
11341134
(Airflow, "001") => rules::airflow::rules::AirflowVariableNameTaskIdMismatch,
11351135
(Airflow, "002") => rules::airflow::rules::AirflowDagNoScheduleArgument,
11361136
(Airflow, "003") => rules::airflow::rules::AirflowVariableGetOutsideTask,
1137+
(Airflow, "004") => rules::airflow::rules::AirflowTaskBranchAsShortCircuit,
11371138
(Airflow, "201") => rules::airflow::rules::AirflowXcomPullInTemplateString,
11381139
(Airflow, "301") => rules::airflow::rules::Airflow3Removal,
11391140
(Airflow, "302") => rules::airflow::rules::Airflow3MovedToProvider,

crates/ruff_linter/src/rules/airflow/helpers.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,23 @@ pub(crate) fn is_airflow_task(function_def: &StmtFunctionDef, semantic: &Semanti
326326
false
327327
})
328328
}
329+
330+
/// Returns `true` if the given function is decorated with a specific `@task.<variant>`
331+
/// form (e.g., `@task.branch` or `@task.short_circuit`).
332+
pub(crate) fn is_airflow_task_variant(
333+
function_def: &StmtFunctionDef,
334+
semantic: &SemanticModel,
335+
variant: &str,
336+
) -> bool {
337+
function_def.decorator_list.iter().any(|decorator| {
338+
let expr = map_callable(&decorator.expression);
339+
if let Expr::Attribute(ExprAttribute { value, attr, .. }) = expr {
340+
attr.as_str() == variant
341+
&& semantic.resolve_qualified_name(value).is_some_and(|qn| {
342+
matches!(qn.segments(), ["airflow", "decorators" | "sdk", "task"])
343+
})
344+
} else {
345+
false
346+
}
347+
})
348+
}

crates/ruff_linter/src/rules/airflow/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ mod tests {
2121
Rule::AirflowVariableGetOutsideTask,
2222
Path::new("AIR003_dag_decorator.py")
2323
)]
24+
#[test_case(Rule::AirflowTaskBranchAsShortCircuit, Path::new("AIR004.py"))]
25+
#[test_case(Rule::AirflowTaskBranchAsShortCircuit, Path::new("AIR004_sdk.py"))]
2426
#[test_case(Rule::AirflowXcomPullInTemplateString, Path::new("AIR201.py"))]
2527
#[test_case(Rule::Airflow3Removal, Path::new("AIR301_args.py"))]
2628
#[test_case(Rule::Airflow3Removal, Path::new("AIR301_names.py"))]

crates/ruff_linter/src/rules/airflow/rules/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(crate) use removal_in_3::*;
66
pub(crate) use runtime_value_in_dag_or_task::*;
77
pub(crate) use suggested_to_move_to_provider_in_3::*;
88
pub(crate) use suggested_to_update_3_0::*;
9+
pub(crate) use task_branch_as_short_circuit::*;
910
pub(crate) use task_variable_name::*;
1011
pub(crate) use variable_get_outside_task::*;
1112
pub(crate) use xcom_pull_in_template_string::*;
@@ -18,6 +19,7 @@ mod removal_in_3;
1819
mod runtime_value_in_dag_or_task;
1920
mod suggested_to_move_to_provider_in_3;
2021
mod suggested_to_update_3_0;
22+
mod task_branch_as_short_circuit;
2123
mod task_variable_name;
2224
mod variable_get_outside_task;
2325
mod xcom_pull_in_template_string;

0 commit comments

Comments
 (0)