Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.opensearch.tools.jackson.core.JsonParseException;
import org.reflections.Reflections;

import com.fasterxml.jackson.core.JsonParseException;

import lombok.extern.log4j.Log4j2;

@Log4j2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.search.SearchModule;

import com.fasterxml.jackson.core.JsonParseException;
import org.opensearch.tools.jackson.core.JsonParseException;

public class MLCommonsClassLoaderTests {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;

import com.fasterxml.jackson.core.JsonParseException;
import org.opensearch.tools.jackson.core.JsonParseException;

public class MLDeployingSettingTests {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
import org.opensearch.remote.metadata.client.GetDataObjectResponse;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -1815,6 +1816,49 @@ private void mockSdkClientWithAgent(String getResponseJson) {
});
}

/**
* Mock sdkClient to return agent JSON on first call, then fail on subsequent calls (e.g., model metadata fetch).
* This ensures deterministic test behavior for agents that require model lookups.
* Handles both single-arg and two-arg overloads of getDataObjectAsync.
*/
private void mockSdkClientWithAgentThenFail(String getResponseJson) {
when(clusterService.localNode()).thenReturn(localNode);
when(localNode.getId()).thenReturn("test-node-id");
AtomicBoolean firstCall = new AtomicBoolean(true);

// Two-arg version used by MLAgentExecutor.execute() for agent fetch
when(sdkClient.getDataObjectAsync(any(), any())).thenAnswer(inv -> {
CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class);
when(stage.whenComplete(any())).thenAnswer(cbInv -> {
BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0);
if (firstCall.getAndSet(false)) {
// First call: return agent data
GetDataObjectResponse resp = mock(GetDataObjectResponse.class);
when(resp.parser())
.thenReturn(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, getResponseJson));
cb.accept(resp, null);
} else {
// Subsequent calls: fail with not found
cb.accept(null, new RuntimeException("Model not found"));
}
return stage;
});
return stage;
});

// Single-arg version used by AgentUtils for model metadata lookups
when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class))).thenAnswer(inv -> {
CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class);
when(stage.whenComplete(any())).thenAnswer(cbInv -> {
BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0);
// Model metadata calls should fail deterministically
cb.accept(null, new RuntimeException("Model not found"));
return stage;
});
return stage;
});
}

private AgentMLInput buildV2AgentMLInput() {
ContentBlock textBlock = new ContentBlock();
textBlock.setType(ContentType.TEXT);
Expand Down Expand Up @@ -1850,7 +1894,7 @@ public void test_ExecuteAgent_NonV2Agent_DoesNotRouteToExecuteV2Agent() throws I
.lastUpdateTime(Instant.now())
.build();

mockSdkClientWithAgent(serializeAgentToGetResponseJson(convAgent));
mockSdkClientWithAgentThenFail(serializeAgentToGetResponseJson(convAgent));

Map<String, String> params = new HashMap<>();
params.put(QUESTION, "What is ML?");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.io.OutputStream;

import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentGenerator;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -66,11 +60,7 @@ public void testToXContent() throws IOException {
SearchResponse.Clusters.EMPTY,
"iid"
);
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
XContentBuilder builder = XContentFactory.jsonBuilder();
XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(actual);
}
Expand Down Expand Up @@ -99,11 +89,7 @@ public void testToXContentWithError() throws IOException {
SearchResponse.Clusters.EMPTY,
"iid"
);
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
XContentBuilder builder = XContentFactory.jsonBuilder();
XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(actual);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,18 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.Version;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentGenerator;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -276,22 +271,13 @@ public void testToXConent() throws IOException {
null,
messageList
);
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
XContentBuilder builder = XContentFactory.jsonBuilder();
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXContentEmptyParams() throws IOException {
GenerativeQAParameters parameters = new GenerativeQAParameters();
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
parameters.toXContent(builder, null);
XContentBuilder builder = XContentFactory.jsonBuilder();
assertNotNull(parameters.toXContent(builder, null));
}

Expand All @@ -316,11 +302,7 @@ public void testToXContentAllOptionalParameters() throws IOException {
timeout,
llmResponseField
);
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
XContentBuilder builder = XContentFactory.jsonBuilder();
assertNotNull(parameters.toXContent(builder, null));
}
}
Loading