Skip to content

Commit 2ee72d4

Browse files
committed
fixup! multi: decode zero-length onion message payloads
1 parent 12391a0 commit 2ee72d4

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

payload.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,22 @@ func (hp *HopPayload) Decode(r io.Reader) error {
101101

102102
var payloadSize uint16
103103

104-
// If the HopPayload isn't guaranteed to be a TLV payload, we check the
105-
// first byte to see if it is a legacy payload.
106-
if hp.Type != PayloadTLV && isLegacyPayloadByte(peekByte[0]) {
104+
// If the HopPayload might be a legacy payload (indicated by the Type
105+
// being equal to the zero-value PayloadLegacy), we check the first
106+
// byte to see if it is a legacy payload.
107+
if hp.Type == PayloadLegacy && isLegacyPayloadByte(peekByte[0]) {
107108
payloadSize = legacyPayloadSize()
108109
} else {
109-
// If the first byte doesn't indicate a legacy payload, then it
110-
// *must* be a TLV payload.
110+
// Either this is already known to be a TLV payload, or the
111+
// first byte indicates that this is a TLV payload (i.e., the
112+
// first byte is not 0x00).
111113
payloadSize, err = tlvPayloadSize(bufReader)
112114
if err != nil {
113115
return err
114116
}
115117

118+
// We still need to set the payload type in case it was the
119+
// zero-value, PayloadLegacy, AND the first byte was not 0x00.
116120
hp.Type = PayloadTLV
117121
}
118122

@@ -140,11 +144,15 @@ func readPayloadAndHMAC(hp *HopPayload, r io.Reader, payloadSize uint16) error {
140144
// Now that we know the payload size, we'll create a new buffer to read
141145
// it out in full.
142146
hp.Payload = make([]byte, payloadSize)
143-
if _, err := io.ReadFull(r, hp.Payload[:]); err != nil {
144-
return err
147+
148+
_, err := io.ReadFull(r, hp.Payload)
149+
if err != nil {
150+
return fmt.Errorf("%w : %w", ErrIOReadFull, err)
145151
}
146-
if _, err := io.ReadFull(r, hp.HMAC[:]); err != nil {
147-
return err
152+
153+
_, err = io.ReadFull(r, hp.HMAC[:])
154+
if err != nil {
155+
return fmt.Errorf("%w : %w", ErrIOReadFull, err)
148156
}
149157

150158
return nil

sphinx.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
579579
}
580580

581581
// WithTLVPayloadOnly is a functional option that signals that the onion packet
582-
// being processed is from onion message.
582+
// being processed is an onion_message_packet.
583583
func WithTLVPayloadOnly() ProcessOnionOpt {
584584
return func(cfg *processOnionCfg) {
585585
cfg.tlvPayloadOnly = true

sphinx_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ func TestTLVPayloadMessagePacket(t *testing.T) {
430430
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
431431
}
432432

433-
// TestProcessOnionMessageZeroLengthPayload tests that we can properly process an
434-
// onion message that has a zero-length payload.
433+
// TestProcessOnionMessageZeroLengthPayload tests that we can properly process
434+
// an onion message that has a zero-length payload.
435435
func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
436436
t.Parallel()
437437

@@ -443,6 +443,7 @@ func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
443443
router := NewRouter(&PrivKeyECDH{privKey}, NewMemoryReplayLog())
444444
err = router.Start()
445445
require.NoError(t, err)
446+
446447
defer router.Stop()
447448

448449
// Next, create a session key for the onion packet.
@@ -459,6 +460,7 @@ func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
459460
// The hop payload will be an empty TLV payload.
460461
payload, err := NewTLVHopPayload(nil)
461462
require.NoError(t, err)
463+
require.Empty(t, payload.Payload)
462464
path[0].HopPayload = payload
463465

464466
// Now, create the onion packet.

0 commit comments

Comments
 (0)