1+ # SPDX-License-Identifier: GPL-2.0-or-later
2+ # This file is part of Scapy
3+ # See https://scapy.net/ for more information
4+ # Copyright (C) 2024 Lucas Drufva <lucas.drufva@gmail.com>
5+
6+ # scapy.contrib.description = WebSocket
7+ # scapy.contrib.status = loads
8+
9+ # Based on rfc6455
10+
11+ import struct
12+ import base64
13+ import zlib
14+ from hashlib import sha1
15+ from scapy .fields import (BitFieldLenField , Field , BitField , BitEnumField , ConditionalField , XNBytesField )
16+ from scapy .layers .http import HTTPRequest , HTTPResponse , HTTP
17+ from scapy .layers .inet import TCP
18+ from scapy .packet import Packet
19+ from scapy .error import Scapy_Exception
20+ import logging
21+
22+
23+ class PayloadLenField (BitFieldLenField ):
24+
25+ def __init__ (self , name , default , length_of , size = 0 , tot_size = 0 , end_tot_size = 0 ):
26+ # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
27+ super ().__init__ (name , default , size , length_of = length_of , tot_size = tot_size , end_tot_size = end_tot_size )
28+
29+ def getfield (self , pkt , s ):
30+ s , _ = s
31+ # Get the 7-bit field (first byte)
32+ length_byte = s [0 ] & 0x7F
33+ s = s [1 :]
34+
35+ if length_byte <= 125 :
36+ # 7-bit length
37+ return s , length_byte
38+ elif length_byte == 126 :
39+ # 16-bit length
40+ length = struct .unpack ("!H" , s [:2 ])[0 ] # Read 2 bytes
41+ s = s [2 :]
42+ return s , length
43+ elif length_byte == 127 :
44+ # 64-bit length
45+ length = struct .unpack ("!Q" , s [:8 ])[0 ] # Read 8 bytes
46+ s = s [8 :]
47+ return s , length
48+
49+ def addfield (self , pkt , s , val ):
50+ p_field , p_val = pkt .getfield_and_val (self .length_of )
51+ val = p_field .i2len (pkt , p_val )
52+
53+ if val <= 125 :
54+ self .size = 7
55+ return super ().addfield (pkt , s , val )
56+ elif val <= 0xFFFF :
57+ self .size = 7 + 16
58+ s , _ , masked = s
59+ return s + struct .pack ("!BH" , 126 | masked , val )
60+ elif val <= 0xFFFFFFFFFFFFFFFF :
61+ self .size = 7 + 64
62+ s , _ , masked = s
63+ return s + struct .pack ("!BQ" , 127 | masked , val )
64+ else :
65+ raise Scapy_Exception ("%s: Payload length too large" %
66+ self .__class__ .__name__ )
67+
68+
69+
70+ class PayloadField (Field ):
71+ """
72+ Field for handling raw byte payloads with dynamic size.
73+ The length of the payload is described by a preceding PayloadLenField.
74+ """
75+ __slots__ = ["lengthFrom" ]
76+
77+ def __init__ (self , name , lengthFrom ):
78+ """
79+ :param name: Field name
80+ :param lengthFrom: Field name that provides the length of the payload
81+ """
82+ super (PayloadField , self ).__init__ (name , None )
83+ self .lengthFrom = lengthFrom
84+
85+ def getfield (self , pkt , s ):
86+ # Fetch the length from the field that specifies the length
87+ length = getattr (pkt , self .lengthFrom )
88+ payloadData = s [:length ]
89+
90+ if pkt .mask :
91+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
92+ data_int = int .from_bytes (payloadData , 'big' )
93+ mask_repeated = key * (len (payloadData ) // 4 ) + key [: len (payloadData ) % 4 ]
94+ mask_int = int .from_bytes (mask_repeated , 'big' )
95+ payloadData = (data_int ^ mask_int ).to_bytes (len (payloadData ), 'big' )
96+
97+ if ("permessage-deflate" in pkt .extensions ):
98+ try :
99+ payloadData = pkt .decoder [0 ](payloadData + b"\x00 \x00 \xff \xff " )
100+ except Exception :
101+ logging .debug ("Failed to decompress payload" , payloadData )
102+
103+ return s [length :], payloadData
104+
105+ def addfield (self , pkt , s , val ):
106+ if pkt .mask :
107+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
108+ data_int = int .from_bytes (val , 'big' )
109+ mask_repeated = key * (len (val ) // 4 ) + key [: len (val ) % 4 ]
110+ mask_int = int .from_bytes (mask_repeated , 'big' )
111+ val = (data_int ^ mask_int ).to_bytes (len (val ), 'big' )
112+
113+ return s + bytes (val )
114+
115+ def i2len (self , pkt , val ):
116+ # Length of the payload in bytes
117+ return len (val )
118+
119+ class WebSocket (Packet ):
120+ __slots__ = ["extensions" , "decoder" ]
121+
122+ name = "WebSocket"
123+ fields_desc = [
124+ BitField ("fin" , 0 , 1 ),
125+ BitField ("rsv" , 0 , 3 ),
126+ BitEnumField ("opcode" , 0 , 4 ,
127+ {
128+ 0x0 : "none" ,
129+ 0x1 : "text" ,
130+ 0x2 : "binary" ,
131+ 0x8 : "close" ,
132+ 0x9 : "ping" ,
133+ 0xA : "pong" ,
134+ }),
135+ BitField ("mask" , 0 , 1 ),
136+ PayloadLenField ("payloadLen" , 0 , length_of = "wsPayload" , size = 1 ),
137+ ConditionalField (XNBytesField ("maskingKey" , 0 , sz = 4 ), lambda pkt : pkt .mask == 1 ),
138+ PayloadField ("wsPayload" , lengthFrom = "payloadLen" )
139+ ]
140+
141+ def __init__ (self , pkt = None , extensions = [], decoder = None , * args , ** fields ):
142+ self .extensions = extensions
143+ self .decoder = decoder
144+ super ().__init__ (_pkt = pkt , * args , ** fields )
145+
146+ def extract_padding (self , s ):
147+ return '' , s
148+
149+ @classmethod
150+ def tcp_reassemble (cls , data , metadata , session ):
151+ # data = the reassembled data from the same request/flow
152+ # metadata = empty dictionary, that can be used to store data
153+ # during TCP reassembly
154+ # session = a dictionary proper to the bidirectional TCP session,
155+ # that can be used to store anything
156+ # [...]
157+ # If the packet is available, return it. Otherwise don't.
158+ # Whenever you return a packet, the buffer will be discarded.
159+
160+
161+ HANDSHAKE_STATE_CLIENT_OPEN = 0
162+ HANDSHAKE_STATE_SERVER_OPEN = 1
163+ HANDSHAKE_STATE_OPEN = 2
164+
165+ if "handshake-state" not in session :
166+ session ["handshake-state" ] = HANDSHAKE_STATE_CLIENT_OPEN
167+
168+ if "extensions" not in session :
169+ session ["extensions" ] = {}
170+
171+
172+ if session ["handshake-state" ] == HANDSHAKE_STATE_CLIENT_OPEN :
173+ http_data = HTTP (data )
174+ if HTTPRequest in http_data :
175+ http_data = http_data [HTTPRequest ]
176+ else :
177+ return http_data
178+
179+ if http_data .Method != b"GET" :
180+ return HTTP ()/ http_data
181+
182+ if not http_data .Upgrade or http_data .Upgrade .lower () != b"websocket" :
183+ return HTTP ()/ http_data
184+
185+ if not http_data .Unknown_Headers or b"Sec-WebSocket-Key" not in http_data .Unknown_Headers :
186+ return HTTP ()/ http_data
187+
188+ session ["handshake-key" ] = http_data .Unknown_Headers [b"Sec-WebSocket-Key" ]
189+
190+ if "original" in metadata :
191+ session ["server-port" ] = metadata ["original" ][TCP ].dport
192+
193+ session ["handshake-state" ] = HANDSHAKE_STATE_SERVER_OPEN
194+
195+ return http_data
196+
197+ elif session ["handshake-state" ] == HANDSHAKE_STATE_SERVER_OPEN :
198+ http_data = HTTP (data )
199+ if HTTPResponse in http_data :
200+ http_data = http_data [HTTPResponse ]
201+ else :
202+ return http_data
203+
204+ if not http_data .Upgrade .lower () == b"websocket" :
205+ return HTTP ()/ http_data
206+
207+ # Verify key-accept handshake:
208+ correct_accept = base64 .b64encode (sha1 (session ["handshake-key" ] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" .encode ()).digest ())
209+ if not http_data .Unknown_Headers or b"Sec-WebSocket-Accept" not in http_data .Unknown_Headers or http_data .Unknown_Headers [b"Sec-WebSocket-Accept" ] != correct_accept :
210+ # TODO: handle or Logg wrong accept key
211+ pass
212+
213+ if b"Sec-WebSocket-Extensions" in http_data .Unknown_Headers :
214+ session ["extensions" ] = {}
215+ for extension in http_data .Unknown_Headers [b"Sec-WebSocket-Extensions" ].decode ().strip ().split (";" ):
216+ key_value_pair = extension .split ("=" , 1 ) + [None ]
217+ session ["extensions" ][key_value_pair [0 ].strip ()] = key_value_pair [1 ]
218+
219+ if "permessage-deflate" in session ["extensions" ]:
220+ def create_decompressor (window_bits ):
221+ decoder = zlib .decompressobj (wbits = - window_bits )
222+ def decomp (data ):
223+ nonlocal decoder
224+ return decoder .decompress (data , 0 )
225+
226+ def reset ():
227+ nonlocal decoder
228+ nonlocal window_bits
229+ decoder = zlib .decompressobj (wbits = - window_bits )
230+
231+ return (decomp , reset )
232+
233+ # Default values
234+ client_wb = 12
235+ server_wb = 15
236+
237+ # Check for new values in extensions header
238+ if "client_max_window_bits" in session ["extensions" ]:
239+ client_wb = int (session ["extensions" ]["client_max_window_bits" ])
240+
241+ if "server_max_window_bits" in session ["extensions" ]:
242+ server_wb = int (session ["extensions" ]["server_max_window_bits" ])
243+
244+
245+ session ["server-decoder" ] = create_decompressor (client_wb )
246+ session ["client-decoder" ] = create_decompressor (server_wb )
247+
248+
249+ session ["handshake-state" ] = HANDSHAKE_STATE_OPEN
250+
251+ return HTTP ()/ http_data
252+
253+
254+ # Handshake is done:
255+ if "original" not in metadata :
256+ return
257+
258+ if "permessage-deflate" in session ["extensions" ]:
259+ is_server = True if metadata ["original" ][TCP ].sport == session ["server-port" ] else False
260+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ], decoder = session ["server-decoder" ] if is_server else session ["client-decoder" ])
261+ return ws
262+ else :
263+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ])
264+ return ws
0 commit comments