Skip to content

Commit 9c24557

Browse files
Refactor code
Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
1 parent a5ce2c8 commit 9c24557

20 files changed

+680
-311
lines changed

grpc/src/main/java/org/opensearch/ml/grpc/GrpcStatusMapper.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.grpc;
77

88
import java.io.IOException;
9+
import java.util.Map;
910

1011
import org.opensearch.OpenSearchException;
1112
import org.opensearch.OpenSearchSecurityException;
@@ -20,6 +21,21 @@
2021
* Maps OpenSearch and ML Commons exceptions to gRPC Status codes.
2122
*/
2223
public class GrpcStatusMapper {
24+
private static final Map<RestStatus, Status> STATUS_MAP = Map
25+
.of(
26+
RestStatus.FORBIDDEN,
27+
Status.PERMISSION_DENIED,
28+
RestStatus.NOT_FOUND,
29+
Status.NOT_FOUND,
30+
RestStatus.TOO_MANY_REQUESTS,
31+
Status.RESOURCE_EXHAUSTED,
32+
RestStatus.BAD_REQUEST,
33+
Status.INVALID_ARGUMENT,
34+
RestStatus.SERVICE_UNAVAILABLE,
35+
Status.UNAVAILABLE,
36+
RestStatus.UNAUTHORIZED,
37+
Status.UNAUTHENTICATED
38+
);
2339

2440
/**
2541
* Converts an exception to a gRPC Status.
@@ -48,17 +64,8 @@ public static Status toGrpcStatus(Exception exception) {
4864

4965
// Handle OpenSearch exceptions with status codes
5066
if (exception instanceof OpenSearchException osException) {
51-
RestStatus status = osException.status();
52-
53-
return switch (status) {
54-
case FORBIDDEN -> Status.PERMISSION_DENIED.withDescription(exception.getMessage()).withCause(exception);
55-
case NOT_FOUND -> Status.NOT_FOUND.withDescription(exception.getMessage()).withCause(exception);
56-
case TOO_MANY_REQUESTS -> Status.RESOURCE_EXHAUSTED.withDescription(exception.getMessage()).withCause(exception);
57-
case BAD_REQUEST -> Status.INVALID_ARGUMENT.withDescription(exception.getMessage()).withCause(exception);
58-
case SERVICE_UNAVAILABLE -> Status.UNAVAILABLE.withDescription(exception.getMessage()).withCause(exception);
59-
case UNAUTHORIZED -> Status.UNAUTHENTICATED.withDescription(exception.getMessage()).withCause(exception);
60-
default -> Status.INTERNAL.withDescription(exception.getMessage()).withCause(exception);
61-
};
67+
Status grpcStatus = STATUS_MAP.getOrDefault(osException.status(), Status.INTERNAL);
68+
return grpcStatus.withDescription(exception.getMessage()).withCause(exception);
6269
}
6370

6471
// Handle I/O exceptions

grpc/src/main/java/org/opensearch/ml/grpc/MLGrpcServiceFactory.java

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
import java.util.List;
1010

1111
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
12+
import org.opensearch.ml.grpc.interfaces.MLClient;
13+
import org.opensearch.ml.grpc.interfaces.MLModelAccessControlHelper;
14+
import org.opensearch.ml.grpc.interfaces.MLModelManager;
15+
import org.opensearch.ml.grpc.interfaces.MLSdkClient;
16+
import org.opensearch.ml.grpc.interfaces.MLTaskRunner;
17+
import org.opensearch.ml.grpc.interfaces.MLUserContextProvider;
1218
import org.opensearch.transport.grpc.spi.GrpcServiceFactory;
1319

1420
import io.grpc.BindableService;
@@ -22,13 +28,14 @@
2228
@Log4j2
2329
public class MLGrpcServiceFactory implements GrpcServiceFactory {
2430

25-
private static volatile Object modelManager;
26-
private static volatile Object predictTaskRunner;
27-
private static volatile Object executeTaskRunner;
31+
private static volatile MLModelManager modelManager;
32+
private static volatile MLTaskRunner predictTaskRunner;
33+
private static volatile MLTaskRunner executeTaskRunner;
2834
private static volatile MLFeatureEnabledSetting mlFeatureEnabledSetting;
29-
private static volatile Object modelAccessControlHelper;
30-
private static volatile Object client;
31-
private static volatile Object sdkClient;
35+
private static volatile MLModelAccessControlHelper modelAccessControlHelper;
36+
private static volatile MLClient client;
37+
private static volatile MLSdkClient sdkClient;
38+
private static volatile MLUserContextProvider userContextProvider;
3239

3340
/**
3441
* No-arg constructor required by SPI.
@@ -48,23 +55,30 @@ public MLGrpcServiceFactory() {
4855
* @param modelAccessControlHelper helper for validating model access control
4956
* @param client OpenSearch client for validation
5057
* @param sdkClient SDK client for multi-tenant operations
58+
* @param userContextProvider provider for extracting user from security context
5159
*/
5260
public static void initialize(
53-
Object modelManager,
54-
Object predictTaskRunner,
55-
Object executeTaskRunner,
61+
MLModelManager modelManager,
62+
MLTaskRunner predictTaskRunner,
63+
MLTaskRunner executeTaskRunner,
5664
MLFeatureEnabledSetting mlFeatureEnabledSetting,
57-
Object modelAccessControlHelper,
58-
Object client,
59-
Object sdkClient
65+
MLModelAccessControlHelper modelAccessControlHelper,
66+
MLClient client,
67+
MLSdkClient sdkClient,
68+
MLUserContextProvider userContextProvider
6069
) {
70+
if (modelManager == null || predictTaskRunner == null || executeTaskRunner == null) {
71+
throw new IllegalArgumentException("Required dependencies cannot be null");
72+
}
73+
6174
MLGrpcServiceFactory.modelManager = modelManager;
6275
MLGrpcServiceFactory.predictTaskRunner = predictTaskRunner;
6376
MLGrpcServiceFactory.executeTaskRunner = executeTaskRunner;
6477
MLGrpcServiceFactory.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
6578
MLGrpcServiceFactory.modelAccessControlHelper = modelAccessControlHelper;
6679
MLGrpcServiceFactory.client = client;
6780
MLGrpcServiceFactory.sdkClient = sdkClient;
81+
MLGrpcServiceFactory.userContextProvider = userContextProvider;
6882
}
6983

7084
@Override
@@ -77,6 +91,11 @@ public List<BindableService> build() {
7791
List<BindableService> services = new ArrayList<>();
7892

7993
try {
94+
// Validate initialization
95+
if (modelManager == null || predictTaskRunner == null || executeTaskRunner == null) {
96+
throw new IllegalStateException("MLGrpcServiceFactory not initialized. Call initialize() first.");
97+
}
98+
8099
// Create ML streaming service
81100
MLStreamingService streamingService = new MLStreamingService(
82101
modelManager,
@@ -85,7 +104,8 @@ public List<BindableService> build() {
85104
mlFeatureEnabledSetting,
86105
modelAccessControlHelper,
87106
client,
88-
sdkClient
107+
sdkClient,
108+
userContextProvider
89109
);
90110

91111
// Wrap service with tenant ID interceptor to extract metadata

0 commit comments

Comments
 (0)