|
6 | 6 | "fmt" |
7 | 7 | "io" |
8 | 8 | "math" |
| 9 | + "math/bits" |
9 | 10 | ) |
10 | 11 |
|
11 | 12 | //go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go |
@@ -69,7 +70,7 @@ type header struct { |
69 | 70 | payloadLength int64 |
70 | 71 |
|
71 | 72 | masked bool |
72 | | - maskKey [4]byte |
| 73 | + maskKey uint32 |
73 | 74 | } |
74 | 75 |
|
75 | 76 | func makeWriteHeaderBuf() []byte { |
@@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte { |
119 | 120 | if h.masked { |
120 | 121 | b[1] |= 1 << 7 |
121 | 122 | b = b[:len(b)+4] |
122 | | - copy(b[len(b)-4:], h.maskKey[:]) |
| 123 | + binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) |
123 | 124 | } |
124 | 125 |
|
125 | 126 | return b |
@@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) { |
192 | 193 | } |
193 | 194 |
|
194 | 195 | if h.masked { |
195 | | - copy(h.maskKey[:], b) |
| 196 | + h.maskKey = binary.LittleEndian.Uint32(b) |
196 | 197 | } |
197 | 198 |
|
198 | 199 | return h, nil |
@@ -321,122 +322,122 @@ func (ce CloseError) bytes() ([]byte, error) { |
321 | 322 | return buf, nil |
322 | 323 | } |
323 | 324 |
|
324 | | -// xor applies the WebSocket masking algorithm to p |
325 | | -// with the given key where the first 3 bits of pos |
326 | | -// are the starting position in the key. |
| 325 | +// fastXOR applies the WebSocket masking algorithm to p |
| 326 | +// with the given key. |
327 | 327 | // See https://tools.ietf.org/html/rfc6455#section-5.3 |
328 | 328 | // |
329 | | -// The returned value is the position of the next byte |
330 | | -// to be used for masking in the key. This is so that |
331 | | -// unmasking can be performed without the entire frame. |
332 | | -func fastXOR(key [4]byte, keyPos int, b []byte) int { |
333 | | - // If the payload is greater than or equal to 16 bytes, then it's worth |
334 | | - // masking 8 bytes at a time. |
335 | | - // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 |
336 | | - if len(b) >= 16 { |
337 | | - // We first create a key that is 8 bytes long |
338 | | - // and is aligned on the position correctly. |
339 | | - var alignedKey [8]byte |
340 | | - for i := range alignedKey { |
341 | | - alignedKey[i] = key[(i+keyPos)&3] |
342 | | - } |
343 | | - k := binary.LittleEndian.Uint64(alignedKey[:]) |
| 329 | +// The returned value is the correctly rotated key to |
| 330 | +// to continue to mask/unmask the message. |
| 331 | +// |
| 332 | +// It is optimized for LittleEndian and expects the key |
| 333 | +// to be in little endian. |
| 334 | +func fastXOR(key uint32, b []byte) uint32 { |
| 335 | + if len(b) >= 8 { |
| 336 | + key64 := uint64(key)<<32 | uint64(key) |
344 | 337 |
|
345 | 338 | // At some point in the future we can clean these unrolled loops up. |
346 | 339 | // See https://github.com/golang/go/issues/31586#issuecomment-487436401 |
347 | 340 |
|
348 | 341 | // Then we xor until b is less than 128 bytes. |
349 | 342 | for len(b) >= 128 { |
350 | 343 | v := binary.LittleEndian.Uint64(b) |
351 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 344 | + binary.LittleEndian.PutUint64(b, v^key64) |
352 | 345 | v = binary.LittleEndian.Uint64(b[8:]) |
353 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 346 | + binary.LittleEndian.PutUint64(b[8:], v^key64) |
354 | 347 | v = binary.LittleEndian.Uint64(b[16:]) |
355 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
| 348 | + binary.LittleEndian.PutUint64(b[16:], v^key64) |
356 | 349 | v = binary.LittleEndian.Uint64(b[24:]) |
357 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
| 350 | + binary.LittleEndian.PutUint64(b[24:], v^key64) |
358 | 351 | v = binary.LittleEndian.Uint64(b[32:]) |
359 | | - binary.LittleEndian.PutUint64(b[32:], v^k) |
| 352 | + binary.LittleEndian.PutUint64(b[32:], v^key64) |
360 | 353 | v = binary.LittleEndian.Uint64(b[40:]) |
361 | | - binary.LittleEndian.PutUint64(b[40:], v^k) |
| 354 | + binary.LittleEndian.PutUint64(b[40:], v^key64) |
362 | 355 | v = binary.LittleEndian.Uint64(b[48:]) |
363 | | - binary.LittleEndian.PutUint64(b[48:], v^k) |
| 356 | + binary.LittleEndian.PutUint64(b[48:], v^key64) |
364 | 357 | v = binary.LittleEndian.Uint64(b[56:]) |
365 | | - binary.LittleEndian.PutUint64(b[56:], v^k) |
| 358 | + binary.LittleEndian.PutUint64(b[56:], v^key64) |
366 | 359 | v = binary.LittleEndian.Uint64(b[64:]) |
367 | | - binary.LittleEndian.PutUint64(b[64:], v^k) |
| 360 | + binary.LittleEndian.PutUint64(b[64:], v^key64) |
368 | 361 | v = binary.LittleEndian.Uint64(b[72:]) |
369 | | - binary.LittleEndian.PutUint64(b[72:], v^k) |
| 362 | + binary.LittleEndian.PutUint64(b[72:], v^key64) |
370 | 363 | v = binary.LittleEndian.Uint64(b[80:]) |
371 | | - binary.LittleEndian.PutUint64(b[80:], v^k) |
| 364 | + binary.LittleEndian.PutUint64(b[80:], v^key64) |
372 | 365 | v = binary.LittleEndian.Uint64(b[88:]) |
373 | | - binary.LittleEndian.PutUint64(b[88:], v^k) |
| 366 | + binary.LittleEndian.PutUint64(b[88:], v^key64) |
374 | 367 | v = binary.LittleEndian.Uint64(b[96:]) |
375 | | - binary.LittleEndian.PutUint64(b[96:], v^k) |
| 368 | + binary.LittleEndian.PutUint64(b[96:], v^key64) |
376 | 369 | v = binary.LittleEndian.Uint64(b[104:]) |
377 | | - binary.LittleEndian.PutUint64(b[104:], v^k) |
| 370 | + binary.LittleEndian.PutUint64(b[104:], v^key64) |
378 | 371 | v = binary.LittleEndian.Uint64(b[112:]) |
379 | | - binary.LittleEndian.PutUint64(b[112:], v^k) |
| 372 | + binary.LittleEndian.PutUint64(b[112:], v^key64) |
380 | 373 | v = binary.LittleEndian.Uint64(b[120:]) |
381 | | - binary.LittleEndian.PutUint64(b[120:], v^k) |
| 374 | + binary.LittleEndian.PutUint64(b[120:], v^key64) |
382 | 375 | b = b[128:] |
383 | 376 | } |
384 | 377 |
|
385 | 378 | // Then we xor until b is less than 64 bytes. |
386 | 379 | for len(b) >= 64 { |
387 | 380 | v := binary.LittleEndian.Uint64(b) |
388 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 381 | + binary.LittleEndian.PutUint64(b, v^key64) |
389 | 382 | v = binary.LittleEndian.Uint64(b[8:]) |
390 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 383 | + binary.LittleEndian.PutUint64(b[8:], v^key64) |
391 | 384 | v = binary.LittleEndian.Uint64(b[16:]) |
392 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
| 385 | + binary.LittleEndian.PutUint64(b[16:], v^key64) |
393 | 386 | v = binary.LittleEndian.Uint64(b[24:]) |
394 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
| 387 | + binary.LittleEndian.PutUint64(b[24:], v^key64) |
395 | 388 | v = binary.LittleEndian.Uint64(b[32:]) |
396 | | - binary.LittleEndian.PutUint64(b[32:], v^k) |
| 389 | + binary.LittleEndian.PutUint64(b[32:], v^key64) |
397 | 390 | v = binary.LittleEndian.Uint64(b[40:]) |
398 | | - binary.LittleEndian.PutUint64(b[40:], v^k) |
| 391 | + binary.LittleEndian.PutUint64(b[40:], v^key64) |
399 | 392 | v = binary.LittleEndian.Uint64(b[48:]) |
400 | | - binary.LittleEndian.PutUint64(b[48:], v^k) |
| 393 | + binary.LittleEndian.PutUint64(b[48:], v^key64) |
401 | 394 | v = binary.LittleEndian.Uint64(b[56:]) |
402 | | - binary.LittleEndian.PutUint64(b[56:], v^k) |
| 395 | + binary.LittleEndian.PutUint64(b[56:], v^key64) |
403 | 396 | b = b[64:] |
404 | 397 | } |
405 | 398 |
|
406 | 399 | // Then we xor until b is less than 32 bytes. |
407 | 400 | for len(b) >= 32 { |
408 | 401 | v := binary.LittleEndian.Uint64(b) |
409 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 402 | + binary.LittleEndian.PutUint64(b, v^key64) |
410 | 403 | v = binary.LittleEndian.Uint64(b[8:]) |
411 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 404 | + binary.LittleEndian.PutUint64(b[8:], v^key64) |
412 | 405 | v = binary.LittleEndian.Uint64(b[16:]) |
413 | | - binary.LittleEndian.PutUint64(b[16:], v^k) |
| 406 | + binary.LittleEndian.PutUint64(b[16:], v^key64) |
414 | 407 | v = binary.LittleEndian.Uint64(b[24:]) |
415 | | - binary.LittleEndian.PutUint64(b[24:], v^k) |
| 408 | + binary.LittleEndian.PutUint64(b[24:], v^key64) |
416 | 409 | b = b[32:] |
417 | 410 | } |
418 | 411 |
|
419 | 412 | // Then we xor until b is less than 16 bytes. |
420 | 413 | for len(b) >= 16 { |
421 | 414 | v := binary.LittleEndian.Uint64(b) |
422 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 415 | + binary.LittleEndian.PutUint64(b, v^key64) |
423 | 416 | v = binary.LittleEndian.Uint64(b[8:]) |
424 | | - binary.LittleEndian.PutUint64(b[8:], v^k) |
| 417 | + binary.LittleEndian.PutUint64(b[8:], v^key64) |
425 | 418 | b = b[16:] |
426 | 419 | } |
427 | 420 |
|
428 | 421 | // Then we xor until b is less than 8 bytes. |
429 | 422 | for len(b) >= 8 { |
430 | 423 | v := binary.LittleEndian.Uint64(b) |
431 | | - binary.LittleEndian.PutUint64(b, v^k) |
| 424 | + binary.LittleEndian.PutUint64(b, v^key64) |
432 | 425 | b = b[8:] |
433 | 426 | } |
434 | 427 | } |
435 | 428 |
|
| 429 | + // Then we xor until b is less than 4 bytes. |
| 430 | + for len(b) >= 4 { |
| 431 | + v := binary.LittleEndian.Uint32(b) |
| 432 | + binary.LittleEndian.PutUint32(b, v^key) |
| 433 | + b = b[4:] |
| 434 | + } |
| 435 | + |
436 | 436 | // xor remaining bytes. |
437 | 437 | for i := range b { |
438 | | - b[i] ^= key[keyPos&3] |
439 | | - keyPos++ |
| 438 | + b[i] ^= byte(key) |
| 439 | + key = bits.RotateLeft32(key, -8) |
440 | 440 | } |
441 | | - return keyPos & 3 |
| 441 | + |
| 442 | + return key |
442 | 443 | } |
0 commit comments