Skip to content

Commit 6350ed6

Browse files
authored
added part count validation for download request (#6353)
* added part count and content range validation for download * changelog added * remove contentRange validation since part size could different * feedback addressed * delete unused utils * change validation in futurn complete * added unit test * change test name * minor change * minor change * Integration test added * Unit test fixed * swap method body of onError and handleError * minor change
1 parent fb8f93a commit 6350ed6

File tree

5 files changed

+198
-1
lines changed

5 files changed

+198
-1
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "Amazon S3",
4+
"contributor": "",
5+
"description": "Added additional validations for multipart download operations in the Java multipart S3 client"
6+
}

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.reactivestreams.Subscription;
2424
import software.amazon.awssdk.annotations.SdkInternalApi;
2525
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
26+
import software.amazon.awssdk.core.exception.SdkClientException;
2627
import software.amazon.awssdk.services.s3.S3AsyncClient;
2728
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
2829
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
@@ -60,6 +61,11 @@ public class MultipartDownloaderSubscriber implements Subscriber<AsyncResponseTr
6061
*/
6162
private final AtomicInteger completedParts;
6263

64+
/**
65+
* The total number of getObject calls made. This tracks how many times we've actually called getObject.
66+
*/
67+
private final AtomicInteger getObjectCallCount;
68+
6369
/**
6470
* The subscription received from the publisher this subscriber subscribes to.
6571
*/
@@ -94,6 +100,7 @@ public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjec
94100
this.s3 = s3;
95101
this.getObjectRequest = getObjectRequest;
96102
this.completedParts = new AtomicInteger(completedParts);
103+
this.getObjectCallCount = new AtomicInteger(completedParts);
97104
}
98105

