@@ -215,8 +215,12 @@ def visit_right_padded(self, right: Optional[JRightPadded], p: P) -> Optional[JR
215215 after = self ._indent_to (right .after , indent )
216216
217217 else :
218+ method_select_indent = self .cursor .get_nearest_message ("method_select_indent" )
218219 elem = self .visit_and_cast (elem , J , p )
219- after = self .visit_space (right .after , p )
220+ if method_select_indent is not None :
221+ after = self ._indent_to (right .after , method_select_indent )
222+ else :
223+ after = self .visit_space (right .after , p )
220224
221225 else :
222226 if rp_context in ("method_invocation_argument" ,):
@@ -425,6 +429,62 @@ def visit_class_declaration(self, class_decl: ClassDeclaration, p: P) -> J:
425429 class_decl .body , Block , p ))
426430 return class_decl
427431
432+ def visit_method_invocation (self , method : MethodInvocation , p : P ) -> J :
433+ select = method .padding .select
434+ if select is not None and '\n ' in select .after .last_whitespace :
435+ col = self ._compute_select_column (method )
436+ if col >= 0 :
437+ self .cursor .put_message ("method_select_indent" , col )
438+ return super ().visit_method_invocation (method , p )
439+
440+ def _compute_select_column (self , method : MethodInvocation ) -> int :
441+ from rewrite .python .printer import PythonPrinter , PrintOutputCapture
442+
443+ line_start = None
444+ for c in self .cursor .get_path_as_cursors ():
445+ v = c .value
446+ if isinstance (v , J ):
447+ line_start = v
448+ if '\n ' in v .prefix .whitespace :
449+ break
450+ if line_start is None :
451+ return - 1
452+
453+ select = method .select
454+
455+ class _ColumnCounter (PrintOutputCapture ):
456+ def __init__ (self ):
457+ super ().__init__ ()
458+ self .col = 0
459+ self .found = False
460+
461+ def append (self , text ):
462+ if text and not self .found :
463+ for ch in text :
464+ self .col = 0 if ch == '\n ' else self .col + 1
465+ return self
466+
467+ class _Printer (PythonPrinter ):
468+ def __init__ (self , target ):
469+ super ().__init__ ()
470+ orig_visit = self ._delegate .visit
471+ def _check (tree , p ):
472+ if tree is target :
473+ p .found = True
474+ return tree
475+ return orig_visit (tree , p ) if not p .found else tree
476+ self ._delegate .visit = _check
477+
478+ def visit (self , tree , p , parent = None ):
479+ if p .found or tree is select :
480+ p .found = True
481+ return tree
482+ return super ().visit (tree , p )
483+
484+ counter = _ColumnCounter ()
485+ _Printer (select ).print (line_start , counter )
486+ return counter .col if counter .found else - 1
487+
428488 # -------------------------------------------------------------------------
429489 # Expression statement (docstring alignment)
430490 # -------------------------------------------------------------------------
0 commit comments