@@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
6767 }
6868 return msg_count;
6969}
70+
71+ inline bool is_little_endian ()
72+ {
73+ const uint16_t i = 0x01 ;
74+ return *reinterpret_cast <const uint8_t *>(&i) == 0x01 ;
75+ }
76+
77+ inline void write_network_order (unsigned char *buf, const uint32_t value)
78+ {
79+ if (is_little_endian ()) {
80+ ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t >::max ();
81+ *buf++ = (value >> 24 ) & mask;
82+ *buf++ = (value >> 16 ) & mask;
83+ *buf++ = (value >> 8 ) & mask;
84+ *buf++ = value & mask;
85+ } else {
86+ std::memcpy (buf, &value, sizeof (value));
87+ }
88+ }
89+
90+ inline uint32_t read_u32_network_order (const unsigned char *buf)
91+ {
92+ if (is_little_endian ()) {
93+ return (static_cast <uint32_t >(buf[0 ]) << 24 )
94+ + (static_cast <uint32_t >(buf[1 ]) << 16 )
95+ + (static_cast <uint32_t >(buf[2 ]) << 8 )
96+ + static_cast <uint32_t >(buf[3 ]);
97+ } else {
98+ uint32_t value;
99+ std::memcpy (&value, buf, sizeof (value));
100+ return value;
101+ }
102+ }
70103} // namespace detail
71104
72105/* Receive a multipart message.
@@ -189,42 +222,37 @@ message_t encode(const Range &parts)
189222
190223 // First pass check sizes
191224 for (const auto &part : parts) {
192- size_t part_size = part.size ();
225+ const size_t part_size = part.size ();
193226 if (part_size > std::numeric_limits<std::uint32_t >::max ()) {
194227 // Size value must fit into uint32_t.
195228 throw std::range_error (" Invalid size, message part too large" );
196229 }
197- size_t count_size = 5 ;
198- if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
199- count_size = 1 ;
200- }
230+ const size_t count_size =
231+ part_size < std::numeric_limits<std::uint8_t >::max () ? 1 : 5 ;
201232 mmsg_size += part_size + count_size;
202233 }
203234
204235 message_t encoded (mmsg_size);
205236 unsigned char *buf = encoded.data <unsigned char >();
206237 for (const auto &part : parts) {
207- uint32_t part_size = part.size ();
238+ const uint32_t part_size = part.size ();
208239 const unsigned char *part_data =
209240 static_cast <const unsigned char *>(part.data ());
210241
211- // small part
212242 if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
243+ // small part
213244 *buf++ = (unsigned char ) part_size;
214- memcpy (buf, part_data, part_size);
215- buf += part_size;
216- continue ;
245+ } else {
246+ // big part
247+ *buf++ = std::numeric_limits<uint8_t >::max ();
248+ detail::write_network_order (buf, part_size);
249+ buf += sizeof (part_size);
217250 }
218-
219- // big part
220- *buf++ = std::numeric_limits<uint8_t >::max ();
221- *buf++ = (part_size >> 24 ) & std::numeric_limits<std::uint8_t >::max ();
222- *buf++ = (part_size >> 16 ) & std::numeric_limits<std::uint8_t >::max ();
223- *buf++ = (part_size >> 8 ) & std::numeric_limits<std::uint8_t >::max ();
224- *buf++ = part_size & std::numeric_limits<std::uint8_t >::max ();
225- memcpy (buf, part_data, part_size);
251+ std::memcpy (buf, part_data, part_size);
226252 buf += part_size;
227253 }
254+
255+ assert (static_cast <size_t >(buf - encoded.data <unsigned char >()) == mmsg_size);
228256 return encoded;
229257}
230258
@@ -251,22 +279,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
251279 while (source < limit) {
252280 size_t part_size = *source++;
253281 if (part_size == std::numeric_limits<std::uint8_t >::max ()) {
254- if (source > limit - 4 ) {
282+ if (static_cast < size_t >( limit - source) < sizeof ( uint32_t ) ) {
255283 throw std::out_of_range (
256284 " Malformed encoding, overflow in reading size" );
257285 }
258- part_size = (( uint32_t ) source[ 0 ] << 24 ) + (( uint32_t ) source[ 1 ] << 16 )
259- + (( uint32_t ) source[ 2 ] << 8 ) + ( uint32_t ) source[ 3 ];
260- source += 4 ;
286+ part_size = detail::read_u32_network_order ( source);
287+ // the part size is allowed to be less than 0xFF
288+ source += sizeof ( uint32_t ) ;
261289 }
262290
263- if (source > limit - part_size) {
291+ if (static_cast < size_t >( limit - source) < part_size) {
264292 throw std::out_of_range (" Malformed encoding, overflow in reading part" );
265293 }
266294 *out = message_t (source, part_size);
267295 ++out;
268296 source += part_size;
269297 }
298+
299+ assert (source == limit);
270300 return out;
271301}
272302
0 commit comments