Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,16 @@
# Regression test: Don't forget the parentheses in the annotation when breaking
class DefaultRunner:
task_runner_cls: TaskRunnerProtocol | typing.Callable[[], typing.Any] = DefaultTaskRunner


# Preserve parentheses around invalid type expressions.
def preserve_invalid_type_expressions_in_annotations():
named: (value := int) = 1
yielded_with_value: (yield 1) = 1
yielded: (yield 1)
yielded_from: (yield from iter) = 1


async def preserve_invalid_type_expressions_in_async_annotations():
awaited: (await g()) = 1
bare_awaited: (await g())
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,16 @@ def args_with_type_annotations_no_after_colon_comment(
int # trailing type
# after type
): pass


# Preserve parentheses around invalid type expressions in parameter annotations.
def preserve_named_param_annotation(x: (value := int)):
pass


def preserve_yield_param_annotation(x: (yield 1)):
pass


async def preserve_await_param_annotation(x: (await g())):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,8 @@ def process_board_action(
payload: WildValue, action_type: Optional[str]
) -> Optional[Tuple[str, str]]:
pass


# Preserve parentheses around invalid type expressions.
async def preserve_await_return_annotation() -> (await g()):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,9 @@
type bound_and_default[T:int=int] = int
type long_bound_short_default[T: (aaaaaaaaaaaaaaaaaaaaaaaaaaaaa, bbbbbbbbbbbbbbb, ccccccccccc, ddddddddddddd, eeeeeee)=a]=int
type short_bound_long_default[T:a= (aaaaaaaaaaaaaaaaaaaaaaaaaaaaa, bbbbbbbbbbbbbbb, ccccccccccc, ddddddddddddd, eeeeeee)]=int

# Preserve parentheses around invalid type expressions.
type NamedExprValue = (value := int)

async def preserve_await_type_alias_value():
type AwaitValue = (await g())
4 changes: 3 additions & 1 deletion crates/ruff_python_formatter/src/expression/expr_await.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use ruff_formatter::write;
use ruff_python_ast::AnyNodeRef;
use ruff_python_ast::ExprAwait;
use ruff_text_size::Ranged;

use crate::expression::maybe_parenthesize_expression;
use crate::expression::parentheses::{
NeedsParentheses, OptionalParentheses, Parenthesize, is_expression_parenthesized,
is_type_annotation_of,
};
use crate::prelude::*;

