diff --git a/application/src/test/groovy/javasabr/mqtt/broker/application/PublishRetryTest.groovy b/application/src/test/groovy/javasabr/mqtt/broker/application/PublishRetryTest.groovy index 5ae84a2..e6b647c 100644 --- a/application/src/test/groovy/javasabr/mqtt/broker/application/PublishRetryTest.groovy +++ b/application/src/test/groovy/javasabr/mqtt/broker/application/PublishRetryTest.groovy @@ -12,7 +12,14 @@ import javasabr.mqtt.network.message.in.ConnectAckMqttInMessage import javasabr.mqtt.network.message.in.PublishMqttInMessage import javasabr.mqtt.network.message.in.PublishReleaseMqttInMessage import javasabr.mqtt.network.message.in.SubscribeAckMqttInMessage -import javasabr.mqtt.network.message.out.* +import javasabr.mqtt.network.message.out.ConnectMqtt311OutMessage +import javasabr.mqtt.network.message.out.ConnectMqtt5OutMessage +import javasabr.mqtt.network.message.out.PublishCompleteMqtt311OutMessage +import javasabr.mqtt.network.message.out.PublishCompleteMqtt5OutMessage +import javasabr.mqtt.network.message.out.PublishReceivedMqtt311OutMessage +import javasabr.mqtt.network.message.out.PublishReceivedMqtt5OutMessage +import javasabr.mqtt.network.message.out.SubscribeMqtt311OutMessage +import javasabr.mqtt.network.message.out.SubscribeMqtt5OutMessage import javasabr.mqtt.service.session.MqttSessionService import javasabr.rlib.collections.array.Array import org.springframework.beans.factory.annotation.Autowired diff --git a/model/src/main/java/javasabr/mqtt/model/MqttClientConnectionConfig.java b/model/src/main/java/javasabr/mqtt/model/MqttClientConnectionConfig.java index 48c9dd0..6296b09 100644 --- a/model/src/main/java/javasabr/mqtt/model/MqttClientConnectionConfig.java +++ b/model/src/main/java/javasabr/mqtt/model/MqttClientConnectionConfig.java @@ -35,4 +35,8 @@ public boolean sessionsEnabled() { public int maxTopicLevels() { return server.maxTopicLevels(); } + + public int maxStringLength() { + return server.maxStringLength(); + } } diff --git a/model/src/main/java/javasabr/mqtt/model/MqttProtocolErrors.java b/model/src/main/java/javasabr/mqtt/model/MqttProtocolErrors.java new file mode 100644 index 0000000..d4ba215 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/MqttProtocolErrors.java @@ -0,0 +1,6 @@ +package javasabr.mqtt.model; + +public interface MqttProtocolErrors { + String NO_ANY_TOPIC_FILTER = "No any topic filters"; + String UNSUPPORTED_QOS_OR_RETAIN_HANDLING = "Unsupported qos or retain handling"; +} diff --git a/network/src/main/java/javasabr/mqtt/network/message/in/SubscribeMqttInMessage.java b/network/src/main/java/javasabr/mqtt/network/message/in/SubscribeMqttInMessage.java index feb7300..de63837 100644 --- a/network/src/main/java/javasabr/mqtt/network/message/in/SubscribeMqttInMessage.java +++ b/network/src/main/java/javasabr/mqtt/network/message/in/SubscribeMqttInMessage.java @@ -4,9 +4,10 @@ import java.util.EnumSet; import java.util.Set; import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.MqttClientConnectionConfig; import javasabr.mqtt.model.MqttMessageProperty; import javasabr.mqtt.model.MqttProperties; -import javasabr.mqtt.model.MqttServerConnectionConfig; +import javasabr.mqtt.model.MqttProtocolErrors; import javasabr.mqtt.model.MqttVersion; import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.SubscribeRetainHandling; @@ -21,6 +22,7 @@ import lombok.Getter; import lombok.experimental.Accessors; import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; /** * Subscribe request. @@ -30,7 +32,10 @@ @FieldDefaults(level = AccessLevel.PROTECTED) public class SubscribeMqttInMessage extends TrackableMqttInMessage { + private static final Array EMPTY_SUBSCRIPTIONS = Array.empty(RequestedSubscription.class); + private static final byte MESSAGE_TYPE = (byte) MqttMessageType.SUBSCRIBE.ordinal(); + public static final byte MESSAGE_FLAGS = 0b0000_0010; static { DebugUtils.registerIncludedFields("subscriptions"); @@ -53,14 +58,14 @@ public class SubscribeMqttInMessage extends TrackableMqttInMessage { */ MqttMessageProperty.USER_PROPERTY); - final MutableArray subscriptions; + @Nullable + MutableArray subscriptions; // properties int subscriptionId; public SubscribeMqttInMessage(byte info) { super(info); - this.subscriptions = ArrayFactory.mutableArray(RequestedSubscription.class); this.subscriptionId = MqttProperties.SUBSCRIPTION_ID_IS_NOT_SET; } @@ -71,22 +76,25 @@ public byte messageType() { @Override protected boolean validMessageFlags(byte messageFlags) { - return messageFlags == 0b0000_0010; + return messageFlags == MESSAGE_FLAGS; } @Override protected void readPayload(MqttConnection connection, ByteBuffer buffer) { if (buffer.remaining() < 1) { - throw new MalformedProtocolMqttException("No any topic filters"); + throw new MalformedProtocolMqttException(MqttProtocolErrors.NO_ANY_TOPIC_FILTER); } - MqttServerConnectionConfig severConnConfig = connection.serverConnectionConfig(); + MqttClientConnectionConfig connectionConfig = connection.clientConnectionConfig(); + int maxStringLength = connectionConfig.maxStringLength(); boolean isMqtt5 = connection.isSupported(MqttVersion.MQTT_5); + subscriptions = ArrayFactory.mutableArray(RequestedSubscription.class); + // http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718066 // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901168 while (buffer.hasRemaining()) { - String topicFilter = readString(buffer, severConnConfig.maxStringLength()); + String topicFilter = readString(buffer, maxStringLength); int options = readByteUnsigned(buffer); int qosLevel = options & 0b0000_0011; @@ -105,7 +113,7 @@ protected void readPayload(MqttConnection connection, ByteBuffer buffer) { QoS qos = QoS.ofCode(qosLevel); if (qos == QoS.INVALID || retainHandling == SubscribeRetainHandling.INVALID) { - throw new MalformedProtocolMqttException("Unsupported qos or retain handling"); + throw new MalformedProtocolMqttException(MqttProtocolErrors.UNSUPPORTED_QOS_OR_RETAIN_HANDLING); } subscriptions.add(new RequestedSubscription( @@ -131,11 +139,11 @@ protected void applyProperty(MqttMessageProperty property, long value) { } public Array subscriptions() { - return subscriptions; + return subscriptions == null ? EMPTY_SUBSCRIPTIONS : subscriptions; } public int subscriptionsCount() { - return subscriptions.size(); + return subscriptions().size(); } private static void validateMqtt311Options(int options) { diff --git a/network/src/main/java/javasabr/mqtt/network/message/in/UnsubscribeMqttInMessage.java b/network/src/main/java/javasabr/mqtt/network/message/in/UnsubscribeMqttInMessage.java index 3e99436..ce31a7b 100644 --- a/network/src/main/java/javasabr/mqtt/network/message/in/UnsubscribeMqttInMessage.java +++ b/network/src/main/java/javasabr/mqtt/network/message/in/UnsubscribeMqttInMessage.java @@ -4,7 +4,9 @@ import java.util.EnumSet; import java.util.Set; import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.MqttClientConnectionConfig; import javasabr.mqtt.model.MqttMessageProperty; +import javasabr.mqtt.model.MqttProtocolErrors; import javasabr.mqtt.model.exception.MalformedProtocolMqttException; import javasabr.mqtt.network.MqttConnection; import javasabr.mqtt.network.message.MqttMessageType; @@ -22,7 +24,7 @@ */ @Getter @Accessors(fluent = true) -@FieldDefaults(level = AccessLevel.PRIVATE) +@FieldDefaults(level = AccessLevel.PROTECTED) public class UnsubscribeMqttInMessage extends TrackableMqttInMessage { static { @@ -30,6 +32,7 @@ public class UnsubscribeMqttInMessage extends TrackableMqttInMessage { } private static final byte MESSAGE_TYPE = (byte) MqttMessageType.UNSUBSCRIBE.ordinal(); + public static final byte MESSAGE_FLAGS = 0b0000_0010; private static final Set AVAILABLE_PROPERTIES = EnumSet.of( /* @@ -57,12 +60,16 @@ protected boolean validMessageFlags(byte messageFlags) { @Override protected void readPayload(MqttConnection connection, ByteBuffer buffer) { - if (buffer.remaining() < 1) { - throw new MalformedProtocolMqttException("No any topic filters."); + if (!buffer.hasRemaining()) { + throw new MalformedProtocolMqttException(MqttProtocolErrors.NO_ANY_TOPIC_FILTER); } + + MqttClientConnectionConfig connectionConfig = connection.clientConnectionConfig(); + int maxStringLength = connectionConfig.maxStringLength(); + rawTopicFilters = ArrayFactory.mutableArray(String.class); while (buffer.hasRemaining()) { - rawTopicFilters.add(readString(buffer, Integer.MAX_VALUE)); + rawTopicFilters.add(readString(buffer, maxStringLength)); } } @@ -70,6 +77,10 @@ public Array rawTopicFilters() { return rawTopicFilters == null ? EMPTY_STRINGS : rawTopicFilters; } + public int topicFiltersCount() { + return rawTopicFilters == null ? 0 : rawTopicFilters.size(); + } + @Override protected Set availableProperties() { return AVAILABLE_PROPERTIES; diff --git a/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt311OutMessage.java b/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt311OutMessage.java index 01e9547..d98bcd5 100644 --- a/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt311OutMessage.java +++ b/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt311OutMessage.java @@ -4,12 +4,16 @@ import javasabr.mqtt.network.MqttConnection; import javasabr.mqtt.network.message.MqttMessageType; import lombok.AccessLevel; +import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.experimental.Accessors; import lombok.experimental.FieldDefaults; /** * Unsubscribe acknowledgement. */ +@Getter +@Accessors(fluent = true) @RequiredArgsConstructor @FieldDefaults(level = AccessLevel.PROTECTED, makeFinal = true) public class UnsubscribeAckMqtt311OutMessage extends MqttOutMessage { diff --git a/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt5OutMessage.java b/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt5OutMessage.java index a0da8b6..00687c6 100644 --- a/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt5OutMessage.java +++ b/network/src/main/java/javasabr/mqtt/network/message/out/UnsubscribeAckMqtt5OutMessage.java @@ -9,11 +9,15 @@ import javasabr.mqtt.network.MqttConnection; import javasabr.rlib.collections.array.Array; import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; import lombok.experimental.FieldDefaults; /** * Unsubscribe acknowledgement. */ +@Getter +@Accessors(fluent = true) @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class UnsubscribeAckMqtt5OutMessage extends UnsubscribeAckMqtt311OutMessage { diff --git a/network/src/test/groovy/javasabr/mqtt/network/message/in/SubscribeMqttInMessageTest.groovy b/network/src/test/groovy/javasabr/mqtt/network/message/in/SubscribeMqttInMessageTest.groovy index f68a3ee..c00dc2d 100644 --- a/network/src/test/groovy/javasabr/mqtt/network/message/in/SubscribeMqttInMessageTest.groovy +++ b/network/src/test/groovy/javasabr/mqtt/network/message/in/SubscribeMqttInMessageTest.groovy @@ -1,7 +1,9 @@ package javasabr.mqtt.network.message.in + import javasabr.mqtt.model.MqttMessageProperty import javasabr.mqtt.model.MqttProperties +import javasabr.mqtt.model.MqttProtocolErrors import javasabr.mqtt.model.QoS import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.exception.MalformedProtocolMqttException @@ -139,7 +141,7 @@ class SubscribeMqttInMessageTest extends BaseMqttInMessageTest { then: !successful2 inMessage2.exception() instanceof MalformedProtocolMqttException - inMessage2.exception().message == 'Unsupported qos or retain handling' + inMessage2.exception().message == MqttProtocolErrors.UNSUPPORTED_QOS_OR_RETAIN_HANDLING when: def dataBuffer3 = BufferUtils.prepareBuffer(512) { it.putShort(messageId) @@ -150,6 +152,6 @@ class SubscribeMqttInMessageTest extends BaseMqttInMessageTest { then: !successful3 inMessage3.exception() instanceof MalformedProtocolMqttException - inMessage3.exception().message == 'No any topic filters' + inMessage3.exception().message == MqttProtocolErrors.NO_ANY_TOPIC_FILTER } } diff --git a/network/src/testFixtures/groovy/javasabr/mqtt/network/MqttMockClient.groovy b/network/src/testFixtures/groovy/javasabr/mqtt/network/MqttMockClient.groovy index eb5b5a9..c0cf43b 100644 --- a/network/src/testFixtures/groovy/javasabr/mqtt/network/MqttMockClient.groovy +++ b/network/src/testFixtures/groovy/javasabr/mqtt/network/MqttMockClient.groovy @@ -1,7 +1,11 @@ package javasabr.mqtt.network import javasabr.mqtt.network.message.MqttMessageType -import javasabr.mqtt.network.message.in.* +import javasabr.mqtt.network.message.in.ConnectAckMqttInMessage +import javasabr.mqtt.network.message.in.MqttInMessage +import javasabr.mqtt.network.message.in.PublishMqttInMessage +import javasabr.mqtt.network.message.in.PublishReleaseMqttInMessage +import javasabr.mqtt.network.message.in.SubscribeAckMqttInMessage import javasabr.mqtt.network.message.out.DisconnectMqtt311OutMessage import javasabr.mqtt.network.message.out.MqttOutMessage import javasabr.mqtt.network.util.MqttDataUtils diff --git a/network/src/testFixtures/groovy/javasabr/mqtt/network/NetworkUnitSpecification.groovy b/network/src/testFixtures/groovy/javasabr/mqtt/network/NetworkUnitSpecification.groovy index ce5c690..e377213 100644 --- a/network/src/testFixtures/groovy/javasabr/mqtt/network/NetworkUnitSpecification.groovy +++ b/network/src/testFixtures/groovy/javasabr/mqtt/network/NetworkUnitSpecification.groovy @@ -1,6 +1,11 @@ package javasabr.mqtt.network -import javasabr.mqtt.model.* + +import javasabr.mqtt.model.MqttClientConnectionConfig +import javasabr.mqtt.model.MqttServerConnectionConfig +import javasabr.mqtt.model.MqttVersion +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.data.type.StringPair import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode diff --git a/service/src/main/java/javasabr/mqtt/service/impl/DefaultConnectionService.java b/service/src/main/java/javasabr/mqtt/service/impl/DefaultConnectionService.java index 79d36b5..166c793 100644 --- a/service/src/main/java/javasabr/mqtt/service/impl/DefaultConnectionService.java +++ b/service/src/main/java/javasabr/mqtt/service/impl/DefaultConnectionService.java @@ -66,7 +66,7 @@ protected void processReceivedValidMessage( try { MqttInMessageHandler messageHandler = inMessageHandlers[mqttInMessage.messageType()]; //noinspection DataFlowIssue - messageHandler.processReceivedValidMessage(connection, mqttInMessage); + messageHandler.processValidMessage(connection, mqttInMessage); } catch (IndexOutOfBoundsException | NullPointerException ex) { log.warning(mqttInMessage, "Received not supported MQTT message:[%s]"::formatted); } @@ -90,7 +90,7 @@ protected void processReceivedInvalidMessage( try { MqttInMessageHandler messageHandler = inMessageHandlers[mqttInMessage.messageType()]; //noinspection DataFlowIssue - messageHandler.processReceivedInvalidMessage(connection, mqttInMessage); + messageHandler.processInvalidMessage(connection, mqttInMessage); } catch (IndexOutOfBoundsException | NullPointerException ex) { log.warning(mqttInMessage, "Received not supported MQTT message:[%s]"::formatted); } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/MqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/MqttInMessageHandler.java index fcd26b1..ece8220 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/MqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/MqttInMessageHandler.java @@ -8,7 +8,7 @@ public interface MqttInMessageHandler { MqttMessageType messageType(); - void processReceivedValidMessage(MqttConnection connection, MqttInMessage mqttInMessage); + void processValidMessage(MqttConnection connection, MqttInMessage mqttInMessage); - void processReceivedInvalidMessage(MqttConnection connection, MqttInMessage mqttInMessage); + void processInvalidMessage(MqttConnection connection, MqttInMessage mqttInMessage); } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/AbstractMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/AbstractMqttInMessageHandler.java index 2571ce2..d6ac6cc 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/AbstractMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/AbstractMqttInMessageHandler.java @@ -6,6 +6,8 @@ import javasabr.mqtt.network.MqttConnection; import javasabr.mqtt.network.message.in.MqttInMessage; import javasabr.mqtt.network.message.out.MqttOutMessage; +import javasabr.mqtt.network.session.MqttSession; +import javasabr.mqtt.network.util.ExtraErrorReasons; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.message.handler.MqttInMessageHandler; import lombok.AccessLevel; @@ -23,8 +25,12 @@ public abstract class AbstractMqttInMessageHandler expectedNetworkPacket; MessageOutFactoryService messageOutFactoryService; + protected boolean requireSession() { + return true; + } + @Override - public void processReceivedValidMessage(MqttConnection connection, MqttInMessage mqttInMessage) { + public void processValidMessage(MqttConnection connection, MqttInMessage mqttInMessage) { MqttClient client = connection.client(); if (!expectedClient.isInstance(client)) { log.warning(client, "Received not expected client:[%s]"::formatted); @@ -35,11 +41,21 @@ public void processReceivedValidMessage(MqttConnection connection, MqttInMessage } C castedClient = expectedClient.cast(client); M castedMessage = expectedNetworkPacket.cast(mqttInMessage); - processReceivedValidMessage(connection, castedClient, castedMessage); + if (requireSession()) { + MqttSession session = client.session(); + if (session == null) { + log.warning(client.clientId(), "[%s] Session is already closed"::formatted); + handleSessionIsAlreadyClosed(client); + return; + } + processValidMessage(connection, castedClient, session, castedMessage); + } else { + processValidMessage(connection, castedClient, castedMessage); + } } @Override - public void processReceivedInvalidMessage(MqttConnection connection, MqttInMessage mqttInMessage) { + public void processInvalidMessage(MqttConnection connection, MqttInMessage mqttInMessage) { MqttClient client = connection.client(); if (!expectedClient.isInstance(client)) { log.warning(client, "Received not expected client:[%s]"::formatted); @@ -50,23 +66,58 @@ public void processReceivedInvalidMessage(MqttConnection connection, MqttInMessa } C castedClient = expectedClient.cast(client); M castedMessage = expectedNetworkPacket.cast(mqttInMessage); - processReceivedInvalidMessage(connection, castedClient, castedMessage); + if (requireSession()) { + MqttSession session = client.session(); + if (session == null) { + log.warning(client.clientId(), "[%s] Session is already closed"::formatted); + handleSessionIsAlreadyClosed(client); + return; + } + processInvalidMessage(connection, castedClient, session, castedMessage); + } else { + processInvalidMessage(connection, castedClient, castedMessage); + } } - protected abstract void processReceivedValidMessage(MqttConnection connection, C client, M message); + protected void processValidMessage(MqttConnection connection, C client, M message) {} + + protected void processValidMessage(MqttConnection connection, C client, MqttSession session, M message) {} + + protected boolean processInvalidMessage(MqttConnection connection, C client, M message) { + Exception exception = message.exception(); + if (exception instanceof MalformedProtocolMqttException) { + malformedProtocolError(connection, client, exception); + return true; + } + return false; + } - protected boolean processReceivedInvalidMessage(MqttConnection connection, C client, M message) { + protected boolean processInvalidMessage(MqttConnection connection, C client, MqttSession session, M message) { Exception exception = message.exception(); if (exception instanceof MalformedProtocolMqttException) { - // send feedback and close connection - MqttOutMessage feedback = messageOutFactoryService - .resolveFactory(client) - .newDisconnect(client, DisconnectReasonCode.MALFORMED_PACKET); - client - .sendWithFeedback(feedback) - .thenAccept(_ -> connection.close()); + malformedProtocolError(connection, client, exception); return true; } return false; } + + protected void malformedProtocolError(MqttConnection connection, C client, Exception exception) { + // send feedback and close connection + MqttOutMessage feedback = messageOutFactoryService + .resolveFactory(client) + .newDisconnect(client, DisconnectReasonCode.MALFORMED_PACKET, exception.getMessage()); + client + .sendWithFeedback(feedback) + .thenAccept(_ -> connection.close()); + } + + protected void handleSessionIsAlreadyClosed(MqttClient client) { + MqttOutMessage response = messageOutFactoryService + .resolveFactory(client) + .newDisconnect( + client, + DisconnectReasonCode.UNSPECIFIED_ERROR, + ExtraErrorReasons.SESSION_IS_ALREADY_CLOSED); + client.closeWithReason(response); + } } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/ConnectInMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/ConnectInMqttInMessageHandler.java index eec789b..4180ef9 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/ConnectInMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/ConnectInMqttInMessageHandler.java @@ -63,7 +63,12 @@ public MqttMessageType messageType() { } @Override - protected void processReceivedValidMessage( + protected boolean requireSession() { + return false; + } + + @Override + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, ConnectMqttInMessage message) { @@ -218,9 +223,10 @@ private boolean onSentConnAck(MqttClient.UnsafeMqttClient client, MqttSession se } @Override - protected boolean processReceivedInvalidMessage( + protected boolean processInvalidMessage( MqttConnection connection, ExternalMqttClient client, + MqttSession session, ConnectMqttInMessage message) { Exception exception = message.exception(); if (exception instanceof ConnectionRejectException cre) { @@ -230,6 +236,6 @@ protected boolean processReceivedInvalidMessage( client.closeWithReason(feedback); return true; } - return super.processReceivedInvalidMessage(connection, client, message); + return super.processInvalidMessage(connection, client, session, message); } } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/DisconnectMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/DisconnectMqttInMessageHandler.java index de45bea..4412848 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/DisconnectMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/DisconnectMqttInMessageHandler.java @@ -5,6 +5,7 @@ import javasabr.mqtt.network.impl.ExternalMqttClient; import javasabr.mqtt.network.message.MqttMessageType; import javasabr.mqtt.network.message.in.DisconnectMqttInMessage; +import javasabr.mqtt.network.session.MqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import lombok.CustomLog; @@ -21,9 +22,10 @@ public MqttMessageType messageType() { } @Override - protected void processReceivedValidMessage( + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, + MqttSession session, DisconnectMqttInMessage message) { DisconnectReasonCode reasonCode = message.reasonCode(); if (reasonCode == DisconnectReasonCode.NORMAL_DISCONNECTION) { diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PendingOutResponseMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PendingOutResponseMqttInMessageHandler.java index 771ab12..43cc810 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PendingOutResponseMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PendingOutResponseMqttInMessageHandler.java @@ -17,10 +17,11 @@ protected PendingOutResponseMqttInMessageHandler( } @Override - protected void processReceivedValidMessage(MqttConnection connection, ExternalMqttClient client, P message) { - MqttSession session = client.session(); - if (session != null) { - session.updateOutPendingPacket(client, message); - } + protected void processValidMessage( + MqttConnection connection, + ExternalMqttClient client, + MqttSession session, + P message) { + session.updateOutPendingPacket(client, message); } } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishMqttInMessageHandler.java index 212d7b9..ca4ac68 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishMqttInMessageHandler.java @@ -36,19 +36,13 @@ public PublishMqttInMessageHandler( } @Override - protected void processReceivedValidMessage( + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, + MqttSession session, PublishMqttInMessage message) { - MqttSession session = client.session(); - if (session == null) { - log.warning(client.clientId(), "[%s] Client has no any session..."::formatted); - return; - } - int messageId = message.messageId(); - if (messageId > 0 && session.hasInPending(messageId)) { client.send(messageOutFactoryService .resolveFactory(client) diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishReleaseMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishReleaseMqttInMessageHandler.java index 56f6b7a..5cd42aa 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishReleaseMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/PublishReleaseMqttInMessageHandler.java @@ -20,13 +20,11 @@ public MqttMessageType messageType() { } @Override - protected void processReceivedValidMessage( + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, + MqttSession session, PublishReleaseMqttInMessage message) { - MqttSession session = client.session(); - if (session != null) { - session.updateInPendingPacket(client, message); - } + session.updateInPendingPacket(client, message); } } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java index 2855cbb..9e5528d 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java @@ -19,7 +19,6 @@ import javasabr.mqtt.network.message.out.MqttOutMessage; import javasabr.mqtt.network.session.MessageTacker; import javasabr.mqtt.network.session.MqttSession; -import javasabr.mqtt.network.util.ExtraErrorReasons; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; @@ -57,23 +56,17 @@ public MqttMessageType messageType() { } @Override - protected void processReceivedValidMessage( + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, + MqttSession session, SubscribeMqttInMessage subscribeMessage) { MqttClientConnectionConfig connectionConfig = client.connectionConfig(); - MqttSession session = client.session(); - if (session == null) { - log.warning(client.clientId(), "[%s] Session is already closed"::formatted); - handleSessionIsAlreadyClosed(client); - return; - } - int messageId = subscribeMessage.messageId(); MessageTacker messageTacker = session.inMessageTracker(); if (messageTacker.isInUse(messageId)) { - log.warning(client.clientId(), messageId, "[%s] Message id:[%d] is already in use"::formatted); + log.warning(client.clientId(), messageId, "[%s] MessageId:[%d] is already in use"::formatted); handleMessageIdIsInUse(client, subscribeMessage); return; } @@ -84,7 +77,7 @@ protected void processReceivedValidMessage( if (subscriptionId != MqttProperties.SUBSCRIPTION_ID_IS_NOT_SET) { if (!connectionConfig.subscriptionIdAvailable()) { log.warning(client.clientId(), subscriptionId, - "[%s] Provided subscription id:[%d] but server doesn't allow it"::formatted); + "[%s] Provided subscriptionId:[%d] but server doesn't allow it"::formatted); handleSubscriptionIdNotSupported(client, session, subscribeMessage); return; } @@ -95,12 +88,12 @@ protected void processReceivedValidMessage( client, subscribeMessage.subscriptions(), subscriptionId); - Array subscriptionResults = subscriptionService + Array subscribeResults = subscriptionService .subscribe(client, session, subscriptions); - sendSubscriptionResults(client, session, subscribeMessage, subscriptionResults); + sendSubscribeResults(client, session, subscribeMessage, subscribeResults); - SubscribeAckReasonCode anyReasonToDisconnect = subscriptionResults + SubscribeAckReasonCode anyReasonToDisconnect = subscribeResults .iterations() .reversedArgs() .findAny(DISCONNECT_CASES, Set::contains); @@ -139,53 +132,43 @@ private Array transformSubscriptions( return subscriptions; } - private void handleSessionIsAlreadyClosed(ExternalMqttClient client) { - MqttOutMessage response = messageOutFactoryService - .resolveFactory(client) - .newDisconnect( - client, - DisconnectReasonCode.UNSPECIFIED_ERROR, - ExtraErrorReasons.SESSION_IS_ALREADY_CLOSED); - client.closeWithReason(response); - } - private void handleMessageIdIsInUse( ExternalMqttClient client, SubscribeMqttInMessage subscribeMessage) { - Array subscriptionResults = Array.repeated( + Array subscribeResults = Array.repeated( SubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE, subscribeMessage.subscriptionsCount()); client.send(messageOutFactoryService .resolveFactory(client) - .newSubscribeAck(subscribeMessage.messageId(), subscriptionResults)); + .newSubscribeAck(subscribeMessage.messageId(), subscribeResults)); } private void handleSubscriptionIdNotSupported( ExternalMqttClient client, MqttSession session, SubscribeMqttInMessage subscribeMessage) { - Array subscriptionResults = Array.repeated( + Array subscribeResults = Array.repeated( SubscribeAckReasonCode.SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED, subscribeMessage.subscriptionsCount()); int messageId = subscribeMessage.messageId(); MqttOutMessage response = messageOutFactoryService .resolveFactory(client) - .newSubscribeAck(messageId, subscriptionResults); + .newSubscribeAck(messageId, subscribeResults); client.sendWithFeedback(response) .thenAccept(_ -> session .inMessageTracker() .remove(messageId)); } - private void sendSubscriptionResults( + private void sendSubscribeResults( ExternalMqttClient client, MqttSession session, SubscribeMqttInMessage subscribeMessage, - Array subscriptionResults) { + Array subscribeResults) { int messageId = subscribeMessage.messageId(); MqttOutMessage response = messageOutFactoryService .resolveFactory(client) - .newSubscribeAck(messageId, subscriptionResults); + .newSubscribeAck(messageId, subscribeResults); client.sendWithFeedback(response) .thenAccept(_ -> session .inMessageTracker() diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandler.java index a78dfde..83d47cb 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandler.java @@ -6,14 +6,19 @@ import javasabr.mqtt.network.impl.ExternalMqttClient; import javasabr.mqtt.network.message.MqttMessageType; import javasabr.mqtt.network.message.in.UnsubscribeMqttInMessage; +import javasabr.mqtt.network.message.out.MqttOutMessage; +import javasabr.mqtt.network.session.MessageTacker; +import javasabr.mqtt.network.session.MqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayCollectors; import lombok.AccessLevel; +import lombok.CustomLog; import lombok.experimental.FieldDefaults; +@CustomLog @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class UnsubscribeMqttInMessageHandler extends AbstractMqttInMessageHandler { @@ -36,22 +41,49 @@ public MqttMessageType messageType() { } @Override - protected void processReceivedValidMessage( + protected void processValidMessage( MqttConnection connection, ExternalMqttClient client, - UnsubscribeMqttInMessage message) { + MqttSession session, + UnsubscribeMqttInMessage unsubscribeMessage) { - Array topicFilters = message + int messageId = unsubscribeMessage.messageId(); + MessageTacker messageTacker = session.inMessageTracker(); + if (messageTacker.isInUse(messageId)) { + log.warning(client.clientId(), messageId, "[%s] MessageId:[%d] is already in use"::formatted); + handleMessageIdIsInUse(client, unsubscribeMessage); + return; + } + + messageTacker.add(messageId); + + Array topicFilters = unsubscribeMessage .rawTopicFilters() .stream() .map(rawTopicFilter -> topicService.createTopicFilter(client, rawTopicFilter)) .collect(ArrayCollectors.toArray(TopicFilter.class)); Array unsubscribeResults = subscriptionService - .unsubscribe(client, client.session(), topicFilters); + .unsubscribe(client, session, topicFilters); + + MqttOutMessage response = messageOutFactoryService + .resolveFactory(client) + .newUnsubscribeAck(unsubscribeMessage.messageId(), unsubscribeResults); + client.sendWithFeedback(response) + .thenAccept(_ -> session + .inMessageTracker() + .remove(messageId)); + } + + private void handleMessageIdIsInUse( + ExternalMqttClient client, + UnsubscribeMqttInMessage unsubscribeMessage) { + Array unsubscribeResults = Array.repeated( + UnsubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE, + unsubscribeMessage.topicFiltersCount()); client.send(messageOutFactoryService .resolveFactory(client) - .newUnsubscribeAck(message.messageId(), unsubscribeResults, message.userProperties())); + .newUnsubscribeAck(unsubscribeMessage.messageId(), unsubscribeResults)); } } diff --git a/service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy b/service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy index 5ef51c6..8a7d393 100644 --- a/service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy +++ b/service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy @@ -1,6 +1,11 @@ package javasabr.mqtt.service -import javasabr.mqtt.model.* + +import javasabr.mqtt.model.MqttClientConnectionConfig +import javasabr.mqtt.model.MqttProperties +import javasabr.mqtt.model.MqttServerConnectionConfig +import javasabr.mqtt.model.MqttVersion +import javasabr.mqtt.model.QoS import javasabr.mqtt.network.MqttConnection import javasabr.mqtt.network.handler.MqttClientReleaseHandler import javasabr.mqtt.service.impl.DefaultMessageOutFactoryService diff --git a/service/src/test/groovy/javasabr/mqtt/service/TestExternalMqttClient.groovy b/service/src/test/groovy/javasabr/mqtt/service/TestExternalMqttClient.groovy index c0598e5..3e3d75c 100644 --- a/service/src/test/groovy/javasabr/mqtt/service/TestExternalMqttClient.groovy +++ b/service/src/test/groovy/javasabr/mqtt/service/TestExternalMqttClient.groovy @@ -7,16 +7,25 @@ import javasabr.mqtt.network.message.out.MqttOutMessage import javasabr.rlib.collections.array.MutableArray import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executor +import java.util.concurrent.TimeUnit class TestExternalMqttClient extends ExternalMqttClient { + private static final Executor DELAYED_EXECUTOR = CompletableFuture.delayedExecutor(5000, TimeUnit.MILLISECONDS) + private final MutableArray sentMessages + private boolean returnCompletedFeatures = true; TestExternalMqttClient(MqttConnection connection, MqttClientReleaseHandler releaseHandler) { super(connection, releaseHandler) this.sentMessages = MutableArray.ofType(MqttOutMessage) } + void returnCompletedFeatures(boolean returnCompletedFeatures) { + this.returnCompletedFeatures = returnCompletedFeatures; + } + @Override void send(MqttOutMessage message) { sentMessages.add(message) @@ -25,12 +34,18 @@ class TestExternalMqttClient extends ExternalMqttClient { @Override CompletableFuture sendWithFeedback(MqttOutMessage message) { sentMessages.add(message) + if (!returnCompletedFeatures) { + return CompletableFuture.supplyAsync({ true }, DELAYED_EXECUTOR); + } return CompletableFuture.completedFuture(true) } @Override CompletableFuture closeWithReason(MqttOutMessage message) { sentMessages.add(message) + if (!returnCompletedFeatures) { + return CompletableFuture.supplyAsync({ true }, DELAYED_EXECUTOR); + } return CompletableFuture.completedFuture(true) } diff --git a/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy index 1df9d3e..c740fdc 100644 --- a/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy +++ b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy @@ -12,6 +12,8 @@ import javasabr.mqtt.network.util.ExtraErrorReasons import javasabr.mqtt.service.IntegrationServiceSpecification import javasabr.mqtt.service.TestExternalMqttClient import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.MutableArray +import javasabr.rlib.common.util.ThreadUtils import javasabr.rlib.logger.api.LoggerLevel import javasabr.rlib.logger.api.LoggerManager @@ -32,8 +34,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def mqttClient = mqttConnection.client() as TestExternalMqttClient mqttClient.session(null) when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def disconnectReason = mqttClient.nextSentMessage(DisconnectMqtt5OutMessage) disconnectReason.reasonCode() == DisconnectReasonCode.UNSPECIFIED_ERROR @@ -53,13 +55,14 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def session = mqttClient.session() session.inMessageTracker().add(expectedMessageId) when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) {{ + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of( RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -81,14 +84,15 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) {{ + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ this.messageId = expectedMessageId this.subscriptionId = 25 + this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of( RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -110,14 +114,15 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) {{ + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ this.messageId = expectedMessageId this.subscriptionId = 25 + this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of( RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) }} - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -139,13 +144,14 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) {{ + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of( RequestedSubscription.minimal("topic1/#", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("topic2/+", QoS.EXACTLY_ONCE))) }} - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -171,13 +177,14 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: - def subscribeMessage = new SubscribeMqttInMessage(0 as byte) {{ + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) this.subscriptions.addAll(Array.of( RequestedSubscription.minimal("\$share/group1/topic1/#", QoS.EXACTLY_ONCE), RequestedSubscription.minimal("\$share/group1/topic2/+", QoS.EXACTLY_ONCE))) }} - messageHandler.processReceivedValidMessage(mqttConnection, mqttClient, subscribeMessage) + messageHandler.processValidMessage(mqttConnection, subscribeMessage) then: def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) def reasonCodes = subscribeAck.reasonCodes() @@ -190,4 +197,99 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification && disconnectReason.reason() == "" && disconnectReason.serverReference() == "" } + + def "should close connection by reason MQTT protocol error"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new SubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def mqttClient = mqttConnection.client() as TestExternalMqttClient + when: + def subscribeMessage = new SubscribeMqttInMessage(0 as byte) + messageHandler.processInvalidMessage(mqttConnection, subscribeMessage) + then: + def disconnectReason = mqttClient.nextSentMessage(DisconnectMqtt5OutMessage) + disconnectReason.reasonCode() == DisconnectReasonCode.MALFORMED_PACKET + && disconnectReason.reason() == "Unexpected flags bits:0b0000_0000" + && disconnectReason.serverReference() == "" + } + + def "should reuse the same message if from previous request"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new SubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + when: + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic1", QoS.EXACTLY_ONCE))) + }} + messageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) + def reasonCodes = subscribeAck.reasonCodes() + reasonCodes.size() == 1 + && reasonCodes.get(0) == SubscribeAckReasonCode.GRANTED_QOS_2 + && subscribeAck.messageId() == expectedMessageId + when: + ThreadUtils.sleep(300) + def subscribeMessage2 = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) + }} + messageHandler.processValidMessage(mqttConnection, subscribeMessage2) + then: + def subscribeAck2 = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) + def reasonCodes2 = subscribeAck2.reasonCodes() + reasonCodes2.size() == 1 + && reasonCodes2.get(0) == SubscribeAckReasonCode.GRANTED_QOS_2 + && subscribeAck2.messageId() == expectedMessageId + } + + def "should response that message id is in use because previous is still in progress"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new SubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + mqttClient.returnCompletedFeatures(false) + when: + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) + }} + messageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + def subscribeAck = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) + def reasonCodes = subscribeAck.reasonCodes() + reasonCodes.size() == 1 + && reasonCodes.get(0) == SubscribeAckReasonCode.GRANTED_QOS_2 + && subscribeAck.messageId() == expectedMessageId + when: + ThreadUtils.sleep(300) + def subscribeMessage2 = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) + }} + messageHandler.processValidMessage(mqttConnection, subscribeMessage2) + then: + def subscribeAck2 = mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) + def reasonCodes2 = subscribeAck2.reasonCodes() + reasonCodes2.size() == 1 + && reasonCodes2.get(0) == SubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE + && subscribeAck2.messageId() == expectedMessageId + } } diff --git a/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy new file mode 100644 index 0000000..bdd30db --- /dev/null +++ b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy @@ -0,0 +1,211 @@ +package javasabr.mqtt.service.message.handler.impl + +import javasabr.mqtt.model.MqttVersion +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.reason.code.DisconnectReasonCode +import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode +import javasabr.mqtt.model.subscribtion.Subscription +import javasabr.mqtt.network.message.in.UnsubscribeMqttInMessage +import javasabr.mqtt.network.message.out.DisconnectMqtt5OutMessage +import javasabr.mqtt.network.message.out.UnsubscribeAckMqtt5OutMessage +import javasabr.mqtt.network.util.ExtraErrorReasons +import javasabr.mqtt.service.IntegrationServiceSpecification +import javasabr.mqtt.service.TestExternalMqttClient +import javasabr.mqtt.service.impl.InMemorySubscriptionService +import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.MutableArray +import javasabr.rlib.common.util.ThreadUtils + +class UnsubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification { + + def "should close connection by reason that session is already closed"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new UnsubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def mqttClient = mqttConnection.client() as TestExternalMqttClient + mqttClient.session(null) + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage) + then: + def disconnectReason = mqttClient.nextSentMessage(DisconnectMqtt5OutMessage) + disconnectReason.reasonCode() == DisconnectReasonCode.UNSPECIFIED_ERROR + && disconnectReason.reason() == ExtraErrorReasons.SESSION_IS_ALREADY_CLOSED + && disconnectReason.serverReference() == "" + } + + def "should response that message id is in use"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new UnsubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + def session = mqttClient.session() + session.inMessageTracker().add(expectedMessageId) + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of("topic1", "topic2")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage) + then: + def unsubscribeAck = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes = unsubscribeAck.reasonCodes() + reasonCodes.size() == 2 + && reasonCodes.get(0) == UnsubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE + && reasonCodes.get(1) == UnsubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE + && unsubscribeAck.messageId() == expectedMessageId + } + + def "should response with expected results"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def subscriptionService = new InMemorySubscriptionService() + def messageHandler = new UnsubscribeMqttInMessageHandler( + subscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + def session = mqttClient.session() + def topicFilter1 = defaultTopicService.createTopicFilter(mqttClient, "topic/exist") + def topicFilter2 = defaultTopicService.createTopicFilter(mqttClient, "topic/exist2") + def topicFilter3 = defaultTopicService.createTopicFilter(mqttClient, "topic/exist3") + subscriptionService.subscribe( + mqttClient, + session, + Array.of( + Subscription.minimal(topicFilter1, QoS.AT_MOST_ONCE), + Subscription.minimal(topicFilter2, QoS.AT_MOST_ONCE), + Subscription.minimal(topicFilter3, QoS.AT_MOST_ONCE))) + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of(topicFilter1.rawTopic(), topicFilter2.rawTopic(), "topic/notexist", "topic/invalid##")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage) + then: + def unsubscribeAck = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes = unsubscribeAck.reasonCodes() + reasonCodes.size() == 4 + && reasonCodes.get(0) == UnsubscribeAckReasonCode.SUCCESS + && reasonCodes.get(1) == UnsubscribeAckReasonCode.SUCCESS + && reasonCodes.get(2) == UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED + && reasonCodes.get(3) == UnsubscribeAckReasonCode.TOPIC_FILTER_INVALID + && unsubscribeAck.messageId() == expectedMessageId + when: + def topicName1 = defaultTopicService.createTopicName(mqttClient, "topic/exist") + def topicName2 = defaultTopicService.createTopicName(mqttClient, "topic/exist2") + def topicName3 = defaultTopicService.createTopicName(mqttClient, "topic/exist3") + def subscribers1 = subscriptionService.findSubscribers(topicName1) + def subscribers2 = subscriptionService.findSubscribers(topicName2) + def subscribers3 = subscriptionService.findSubscribers(topicName3) + then: + subscribers1.isEmpty() + subscribers2.isEmpty() + subscribers3.size() == 1 && subscribers3.first().owner() == mqttClient + } + + def "should close connection by reason MQTT protocol error"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new UnsubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def mqttClient = mqttConnection.client() as TestExternalMqttClient + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(0 as byte) + messageHandler.processInvalidMessage(mqttConnection, unsubscribeMessage) + then: + def disconnectReason = mqttClient.nextSentMessage(DisconnectMqtt5OutMessage) + disconnectReason.reasonCode() == DisconnectReasonCode.MALFORMED_PACKET + && disconnectReason.reason() == "Unexpected flags bits:0b0000_0000" + && disconnectReason.serverReference() == "" + } + + def "should reuse the same message if from previous request"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new UnsubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of("topic1")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage) + then: + def unsubscribeAck = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes = unsubscribeAck.reasonCodes() + reasonCodes.size() == 1 + && reasonCodes.get(0) == UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED + && unsubscribeAck.messageId() == expectedMessageId + when: + ThreadUtils.sleep(300) + def unsubscribeMessage2 = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of("topic2")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage2) + then: + def unsubscribeAck2 = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes2 = unsubscribeAck2.reasonCodes() + reasonCodes2.size() == 1 + && reasonCodes2.get(0) == UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED + && unsubscribeAck2.messageId() == expectedMessageId + } + + def "should response that message id is in use because previous is still in progress"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new UnsubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + mqttClient.returnCompletedFeatures(false) + when: + def unsubscribeMessage = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of("topic1")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage) + then: + def unsubscribeAck = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes = unsubscribeAck.reasonCodes() + reasonCodes.size() == 1 + && reasonCodes.get(0) == UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED + && unsubscribeAck.messageId() == expectedMessageId + when: + ThreadUtils.sleep(300) + def unsubscribeMessage2 = new UnsubscribeMqttInMessage(UnsubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = expectedMessageId + this.rawTopicFilters = MutableArray.ofType(String) + this.rawTopicFilters.addAll(Array.of("topic2")) + }} + messageHandler.processValidMessage(mqttConnection, unsubscribeMessage2) + then: + def unsubscribeAck2 = mqttClient.nextSentMessage(UnsubscribeAckMqtt5OutMessage) + def reasonCodes2 = unsubscribeAck2.reasonCodes() + reasonCodes2.size() == 1 + && reasonCodes2.get(0) == UnsubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE + && unsubscribeAck2.messageId() == expectedMessageId + } +}