99106
@Override
@@ -126,11 +133,12 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
126133
GetObjectRequest actualRequest = nextRequest(nextPartToGet);
127134
log.debug(() -> "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet);
128135
CompletableFuture<GetObjectResponse> getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer);
136+
getObjectCallCount.incrementAndGet();
129137
getObjectFutures.add(getObjectFuture);
130138
getObjectFuture.whenComplete((response, error) -> {
131139
if (error != null) {
132140
log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet);
133-
onError(error);
141+
handleError(error);
134142
return;
135143
}
136144
requestMoreIfNeeded(response);
@@ -166,6 +174,7 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
166174
if (totalParts != null && totalParts > 1 && totalComplete < totalParts) {
167175
subscription.request(1);
168176
} else {
177+
validatePartsCount();
169178
log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts));
170179
subscription.cancel();
171180
}
@@ -174,6 +183,13 @@ private void requestMoreIfNeeded(GetObjectResponse response) {
174183

175184
@Override
176185
public void onError(Throwable t) {
186+
handleError(t);
187+
}
188+
189+
/**
190+
* The method used by the Subscriber itself when error occured.
191+
*/
192+
private void handleError(Throwable t) {
177193
CompletableFuture<GetObjectResponse> partFuture;
178194
while ((partFuture = getObjectFutures.poll()) != null) {
179195
partFuture.cancel(true);
@@ -198,4 +214,14 @@ private GetObjectRequest nextRequest(int nextPartToGet) {
198214
}
199215
});
200216
}
217+
218+
private void validatePartsCount() {
219+
int actualGetCount = getObjectCallCount.get();
220+
if (totalParts != null && actualGetCount != totalParts) {
221+
String errorMessage = String.format("PartsCount validation failed. Expected %d, downloaded %d parts.", totalParts,
222+
actualGetCount);
223+
SdkClientException exception = SdkClientException.create(errorMessage);
224+
handleError(exception);
225+
}
226+
}
201227
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.multipart;
17+
18+
19+
import static org.junit.Assert.assertThrows;
20+
import static org.junit.Assert.assertTrue;
21+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
22+
import static org.mockito.ArgumentMatchers.eq;
23+
import static org.mockito.ArgumentMatchers.any;
24+
import static org.mockito.Mockito.when;
25+
26+
import java.util.concurrent.CompletableFuture;
27+
import java.util.concurrent.ExecutionException;
28+
import java.util.concurrent.TimeUnit;
29+
import org.junit.jupiter.api.BeforeEach;
30+
import org.junit.jupiter.api.Test;
31+
import org.mockito.Mock;
32+
import org.mockito.MockitoAnnotations;
33+
import org.reactivestreams.Subscription;
34+
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
35+
import software.amazon.awssdk.core.exception.SdkClientException;
36+
import software.amazon.awssdk.services.s3.S3AsyncClient;
37+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
38+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
39+
40+
public class MultipartDownloaderSubscriberPartCountValidationTest {
41+
@Mock
42+
private S3AsyncClient s3Client;
43+
44+
@Mock
45+
private Subscription subscription;
46+
47+
@Mock
48+
private AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> responseTransformer;
49+
50+
private GetObjectRequest getObjectRequest;
51+
private MultipartDownloaderSubscriber subscriber;
52+
53+
@BeforeEach
54+
void setUp() {
55+
MockitoAnnotations.openMocks(this);
56+
getObjectRequest = GetObjectRequest.builder()
57+
.bucket("test-bucket")
58+
.key("test-key")
59+
.build();
60+
}
61+
62+
@Test
63+
void callCountMatchesTotalParts_shouldPass() throws InterruptedException {
64+
subscriber = new MultipartDownloaderSubscriber(s3Client, getObjectRequest);
65+
GetObjectResponse response1 = createMockResponse(3, "etag1");
66+
GetObjectResponse response2 = createMockResponse(3, "etag2");
67+
GetObjectResponse response3 = createMockResponse(3, "etag3");
68+
69+
CompletableFuture<GetObjectResponse> future1 = CompletableFuture.completedFuture(response1);
70+
CompletableFuture<GetObjectResponse> future2 = CompletableFuture.completedFuture(response2);
71+
CompletableFuture<GetObjectResponse> future3 = CompletableFuture.completedFuture(response3);
72+
73+
when(s3Client.getObject(any(GetObjectRequest.class), eq(responseTransformer)))
74+
.thenReturn(future1, future2, future3);
75+
76+
subscriber.onSubscribe(subscription);
77+
subscriber.onNext(responseTransformer);
78+
subscriber.onNext(responseTransformer);
79+
subscriber.onNext(responseTransformer);
80+
Thread.sleep(100);
81+
82+
subscriber.onComplete();
83+
84+
assertDoesNotThrow(() -> subscriber.future().get(1, TimeUnit.SECONDS));
85+
}
86+
87+
@Test
88+
void callCountMoreThanTotalParts_shouldThrowException() throws InterruptedException {
89+
subscriber = new MultipartDownloaderSubscriber(s3Client, getObjectRequest, 3);
90+
GetObjectResponse response1 = createMockResponse(2, "etag1");
91+
92+
CompletableFuture<GetObjectResponse> future1 = CompletableFuture.completedFuture(response1);
93+
94+
when(s3Client.getObject(any(GetObjectRequest.class), eq(responseTransformer)))
95+
.thenReturn(future1);
96+
97+
subscriber.onSubscribe(subscription);
98+
subscriber.onNext(responseTransformer);
99+
Thread.sleep(100);
100+
101+
subscriber.onComplete();
102+
103+
ExecutionException exception = assertThrows(ExecutionException.class,
104+
() -> subscriber.future().get(1, TimeUnit.SECONDS));
105+
assertTrue(exception.getCause() instanceof SdkClientException);
106+
assertTrue(exception.getCause().getMessage().contains("PartsCount validation failed"));
107+
assertTrue(exception.getCause().getMessage().contains("Expected 2, downloaded 4 parts"));
108+
109+
}
110+
111+
private GetObjectResponse createMockResponse(int partsCount, String etag) {
112+
GetObjectResponse.Builder builder = GetObjectResponse.builder()
113+
.eTag(etag)
114+
.contentLength(1024L);
115+
116+
builder.partsCount(partsCount);
117+
return builder.build();
118+
}
119+
120+
}

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import static org.junit.jupiter.params.provider.Arguments.arguments;
3535
import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.internalErrorBody;
3636
import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.transformersSuppliers;
37+
import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT;
3738

