1616package org .openrewrite .java ;
1717
1818import org .openrewrite .internal .ListUtils ;
19+ import org .openrewrite .internal .lang .Nullable ;
1920import org .openrewrite .java .tree .*;
2021import org .openrewrite .marker .Markers ;
2122
23+ import java .util .List ;
24+ import java .util .Objects ;
25+ import java .util .stream .Collectors ;
26+
2227import static org .openrewrite .Tree .randomId ;
2328import static org .openrewrite .java .tree .Space .format ;
2429
2530public class ImplementInterface <P > extends JavaIsoVisitor <P > {
2631 private final J .ClassDeclaration scope ;
2732 private final JavaType .FullyQualified interfaceType ;
33+ private final @ Nullable List <Expression > typeParameters ;
2834
29- public ImplementInterface (J .ClassDeclaration scope , JavaType .FullyQualified interfaceType ) {
35+ public ImplementInterface (J .ClassDeclaration scope , JavaType .FullyQualified interfaceType , @ Nullable List < Expression > typeParameters ) {
3036 this .scope = scope ;
3137 this .interfaceType = interfaceType ;
38+ this .typeParameters = typeParameters ;
39+ }
40+
41+ public ImplementInterface (J .ClassDeclaration scope , String interfaze , @ Nullable List <Expression > typeParameters ) {
42+ this (scope , JavaType .ShallowClass .build (interfaze ), typeParameters );
43+ }
44+
45+ public ImplementInterface (J .ClassDeclaration scope , JavaType .FullyQualified interfaceType ) {
46+ this (scope , interfaceType , null );
3247 }
3348
3449 public ImplementInterface (J .ClassDeclaration scope , String interfaze ) {
35- this (scope , JavaType . ShallowClass . build ( interfaze ) );
50+ this (scope , interfaze , null );
3651 }
3752
3853 @ Override
@@ -41,7 +56,7 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P
4156 if (c .isScope (scope ) && (c .getImplements () == null || c .getImplements ().stream ()
4257 .noneMatch (f -> TypeUtils .isAssignableTo (f .getType (), interfaceType )))) {
4358
44- if (!classDecl .getSimpleName ().equals (interfaceType .getClassName ())) {
59+ if (!classDecl .getSimpleName ().equals (interfaceType .getClassName ())) {
4560 maybeAddImport (interfaceType );
4661 }
4762
@@ -50,7 +65,29 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, P
5065 .withType (interfaceType )
5166 .withPrefix (format (" " ));
5267
53- c = c .withImplements (ListUtils .concat (c .getImplements (), impl ));
68+ if (typeParameters != null && !typeParameters .isEmpty ()) {
69+ typeParameters .stream ()
70+ .map (Expression ::getType )
71+ .map (t -> (t instanceof JavaType .FullyQualified ) ? (JavaType .FullyQualified ) t : null )
72+ .filter (Objects ::nonNull )
73+ .forEach (t -> maybeAddImport (t .getFullyQualifiedName ()));
74+
75+ List <JRightPadded <Expression >> elements = typeParameters .stream ()
76+ .map (t -> new JRightPadded <>(t , Space .EMPTY , Markers .EMPTY ))
77+ .collect (Collectors .toList ());
78+
79+ J .ParameterizedType typedImpl = new J .ParameterizedType (
80+ randomId (),
81+ Space .EMPTY ,
82+ Markers .EMPTY ,
83+ impl ,
84+ JContainer .build (Space .EMPTY , elements , Markers .EMPTY )
85+ );
86+
87+ c = c .withImplements (ListUtils .concat (c .getImplements (), typedImpl ));
88+ } else {
89+ c = c .withImplements (ListUtils .concat (c .getImplements (), impl ));
90+ }
5491
5592 JContainer <TypeTree > anImplements = c .getPadding ().getImplements ();
5693 assert anImplements != null ;
0 commit comments