Expand Down Expand Up @@ -36,7 +38,7 @@ impl NeedsParentheses for ExprAwait {
parent: AnyNodeRef,
context: &PyFormatContext,
) -> OptionalParentheses {
if parent.is_expr_await() {
if parent.is_expr_await() || is_type_annotation_of(self.range(), parent) {
OptionalParentheses::Always
} else if is_expression_parenthesized(
self.value.as_ref().into(),
Expand Down
5 changes: 5 additions & 0 deletions crates/ruff_python_formatter/src/expression/expr_yield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use ruff_text_size::{Ranged, TextRange};
use crate::expression::maybe_parenthesize_expression;
use crate::expression::parentheses::{
NeedsParentheses, OptionalParentheses, Parenthesize, is_expression_parenthesized,
is_type_annotation_of,
};
use crate::prelude::*;

Expand Down Expand Up @@ -42,6 +43,10 @@ impl NeedsParentheses for AnyExpressionYield<'_> {
parent: AnyNodeRef,
context: &PyFormatContext,
) -> OptionalParentheses {
if is_type_annotation_of(self.range(), parent) {
return OptionalParentheses::Always;
}

// According to https://docs.python.org/3/reference/grammar.html There are two situations
// where we do not want to always parenthesize a yield expression:
// 1. Right hand side of an assignment, e.g. `x = yield y`
Expand Down
9 changes: 9 additions & 0 deletions crates/ruff_python_formatter/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,15 @@ pub(crate) fn is_splittable_expression(expr: &Expr, context: &PyFormatContext) -
}
}

/// Returns `true` if `expr` is invalid in a type-expression position unless the source keeps it
/// parenthesized.
pub(crate) const fn is_invalid_type_expression(expr: &Expr) -> bool {
matches!(
expr,
Expr::Named(_) | Expr::Await(_) | Expr::Yield(_) | Expr::YieldFrom(_)
)
}

/// Returns the sub-expression to which the left-most character in expression belongs.
///
/// For example, in the expression `a + b * c`, the left-most subexpression is `a`. But for
Expand Down
15 changes: 14 additions & 1 deletion crates/ruff_python_formatter/src/expression/parentheses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ruff_python_trivia::CommentRanges;
use ruff_python_trivia::{
BackwardsTokenizer, SimpleToken, SimpleTokenKind, first_non_trivia_token,
};
use ruff_text_size::Ranged;
use ruff_text_size::{Ranged, TextRange};

use crate::comments::{
SourceComment, dangling_comments, dangling_open_parenthesis_comments, trailing_comments,
Expand Down Expand Up @@ -42,6 +42,19 @@ pub(crate) trait NeedsParentheses {
) -> OptionalParentheses;
}

/// Returns `true` if `expr_range` identifies a type annotation child of `parent`,
/// i.e. the annotation of a `StmtAnnAssign` or the return annotation of a `StmtFunctionDef`.
pub(crate) fn is_type_annotation_of(expr_range: TextRange, parent: AnyNodeRef) -> bool {
match parent {
AnyNodeRef::StmtAnnAssign(stmt) => stmt.annotation.range() == expr_range,
AnyNodeRef::StmtFunctionDef(stmt) => stmt
.returns
.as_deref()
.is_some_and(|r| r.range() == expr_range),
_ => false,
}
}

/// From the perspective of the parent statement or expression, when should the child expression
/// get parentheses?
#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
41 changes: 21 additions & 20 deletions crates/ruff_python_formatter/src/statement/stmt_ann_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ impl FormatNodeRule<StmtAnnAssign> for FormatStmtAnnAssign {
value,
simple: _,
} = item;
let comments = f.context().comments().clone();
let annotation_parentheses = annotation
.as_ref()
.needs_parentheses(item.into(), f.context());

write!(f, [target.format(), token(":"), space()])?;

if let Some(value) = value {
if is_splittable_expression(annotation, f.context()) {
if annotation_parentheses != OptionalParentheses::Always
&& is_splittable_expression(annotation, f.context())
{
FormatStatementsLastExpression::RightToLeft {
before_operator: AnyBeforeOperator::Expression(annotation),
operator: AnyAssignmentOperator::Assign,
Expand All @@ -37,26 +43,16 @@ impl FormatNodeRule<StmtAnnAssign> for FormatStmtAnnAssign {
} else {
// Remove unnecessary parentheses around the annotation if the parenthesize long type hints preview style is enabled.
// Ensure we keep the parentheses if the annotation has any comments.
let preserve_parentheses = f.context().comments().has_leading(annotation.as_ref())
|| f.context().comments().has_trailing(annotation.as_ref())
|| matches!(
annotation
.as_ref()
.needs_parentheses(item.into(), f.context()),
OptionalParentheses::Always
);

if preserve_parentheses {
annotation
.format()
.with_options(Parentheses::Always)
.fmt(f)?;
let parentheses = if comments.has_leading(annotation.as_ref())
|| comments.has_trailing(annotation.as_ref())
|| annotation_parentheses == OptionalParentheses::Always
{
Parentheses::Always
} else {
annotation
.format()
.with_options(Parentheses::Never)
.fmt(f)?;
}
Parentheses::Never
};

annotation.format().with_options(parentheses).fmt(f)?;

write!(
f,
Expand All @@ -68,6 +64,11 @@ impl FormatNodeRule<StmtAnnAssign> for FormatStmtAnnAssign {
]
)?;
}
} else if annotation_parentheses == OptionalParentheses::Always {
annotation
.format()
.with_options(Parentheses::Always)
.fmt(f)?;
} else {
// Parenthesize the value and inline the comment if it is a "simple" type annotation, similar
// to what we do with the value.
Expand Down
18 changes: 18 additions & 0 deletions crates/ruff_python_formatter/src/statement/stmt_type_alias.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ruff_formatter::write;
use ruff_python_ast::StmtTypeAlias;

use crate::expression::is_invalid_type_expression;
use crate::expression::parentheses::Parentheses;
use crate::prelude::*;
use crate::statement::stmt_assign::{
AnyAssignmentOperator, AnyBeforeOperator, FormatStatementsLastExpression,
Expand All @@ -21,6 +23,22 @@ impl FormatNodeRule<StmtTypeAlias> for FormatStmtTypeAlias {

write!(f, [token("type"), space(), name.as_ref().format()])?;

if is_invalid_type_expression(value) {
if let Some(type_params) = type_params {
type_params.format().fmt(f)?;
}

return write!(
f,
[
space(),
token("="),
space(),
value.format().with_options(Parentheses::Preserve)
]
);
}

if let Some(type_params) = type_params {
return FormatStatementsLastExpression::RightToLeft {
before_operator: AnyBeforeOperator::TypeParams(type_params),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/ann_assign.py
---
## Input
```python
Expand Down Expand Up @@ -28,6 +27,19 @@ D: (E := 4) = (F := 5)
# Regression test: Don't forget the parentheses in the annotation when breaking
class DefaultRunner:
task_runner_cls: TaskRunnerProtocol | typing.Callable[[], typing.Any] = DefaultTaskRunner


# Preserve parentheses around invalid type expressions.
def preserve_invalid_type_expressions_in_annotations():
named: (value := int) = 1
yielded_with_value: (yield 1) = 1
yielded: (yield 1)
yielded_from: (yield from iter) = 1


async def preserve_invalid_type_expressions_in_async_annotations():
awaited: (await g()) = 1
bare_awaited: (await g())
```

## Output
Expand Down Expand Up @@ -68,4 +80,17 @@ class DefaultRunner:
task_runner_cls: TaskRunnerProtocol | typing.Callable[[], typing.Any] = (
DefaultTaskRunner
)


# Preserve parentheses around invalid type expressions.
def preserve_invalid_type_expressions_in_annotations():
named: (value := int) = 1
yielded_with_value: (yield 1) = 1
yielded: (yield 1)
yielded_from: (yield from iter) = 1


async def preserve_invalid_type_expressions_in_async_annotations():
awaited: (await g()) = 1
bare_awaited: (await g())
```
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/function.py
---
## Input
```python
Expand Down Expand Up @@ -569,6 +568,19 @@ def args_with_type_annotations_no_after_colon_comment(
int # trailing type
# after type
): pass


# Preserve parentheses around invalid type expressions in parameter annotations.
def preserve_named_param_annotation(x: (value := int)):
pass


def preserve_yield_param_annotation(x: (yield 1)):
pass


async def preserve_await_param_annotation(x: (await g())):
pass
```

## Output
Expand Down Expand Up @@ -1318,6 +1330,19 @@ def args_with_type_annotations_no_after_colon_comment(
# after type
):
pass


# Preserve parentheses around invalid type expressions in parameter annotations.
def preserve_named_param_annotation(x: (value := int)):
pass


def preserve_yield_param_annotation(x: (yield 1)):
pass


async def preserve_await_param_annotation(x: (await g())):
pass
```


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/return_annotation.py
---
## Input
```python
Expand Down Expand Up @@ -198,6 +197,11 @@ def process_board_action(
payload: WildValue, action_type: Optional[str]
) -> Optional[Tuple[str, str]]:
pass


# Preserve parentheses around invalid type expressions.
async def preserve_await_return_annotation() -> (await g()):
pass
```

## Output
Expand Down Expand Up @@ -507,4 +511,9 @@ def process_board_action(
payload: WildValue, action_type: Optional[str]
) -> Optional[Tuple[str, str]]:
pass


# Preserve parentheses around invalid type expressions.
async def preserve_await_return_annotation() -> (await g()):
pass
```
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: crates/ruff_python_formatter/tests/fixtures.rs
input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/type_alias.py
---
## Input
```python
Expand Down Expand Up @@ -156,6 +155,12 @@ type after_star[*Ts = *
type bound_and_default[T:int=int] = int
type long_bound_short_default[T: (aaaaaaaaaaaaaaaaaaaaaaaaaaaaa, bbbbbbbbbbbbbbb, ccccccccccc, ddddddddddddd, eeeeeee)=a]=int
type short_bound_long_default[T:a= (aaaaaaaaaaaaaaaaaaaaaaaaaaaaa, bbbbbbbbbbbbbbb, ccccccccccc, ddddddddddddd, eeeeeee)]=int

# Preserve parentheses around invalid type expressions.
type NamedExprValue = (value := int)

async def preserve_await_type_alias_value():
type AwaitValue = (await g())
```

## Output
Expand Down Expand Up @@ -402,4 +407,11 @@ type short_bound_long_default[
eeeeeee,
)
] = int

# Preserve parentheses around invalid type expressions.
type NamedExprValue = (value := int)


async def preserve_await_type_alias_value():
type AwaitValue = (await g())
```
Loading