3839
import com.github.tomakehurst.wiremock.http.Fault;
3940
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
@@ -58,6 +59,7 @@
5859
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
5960
import software.amazon.awssdk.core.SplittingTransformerConfiguration;
6061
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
62+
import software.amazon.awssdk.core.exception.SdkClientException;
6163
import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer;
6264
import software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformer;
6365
import software.amazon.awssdk.core.internal.async.InputStreamResponseTransformer;
@@ -67,6 +69,7 @@
6769
import software.amazon.awssdk.regions.Region;
6870
import software.amazon.awssdk.services.s3.S3AsyncClient;
6971
import software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils;
72+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
7073
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
7174
import software.amazon.awssdk.services.s3.model.S3Exception;
7275
import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier;
@@ -144,6 +147,47 @@ public <T> void errorOnThirdPart_shouldCompleteExceptionallyOnlyPartsGreaterThan
144147
}
145148
}
146149

150+
@ParameterizedTest
151+
@MethodSource("partSizeAndTransformerParams")
152+
public <T> void partCountValidationFailure_shouldThrowException(
153+
AsyncResponseTransformerTestSupplier<T> supplier,
154+
int partSize) {
155+
156+
// To trigger the partCount failure, the resumeContext is used to initialize the actualGetCount larger than the
157+
// totalPart number set in the response. This won't happen in real scenario, just to test if the error can be surfaced
158+
// to the user if the validation fails.
159+
MultipartDownloadResumeContext resumeContext = new MultipartDownloadResumeContext();
160+
resumeContext.addCompletedPart(1);
161+
resumeContext.addCompletedPart(2);
162+
resumeContext.addCompletedPart(3);
163+
resumeContext.addToBytesToLastCompletedParts(3 * partSize);
164+
165+
GetObjectRequest request = GetObjectRequest.builder()
166+
.bucket(BUCKET)
167+
.key(KEY)
168+
.overrideConfiguration(config -> config
169+
.putExecutionAttribute(
170+
MULTIPART_DOWNLOAD_RESUME_CONTEXT,
171+
resumeContext))
172+
.build();
173+
174+
util.stubForPart(BUCKET, KEY, 4, 2, partSize);
175+
176+
// Skip the lazy transformer since the error won't surface unless the content is consumed
177+
AsyncResponseTransformer<GetObjectResponse, T> transformer = supplier.transformer();
178+
if (transformer instanceof InputStreamResponseTransformer || transformer instanceof PublisherAsyncResponseTransformer) {
179+
return;
180+
}
181+
182+
assertThatThrownBy(() -> {
183+
T res = multipartClient.getObject(request, transformer).join();
184+
supplier.body(res);
185+
}).isInstanceOf(CompletionException.class)
186+
.hasCauseInstanceOf(SdkClientException.class)
187+
.hasMessageContaining("PartsCount validation failed. Expected 2, downloaded 4 parts");
188+
189+
}
190+
147191
@ParameterizedTest
148192
@MethodSource("nonRetryableResponseTransformers")
149193
public <T> void errorOnFirstPart_shouldFail(AsyncResponseTransformerTestSupplier<T> supplier) {

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ public byte[] stubForPart(String testBucket, String testKey,int part, int totalP
6666
aResponse()
6767
.withHeader("x-amz-mp-parts-count", totalPart + "")
6868
.withHeader("ETag", eTag)
69+
.withHeader("Content-Length", String.valueOf(body.length))
6970
.withBody(body)));
7071
return body;
7172
}

0 commit comments

Comments
 (0)