@@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
6767 }
6868 return msg_count;
6969}
70+
71+ inline bool is_little_endian ()
72+ {
73+ const uint16_t i = 0x01 ;
74+ return *reinterpret_cast <const uint8_t *>(&i) == 0x01 ;
75+ }
76+
77+ inline void write_network_order (unsigned char *buf, const uint32_t value)
78+ {
79+ if (is_little_endian ()) {
80+ ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t >::max ();
81+ *buf++ = (value >> 24 ) & mask;
82+ *buf++ = (value >> 16 ) & mask;
83+ *buf++ = (value >> 8 ) & mask;
84+ *buf++ = value & mask;
85+ } else {
86+ std::memcpy (buf, &value, sizeof (value));
87+ }
88+ }
89+
90+ inline uint32_t read_u32_network_order (const unsigned char *buf)
91+ {
92+ if (is_little_endian ()) {
93+ return (static_cast <uint32_t >(buf[0 ]) << 24 )
94+ + (static_cast <uint32_t >(buf[1 ]) << 16 )
95+ + (static_cast <uint32_t >(buf[2 ]) << 8 )
96+ + static_cast <uint32_t >(buf[3 ]);
97+ } else {
98+ uint32_t value;
99+ std::memcpy (&value, buf, sizeof (value));
100+ return value;
101+ }
102+ }
70103} // namespace detail
71104
72105/* Receive a multipart message.
@@ -190,42 +223,37 @@ message_t encode(const Range &parts)
190223
191224 // First pass check sizes
192225 for (const auto &part : parts) {
193- size_t part_size = part.size ();
226+ const size_t part_size = part.size ();
194227 if (part_size > std::numeric_limits<std::uint32_t >::max ()) {
195228 // Size value must fit into uint32_t.
196229 throw std::range_error (" Invalid size, message part too large" );
197230 }
198- size_t count_size = 5 ;
199- if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
200- count_size = 1 ;
201- }
231+ const size_t count_size =
232+ part_size < std::numeric_limits<std::uint8_t >::max () ? 1 : 5 ;
202233 mmsg_size += part_size + count_size;
203234 }
204235
205236 message_t encoded (mmsg_size);
206237 unsigned char *buf = encoded.data <unsigned char >();
207238 for (const auto &part : parts) {
208- uint32_t part_size = part.size ();
239+ const uint32_t part_size = part.size ();
209240 const unsigned char *part_data =
210241 static_cast <const unsigned char *>(part.data ());
211242
212- // small part
213243 if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
244+ // small part
214245 *buf++ = (unsigned char ) part_size;
215- memcpy (buf, part_data, part_size);
216- buf += part_size;
217- continue ;
246+ } else {
247+ // big part
248+ *buf++ = std::numeric_limits<uint8_t >::max ();
249+ detail::write_network_order (buf, part_size);
250+ buf += sizeof (part_size);
218251 }
219-
220- // big part
221- *buf++ = std::numeric_limits<uint8_t >::max ();
222- *buf++ = (part_size >> 24 ) & std::numeric_limits<std::uint8_t >::max ();
223- *buf++ = (part_size >> 16 ) & std::numeric_limits<std::uint8_t >::max ();
224- *buf++ = (part_size >> 8 ) & std::numeric_limits<std::uint8_t >::max ();
225- *buf++ = part_size & std::numeric_limits<std::uint8_t >::max ();
226- memcpy (buf, part_data, part_size);
252+ std::memcpy (buf, part_data, part_size);
227253 buf += part_size;
228254 }
255+
256+ assert (static_cast <size_t >(buf - encoded.data <unsigned char >()) == mmsg_size);
229257 return encoded;
230258}
231259
@@ -252,22 +280,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
252280 while (source < limit) {
253281 size_t part_size = *source++;
254282 if (part_size == std::numeric_limits<std::uint8_t >::max ()) {
255- if (source > limit - 4 ) {
283+ if (static_cast < size_t >( limit - source) < sizeof ( uint32_t ) ) {
256284 throw std::out_of_range (
257285 " Malformed encoding, overflow in reading size" );
258286 }
259- part_size = (( uint32_t ) source[ 0 ] << 24 ) + (( uint32_t ) source[ 1 ] << 16 )
260- + (( uint32_t ) source[ 2 ] << 8 ) + ( uint32_t ) source[ 3 ];
261- source += 4 ;
287+ part_size = detail::read_u32_network_order ( source);
288+ // the part size is allowed to be less than 0xFF
289+ source += sizeof ( uint32_t ) ;
262290 }
263291
264- if (source > limit - part_size) {
292+ if (static_cast < size_t >( limit - source) < part_size) {
265293 throw std::out_of_range (" Malformed encoding, overflow in reading part" );
266294 }
267295 *out = message_t (source, part_size);
268296 ++out;
269297 source += part_size;
270298 }
299+
300+ assert (source == limit);
271301 return out;
272302}
273303
0 commit comments