diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java b/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java
index 59ac3fdacf7..21db3793b01 100644
--- a/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java
+++ b/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java
@@ -19,20 +19,21 @@
import com.datastax.oss.driver.api.core.cql.Statement;
import com.datastax.oss.driver.api.core.session.Request;
-import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
/**
* Interface responsible for generating request IDs.
*
- *
Note that all request IDs have a parent/child relationship. A "parent ID" can loosely be
- * thought of as encompassing a sequence of a request + any attendant retries, speculative
+ *
Note that all request IDs have a parent/child relationship. A "session request ID" can loosely
+ * be thought of as encompassing a sequence of a request + any attendant retries, speculative
* executions etc. It's scope is identical to that of a {@link
- * com.datastax.oss.driver.internal.core.cql.CqlRequestHandler}. A "request ID" represents a single
- * request within this larger scope. Note that a request corresponding to a request ID may be
+ * com.datastax.oss.driver.internal.core.cql.CqlRequestHandler}. A "node request ID" represents a
+ * single request within this larger scope. Note that a request corresponding to a request ID may be
* retried; in that case the retry count will be appended to the corresponding identifier in the
* logs.
*/
@@ -67,11 +68,17 @@ default String getCustomPayloadKey() {
default Statement> getDecoratedStatement(
@NonNull Statement> statement, @NonNull String requestId) {
- Map customPayload =
- NullAllowingImmutableMap.builder()
- .putAll(statement.getCustomPayload())
- .put(getCustomPayloadKey(), ByteBuffer.wrap(requestId.getBytes(StandardCharsets.UTF_8)))
- .build();
- return statement.setCustomPayload(customPayload);
+
+ Map existing = new HashMap<>(statement.getCustomPayload());
+ String key = getCustomPayloadKey();
+
+ // Add or overwrite
+ existing.put(key, ByteBuffer.wrap(requestId.getBytes(StandardCharsets.UTF_8)));
+
+ // Allowing null key/values
+ // Wrap a map inside to be immutable without instanciating a new map
+ Map unmodifiableMap = Collections.unmodifiableMap(existing);
+
+ return statement.setCustomPayload(unmodifiableMap);
}
}
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java
index bea52891c18..ccac873c616 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java
@@ -48,6 +48,8 @@
import com.datastax.oss.driver.api.core.servererrors.ServerError;
import com.datastax.oss.driver.api.core.servererrors.UnavailableException;
import com.datastax.oss.driver.api.core.servererrors.WriteTimeoutException;
+import com.datastax.oss.driver.api.core.session.Request;
+import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.datastax.oss.protocol.internal.response.Error;
import com.datastax.oss.protocol.internal.response.error.ReadTimeout;
@@ -55,9 +57,13 @@
import com.datastax.oss.protocol.internal.response.error.WriteTimeout;
import com.tngtech.java.junit.dataprovider.DataProvider;
import com.tngtech.java.junit.dataprovider.UseDataProvider;
+import edu.umd.cs.findbugs.annotations.NonNull;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
public class CqlRequestHandlerRetryTest extends CqlRequestHandlerTestBase {
@@ -384,6 +390,63 @@ public void should_rethrow_error_if_not_idempotent_and_error_unsafe_or_policy_re
}
}
+ @Test
+ @UseDataProvider("failureAndIdempotent")
+ public void should_not_fail_with_duplicate_key_when_retrying_with_request_id_generator(
+ FailureScenario failureScenario, boolean defaultIdempotence, Statement> statement) {
+
+ // Create a RequestIdGenerator that uses the same key as the statement's custom payload
+ RequestIdGenerator requestIdGenerator =
+ new RequestIdGenerator() {
+ private AtomicInteger counter = new AtomicInteger(0);
+
+ @Override
+ public String getSessionRequestId() {
+ return "session-123";
+ }
+
+ @Override
+ public String getNodeRequestId(@NonNull Request request, @NonNull String parentId) {
+ return parentId + "-" + counter.getAndIncrement();
+ }
+ };
+
+ RequestHandlerTestHarness.Builder harnessBuilder =
+ RequestHandlerTestHarness.builder()
+ .withDefaultIdempotence(defaultIdempotence)
+ .withRequestIdGenerator(requestIdGenerator);
+ failureScenario.mockRequestError(harnessBuilder, node1);
+ harnessBuilder.withResponse(node2, defaultFrameOf(singleRow()));
+
+ try (RequestHandlerTestHarness harness = harnessBuilder.build()) {
+ failureScenario.mockRetryPolicyVerdict(
+ harness.getContext().getRetryPolicy(anyString()), RetryVerdict.RETRY_NEXT);
+
+ CompletionStage resultSetFuture =
+ new CqlRequestHandler(statement, harness.getSession(), harness.getContext(), "test")
+ .handle();
+
+ // The test should succeed without throwing a duplicate key exception
+ assertThatStage(resultSetFuture)
+ .isSuccess(
+ resultSet -> {
+ Iterator rows = resultSet.currentPage().iterator();
+ assertThat(rows.hasNext()).isTrue();
+ assertThat(rows.next().getString("message")).isEqualTo("hello, world");
+
+ ExecutionInfo executionInfo = resultSet.getExecutionInfo();
+ assertThat(executionInfo.getCoordinator()).isEqualTo(node2);
+ assertThat(executionInfo.getErrors()).hasSize(1);
+ assertThat(executionInfo.getErrors().get(0).getKey()).isEqualTo(node1);
+
+ // Verify that the custom payload still contains the request ID key
+ // (either the original value or the generated one, depending on implementation)
+ assertThat(executionInfo.getRequest().getCustomPayload().get("request-id"))
+ .isEqualTo(ByteBuffer.wrap("session-123-1".getBytes(StandardCharsets.UTF_8)));
+ });
+ }
+ }
+
/**
* Sets up the mocks to simulate an error from a node, and make the retry policy return a given
* decision for that error.
diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java
index 6ecd6111992..6a7657d5809 100644
--- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java
+++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java
@@ -37,6 +37,7 @@
import com.datastax.oss.driver.api.core.session.Session;
import com.datastax.oss.driver.api.core.specex.SpeculativeExecutionPolicy;
import com.datastax.oss.driver.api.core.time.TimestampGenerator;
+import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
import com.datastax.oss.driver.internal.core.DefaultConsistencyLevelRegistry;
import com.datastax.oss.driver.internal.core.ProtocolFeature;
@@ -170,7 +171,8 @@ protected RequestHandlerTestHarness(Builder builder) {
when(context.getRequestTracker()).thenReturn(new NoopRequestTracker(context));
- when(context.getRequestIdGenerator()).thenReturn(Optional.empty());
+ when(context.getRequestIdGenerator())
+ .thenReturn(Optional.ofNullable(builder.requestIdGenerator));
}
public DefaultSession getSession() {
@@ -203,6 +205,7 @@ public static class Builder {
private final List poolBehaviors = new ArrayList<>();
private boolean defaultIdempotence;
private ProtocolVersion protocolVersion;
+ private RequestIdGenerator requestIdGenerator;
/**
* Sets the given node as the next one in the query plan; an empty pool will be simulated when
@@ -258,6 +261,11 @@ public Builder withProtocolVersion(ProtocolVersion protocolVersion) {
return this;
}
+ public Builder withRequestIdGenerator(RequestIdGenerator requestIdGenerator) {
+ this.requestIdGenerator = requestIdGenerator;
+ return this;
+ }
+
/**
* Sets the given node as the next one in the query plan; the test code is responsible of
* calling the methods on the returned object to complete the write and the query.
diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java
index 2848a8fb629..516a62bb1f7 100644
--- a/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java
+++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java
@@ -17,12 +17,14 @@
*/
package com.datastax.oss.driver.core.tracker;
+import static com.datastax.oss.driver.Assertions.assertThatStage;
import static org.assertj.core.api.Assertions.assertThat;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigLoader;
import com.datastax.oss.driver.api.core.cql.ResultSet;
+import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.Statement;
import com.datastax.oss.driver.api.core.session.Request;
import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator;
@@ -119,7 +121,24 @@ public void should_not_write_id_to_custom_payload_when_key_is_not_set() {
try (CqlSession session = SessionUtils.newSession(ccmRule, loader)) {
String query = "SELECT * FROM system.local";
ResultSet rs = session.execute(query);
- assertThat(rs.getExecutionInfo().getRequest().getCustomPayload().get("trace_key")).isNull();
+ assertThat(rs.getExecutionInfo().getRequest().getCustomPayload().get("request-id")).isNull();
+ }
+ }
+
+ @Test
+ public void should_succeed_with_null_value_in_custom_payload() {
+ DriverConfigLoader loader =
+ SessionUtils.configLoaderBuilder()
+ .withString(
+ DefaultDriverOption.REQUEST_ID_GENERATOR_CLASS, "W3CContextRequestIdGenerator")
+ .build();
+ try (CqlSession session = SessionUtils.newSession(ccmRule, loader)) {
+ String query = "SELECT * FROM system.local";
+ Map customPayload =
+ new NullAllowingImmutableMap.Builder(1).put("my_key", null).build();
+ SimpleStatement statement =
+ SimpleStatement.newInstance(query).setCustomPayload(customPayload);
+ assertThatStage(session.executeAsync(statement)).isSuccess();
}
}
}