99import java .util .List ;
1010
1111import org .opensearch .ml .common .settings .MLFeatureEnabledSetting ;
12+ import org .opensearch .ml .grpc .interfaces .MLClient ;
13+ import org .opensearch .ml .grpc .interfaces .MLModelAccessControlHelper ;
14+ import org .opensearch .ml .grpc .interfaces .MLModelManager ;
15+ import org .opensearch .ml .grpc .interfaces .MLSdkClient ;
16+ import org .opensearch .ml .grpc .interfaces .MLTaskRunner ;
17+ import org .opensearch .ml .grpc .interfaces .MLUserContextProvider ;
1218import org .opensearch .transport .grpc .spi .GrpcServiceFactory ;
1319
1420import io .grpc .BindableService ;
2228@ Log4j2
2329public class MLGrpcServiceFactory implements GrpcServiceFactory {
2430
25- private static volatile Object modelManager ;
26- private static volatile Object predictTaskRunner ;
27- private static volatile Object executeTaskRunner ;
31+ private static volatile MLModelManager modelManager ;
32+ private static volatile MLTaskRunner predictTaskRunner ;
33+ private static volatile MLTaskRunner executeTaskRunner ;
2834 private static volatile MLFeatureEnabledSetting mlFeatureEnabledSetting ;
29- private static volatile Object modelAccessControlHelper ;
30- private static volatile Object client ;
31- private static volatile Object sdkClient ;
35+ private static volatile MLModelAccessControlHelper modelAccessControlHelper ;
36+ private static volatile MLClient client ;
37+ private static volatile MLSdkClient sdkClient ;
38+ private static volatile MLUserContextProvider userContextProvider ;
3239
3340 /**
3441 * No-arg constructor required by SPI.
@@ -48,23 +55,30 @@ public MLGrpcServiceFactory() {
4855 * @param modelAccessControlHelper helper for validating model access control
4956 * @param client OpenSearch client for validation
5057 * @param sdkClient SDK client for multi-tenant operations
58+ * @param userContextProvider provider for extracting user from security context
5159 */
5260 public static void initialize (
53- Object modelManager ,
54- Object predictTaskRunner ,
55- Object executeTaskRunner ,
61+ MLModelManager modelManager ,
62+ MLTaskRunner predictTaskRunner ,
63+ MLTaskRunner executeTaskRunner ,
5664 MLFeatureEnabledSetting mlFeatureEnabledSetting ,
57- Object modelAccessControlHelper ,
58- Object client ,
59- Object sdkClient
65+ MLModelAccessControlHelper modelAccessControlHelper ,
66+ MLClient client ,
67+ MLSdkClient sdkClient ,
68+ MLUserContextProvider userContextProvider
6069 ) {
70+ if (modelManager == null || predictTaskRunner == null || executeTaskRunner == null ) {
71+ throw new IllegalArgumentException ("Required dependencies cannot be null" );
72+ }
73+
6174 MLGrpcServiceFactory .modelManager = modelManager ;
6275 MLGrpcServiceFactory .predictTaskRunner = predictTaskRunner ;
6376 MLGrpcServiceFactory .executeTaskRunner = executeTaskRunner ;
6477 MLGrpcServiceFactory .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
6578 MLGrpcServiceFactory .modelAccessControlHelper = modelAccessControlHelper ;
6679 MLGrpcServiceFactory .client = client ;
6780 MLGrpcServiceFactory .sdkClient = sdkClient ;
81+ MLGrpcServiceFactory .userContextProvider = userContextProvider ;
6882 }
6983
7084 @ Override
@@ -77,6 +91,11 @@ public List<BindableService> build() {
7791 List <BindableService > services = new ArrayList <>();
7892
7993 try {
94+ // Validate initialization
95+ if (modelManager == null || predictTaskRunner == null || executeTaskRunner == null ) {
96+ throw new IllegalStateException ("MLGrpcServiceFactory not initialized. Call initialize() first." );
97+ }
98+
8099 // Create ML streaming service
81100 MLStreamingService streamingService = new MLStreamingService (
82101 modelManager ,
@@ -85,7 +104,8 @@ public List<BindableService> build() {
85104 mlFeatureEnabledSetting ,
86105 modelAccessControlHelper ,
87106 client ,
88- sdkClient
107+ sdkClient ,
108+ userContextProvider
89109 );
90110
91111 // Wrap service with tenant ID interceptor to extract metadata
0 commit comments