Skip to content

Commit 59b73bc

Browse files
feat: Implement pre-initialized Docker container pool to improve /eval request performance
- Added a pool of pre-initialized Docker containers, each assigned a unique session ID and ready for immediate use. - Pre-configured each container with the necessary startup scripts to ensure environments are ready for /eval requests. - On receiving a new /eval request, the system now allocates a container from the pool, reducing the need to create and initialize a container on demand. - Improved request latency by significantly reducing the time taken to start and initialize Docker containers during each session.
1 parent 4fb5436 commit 59b73bc

File tree

5 files changed

+226
-90
lines changed

5 files changed

+226
-90
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package org.togetherjava.jshellapi.dto;
2+
3+
import java.io.BufferedReader;
4+
import java.io.BufferedWriter;
5+
import java.io.InputStream;
6+
import java.io.OutputStream;
7+
8+
public record ContainerState(boolean isCached, String containerId, BufferedReader containerOutput, BufferedWriter containerInput) {
9+
}

JShellAPI/src/main/java/org/togetherjava/jshellapi/service/DockerService.java

Lines changed: 187 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.github.dockerjava.api.DockerClient;
44
import com.github.dockerjava.api.async.ResultCallback;
5+
import com.github.dockerjava.api.command.InspectContainerResponse;
56
import com.github.dockerjava.api.command.PullImageResultCallback;
67
import com.github.dockerjava.api.model.*;
78
import com.github.dockerjava.core.DefaultDockerClientConfig;
@@ -10,28 +11,35 @@
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
1213
import org.springframework.beans.factory.DisposableBean;
13-
import org.springframework.lang.Nullable;
1414
import org.springframework.stereotype.Service;
1515

1616
import org.togetherjava.jshellapi.Config;
17+
import org.togetherjava.jshellapi.dto.ContainerState;
1718

1819
import java.io.*;
1920
import java.nio.charset.StandardCharsets;
2021
import java.time.Duration;
2122
import java.util.*;
22-
import java.util.concurrent.TimeUnit;
23+
import java.util.concurrent.*;
2324

