Skip to content

Commit fbe3ca4

Browse files
committed
fix ut
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 82a4bd4 commit fbe3ca4

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() {
245245
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
246246
Settings settings = Settings.builder().build();
247247
threadContext = new ThreadContext(settings);
248-
when(executor.getClient()).thenReturn(client);
248+
executor.setClient(client);
249249
when(client.threadPool()).thenReturn(threadPool);
250250
when(threadPool.getThreadContext()).thenReturn(threadContext);
251251

@@ -724,7 +724,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() {
724724
Settings settings = Settings.builder().build();
725725
threadContext = new ThreadContext(settings);
726726
ExecutorService executorService = mock(ExecutorService.class);
727-
when(executor.getClient()).thenReturn(client);
727+
executor.setClient(client);
728728
when(client.threadPool()).thenReturn(threadPool);
729729
when(threadPool.getThreadContext()).thenReturn(threadContext);
730730
when(threadPool.executor(any())).thenReturn(executorService);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import org.mockito.Mockito;
3030
import org.mockito.MockitoAnnotations;
3131
import org.opensearch.common.collect.Tuple;
32+
import org.opensearch.common.settings.Settings;
33+
import org.opensearch.common.util.concurrent.ThreadContext;
3234
import org.opensearch.core.action.ActionListener;
3335
import org.opensearch.ml.common.FunctionName;
3436
import org.opensearch.ml.common.connector.Connector;
@@ -41,6 +43,8 @@
4143
import org.opensearch.ml.common.output.model.ModelTensors;
4244
import org.opensearch.ml.common.transport.MLTaskResponse;
4345
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
46+
import org.opensearch.threadpool.ThreadPool;
47+
import org.opensearch.transport.client.Client;
4448

4549
import com.google.common.collect.ImmutableMap;
4650

@@ -51,9 +55,19 @@ public class HttpJsonConnectorExecutorTest {
5155
@Mock
5256
private ActionListener<Tuple<Integer, ModelTensors>> actionListener;
5357

58+
@Mock
59+
private ThreadPool threadPool;
60+
61+
@Mock
62+
private Client client;
63+
64+
private ThreadContext threadContext;
65+
5466
@Before
5567
public void setUp() {
5668
MockitoAnnotations.openMocks(this);
69+
Settings settings = Settings.builder().build();
70+
threadContext = new ThreadContext(settings);
5771
}
5872

5973
@Test
@@ -95,8 +109,11 @@ public void invokeRemoteService_invalidIpAddress() {
95109
.protocol("http")
96110
.actions(Arrays.asList(predictAction))
97111
.build();
98-
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
112+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
99113
executor.setConnectorPrivateIpEnabled(false);
114+
executor.setClient(client);
115+
when(client.threadPool()).thenReturn(threadPool);
116+
when(threadPool.getThreadContext()).thenReturn(threadContext);
100117
executor
101118
.invokeRemoteService(
102119
PREDICT.name(),
@@ -128,8 +145,11 @@ public void invokeRemoteService_EnabledPrivateIpAddress() {
128145
.protocol("http")
129146
.actions(Arrays.asList(predictAction))
130147
.build();
131-
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
148+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
132149
executor.setConnectorPrivateIpEnabled(true);
150+
executor.setClient(client);
151+
when(client.threadPool()).thenReturn(threadPool);
152+
when(threadPool.getThreadContext()).thenReturn(threadContext);
133153
executor
134154
.invokeRemoteService(
135155
PREDICT.name(),
@@ -158,8 +178,11 @@ public void invokeRemoteService_DisabledPrivateIpAddress() {
158178
.protocol("http")
159179
.actions(Arrays.asList(predictAction))
160180
.build();
161-
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
181+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
162182
executor.setConnectorPrivateIpEnabled(false);
183+
executor.setClient(client);
184+
when(client.threadPool()).thenReturn(threadPool);
185+
when(threadPool.getThreadContext()).thenReturn(threadContext);
163186
executor
164187
.invokeRemoteService(
165188
PREDICT.name(),
@@ -215,7 +238,10 @@ public void invokeRemoteService_get_request() {
215238
.protocol("http")
216239
.actions(Arrays.asList(predictAction))
217240
.build();
218-
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
241+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
242+
executor.setClient(client);
243+
when(client.threadPool()).thenReturn(threadPool);
244+
when(threadPool.getThreadContext()).thenReturn(threadContext);
219245
executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener);
220246
}
221247

@@ -235,7 +261,10 @@ public void invokeRemoteService_post_request() {
235261
.protocol("http")
236262
.actions(Arrays.asList(predictAction))
237263
.build();
238-
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
264+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
265+
executor.setClient(client);
266+
when(client.threadPool()).thenReturn(threadPool);
267+
when(threadPool.getThreadContext()).thenReturn(threadContext);
239268
executor
240269
.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener);
241270
}
@@ -257,6 +286,9 @@ public void invokeRemoteService_nullHttpClient_throwMLException() {
257286
.actions(Arrays.asList(predictAction))
258287
.build();
259288
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
289+
executor.setClient(client);
290+
when(client.threadPool()).thenReturn(threadPool);
291+
when(threadPool.getThreadContext()).thenReturn(threadContext);
260292
when(executor.getHttpClient()).thenReturn(null);
261293
executor
262294
.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener);

0 commit comments

Comments
 (0)