diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/AcceptableExtension.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/AcceptableExtension.java new file mode 100644 index 000000000000..1fcb53e95a2f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/AcceptableExtension.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.bind.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation for validating file extensions of multipart file uploads in Spring MVC + * controller methods. When applied to a {@link org.springframework.web.multipart.MultipartFile} + * parameter, it restricts the acceptable file extensions that can be uploaded. + * + *
This annotation works in conjunction with a custom argument resolver or validator + * to enforce file extension constraints at the controller level, providing early + * validation before file processing. + * + *
Example usage: + *
+ * @PostMapping("/upload")
+ * public String handleFileUpload(
+ * @AcceptableExtension(extensions = {"jpg", "png", "pdf"})
+ * @RequestParam("file") MultipartFile file) {
+ * // Process file
+ * return "success";
+ * }
+ *
+ *
+ * @author Aleksei Iakhnenko
+ * @since 7.0
+ * @see org.springframework.web.multipart.MultipartFile
+ * @see org.springframework.web.bind.annotation.RequestParam
+ */
+@Target(ElementType.PARAMETER)
+@Retention(RetentionPolicy.RUNTIME)
+public @interface AcceptableExtension {
+ String[] extensions() default {};
+ String message() default "Invalid file extension";
+}
diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolver.java
new file mode 100644
index 000000000000..21c225ef18df
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolver.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright 2002-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.method.annotation;
+
+import java.util.Arrays;
+
+import jakarta.servlet.http.HttpServletRequest;
+import org.jspecify.annotations.Nullable;
+
+import org.springframework.core.MethodParameter;
+import org.springframework.util.StringUtils;
+import org.springframework.web.bind.annotation.AcceptableExtension;
+import org.springframework.web.bind.support.WebDataBinderFactory;
+import org.springframework.web.context.request.NativeWebRequest;
+import org.springframework.web.method.support.HandlerMethodArgumentResolver;
+import org.springframework.web.method.support.ModelAndViewContainer;
+import org.springframework.web.multipart.MultipartException;
+import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.multipart.support.MultipartResolutionDelegate;
+
+/**
+ * Resolves method arguments annotated with @AcceptableExtension and validates
+ * file extensions for MultipartFile parameters.
+ *
+ * @author Aleksei Iakhnenko
+ * @since 7.0
+ * @see AcceptableExtension
+ */
+public class AcceptableExtensionMethodArgumentResolver implements HandlerMethodArgumentResolver {
+
+ @Override
+ public boolean supportsParameter(MethodParameter parameter) {
+ return parameter.hasParameterAnnotation(AcceptableExtension.class);
+ }
+
+ @Override
+ @Nullable
+ public Object resolveArgument(
+ MethodParameter parameter,
+ @Nullable ModelAndViewContainer mavContainer,
+ NativeWebRequest webRequest,
+ @Nullable WebDataBinderFactory binderFactory) throws Exception {
+
+ AcceptableExtension annotation = parameter.getParameterAnnotation(AcceptableExtension.class);
+ if (annotation == null) {
+ return null;
+ }
+
+ HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
+ if (servletRequest == null) {
+ return null;
+ }
+
+ String paramName = getParameterName(parameter);
+ if (paramName == null) {
+ return null;
+ }
+
+ Object resolvedArgument = MultipartResolutionDelegate.resolveMultipartArgument(
+ paramName, parameter, servletRequest);
+
+ MultipartFile file = (resolvedArgument instanceof MultipartFile) ?
+ (MultipartFile) resolvedArgument :
+ null;
+
+ if (file != null && !file.isEmpty()) {
+ String filename = file.getOriginalFilename();
+ if (StringUtils.hasText(filename)) {
+ String extension = StringUtils.getFilenameExtension(filename);
+ if (extension != null && !isAcceptableExtension(extension, annotation.extensions())) {
+ throw new MultipartException(annotation.message() +
+ ". Allowed: " + Arrays.toString(annotation.extensions()) +
+ ", received: " + extension);
+ }
+ }
+ }
+
+ return file;
+ }
+
+ /**
+ * Determine the name for the given method parameter.
+ * @param parameter the method parameter
+ * @return the parameter name, or {@code null} if not resolvable
+ */
+ @Nullable
+ private String getParameterName(MethodParameter parameter) {
+ org.springframework.web.bind.annotation.RequestParam requestParam =
+ parameter.getParameterAnnotation(org.springframework.web.bind.annotation.RequestParam.class);
+
+ if (requestParam != null) {
+ String paramName = requestParam.value();
+ if (StringUtils.hasText(paramName)) {
+ return paramName;
+ }
+ paramName = requestParam.name();
+ if (StringUtils.hasText(paramName)) {
+ return paramName;
+ }
+ }
+
+ // Fallback to actual parameter name if available
+ return parameter.getParameterName();
+ }
+
+ private boolean isAcceptableExtension(String extension, String[] acceptableExtensions) {
+ if (acceptableExtensions.length == 0) {
+ return true;
+ }
+ for (String acceptable : acceptableExtensions) {
+ if (acceptable.equalsIgnoreCase(extension)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+}
diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolverTests.java
new file mode 100644
index 000000000000..d754792e3404
--- /dev/null
+++ b/spring-web/src/test/java/org/springframework/web/method/annotation/AcceptableExtensionMethodArgumentResolverTests.java
@@ -0,0 +1,449 @@
+/*
+ * Copyright 2002-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.method.annotation;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.core.MethodParameter;
+import org.springframework.web.bind.annotation.AcceptableExtension;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.context.request.NativeWebRequest;
+import org.springframework.web.context.request.ServletWebRequest;
+import org.springframework.web.multipart.MultipartException;
+import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
+import org.springframework.web.testfixture.servlet.MockMultipartFile;
+import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest;
+
+import java.lang.reflect.Method;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Unit tests for {@link AcceptableExtensionMethodArgumentResolver}.
+ *
+ * @author Aleksei Iakhnenko
+ */
+class AcceptableExtensionMethodArgumentResolverTests {
+
+ private AcceptableExtensionMethodArgumentResolver resolver;
+
+ @BeforeEach
+ void setUp() {
+ this.resolver = new AcceptableExtensionMethodArgumentResolver();
+ }
+
+ @Test
+ void supportsParameterWithAcceptableExtensionAnnotation() throws Exception {
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+ assertThat(this.resolver.supportsParameter(parameter)).isTrue();
+ }
+
+ @Test
+ void doesNotSupportParameterWithoutAcceptableExtensionAnnotation() throws Exception {
+ MethodParameter parameter = getMethodParameter("handleFileUploadWithoutAnnotation", 0);
+ assertThat(this.resolver.supportsParameter(parameter)).isFalse();
+ }
+
+ @Test
+ void resolveArgumentWithValidExtension() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "test.jpg", "image/jpeg", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ MultipartFile resolvedFile = (MultipartFile) result;
+ assertThat(resolvedFile.getOriginalFilename()).isEqualTo("test.jpg");
+ }
+
+ @Test
+ void resolveArgumentWithInvalidExtensionThrowsException() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "test.exe", "application/octet-stream", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ assertThatThrownBy(() -> this.resolver.resolveArgument(parameter, null, webRequest, null))
+ .isInstanceOf(MultipartException.class)
+ .hasMessageContaining("Invalid file extension")
+ .hasMessageContaining("Allowed: [jpg, png, pdf]")
+ .hasMessageContaining("received: exe");
+ }
+
+ @Test
+ void resolveArgumentWithEmptyExtensionsArrayAcceptsAnyExtension() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "document", "test.xyz", "application/octet-stream", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUploadWithEmptyExtensions", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ MultipartFile resolvedFile = (MultipartFile) result;
+ assertThat(resolvedFile.getOriginalFilename()).isEqualTo("test.xyz");
+ }
+
+ @Test
+ void resolveArgumentWithEmptyFile() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "test.jpg", "image/jpeg", new byte[0]);
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ MultipartFile resolvedFile = (MultipartFile) result;
+ assertThat(resolvedFile.isEmpty()).isTrue();
+ }
+
+ @Test
+ void resolveArgumentWithNullFilename() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", null, "image/jpeg", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ }
+
+ @Test
+ void resolveArgumentWithEmptyFilename() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "", "image/jpeg", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ }
+
+ @Test
+ void resolveArgumentWithFilenameWithoutExtension() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "testfile", "application/octet-stream", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ }
+
+ @Test
+ void resolveArgumentWithCaseInsensitiveExtension() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "file", "test.JPG", "image/jpeg", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ MultipartFile resolvedFile = (MultipartFile) result;
+ assertThat(resolvedFile.getOriginalFilename()).isEqualTo("test.JPG");
+ }
+
+ @Test
+ void resolveArgumentWithCustomMessage() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "avatar", "profile.exe", "application/octet-stream", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleAvatarUpload", 0);
+
+ assertThatThrownBy(() -> this.resolver.resolveArgument(parameter, null, webRequest, null))
+ .isInstanceOf(MultipartException.class)
+ .hasMessageContaining("Please upload only image files");
+ }
+
+ @Test
+ void resolveArgumentWithRequestParamName() throws Exception {
+ MockMultipartFile file = new MockMultipartFile(
+ "uploadedFile", "test.png", "image/png", "content".getBytes());
+
+ MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
+ request.addFile(file);
+
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUploadWithCustomName", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isInstanceOf(MultipartFile.class);
+ MultipartFile resolvedFile = (MultipartFile) result;
+ assertThat(resolvedFile.getOriginalFilename()).isEqualTo("test.png");
+ }
+
+ @Test
+ void resolveArgumentWithNonMultipartRequest() throws Exception {
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ NativeWebRequest webRequest = new ServletWebRequest(request);
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ Object result = this.resolver.resolveArgument(parameter, null, webRequest, null);
+
+ assertThat(result).isNull();
+ }
+
+ @Test
+ void resolveArgumentWithoutHttpServletRequest() throws Exception {
+ MethodParameter parameter = getMethodParameter("handleFileUpload", 0);
+
+ // Simulate a scenario where getNativeRequest returns null
+ NativeWebRequest emptyWebRequest = new NativeWebRequest() {
+ @Override
+ public Object getNativeRequest() {
+ return new Object(); // Not an HttpServletRequest
+ }
+
+ @Override
+ public Object getNativeResponse() {
+ return null;
+ }
+
+ @Override
+ public