Skip to content

Commit 30d65ee

Browse files
committed
multi: decode zero-length onion message payloads
Since the onion message payload can be zero-length, we need to decode it correctly. This commit adds a boolean flag to the HopPayload Decode that tells whether the payload is an onion message payload or not. If it is, the payload is decoded as a tlv payload also if the first byte is 0x00. sphinx_test: Add zero-length payload om test
1 parent 4f2dbed commit 30d65ee

File tree

4 files changed

+161
-43
lines changed

4 files changed

+161
-43
lines changed

error.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package sphinx
22

3-
import "fmt"
3+
import (
4+
"errors"
5+
"fmt"
6+
)
47

58
var (
69
// ErrReplayedPacket is an error returned when a packet is rejected
@@ -24,4 +27,7 @@ var (
2427
// ErrLogEntryNotFound is an error returned when a packet lookup in a replay
2528
// log fails because it is missing.
2629
ErrLogEntryNotFound = fmt.Errorf("sphinx packet is not in log")
30+
31+
// ErrIOReadFull is returned when an io read full operation fails.
32+
ErrIOReadFull = errors.New("io read full error")
2733
)

payload.go

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -87,48 +87,62 @@ func (hp *HopPayload) Encode(w io.Writer) error {
8787
}
8888

8989
// Decode unpacks an encoded HopPayload from the passed reader into the target
90-
// HopPayload.
91-
func (hp *HopPayload) Decode(r io.Reader) error {
90+
// HopPayload. tlvGuaranteed should be set to true if the caller only wishes to
91+
// accept TLV encoded payloads. By doing so, zero-lengt tlv payloads are
92+
// supported. If set to false, then the function will inspect the first byte to
93+
// determine the type of payload.
94+
func DecodeHopPayload(r io.Reader, tlvGuaranteed bool) (*HopPayload, error) {
9295
bufReader := bufio.NewReader(r)
9396

94-
// In order to properly parse the payload, we'll need to check the
95-
// first byte. We'll use a bufio reader to peek at it without consuming
96-
// it from the buffer.
97+
var payloadSize uint16
98+
99+
hopPayload := &HopPayload{}
100+
101+
// If we are not sure if this is a TLV or legacy payload, then we need
102+
// to inspect the first byte to determine the type of payload. The first
103+
// byte is either a realm (legacy) or the beginning of a var-int
104+
// encoding the length of the payload (TLV). We'll use a bufio reader to
105+
// peek at it without consuming it from the buffer.
97106
peekByte, err := bufReader.Peek(1)
98107
if err != nil {
99-
return err
108+
return nil, fmt.Errorf("peek first payload byte: %w", err)
100109
}
101110

102-
var (
103-
legacyPayload = isLegacyPayloadByte(peekByte[0])
104-
payloadSize uint16
105-
)
111+
switch {
112+
case tlvGuaranteed:
113+
// If we're instructed to only accept TLV payloads, then we set
114+
// the type accordingly. This allows us to support zero-length
115+
// TLV payloads.
116+
117+
hopPayload.Type = PayloadTLV
106118

107-
if legacyPayload {
108-
payloadSize = legacyPayloadSize()
109-
hp.Type = PayloadLegacy
110-
} else {
111119
payloadSize, err = tlvPayloadSize(bufReader)
112120
if err != nil {
113-
return err
121+
return nil, err
114122
}
115123

116-
hp.Type = PayloadTLV
117-
}
124+
case isLegacyPayloadByte(peekByte[0]):
125+
// If the first byte indicates that this is a legacy payload,
126+
// then we set the type accordingly.
127+
hopPayload.Type = PayloadLegacy
128+
payloadSize = legacyPayloadSize()
118129

119-
// Now that we know the payload size, we'll create a new buffer to
120-
// read it out in full.
121-
//
122-
// TODO(roasbeef): can avoid all these copies
123-
hp.Payload = make([]byte, payloadSize)
124-
if _, err := io.ReadFull(bufReader, hp.Payload[:]); err != nil {
125-
return err
130+
default:
131+
// Otherwise, we set the type to TLV.
132+
hopPayload.Type = PayloadTLV
133+
134+
payloadSize, err = tlvPayloadSize(bufReader)
135+
if err != nil {
136+
return nil, err
137+
}
126138
}
127-
if _, err := io.ReadFull(bufReader, hp.HMAC[:]); err != nil {
128-
return err
139+
140+
err = readPayloadAndHMAC(hopPayload, bufReader, payloadSize)
141+
if err != nil {
142+
return nil, err
129143
}
130144

131-
return nil
145+
return hopPayload, nil
132146
}
133147

134148
// HopData attempts to extract a set of forwarding instructions from the target
@@ -146,6 +160,26 @@ func (hp *HopPayload) HopData() (*HopData, error) {
146160
return nil, nil
147161
}
148162

163+
// readPayloadAndHMAC reads the payload and HMAC from the reader into the
164+
// HopPayload.
165+
func readPayloadAndHMAC(hp *HopPayload, r io.Reader, payloadSize uint16) error {
166+
// Now that we know the payload size, we'll create a new buffer to read
167+
// it out in full.
168+
hp.Payload = make([]byte, payloadSize)
169+
170+
_, err := io.ReadFull(r, hp.Payload)
171+
if err != nil {
172+
return fmt.Errorf("%w : %w", ErrIOReadFull, err)
173+
}
174+
175+
_, err = io.ReadFull(r, hp.HMAC[:])
176+
if err != nil {
177+
return fmt.Errorf("%w : %w", ErrIOReadFull, err)
178+
}
179+
180+
return nil
181+
}
182+
149183
// tlvPayloadSize uses the passed reader to extract the payload length encoded
150184
// as a var-int.
151185
func tlvPayloadSize(r io.Reader) (uint16, error) {
@@ -314,8 +348,12 @@ func legacyNumBytes() int {
314348
return LegacyHopDataSize
315349
}
316350

317-
// isLegacyPayload returns true if the given byte is equal to the 0x00 byte
318-
// which indicates that the payload should be decoded as a legacy payload.
351+
// isLegacyPayloadByte determines if the first byte of a hop payload indicates
352+
// that it is a legacy payload. The first byte of a legacy payload will always
353+
// be 0x00, as this is the realm. For TLV payloads, the first byte is a
354+
// var-int encoding the length of the payload. A TLV stream can be empty, in
355+
// which case its length is 0, which is also encoded as a 0x00 byte. This
356+
// creates an ambiguity between a legacy payload and an empty TLV payload.
319357
func isLegacyPayloadByte(b byte) bool {
320358
return b == 0x00
321359
}

sphinx.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ func (r *Router) Stop() {
510510
// processOnionCfg is a set of config values that can be used to modify how an
511511
// onion is processed.
512512
type processOnionCfg struct {
513-
blindingPoint *btcec.PublicKey
513+
blindingPoint *btcec.PublicKey
514+
tlvPayloadOnly bool
514515
}
515516

516517
// ProcessOnionOpt defines the signature of a function option that can be used
@@ -525,6 +526,14 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
525526
}
526527
}
527528

529+
// WithTLVPayloadOnly is a functional option that signals that the onion packet
530+
// being processed is an onion_message_packet.
531+
func WithTLVPayloadOnly() ProcessOnionOpt {
532+
return func(cfg *processOnionCfg) {
533+
cfg.tlvPayloadOnly = true
534+
}
535+
}
536+
528537
// ProcessOnionPacket processes an incoming onion packet which has been forward
529538
// to the target Sphinx router. If the encoded ephemeral key isn't on the
530539
// target Elliptic Curve, then the packet is rejected. Similarly, if the
@@ -560,7 +569,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte,
560569
// Continue to optimistically process this packet, deferring replay
561570
// protection until the end to reduce the penalty of multiple IO
562571
// operations.
563-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
572+
packet, err := processOnionPacket(
573+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
574+
)
564575
if err != nil {
565576
return nil, err
566577
}
@@ -594,7 +605,9 @@ func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket, assocData []byte,
594605
return nil, err
595606
}
596607

597-
return processOnionPacket(onionPkt, &sharedSecret, assocData)
608+
return processOnionPacket(
609+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
610+
)
598611
}
599612

600613
// DecryptBlindedHopData uses the router's private key to decrypt data encrypted
@@ -625,7 +638,8 @@ func (r *Router) OnionPublicKey() *btcec.PublicKey {
625638
// packet. This function returns the next inner onion packet layer, along with
626639
// the hop data extracted from the outer onion packet.
627640
func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
628-
assocData []byte) (*OnionPacket, *HopPayload, error) {
641+
assocData []byte, tlvPayloadOnly bool) (*OnionPacket, *HopPayload,
642+
error) {
629643

630644
dhKey := onionPkt.EphemeralKey
631645
routeInfo := onionPkt.RoutingInfo
@@ -649,8 +663,8 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
649663
zeroBytes := bytes.Repeat([]byte{0}, MaxPayloadSize)
650664
headerWithPadding := append(routeInfo[:], zeroBytes...)
651665

652-
var hopInfo [numStreamBytes]byte
653-
xor(hopInfo[:], headerWithPadding, streamBytes)
666+
hopInfo := make([]byte, numStreamBytes)
667+
xor(hopInfo, headerWithPadding, streamBytes)
654668

655669
// Randomize the DH group element for the next hop using the
656670
// deterministic blinding factor.
@@ -660,8 +674,10 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
660674
// With the MAC checked, and the payload decrypted, we can now parse
661675
// out the payload so we can derive the specified forwarding
662676
// instructions.
663-
var hopPayload HopPayload
664-
if err := hopPayload.Decode(bytes.NewReader(hopInfo[:])); err != nil {
677+
hopPayload, err := DecodeHopPayload(
678+
bytes.NewReader(hopInfo), tlvPayloadOnly,
679+
)
680+
if err != nil {
665681
return nil, nil, err
666682
}
667683

@@ -676,14 +692,14 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
676692
HeaderMAC: hopPayload.HMAC,
677693
}
678694

679-
return innerPkt, &hopPayload, nil
695+
return innerPkt, hopPayload, nil
680696
}
681697

682698
// processOnionPacket performs the primary key derivation and handling of onion
683699
// packets. The processed packets returned from this method should only be used
684700
// if the packet was not flagged as a replayed packet.
685701
func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
686-
assocData []byte) (*ProcessedPacket, error) {
702+
assocData []byte, tlvPayloadOnly bool) (*ProcessedPacket, error) {
687703

688704
// First, we'll unwrap an initial layer of the onion packet. Typically,
689705
// we'll only have a single layer to unwrap, However, if the sender has
@@ -693,7 +709,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
693709
// they can properly check the HMAC and unwrap a layer for their
694710
// handoff hop.
695711
innerPkt, outerHopPayload, err := unwrapPacket(
696-
onionPkt, sharedSecret, assocData,
712+
onionPkt, sharedSecret, assocData, tlvPayloadOnly,
697713
)
698714
if err != nil {
699715
return nil, err
@@ -703,7 +719,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
703719
// However if the uncovered 'nextMac' is all zeroes, then this
704720
// indicates that we're the final hop in the route.
705721
var action ProcessCode = MoreHops
706-
if bytes.Compare(zeroHMAC[:], outerHopPayload.HMAC[:]) == 0 {
722+
if bytes.Equal(zeroHMAC[:], outerHopPayload.HMAC[:]) {
707723
action = ExitNode
708724
}
709725

@@ -794,7 +810,9 @@ func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
794810
// Continue to optimistically process this packet, deferring replay
795811
// protection until the end to reduce the penalty of multiple IO
796812
// operations.
797-
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
813+
packet, err := processOnionPacket(
814+
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
815+
)
798816
if err != nil {
799817
return err
800818
}

sphinx_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,62 @@ func TestTLVPayloadMessagePacket(t *testing.T) {
299299
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
300300
}
301301

302+
// TestProcessOnionMessageZeroLengthPayload tests that we can properly process
303+
// an onion message that has a zero-length payload.
304+
func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
305+
t.Parallel()
306+
307+
// First, create a router that will be the destination of the onion
308+
// message.
309+
privKey, err := btcec.NewPrivateKey()
310+
require.NoError(t, err)
311+
312+
router := NewRouter(&PrivKeyECDH{privKey}, NewMemoryReplayLog())
313+
err = router.Start()
314+
require.NoError(t, err)
315+
316+
defer router.Stop()
317+
318+
// Next, create a session key for the onion packet.
319+
sessionKey, err := btcec.NewPrivateKey()
320+
require.NoError(t, err)
321+
322+
// We'll create a simple one-hop path.
323+
path := &PaymentPath{
324+
{
325+
NodePub: *privKey.PubKey(),
326+
},
327+
}
328+
329+
// The hop payload will be an empty TLV payload.
330+
payload, err := NewTLVHopPayload(nil)
331+
require.NoError(t, err)
332+
require.Empty(t, payload.Payload)
333+
path[0].HopPayload = payload
334+
335+
// Now, create the onion packet.
336+
onionPacket, err := NewOnionPacket(
337+
path, sessionKey, nil, DeterministicPacketFiller,
338+
)
339+
require.NoError(t, err)
340+
341+
// We'll now process the packet, making sure to indicate that this is
342+
// an onion message.
343+
processedPacket, err := router.ProcessOnionPacket(
344+
onionPacket, nil, 0, WithTLVPayloadOnly(),
345+
)
346+
require.NoError(t, err)
347+
348+
// The packet should be decoded as an exit node.
349+
require.EqualValues(t, ExitNode, processedPacket.Action)
350+
351+
// The payload should be of type TLV.
352+
require.Equal(t, PayloadTLV, processedPacket.Payload.Type)
353+
354+
// And the payload should be empty.
355+
require.Empty(t, processedPacket.Payload.Payload)
356+
}
357+
302358
func TestSphinxCorrectness(t *testing.T) {
303359
nodes, _, hopDatas, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
304360
if err != nil {

0 commit comments

Comments
 (0)