diff --git a/memstore.go b/memstore.go index 2a838c1..60532a5 100644 --- a/memstore.go +++ b/memstore.go @@ -81,7 +81,7 @@ func (store *MemoryStore) Put(key string, message packets.ControlPacket) { } // Get takes a key and looks in the store for a matching Message -// returning either the Message pointer or nil. +// returning either a copy of the Message as packets.ControlPacket or nil. func (store *MemoryStore) Get(key string) packets.ControlPacket { store.RLock() defer store.RUnlock() @@ -93,10 +93,12 @@ func (store *MemoryStore) Get(key string) packets.ControlPacket { m := store.messages[key] if m == nil { store.logger.Warn("memorystore get: message not found", slog.Uint64("messageID", uint64(mid)), slog.String("component", string(STR))) - } else { - store.logger.Debug("memorystore get: message found", slog.Uint64("messageID", uint64(mid)), slog.String("component", string(STR))) + return m } - return m + + store.logger.Debug("memorystore get: message found", slog.Uint64("messageID", uint64(mid)), slog.String("component", string(STR))) + + return m.Copy() } // All returns a slice of strings containing all the keys currently diff --git a/packets/connack.go b/packets/connack.go index 962197a..df4855d 100644 --- a/packets/connack.go +++ b/packets/connack.go @@ -69,3 +69,12 @@ func (ca *ConnackPacket) Unpack(b io.Reader) error { func (ca *ConnackPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the ConnackPacket +func (ca *ConnackPacket) Copy() ControlPacket { + cp := NewControlPacket(Connack).(*ConnackPacket) + + *cp = *ca + + return cp +} diff --git a/packets/connect.go b/packets/connect.go index 2313cb1..15c560b 100644 --- a/packets/connect.go +++ b/packets/connect.go @@ -172,3 +172,22 @@ func (c *ConnectPacket) Validate() byte { func (c *ConnectPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the ConnectPacket +func (c *ConnectPacket) Copy() ControlPacket { + cp := NewControlPacket(Connect).(*ConnectPacket) + + *cp = *c + + if len(c.Password) > 0 { + cp.Password = make([]byte, len(c.Password)) + copy(cp.Password, c.Password) + } + + if len(c.WillMessage) > 0 { + cp.WillMessage = make([]byte, len(c.WillMessage)) + copy(cp.WillMessage, c.WillMessage) + } + + return cp +} diff --git a/packets/disconnect.go b/packets/disconnect.go index c0ca3b9..38e3a1a 100644 --- a/packets/disconnect.go +++ b/packets/disconnect.go @@ -51,3 +51,12 @@ func (d *DisconnectPacket) Unpack(b io.Reader) error { func (d *DisconnectPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the DisconnectPacket +func (d *DisconnectPacket) Copy() ControlPacket { + cp := NewControlPacket(Disconnect).(*DisconnectPacket) + + *cp = *d + + return cp +} diff --git a/packets/packets.go b/packets/packets.go index 05cffcc..56c486a 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -32,6 +32,7 @@ type ControlPacket interface { Unpack(io.Reader) error String() string Details() Details + Copy() ControlPacket } // PacketNames maps the constants for each of the MQTT packet types diff --git a/packets/packets_test.go b/packets/packets_test.go index e940eeb..2011562 100644 --- a/packets/packets_test.go +++ b/packets/packets_test.go @@ -18,6 +18,8 @@ package packets import ( "bytes" + "fmt" + "reflect" "testing" ) @@ -270,3 +272,120 @@ func TestEncoding(t *testing.T) { } } + +// isCopy checks if the original and copy are the same, recursively. +// It will fail the test if the values are different or if the pointer +// of the original and copy are the same. +func isCopy(t *testing.T, original, copy any, fieldName ...string) { + t.Helper() + + log := func(field string, original, copy interface{}) { + t.Logf("Field: %s", field) + t.Logf("Original: %#v", original) + t.Logf("Copy: %#v", copy) + } + + originalValue := reflect.ValueOf(original) + copyValue := reflect.ValueOf(copy) + + fullFieldName := "" + if len(fieldName) > 0 { + fullFieldName = fieldName[0] + for _, name := range fieldName[1:] { + fullFieldName += "." + name + } + } + + if originalValue.Kind() != copyValue.Kind() { + log(fullFieldName, original, copy) + t.Errorf("Kind of original and copy are different: %s != %s", originalValue.Kind(), copyValue.Kind()) + } + + switch originalValue.Kind() { + case reflect.Ptr: + if originalValue.Pointer() == copyValue.Pointer() { + log(fullFieldName, original, copy) + t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer()) + } + isCopy(t, originalValue.Elem().Interface(), copyValue.Elem().Interface(), append(fieldName, originalValue.Type().Elem().Name())...) + case reflect.Slice: + if originalValue.IsNil() && copyValue.IsNil() { + return + } + if originalValue.IsNil() != copyValue.IsNil() { + log(fullFieldName, original, copy) + t.Errorf("IsNil of original and copy are different: %t != %t", originalValue.IsNil(), copyValue.IsNil()) + } + if originalValue.Len() != copyValue.Len() { + log(fullFieldName, original, copy) + t.Errorf("Length of original and copy are different: %d != %d", originalValue.Len(), copyValue.Len()) + } + if originalValue.Len() > 0 && originalValue.Pointer() == copyValue.Pointer() { + log(fullFieldName, original, copy) + t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer()) + } + for i := 0; i < originalValue.Len(); i++ { + isCopy(t, originalValue.Index(i).Interface(), copyValue.Index(i).Interface(), append(fieldName, fmt.Sprintf("[%d]", i))...) + } + case reflect.Struct: + for i := 0; i < originalValue.Type().NumField(); i++ { + field := originalValue.Type().Field(i) + isCopy(t, originalValue.Field(i).Interface(), copyValue.Field(i).Interface(), append(fieldName, field.Name)...) + } + default: + if !reflect.DeepEqual(originalValue.Interface(), copyValue.Interface()) { + log(fullFieldName, original, copy) + t.Errorf("Values of original and copy are different: %v != %v", originalValue.Interface(), copyValue.Interface()) + } + } +} + +// createValidPointers creates valid pointer for map, slices or normal pointer if they are nil. +func createValidPointers(s any) { + val := reflect.ValueOf(s).Elem() + for i := range val.NumField() { + field := val.Field(i) + switch field.Kind() { + case reflect.Ptr: + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + case reflect.Slice: + if field.IsNil() { + field.Set(reflect.MakeSlice(field.Type(), 1, 1)) + } + case reflect.Map: + if field.IsNil() { + field.Set(reflect.MakeMap(field.Type())) + } + case reflect.Struct: + createValidPointers(field.Addr().Interface()) + } + } +} + +func TestPacketCopy(t *testing.T) { + packets := []ControlPacket{ + NewControlPacket(Connack).(*ConnackPacket), + NewControlPacket(Connect).(*ConnectPacket), + NewControlPacket(Disconnect).(*DisconnectPacket), + NewControlPacket(Pingreq).(*PingreqPacket), + NewControlPacket(Pingresp).(*PingrespPacket), + NewControlPacket(Puback).(*PubackPacket), + NewControlPacket(Pubcomp).(*PubcompPacket), + NewControlPacket(Publish).(*PublishPacket), + NewControlPacket(Pubrec).(*PubrecPacket), + NewControlPacket(Pubrel).(*PubrelPacket), + NewControlPacket(Suback).(*SubackPacket), + NewControlPacket(Subscribe).(*SubscribePacket), + NewControlPacket(Unsuback).(*UnsubackPacket), + NewControlPacket(Unsubscribe).(*UnsubscribePacket), + } + + for _, packet := range packets { + createValidPointers(packet) + copy := packet.Copy() + + isCopy(t, packet, copy) + } +} diff --git a/packets/pingreq.go b/packets/pingreq.go index 99ee073..e864dbe 100644 --- a/packets/pingreq.go +++ b/packets/pingreq.go @@ -51,3 +51,12 @@ func (pr *PingreqPacket) Unpack(b io.Reader) error { func (pr *PingreqPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the PingreqPacket +func (pr *PingreqPacket) Copy() ControlPacket { + cp := NewControlPacket(Pingreq).(*PingreqPacket) + + *cp = *pr + + return cp +} diff --git a/packets/pingresp.go b/packets/pingresp.go index ac813f2..c6e5368 100644 --- a/packets/pingresp.go +++ b/packets/pingresp.go @@ -51,3 +51,12 @@ func (pr *PingrespPacket) Unpack(b io.Reader) error { func (pr *PingrespPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the PingrespPacket +func (pr *PingrespPacket) Copy() ControlPacket { + cp := NewControlPacket(Pingresp).(*PingrespPacket) + + *cp = *pr + + return cp +} diff --git a/packets/puback.go b/packets/puback.go index 6e8f07b..3b89d1d 100644 --- a/packets/puback.go +++ b/packets/puback.go @@ -59,3 +59,12 @@ func (pa *PubackPacket) Unpack(b io.Reader) error { func (pa *PubackPacket) Details() Details { return Details{Qos: pa.Qos, MessageID: pa.MessageID} } + +// Copy creates a deep copy of the PubackPacket +func (pa *PubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Puback).(*PubackPacket) + + *cp = *pa + + return cp +} diff --git a/packets/pubcomp.go b/packets/pubcomp.go index 07b2715..108e2b2 100644 --- a/packets/pubcomp.go +++ b/packets/pubcomp.go @@ -59,3 +59,12 @@ func (pc *PubcompPacket) Unpack(b io.Reader) error { func (pc *PubcompPacket) Details() Details { return Details{Qos: pc.Qos, MessageID: pc.MessageID} } + +// Copy creates a deep copy of the PubcompPacket +func (pc *PubcompPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubcomp).(*PubcompPacket) + + *cp = *pc + + return cp +} diff --git a/packets/publish.go b/packets/publish.go index 27f9582..994784d 100644 --- a/packets/publish.go +++ b/packets/publish.go @@ -58,7 +58,7 @@ func (p *PublishPacket) Write(w io.Writer) error { // Unpack decodes the details of a ControlPacket after the fixed // header has been read func (p *PublishPacket) Unpack(b io.Reader) error { - var payloadLength = p.FixedHeader.RemainingLength + payloadLength := p.FixedHeader.RemainingLength var err error p.TopicName, err = decodeString(b) if err != nil { @@ -83,20 +83,22 @@ func (p *PublishPacket) Unpack(b io.Reader) error { return err } -// Copy creates a new PublishPacket with the same topic and payload -// but an empty fixed header, useful for when you want to deliver -// a message with different properties such as Qos but the same -// content -func (p *PublishPacket) Copy() *PublishPacket { - newP := NewControlPacket(Publish).(*PublishPacket) - newP.TopicName = p.TopicName - newP.Payload = p.Payload - - return newP -} - // Details returns a Details struct containing the Qos and // MessageID of this ControlPacket func (p *PublishPacket) Details() Details { return Details{Qos: p.Qos, MessageID: p.MessageID} } + +// Copy creates a deep copy of the PublishPacket +func (p *PublishPacket) Copy() ControlPacket { + cp := NewControlPacket(Publish).(*PublishPacket) + + *cp = *p + + if len(p.Payload) > 0 { + cp.Payload = make([]byte, len(p.Payload)) + copy(cp.Payload, p.Payload) + } + + return cp +} diff --git a/packets/pubrec.go b/packets/pubrec.go index 5f97764..c58826e 100644 --- a/packets/pubrec.go +++ b/packets/pubrec.go @@ -59,3 +59,12 @@ func (pr *PubrecPacket) Unpack(b io.Reader) error { func (pr *PubrecPacket) Details() Details { return Details{Qos: pr.Qos, MessageID: pr.MessageID} } + +// Copy creates a deep copy of the PubrecPacket +func (pr *PubrecPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubrec).(*PubrecPacket) + + *cp = *pr + + return cp +} diff --git a/packets/pubrel.go b/packets/pubrel.go index 432a752..bb71139 100644 --- a/packets/pubrel.go +++ b/packets/pubrel.go @@ -59,3 +59,12 @@ func (pr *PubrelPacket) Unpack(b io.Reader) error { func (pr *PubrelPacket) Details() Details { return Details{Qos: pr.Qos, MessageID: pr.MessageID} } + +// Copy creates a deep copy of the PubrelPacket +func (pr *PubrelPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubrel).(*PubrelPacket) + + *cp = *pr + + return cp +} diff --git a/packets/suback.go b/packets/suback.go index 02949a0..960b059 100644 --- a/packets/suback.go +++ b/packets/suback.go @@ -74,3 +74,17 @@ func (sa *SubackPacket) Unpack(b io.Reader) error { func (sa *SubackPacket) Details() Details { return Details{Qos: 0, MessageID: sa.MessageID} } + +// Copy creates a deep copy of the SubackPacket +func (sa *SubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Suback).(*SubackPacket) + + *cp = *sa + + if len(sa.ReturnCodes) > 0 { + cp.ReturnCodes = make([]byte, len(sa.ReturnCodes)) + copy(cp.ReturnCodes, sa.ReturnCodes) + } + + return cp +} diff --git a/packets/subscribe.go b/packets/subscribe.go index 72f9f69..78d80b8 100644 --- a/packets/subscribe.go +++ b/packets/subscribe.go @@ -86,3 +86,22 @@ func (s *SubscribePacket) Unpack(b io.Reader) error { func (s *SubscribePacket) Details() Details { return Details{Qos: 1, MessageID: s.MessageID} } + +// Copy creates a deep copy of the SubscribePacket +func (s *SubscribePacket) Copy() ControlPacket { + cp := NewControlPacket(Subscribe).(*SubscribePacket) + + *cp = *s + + if len(s.Topics) > 0 { + cp.Topics = make([]string, len(s.Topics)) + copy(cp.Topics, s.Topics) + } + + if len(s.Qoss) > 0 { + cp.Qoss = make([]byte, len(s.Qoss)) + copy(cp.Qoss, s.Qoss) + } + + return cp +} diff --git a/packets/unsuback.go b/packets/unsuback.go index 2e1a876..8e2da6a 100644 --- a/packets/unsuback.go +++ b/packets/unsuback.go @@ -59,3 +59,12 @@ func (ua *UnsubackPacket) Unpack(b io.Reader) error { func (ua *UnsubackPacket) Details() Details { return Details{Qos: 0, MessageID: ua.MessageID} } + +// Copy creates a deep copy of the UnsubackPacket +func (ua *UnsubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Unsuback).(*UnsubackPacket) + + *cp = *ua + + return cp +} diff --git a/packets/unsubscribe.go b/packets/unsubscribe.go index df0b483..63ee80d 100644 --- a/packets/unsubscribe.go +++ b/packets/unsubscribe.go @@ -73,3 +73,17 @@ func (u *UnsubscribePacket) Unpack(b io.Reader) error { func (u *UnsubscribePacket) Details() Details { return Details{Qos: 1, MessageID: u.MessageID} } + +// Copy creates a deep copy of the UnsubscribePacket +func (u *UnsubscribePacket) Copy() ControlPacket { + cp := NewControlPacket(Unsubscribe).(*UnsubscribePacket) + + *cp = *u + + if len(u.Topics) > 0 { + cp.Topics = make([]string, len(u.Topics)) + copy(cp.Topics, u.Topics) + } + + return cp +}