diff --git a/build.gradle b/build.gradle index 37bc99c..74ea462 100644 --- a/build.gradle +++ b/build.gradle @@ -113,8 +113,11 @@ dependencies { implementation 'commons-codec:commons-codec:1.15' implementation 'org.apache.httpcomponents:httpcore-nio:4.4.15' + implementation 'javax.validation:validation-api:2.0.1.Final' + testImplementation 'junit:junit:4.13.2' testImplementation 'org.easymock:easymock:4.3' + testImplementation("org.springframework.boot:spring-boot-starter-web:${springBotVersion}") { exclude module: 'logback-classic' } diff --git a/src/main/java/com/googlecode/jsonrpc4j/JsonRpcBasicServer.java b/src/main/java/com/googlecode/jsonrpc4j/JsonRpcBasicServer.java index 9e195ff..2531f62 100644 --- a/src/main/java/com/googlecode/jsonrpc4j/JsonRpcBasicServer.java +++ b/src/main/java/com/googlecode/jsonrpc4j/JsonRpcBasicServer.java @@ -8,25 +8,58 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.validation.ValidationException; +import javax.validation.constraints.NotNull; + +import java.beans.IntrospectionException; +import java.beans.Introspector; +import java.beans.PropertyDescriptor; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.annotation.Annotation; -import java.lang.reflect.*; +import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.lang.reflect.Type; +import java.lang.reflect.UndeclaredThrowableException; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; +import java.util.stream.Collectors; -import static com.googlecode.jsonrpc4j.ErrorResolver.JsonError.ERROR_NOT_HANDLED; -import static com.googlecode.jsonrpc4j.ErrorResolver.JsonError.INTERNAL_ERROR; -import static com.googlecode.jsonrpc4j.ReflectionUtil.findCandidateMethods; -import static com.googlecode.jsonrpc4j.ReflectionUtil.getParameterTypes; -import static com.googlecode.jsonrpc4j.Util.hasNonNullData; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.NullNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.googlecode.jsonrpc4j.ErrorResolver.JsonError; +import net.iharder.Base64; /** * A JSON-RPC request server reads JSON-RPC requests from an input stream and writes responses to an output stream. @@ -255,6 +288,11 @@ public int handleRequest(final InputStream input, final OutputStream output) thr JsonResponse responseError = createResponseError(VERSION, NULL, JsonError.PARSE_ERROR); writeAndFlushValue(output, responseError.getResponse()); return responseError.getCode(); + } catch (ValidationException e) { + JsonResponse responseError = + createResponseError(VERSION, NULL, JsonError.METHOD_PARAMS_INVALID); + writeAndFlushValue(output, responseError.getResponse()); + return responseError.getCode(); } } @@ -461,7 +499,7 @@ private JsonResponse handleObject(final ObjectNode node) return createResponseSuccess(jsonRpc, id, handler.result); } return new JsonResponse(null, JsonError.OK.code); - } catch (JsonParseException | JsonMappingException e) { + } catch (JsonParseException | JsonMappingException | ValidationException e) { throw e; // rethrow this, it will be handled as PARSE_ERROR later } catch (ParameterConvertException pce) { handler.error = pce.getCause(); @@ -595,6 +633,9 @@ private JsonNode invoke(Object target, Method method, List params) thr if (convertedParameterTransformer != null) { convertedParams = convertedParameterTransformer.transformConvertedParameters(target, convertedParams); } + if (!allowLessParams) { + collectApiModelsAndValidate(convertedParams); + } result = method.invoke(target, convertedParams); } @@ -603,6 +644,59 @@ private JsonNode invoke(Object target, Method method, List params) thr return hasReturnValue(method) ? mapper.valueToTree(result) : null; } + private void collectApiModelsAndValidate(Object[] params) { + + List requestModels = Arrays.stream(params) + .filter(param -> !param.getClass().isPrimitive()) + .collect(Collectors.toList()); + requestModels.forEach(this::validateFields); + } + + private void validateFields(Object requestModel) { + + Arrays.stream(requestModel.getClass().getDeclaredFields()).forEach(field -> { + validateField(requestModel, field); + }); + + } + + private void validateField(Object requestModel, Field field) { + + if (fieldIsRequired(field)) { + try { + for (PropertyDescriptor pd : Introspector.getBeanInfo(requestModel.getClass()) + .getPropertyDescriptors()) { + if (pd.getReadMethod() != null && !"class".equals(pd.getName()) + && Objects.equals(pd.getName(), field.getName())) { + invokeGetterAndValidate(requestModel, pd.getReadMethod(), field.getName()); + } + } + } catch (IntrospectionException e) { + logger.warn("Unable to find getter for field {} in class {}", field.getName(), + requestModel.getClass().getName()); + } + + } + } + + private void invokeGetterAndValidate(Object o, Method gett, String fieldName) { + + try { + if (gett.invoke(o) == null) { + + throw new ValidationException( + String.format("Field %s cannot be empty", fieldName)); + } + } catch (IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); + } + } + + private boolean fieldIsRequired(Field field) { + + return field.getAnnotationsByType(NotNull.class).length > 0; + } + private Object invokePrimitiveVarargs(Object target, Method method, List params, Class componentType) throws IllegalAccessException, InvocationTargetException { // need to cast to object here in order to support primitives. Object convertedParams = Array.newInstance(componentType, params.size()); diff --git a/src/test/java/com/googlecode/jsonrpc4j/server/JsonRpcServerAnnotatedParamTest.java b/src/test/java/com/googlecode/jsonrpc4j/server/JsonRpcServerAnnotatedParamTest.java index 66026ac..2c9d8ab 100644 --- a/src/test/java/com/googlecode/jsonrpc4j/server/JsonRpcServerAnnotatedParamTest.java +++ b/src/test/java/com/googlecode/jsonrpc4j/server/JsonRpcServerAnnotatedParamTest.java @@ -15,7 +15,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.UUID; - import static com.googlecode.jsonrpc4j.ErrorResolver.JsonError.METHOD_PARAMS_INVALID; import static com.googlecode.jsonrpc4j.ErrorResolver.JsonError.PARSE_ERROR; import static com.googlecode.jsonrpc4j.JsonRpcBasicServer.ID; @@ -23,6 +22,24 @@ import static com.googlecode.jsonrpc4j.util.Util.*; import static org.junit.Assert.*; +import javax.validation.constraints.NotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import com.fasterxml.jackson.databind.JsonNode; +import com.googlecode.jsonrpc4j.JsonRpcBasicServer; +import com.googlecode.jsonrpc4j.JsonRpcParam; +import com.googlecode.jsonrpc4j.util.Util; + @RunWith(EasyMockRunner.class) public class JsonRpcServerAnnotatedParamTest { @@ -175,21 +192,44 @@ public void callParseErrorJson() throws Exception { assertEquals(PARSE_ERROR.code, errorCode(error(byteArrayOutputStream)).asInt()); } - @Test - public void callMethodWithIncompatibleParamTypeAndExpectInvalidParamsError() throws Exception { - final Object invalidDouble = "callMeDouble"; - jsonRpcServerAnnotatedParam.handleRequest( - createStream( - messageWithListParams( - 3, - METHOD_WITH_DIFFERENT_TYPES, - false, invalidDouble, UUID.randomUUID() - ) - ), - byteArrayOutputStream - ); - assertEquals(METHOD_PARAMS_INVALID.code, errorCode(error(byteArrayOutputStream)).asInt()); - } + @Test + public void callMethodWithAllRequiredParametersInObjectAsParam() throws Exception { + EasyMock.expect(mockService.testMethodWithObjParam(EasyMock.anyObject(String.class),EasyMock.anyObject(TestRequestObj.class))).andReturn("success"); + EasyMock.replay(mockService); + jsonRpcServerAnnotatedParam.handleRequest(messageWithMapParamsStream("testMethodWithObjParam", param1, param2,"obj",new TestRequestObj("1","2","3")), byteArrayOutputStream); + assertEquals("success", result().textValue()); + } + + @Test + public void callMethodWithNullInRequiredParametersInObjectAsParam() throws Exception { + EasyMock.expect(mockService.testMethodWithObjParam(EasyMock.anyObject(String.class),EasyMock.anyObject(TestRequestObj.class))).andReturn("success"); + EasyMock.replay(mockService); + jsonRpcServerAnnotatedParam.handleRequest(messageWithMapParamsStream("testMethodWithObjParam", param1, param2,"obj",new TestRequestObj(null,"2","3")), byteArrayOutputStream); + assertEquals(METHOD_PARAMS_INVALID.code, errorCode(error(byteArrayOutputStream)).intValue()); + } + + @Test + public void callMethodWithNullInNonRequiredParametersInObjectAsParam() throws Exception { + EasyMock.expect(mockService.testMethodWithObjParam(EasyMock.anyObject(String.class),EasyMock.anyObject(TestRequestObj.class))).andReturn("success"); + EasyMock.replay(mockService); + jsonRpcServerAnnotatedParam.handleRequest(messageWithMapParamsStream("testMethodWithObjParam", param1, param2,"obj",new TestRequestObj("1","2",null)), byteArrayOutputStream); + assertEquals("success", result().textValue()); + } + + public void callMethodWithIncompatibleParamTypeAndExpectInvalidParamsError() throws Exception { + final Object invalidDouble = "callMeDouble"; + jsonRpcServerAnnotatedParam.handleRequest( + createStream( + messageWithListParams( + 3, + METHOD_WITH_DIFFERENT_TYPES, + false, invalidDouble, UUID.randomUUID() + ) + ), + byteArrayOutputStream + ); + assertEquals(METHOD_PARAMS_INVALID.code, errorCode(error(byteArrayOutputStream)).asInt()); + } @Test public void callMethodWithIncompatibleParamTypeAndExpectProperJsonRpcIdResponse() throws Exception { @@ -213,6 +253,8 @@ public void callMethodWithIncompatibleParamTypeAndExpectProperJsonRpcIdResponse( public interface ServiceInterfaceWithParamNameAnnotation { String testMethod(@JsonRpcParam("param1") String param1); + + String testMethodWithObjParam(@JsonRpcParam("param1") String param1,@JsonRpcParam("obj") TestRequestObj obj); String overloadedMethod(); @@ -232,4 +274,58 @@ String methodWithDifferentTypes( @JsonRpcParam("param3") UUID doubleParam3 ); } + + public static class TestRequestObj { + + public TestRequestObj(String requiredValue, String anotherRequiredValue, + String nonRequiredValue) { + + this.requiredValue = requiredValue; + this.anotherRequiredValue = anotherRequiredValue; + this.nonRequiredValue = nonRequiredValue; + } + + // for serialization + public TestRequestObj() { + + } + + @NotNull + public String requiredValue; + + @NotNull + public String anotherRequiredValue; + + public String nonRequiredValue; + + public String getRequiredValue() { + + return requiredValue; + } + + public void setRequiredValue(String requiredValue) { + + this.requiredValue = requiredValue; + } + + public String getAnotherRequiredValue() { + + return anotherRequiredValue; + } + + public void setAnotherRequiredValue(String anotherRequiredValue) { + + this.anotherRequiredValue = anotherRequiredValue; + } + + public String getNonRequiredValue() { + + return nonRequiredValue; + } + + public void setNonRequiredValue(String nonRequiredValue) { + + this.nonRequiredValue = nonRequiredValue; + } + } }