2425
@Service
2526
public class DockerService implements DisposableBean {
2627
private static final Logger LOGGER = LoggerFactory.getLogger(DockerService.class);
2728
private static final String WORKER_LABEL = "jshell-api-worker";
2829
private static final UUID WORKER_UNIQUE_ID = UUID.randomUUID();
30+
private static final String IMAGE_NAME = "togetherjava.org:5001/togetherjava/jshellwrapper";
31+
private static final String IMAGE_TAG = "master";
2932

3033
private final DockerClient client;
34+
private final Config config;
35+
private final ExecutorService executor = Executors.newSingleThreadExecutor();
36+
private final ConcurrentHashMap<StartupScriptId, String> cachedContainers = new ConcurrentHashMap<>();
37+
private final StartupScriptsService startupScriptsService;
3138

3239
private final String jshellWrapperBaseImageName;
3340

34-
public DockerService(Config config) {
41+
public DockerService(Config config, StartupScriptsService startupScriptsService) throws InterruptedException, IOException {
42+
this.startupScriptsService = startupScriptsService;
3543
DefaultDockerClientConfig clientConfig =
3644
DefaultDockerClientConfig.createDefaultConfigBuilder().build();
3745
ApacheDockerHttpClient httpClient =
@@ -41,11 +49,16 @@ public DockerService(Config config) {
4149
.connectionTimeout(Duration.ofSeconds(config.dockerConnectionTimeout()))
4250
.build();
4351
this.client = DockerClientImpl.getInstance(clientConfig, httpClient);
52+
this.config = config;
4453

4554
this.jshellWrapperBaseImageName =
4655
config.jshellWrapperImageName().split(Config.JSHELL_WRAPPER_IMAGE_NAME_TAG)[0];
4756

57+
if (!isImagePresentLocally()) {
58+
pullImage();
59+
}
4860
cleanupLeftovers(WORKER_UNIQUE_ID);
61+
executor.submit(() -> initializeCachedContainer(StartupScriptId.EMPTY));
4962
}
5063

5164
private void cleanupLeftovers(UUID currentId) {
@@ -62,80 +75,198 @@ private void cleanupLeftovers(UUID currentId) {
6275
}
6376
}
6477

65-
public String spawnContainer(long maxMemoryMegs, long cpus, @Nullable String cpuSetCpus,
66-
String name, Duration evalTimeout, long sysoutLimit) throws InterruptedException {
67-
68-
boolean presentLocally = client.listImagesCmd()
69-
.withFilter("reference", List.of(jshellWrapperBaseImageName))
70-
.exec()
71-
.stream()
72-
.flatMap(it -> Arrays.stream(it.getRepoTags()))
73-
.anyMatch(it -> it.endsWith(Config.JSHELL_WRAPPER_IMAGE_NAME_TAG));
78+
/**
79+
* Checks if the Docker image with the given name and tag is present locally.
80+
*
81+
* @return true if the image is present, false otherwise.
82+
*/
83+
private boolean isImagePresentLocally() {
84+
return client.listImagesCmd()
85+
.withFilter("reference", List.of(jshellWrapperBaseImageName))
86+
.exec()
87+
.stream()
88+
.flatMap(it -> Arrays.stream(it.getRepoTags()))
89+
.anyMatch(it -> it.endsWith(Config.JSHELL_WRAPPER_IMAGE_NAME_TAG));
90+
}
7491

75-
if (!presentLocally) {
92+
/**
93+
* Pulls the Docker image.
94+
*/
95+
private void pullImage() throws InterruptedException {
96+
if (!isImagePresentLocally()) {
7697
client.pullImageCmd(jshellWrapperBaseImageName)
77-
.withTag("master")
78-
.exec(new PullImageResultCallback())
79-
.awaitCompletion(5, TimeUnit.MINUTES);
98+
.withTag(IMAGE_TAG)
99+
.exec(new PullImageResultCallback())
100+
.awaitCompletion(5, TimeUnit.MINUTES);
80101
}
102+
}
81103

82-
return client
83-
.createContainerCmd(jshellWrapperBaseImageName + Config.JSHELL_WRAPPER_IMAGE_NAME_TAG)
84-
.withHostConfig(HostConfig.newHostConfig()
104+
/**
105+
* Creates a Docker container with the given name.
106+
*
107+
* @param name The name of the container to create.
108+
* @return The ID of the created container.
109+
*/
110+
public String createContainer(String name) {
111+
HostConfig hostConfig = HostConfig.newHostConfig()
85112
.withAutoRemove(true)
86113
.withInit(true)
87114
.withCapDrop(Capability.ALL)
88115
.withNetworkMode("none")
89116
.withPidsLimit(2000L)
90117
.withReadonlyRootfs(true)
91-
.withMemory(maxMemoryMegs * 1024 * 1024)
92-
.withCpuCount(cpus)
93-
.withCpusetCpus(cpuSetCpus))
94-
.withStdinOpen(true)
95-
.withAttachStdin(true)
96-
.withAttachStderr(true)
97-
.withAttachStdout(true)
98-
.withEnv("evalTimeoutSeconds=" + evalTimeout.toSeconds(),
99-
"sysOutCharLimit=" + sysoutLimit)
100-
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
101-
.withName(name)
102-
.exec()
103-
.getId();
118+
.withMemory((long) config.dockerMaxRamMegaBytes() * 1024 * 1024)
119+
.withCpuCount((long) Math.ceil(config.dockerCPUsUsage()))
120+
.withCpusetCpus(config.dockerCPUSetCPUs());
121+
122+
return client.createContainerCmd(jshellWrapperBaseImageName + Config.JSHELL_WRAPPER_IMAGE_NAME_TAG)
123+
.withHostConfig(hostConfig)
124+
.withStdinOpen(true)
125+
.withAttachStdin(true)
126+
.withAttachStderr(true)
127+
.withAttachStdout(true)
128+
.withEnv("evalTimeoutSeconds=" + config.evalTimeoutSeconds(),
129+
"sysOutCharLimit=" + config.sysOutCharLimit())
130+
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
131+
.withName(name)
132+
.exec()
133+
.getId();
134+
}
135+
136+
/**
137+
* Spawns a new Docker container with specified configurations.
138+
*
139+
* @param name Name of the container.
140+
* @param startupScriptId Script to initialize the container with.
141+
* @return The ContainerState of the newly created container.
142+
*/
143+
public ContainerState initializeContainer(String name, StartupScriptId startupScriptId) throws IOException {
144+
if (cachedContainers.isEmpty() || !cachedContainers.containsKey(startupScriptId)) {
145+
String containerId = createContainer(name);
146+
return setupContainerWithScript(containerId, true, startupScriptId);
147+
}
148+
String containerId = cachedContainers.get(startupScriptId);
149+
executor.submit(() -> initializeCachedContainer(startupScriptId));
150+
// Rename container with new name.
151+
client.renameContainerCmd(containerId).withName(name).exec();
152+
return setupContainerWithScript(containerId, false, startupScriptId);
153+
}
154+
155+
/**
156+
* Initializes a new cached docker container with specified configurations.
157+
*
158+
* @param startupScriptId Script to initialize the container with.
159+
*/
160+
private void initializeCachedContainer(StartupScriptId startupScriptId) {
161+
String containerName = cachedContainerName();
162+
String id = createContainer(containerName);
163+
startContainer(id);
164+
165+
try (PipedInputStream containerInput = new PipedInputStream();
166+
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)))) {
167+
attachToContainer(id, containerInput);
168+
169+
writer.write(Utils.sanitizeStartupScript(startupScriptsService.get(startupScriptId)));
170+
writer.newLine();
171+
writer.flush();
172+
173+
cachedContainers.put(startupScriptId, id);
174+
} catch (IOException e) {
175+
killContainerByName(containerName);
176+
throw new RuntimeException(e);
177+
}
104178
}
105179

106-
public InputStream startAndAttachToContainer(String containerId, InputStream stdin)
107-
throws IOException {
180+
/**
181+
*
182+
* @param containerId The id of the container
183+
* @param isCached Indicator if the container is cached or new
184+
* @param startupScriptId The startup script id of the session
185+
* @return ContainerState of the spawned container.
186+
* @throws IOException if an I/O error occurs
187+
*/
188+
private ContainerState setupContainerWithScript(String containerId, boolean isCached, StartupScriptId startupScriptId) throws IOException {
189+
if (!isCached) {
190+
startContainer(containerId);
191+
}
192+
PipedInputStream containerInput = new PipedInputStream();
193+
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)));
194+
195+
InputStream containerOutput = attachToContainer(containerId, containerInput);
196+
BufferedReader reader = new BufferedReader(new InputStreamReader(containerOutput));
197+
198+
if (!isCached) {
199+
writer.write(Utils.sanitizeStartupScript(startupScriptsService.get(startupScriptId)));
200+
writer.newLine();
201+
writer.flush();
202+
}
203+
204+
return new ContainerState(isCached, containerId, reader, writer);
205+
}
206+
207+
/**
208+
* Creates a new container
209+
* @param containerId the ID of the container to start
210+
*/
211+
public void startContainer(String containerId) {
212+
if (!isContainerRunning(containerId)) {
213+
client.startContainerCmd(containerId).exec();
214+
}
215+
}
216+
217+
/**
218+
* Attaches to a running Docker container's input (stdin) and output streams (stdout, stderr).
219+
* Logs any output from stderr and returns an InputStream to read stdout.
220+
*
221+
* @param containerId the ID of the running container to attach to
222+
* @param containerInput the input stream (containerInput) to send to the container
223+
* @return InputStream to read the container's stdout
224+
* @throws IOException if an I/O error occurs
225+
*/
226+
public InputStream attachToContainer(String containerId, InputStream containerInput) throws IOException {
108227
PipedInputStream pipeIn = new PipedInputStream();
109228
PipedOutputStream pipeOut = new PipedOutputStream(pipeIn);
110229

111230
client.attachContainerCmd(containerId)
112-
.withLogs(true)
113-
.withFollowStream(true)
114-
.withStdOut(true)
115-
.withStdErr(true)
116-
.withStdIn(stdin)
117-
.exec(new ResultCallback.Adapter<>() {
118-
@Override
119-
public void onNext(Frame object) {
120-
try {
121-
String payloadString =
122-
new String(object.getPayload(), StandardCharsets.UTF_8);
123-
if (object.getStreamType() == StreamType.STDOUT) {
124-
pipeOut.write(object.getPayload());
125-
} else {
126-
LOGGER.warn("Received STDERR from container {}: {}", containerId,
127-
payloadString);
231+
.withLogs(true)
232+
.withFollowStream(true)
233+
.withStdOut(true)
234+
.withStdErr(true)
235+
.withStdIn(containerInput)
236+
.exec(new ResultCallback.Adapter<>() {
237+
@Override
238+
public void onNext(Frame object) {
239+
try {
240+
String payloadString = new String(object.getPayload(), StandardCharsets.UTF_8);
241+
if (object.getStreamType() == StreamType.STDOUT) {
242+
pipeOut.write(object.getPayload()); // Write stdout data to pipeOut
243+
} else {
244+
LOGGER.warn("Received STDERR from container {}: {}", containerId, payloadString);
245+
}
246+
} catch (IOException e) {
247+
throw new UncheckedIOException(e);
128248
}
129-
} catch (IOException e) {
130-
throw new UncheckedIOException(e);
131249
}
132-
}
133-
});
250+
});
134251

135-
client.startContainerCmd(containerId).exec();
136252
return pipeIn;
137253
}
138254

255+
/**
256+
* Checks if the Docker container with the given ID is currently running.
257+
*
258+
* @param containerId the ID of the container to check
259+
* @return true if the container is running, false otherwise
260+
*/
261+
public boolean isContainerRunning(String containerId) {
262+
InspectContainerResponse containerResponse = client.inspectContainerCmd(containerId).exec();
263+
return Boolean.TRUE.equals(containerResponse.getState().getRunning());
264+
}
265+
266+
private String cachedContainerName() {
267+
return "cached_session_" + UUID.randomUUID();
268+
}
269+
139270
public void killContainerByName(String name) {
140271
LOGGER.debug("Fetching container to kill {}.", name);
141272
List<Container> containers = client.listContainersCmd().withNameFilter(Set.of(name)).exec();
@@ -156,6 +287,7 @@ public boolean isDead(String containerName) {
156287
@Override
157288
public void destroy() throws Exception {
158289
LOGGER.info("destroy() called. Destroying all containers...");
290+
executor.shutdown();
159291
cleanupLeftovers(UUID.randomUUID());
160292
client.close();
161293
}

0 commit comments

Comments
 (0)