diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java b/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java index 59ac3fdacf7..21db3793b01 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/tracker/RequestIdGenerator.java @@ -19,20 +19,21 @@ import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.session.Request; -import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap; import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; /** * Interface responsible for generating request IDs. * - *

Note that all request IDs have a parent/child relationship. A "parent ID" can loosely be - * thought of as encompassing a sequence of a request + any attendant retries, speculative + *

Note that all request IDs have a parent/child relationship. A "session request ID" can loosely + * be thought of as encompassing a sequence of a request + any attendant retries, speculative * executions etc. It's scope is identical to that of a {@link - * com.datastax.oss.driver.internal.core.cql.CqlRequestHandler}. A "request ID" represents a single - * request within this larger scope. Note that a request corresponding to a request ID may be + * com.datastax.oss.driver.internal.core.cql.CqlRequestHandler}. A "node request ID" represents a + * single request within this larger scope. Note that a request corresponding to a request ID may be * retried; in that case the retry count will be appended to the corresponding identifier in the * logs. */ @@ -67,11 +68,17 @@ default String getCustomPayloadKey() { default Statement getDecoratedStatement( @NonNull Statement statement, @NonNull String requestId) { - Map customPayload = - NullAllowingImmutableMap.builder() - .putAll(statement.getCustomPayload()) - .put(getCustomPayloadKey(), ByteBuffer.wrap(requestId.getBytes(StandardCharsets.UTF_8))) - .build(); - return statement.setCustomPayload(customPayload); + + Map existing = new HashMap<>(statement.getCustomPayload()); + String key = getCustomPayloadKey(); + + // Add or overwrite + existing.put(key, ByteBuffer.wrap(requestId.getBytes(StandardCharsets.UTF_8))); + + // Allowing null key/values + // Wrap a map inside to be immutable without instanciating a new map + Map unmodifiableMap = Collections.unmodifiableMap(existing); + + return statement.setCustomPayload(unmodifiableMap); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java index bea52891c18..ccac873c616 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandlerRetryTest.java @@ -48,6 +48,8 @@ import com.datastax.oss.driver.api.core.servererrors.ServerError; import com.datastax.oss.driver.api.core.servererrors.UnavailableException; import com.datastax.oss.driver.api.core.servererrors.WriteTimeoutException; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator; import com.datastax.oss.protocol.internal.ProtocolConstants; import com.datastax.oss.protocol.internal.response.Error; import com.datastax.oss.protocol.internal.response.error.ReadTimeout; @@ -55,9 +57,13 @@ import com.datastax.oss.protocol.internal.response.error.WriteTimeout; import com.tngtech.java.junit.dataprovider.DataProvider; import com.tngtech.java.junit.dataprovider.UseDataProvider; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Iterator; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; public class CqlRequestHandlerRetryTest extends CqlRequestHandlerTestBase { @@ -384,6 +390,63 @@ public void should_rethrow_error_if_not_idempotent_and_error_unsafe_or_policy_re } } + @Test + @UseDataProvider("failureAndIdempotent") + public void should_not_fail_with_duplicate_key_when_retrying_with_request_id_generator( + FailureScenario failureScenario, boolean defaultIdempotence, Statement statement) { + + // Create a RequestIdGenerator that uses the same key as the statement's custom payload + RequestIdGenerator requestIdGenerator = + new RequestIdGenerator() { + private AtomicInteger counter = new AtomicInteger(0); + + @Override + public String getSessionRequestId() { + return "session-123"; + } + + @Override + public String getNodeRequestId(@NonNull Request request, @NonNull String parentId) { + return parentId + "-" + counter.getAndIncrement(); + } + }; + + RequestHandlerTestHarness.Builder harnessBuilder = + RequestHandlerTestHarness.builder() + .withDefaultIdempotence(defaultIdempotence) + .withRequestIdGenerator(requestIdGenerator); + failureScenario.mockRequestError(harnessBuilder, node1); + harnessBuilder.withResponse(node2, defaultFrameOf(singleRow())); + + try (RequestHandlerTestHarness harness = harnessBuilder.build()) { + failureScenario.mockRetryPolicyVerdict( + harness.getContext().getRetryPolicy(anyString()), RetryVerdict.RETRY_NEXT); + + CompletionStage resultSetFuture = + new CqlRequestHandler(statement, harness.getSession(), harness.getContext(), "test") + .handle(); + + // The test should succeed without throwing a duplicate key exception + assertThatStage(resultSetFuture) + .isSuccess( + resultSet -> { + Iterator rows = resultSet.currentPage().iterator(); + assertThat(rows.hasNext()).isTrue(); + assertThat(rows.next().getString("message")).isEqualTo("hello, world"); + + ExecutionInfo executionInfo = resultSet.getExecutionInfo(); + assertThat(executionInfo.getCoordinator()).isEqualTo(node2); + assertThat(executionInfo.getErrors()).hasSize(1); + assertThat(executionInfo.getErrors().get(0).getKey()).isEqualTo(node1); + + // Verify that the custom payload still contains the request ID key + // (either the original value or the generated one, depending on implementation) + assertThat(executionInfo.getRequest().getCustomPayload().get("request-id")) + .isEqualTo(ByteBuffer.wrap("session-123-1".getBytes(StandardCharsets.UTF_8))); + }); + } + } + /** * Sets up the mocks to simulate an error from a node, and make the retry policy return a given * decision for that error. diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java index 6ecd6111992..6a7657d5809 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java @@ -37,6 +37,7 @@ import com.datastax.oss.driver.api.core.session.Session; import com.datastax.oss.driver.api.core.specex.SpeculativeExecutionPolicy; import com.datastax.oss.driver.api.core.time.TimestampGenerator; +import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator; import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; import com.datastax.oss.driver.internal.core.DefaultConsistencyLevelRegistry; import com.datastax.oss.driver.internal.core.ProtocolFeature; @@ -170,7 +171,8 @@ protected RequestHandlerTestHarness(Builder builder) { when(context.getRequestTracker()).thenReturn(new NoopRequestTracker(context)); - when(context.getRequestIdGenerator()).thenReturn(Optional.empty()); + when(context.getRequestIdGenerator()) + .thenReturn(Optional.ofNullable(builder.requestIdGenerator)); } public DefaultSession getSession() { @@ -203,6 +205,7 @@ public static class Builder { private final List poolBehaviors = new ArrayList<>(); private boolean defaultIdempotence; private ProtocolVersion protocolVersion; + private RequestIdGenerator requestIdGenerator; /** * Sets the given node as the next one in the query plan; an empty pool will be simulated when @@ -258,6 +261,11 @@ public Builder withProtocolVersion(ProtocolVersion protocolVersion) { return this; } + public Builder withRequestIdGenerator(RequestIdGenerator requestIdGenerator) { + this.requestIdGenerator = requestIdGenerator; + return this; + } + /** * Sets the given node as the next one in the query plan; the test code is responsible of * calling the methods on the returned object to complete the write and the query. diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java index 2848a8fb629..516a62bb1f7 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/tracker/RequestIdGeneratorIT.java @@ -17,12 +17,14 @@ */ package com.datastax.oss.driver.core.tracker; +import static com.datastax.oss.driver.Assertions.assertThatStage; import static org.assertj.core.api.Assertions.assertThat; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfigLoader; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.session.Request; import com.datastax.oss.driver.api.core.tracker.RequestIdGenerator; @@ -119,7 +121,24 @@ public void should_not_write_id_to_custom_payload_when_key_is_not_set() { try (CqlSession session = SessionUtils.newSession(ccmRule, loader)) { String query = "SELECT * FROM system.local"; ResultSet rs = session.execute(query); - assertThat(rs.getExecutionInfo().getRequest().getCustomPayload().get("trace_key")).isNull(); + assertThat(rs.getExecutionInfo().getRequest().getCustomPayload().get("request-id")).isNull(); + } + } + + @Test + public void should_succeed_with_null_value_in_custom_payload() { + DriverConfigLoader loader = + SessionUtils.configLoaderBuilder() + .withString( + DefaultDriverOption.REQUEST_ID_GENERATOR_CLASS, "W3CContextRequestIdGenerator") + .build(); + try (CqlSession session = SessionUtils.newSession(ccmRule, loader)) { + String query = "SELECT * FROM system.local"; + Map customPayload = + new NullAllowingImmutableMap.Builder(1).put("my_key", null).build(); + SimpleStatement statement = + SimpleStatement.newInstance(query).setCustomPayload(customPayload); + assertThatStage(session.executeAsync(statement)).isSuccess(); } } }