2929import org .mockito .Mockito ;
3030import org .mockito .MockitoAnnotations ;
3131import org .opensearch .common .collect .Tuple ;
32+ import org .opensearch .common .settings .Settings ;
33+ import org .opensearch .common .util .concurrent .ThreadContext ;
3234import org .opensearch .core .action .ActionListener ;
3335import org .opensearch .ml .common .FunctionName ;
3436import org .opensearch .ml .common .connector .Connector ;
4143import org .opensearch .ml .common .output .model .ModelTensors ;
4244import org .opensearch .ml .common .transport .MLTaskResponse ;
4345import org .opensearch .ml .engine .algorithms .remote .streaming .StreamPredictActionListener ;
46+ import org .opensearch .threadpool .ThreadPool ;
47+ import org .opensearch .transport .client .Client ;
4448
4549import com .google .common .collect .ImmutableMap ;
4650
@@ -51,9 +55,19 @@ public class HttpJsonConnectorExecutorTest {
5155 @ Mock
5256 private ActionListener <Tuple <Integer , ModelTensors >> actionListener ;
5357
58+ @ Mock
59+ private ThreadPool threadPool ;
60+
61+ @ Mock
62+ private Client client ;
63+
64+ private ThreadContext threadContext ;
65+
5466 @ Before
5567 public void setUp () {
5668 MockitoAnnotations .openMocks (this );
69+ Settings settings = Settings .builder ().build ();
70+ threadContext = new ThreadContext (settings );
5771 }
5872
5973 @ Test
@@ -95,8 +109,11 @@ public void invokeRemoteService_invalidIpAddress() {
95109 .protocol ("http" )
96110 .actions (Arrays .asList (predictAction ))
97111 .build ();
98- HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor (connector );
112+ HttpJsonConnectorExecutor executor = spy ( new HttpJsonConnectorExecutor (connector ) );
99113 executor .setConnectorPrivateIpEnabled (false );
114+ executor .setClient (client );
115+ when (client .threadPool ()).thenReturn (threadPool );
116+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
100117 executor
101118 .invokeRemoteService (
102119 PREDICT .name (),
@@ -128,8 +145,11 @@ public void invokeRemoteService_EnabledPrivateIpAddress() {
128145 .protocol ("http" )
129146 .actions (Arrays .asList (predictAction ))
130147 .build ();
131- HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor (connector );
148+ HttpJsonConnectorExecutor executor = spy ( new HttpJsonConnectorExecutor (connector ) );
132149 executor .setConnectorPrivateIpEnabled (true );
150+ executor .setClient (client );
151+ when (client .threadPool ()).thenReturn (threadPool );
152+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
133153 executor
134154 .invokeRemoteService (
135155 PREDICT .name (),
@@ -158,8 +178,11 @@ public void invokeRemoteService_DisabledPrivateIpAddress() {
158178 .protocol ("http" )
159179 .actions (Arrays .asList (predictAction ))
160180 .build ();
161- HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor (connector );
181+ HttpJsonConnectorExecutor executor = spy ( new HttpJsonConnectorExecutor (connector ) );
162182 executor .setConnectorPrivateIpEnabled (false );
183+ executor .setClient (client );
184+ when (client .threadPool ()).thenReturn (threadPool );
185+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
163186 executor
164187 .invokeRemoteService (
165188 PREDICT .name (),
@@ -215,7 +238,10 @@ public void invokeRemoteService_get_request() {
215238 .protocol ("http" )
216239 .actions (Arrays .asList (predictAction ))
217240 .build ();
218- HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor (connector );
241+ HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
242+ executor .setClient (client );
243+ when (client .threadPool ()).thenReturn (threadPool );
244+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
219245 executor .invokeRemoteService (PREDICT .name (), createMLInput (), new HashMap <>(), null , new ExecutionContext (0 ), actionListener );
220246 }
221247
@@ -235,7 +261,10 @@ public void invokeRemoteService_post_request() {
235261 .protocol ("http" )
236262 .actions (Arrays .asList (predictAction ))
237263 .build ();
238- HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor (connector );
264+ HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
265+ executor .setClient (client );
266+ when (client .threadPool ()).thenReturn (threadPool );
267+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
239268 executor
240269 .invokeRemoteService (PREDICT .name (), createMLInput (), new HashMap <>(), "hello world" , new ExecutionContext (0 ), actionListener );
241270 }
@@ -257,6 +286,9 @@ public void invokeRemoteService_nullHttpClient_throwMLException() {
257286 .actions (Arrays .asList (predictAction ))
258287 .build ();
259288 HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
289+ executor .setClient (client );
290+ when (client .threadPool ()).thenReturn (threadPool );
291+ when (threadPool .getThreadContext ()).thenReturn (threadContext );
260292 when (executor .getHttpClient ()).thenReturn (null );
261293 executor
262294 .invokeRemoteService (PREDICT .name (), createMLInput (), new HashMap <>(), "hello world" , new ExecutionContext (0 ), actionListener );
0 commit comments