1+ # SPDX-License-Identifier: GPL-2.0-or-later
2+ # This file is part of Scapy
3+ # See https://scapy.net/ for more information
4+ # Copyright (C) 2024 Lucas Drufva <lucas.drufva@gmail.com>
5+
6+ # scapy.contrib.description = WebSocket
7+ # scapy.contrib.status = loads
8+
9+ # Based on rfc6455
10+
11+ import struct
12+ import base64
13+ import zlib
14+ from hashlib import sha1
15+ from scapy .fields import (BitFieldLenField , Field , BitField , BitEnumField , ConditionalField , XNBytesField )
16+ from scapy .layers .http import HTTPRequest , HTTPResponse
17+ from scapy .layers .inet import TCP
18+ from scapy .packet import Packet
19+ from scapy .error import Scapy_Exception
20+ import logging
21+
22+
23+ class PayloadLenField (BitFieldLenField ):
24+
25+ def __init__ (self , name , default , length_of , size = 0 , tot_size = 0 , end_tot_size = 0 ):
26+ # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
27+ super ().__init__ (name , default , size , length_of = length_of , tot_size = tot_size , end_tot_size = end_tot_size )
28+
29+ def getfield (self , pkt , s ):
30+ s , _ = s
31+ # Get the 7-bit field (first byte)
32+ length_byte = s [0 ] & 0x7F
33+ s = s [1 :]
34+
35+ if length_byte <= 125 :
36+ # 7-bit length
37+ return s , length_byte
38+ elif length_byte == 126 :
39+ # 16-bit length
40+ length = struct .unpack ("!H" , s [:2 ])[0 ] # Read 2 bytes
41+ s = s [2 :]
42+ return s , length
43+ elif length_byte == 127 :
44+ # 64-bit length
45+ length = struct .unpack ("!Q" , s [:8 ])[0 ] # Read 8 bytes
46+ s = s [8 :]
47+ return s , length
48+
49+ def addfield (self , pkt , s , val ):
50+ p_field , p_val = pkt .getfield_and_val (self .length_of )
51+ val = p_field .i2len (pkt , p_val )
52+
53+ if val <= 125 :
54+ self .size = 7
55+ return super ().addfield (pkt , s , val )
56+ elif val <= 0xFFFF :
57+ self .size = 7 + 16
58+ s , _ , masked = s
59+ return s + struct .pack ("!BH" , 126 | masked , val )
60+ elif val <= 0xFFFFFFFFFFFFFFFF :
61+ self .size = 7 + 64
62+ s , _ , masked = s
63+ return s + struct .pack ("!BQ" , 127 | masked , val )
64+ else :
65+ raise Scapy_Exception ("%s: Payload length too large" %
66+ self .__class__ .__name__ )
67+
68+
69+
70+ class PayloadField (Field ):
71+ """
72+ Field for handling raw byte payloads with dynamic size.
73+ The length of the payload is described by a preceding PayloadLenField.
74+ """
75+ __slots__ = ["lengthFrom" ]
76+
77+ def __init__ (self , name , lengthFrom ):
78+ """
79+ :param name: Field name
80+ :param lengthFrom: Field name that provides the length of the payload
81+ """
82+ super (PayloadField , self ).__init__ (name , None )
83+ self .lengthFrom = lengthFrom
84+
85+ def getfield (self , pkt , s ):
86+ # Fetch the length from the field that specifies the length
87+ length = getattr (pkt , self .lengthFrom )
88+ payloadData = s [:length ]
89+
90+ if pkt .mask :
91+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
92+ data_int = int .from_bytes (payloadData , 'big' )
93+ mask_repeated = key * (len (payloadData ) // 4 ) + key [: len (payloadData ) % 4 ]
94+ mask_int = int .from_bytes (mask_repeated , 'big' )
95+ payloadData = (data_int ^ mask_int ).to_bytes (len (payloadData ), 'big' )
96+
97+ if ("permessage-deflate" in pkt .extensions ):
98+ try :
99+ payloadData = pkt .decoder [0 ](payloadData + b"\x00 \x00 \xff \xff " )
100+ except Exception :
101+ logging .debug ("Failed to decompress payload" , payloadData )
102+
103+ return s [length :], payloadData
104+
105+ def addfield (self , pkt , s , val ):
106+ if pkt .mask :
107+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
108+ data_int = int .from_bytes (val , 'big' )
109+ mask_repeated = key * (len (val ) // 4 ) + key [: len (val ) % 4 ]
110+ mask_int = int .from_bytes (mask_repeated , 'big' )
111+ val = (data_int ^ mask_int ).to_bytes (len (val ), 'big' )
112+
113+ return s + bytes (val )
114+
115+ def i2len (self , pkt , val ):
116+ # Length of the payload in bytes
117+ return len (val )
118+
119+ class WebSocket (Packet ):
120+ __slots__ = ["extensions" , "decoder" ]
121+
122+ name = "WebSocket"
123+ fields_desc = [
124+ BitField ("fin" , 0 , 1 ),
125+ BitField ("rsv" , 0 , 3 ),
126+ BitEnumField ("opcode" , 0 , 4 ,
127+ {
128+ 0x0 : "none" ,
129+ 0x1 : "text" ,
130+ 0x2 : "binary" ,
131+ 0x8 : "close" ,
132+ 0x9 : "ping" ,
133+ 0xA : "pong" ,
134+ }),
135+ BitField ("mask" , 0 , 1 ),
136+ PayloadLenField ("payloadLen" , 0 , length_of = "wsPayload" , size = 1 ),
137+ ConditionalField (XNBytesField ("maskingKey" , 0 , sz = 4 ), lambda pkt : pkt .mask == 1 ),
138+ PayloadField ("wsPayload" , lengthFrom = "payloadLen" )
139+ ]
140+
141+ def __init__ (self , pkt = None , extensions = [], decoder = None , * args , ** fields ):
142+ self .extensions = extensions
143+ self .decoder = decoder
144+ super ().__init__ (_pkt = pkt , * args , ** fields )
145+
146+ def extract_padding (self , s ):
147+ return '' , s
148+
149+ @classmethod
150+ def tcp_reassemble (cls , data , metadata , session ):
151+ # data = the reassembled data from the same request/flow
152+ # metadata = empty dictionary, that can be used to store data
153+ # during TCP reassembly
154+ # session = a dictionary proper to the bidirectional TCP session,
155+ # that can be used to store anything
156+ # [...]
157+ # If the packet is available, return it. Otherwise don't.
158+ # Whenever you return a packet, the buffer will be discarded.
159+
160+
161+ HANDSHAKE_STATE_CLIENT_OPEN = 0
162+ HANDSHAKE_STATE_SERVER_OPEN = 1
163+ HANDSHAKE_STATE_OPEN = 2
164+
165+ if "handshake-state" not in session :
166+ session ["handshake-state" ] = HANDSHAKE_STATE_CLIENT_OPEN
167+
168+ if "extensions" not in session :
169+ session ["extensions" ] = {}
170+
171+
172+ if session ["handshake-state" ] == HANDSHAKE_STATE_CLIENT_OPEN :
173+ ht = HTTPRequest (data )
174+
175+ if ht .Method != b"GET" :
176+ return None
177+
178+ if not ht .Upgrade or ht .Upgrade .lower () != b"websocket" :
179+ return None
180+
181+ if b"Sec-WebSocket-Key" not in ht .Unknown_Headers :
182+ return None
183+
184+
185+ session ["handshake-key" ] = ht .Unknown_Headers [b"Sec-WebSocket-Key" ]
186+
187+ if "original" in metadata :
188+ session ["server-port" ] = metadata ["original" ][TCP ].dport
189+ else :
190+ print ("No original packet" )
191+
192+ session ["handshake-state" ] = HANDSHAKE_STATE_SERVER_OPEN
193+
194+ return ht
195+
196+ elif session ["handshake-state" ] == HANDSHAKE_STATE_SERVER_OPEN :
197+ ht = HTTPResponse (data )
198+
199+ if not ht .Upgrade .lower () == b"websocket" :
200+ return None
201+
202+ # Verify key-accept handshake:
203+ correct_accept = base64 .b64encode (sha1 (session ["handshake-key" ] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" .encode ()).digest ())
204+ if ht .Unknown_Headers [b"Sec-WebSocket-Accept" ] != correct_accept :
205+ #TODO handle or Logg wrong accept key
206+ pass
207+
208+ if b"Sec-WebSocket-Extensions" in ht .Unknown_Headers :
209+ session ["extensions" ] = {}
210+ for extension in ht .Unknown_Headers [b"Sec-WebSocket-Extensions" ].decode ().strip ().split (";" ):
211+ key_value_pair = extension .split ("=" , 1 ) + [None ]
212+ session ["extensions" ][key_value_pair [0 ].strip ()] = key_value_pair [1 ]
213+
214+ if "permessage-deflate" in session ["extensions" ]:
215+ def create_decompressor (window_bits ):
216+ decoder = zlib .decompressobj (wbits = - window_bits )
217+ def decomp (data ):
218+ nonlocal decoder
219+ return decoder .decompress (data , 0 )
220+
221+ def reset ():
222+ nonlocal decoder
223+ nonlocal window_bits
224+ decoder = zlib .decompressobj (wbits = - window_bits )
225+
226+ return (decomp , reset )
227+
228+ # Default values
229+ client_wb = 12
230+ server_wb = 15
231+
232+ # Check for new values in extensions header
233+ if "client_max_window_bits" in session ["extensions" ]:
234+ client_wb = int (session ["extensions" ]["client_max_window_bits" ])
235+
236+ if "server_max_window_bits" in session ["extensions" ]:
237+ server_wb = int (session ["extensions" ]["server_max_window_bits" ])
238+
239+
240+ session ["server-decoder" ] = create_decompressor (client_wb )
241+ session ["client-decoder" ] = create_decompressor (server_wb )
242+
243+
244+ session ["handshake-state" ] = HANDSHAKE_STATE_OPEN
245+
246+ return ht
247+
248+
249+ # Handshake is done:
250+ if "original" not in metadata :
251+ return
252+
253+ if "permessage-deflate" in session ["extensions" ]:
254+ is_server = True if metadata ["original" ][TCP ].sport == session ["server-port" ] else False
255+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ], decoder = session ["server-decoder" ] if is_server else session ["client-decoder" ])
256+ return ws
257+ else :
258+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ])
259+ return ws
0 commit comments