1+ # SPDX-License-Identifier: GPL-2.0-or-later
2+ # This file is part of Scapy
3+ # See https://scapy.net/ for more information
4+ # Copyright (C) 2024 Lucas Drufva <lucas.drufva@gmail.com>
5+
6+ # scapy.contrib.description = WebSocket
7+ # scapy.contrib.status = loads
8+
9+ # Based on rfc6455
10+
11+ import struct
12+ import base64
13+ import zlib
14+ from hashlib import sha1
15+ from scapy .fields import (BitFieldLenField , Field , BitField , BitEnumField , ConditionalField , XIntField , FieldLenField , XNBytesField )
16+ from scapy .layers .http import HTTPRequest , HTTPResponse
17+ from scapy .layers .inet import TCP
18+ from scapy .packet import Packet
19+ from scapy .error import Scapy_Exception
20+
21+
22+ class PayloadLenField (BitFieldLenField ):
23+
24+ def __init__ (self , name , default , length_of , size = 0 , tot_size = 0 , end_tot_size = 0 ):
25+ # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
26+ super ().__init__ (name , default , size , length_of = length_of , tot_size = tot_size , end_tot_size = end_tot_size )
27+
28+ def getfield (self , pkt , s ):
29+ s , _ = s
30+ # Get the 7-bit field (first byte)
31+ length_byte = s [0 ] & 0x7F
32+ s = s [1 :]
33+
34+ if length_byte <= 125 :
35+ # 7-bit length
36+ return s , length_byte
37+ elif length_byte == 126 :
38+ # 16-bit length
39+ length = struct .unpack ("!H" , s [:2 ])[0 ] # Read 2 bytes
40+ s = s [2 :]
41+ return s , length
42+ elif length_byte == 127 :
43+ # 64-bit length
44+ length = struct .unpack ("!Q" , s [:8 ])[0 ] # Read 8 bytes
45+ s = s [8 :]
46+ return s , length
47+
48+ def addfield (self , pkt , s , val ):
49+ p_field , p_val = pkt .getfield_and_val (self .length_of )
50+ val = p_field .i2len (pkt , p_val )
51+
52+ if val <= 125 :
53+ self .size = 7
54+ return super ().addfield (pkt , s , val )
55+ elif val <= 0xFFFF :
56+ self .size = 7 + 16
57+ s , _ , masked = s
58+ return s + struct .pack ("!BH" , 126 | masked , val )
59+ elif val <= 0xFFFFFFFFFFFFFFFF :
60+ self .size = 7 + 64
61+ s , _ , masked = s
62+ return s + struct .pack ("!BQ" , 127 | masked , val )
63+ else :
64+ raise Scapy_Exception ("%s: Payload length too large" %
65+ self .__class__ .__name__ )
66+
67+
68+
69+ class PayloadField (Field ):
70+ """
71+ Field for handling raw byte payloads with dynamic size.
72+ The length of the payload is described by a preceding PayloadLenField.
73+ """
74+ __slots__ = ["lengthFrom" ]
75+
76+ def __init__ (self , name , lengthFrom ):
77+ """
78+ :param name: Field name
79+ :param lengthFrom: Field name that provides the length of the payload
80+ """
81+ super (PayloadField , self ).__init__ (name , None )
82+ self .lengthFrom = lengthFrom
83+
84+ def getfield (self , pkt , s ):
85+ # Fetch the length from the field that specifies the length
86+ length = getattr (pkt , self .lengthFrom )
87+ payloadData = s [:length ]
88+
89+ if pkt .mask :
90+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
91+ data_int = int .from_bytes (payloadData , 'big' )
92+ mask_repeated = key * (len (payloadData ) // 4 ) + key [: len (payloadData ) % 4 ]
93+ mask_int = int .from_bytes (mask_repeated , 'big' )
94+ payloadData = (data_int ^ mask_int ).to_bytes (len (payloadData ), 'big' )
95+
96+ if ("permessage-deflate" in pkt .extensions ):
97+ try :
98+ payloadData = pkt .decoder [0 ](payloadData + b"\x00 \x00 \xff \xff " )
99+ except Exception :
100+ # Failed to decompress payload
101+ pass
102+
103+ return s [length :], payloadData
104+
105+ def addfield (self , pkt , s , val ):
106+ # Ensure val is bytes and append the data to the packet
107+ return s + bytes (val )
108+
109+ def i2len (self , pkt , val ):
110+ # Length of the payload in bytes
111+ return len (val )
112+
113+ class WebSocket (Packet ):
114+ __slots__ = ["extensions" , "decoder" ]
115+
116+ name = "WebSocket"
117+ fields_desc = [
118+ BitField ("fin" , 0 , 1 ),
119+ BitField ("rsv" , 0 , 3 ),
120+ BitEnumField ("opcode" , 0 , 4 ,
121+ {
122+ 0x0 : "none" ,
123+ 0x1 : "text" ,
124+ 0x2 : "binary" ,
125+ 0x8 : "close" ,
126+ 0x9 : "ping" ,
127+ 0xA : "pong" ,
128+ }),
129+ BitField ("mask" , 0 , 1 ),
130+ PayloadLenField ("payloadLen" , 0 , length_of = "wsPayload" , size = 1 ),
131+ ConditionalField (XNBytesField ("maskingKey" , 0 , sz = 4 ), lambda pkt : pkt .mask == 1 ),
132+ PayloadField ("wsPayload" , lengthFrom = "payloadLen" )
133+ ]
134+
135+ def __init__ (self , pkt = None , extensions = [], decoder = None , * args , ** fields ):
136+ self .extensions = extensions
137+ self .decoder = decoder
138+ super ().__init__ (_pkt = pkt , * args , ** fields )
139+
140+ def extract_padding (self , s ):
141+ return '' , s
142+
143+ @classmethod
144+ def tcp_reassemble (cls , data , metadata , session ):
145+ # data = the reassembled data from the same request/flow
146+ # metadata = empty dictionary, that can be used to store data
147+ # during TCP reassembly
148+ # session = a dictionary proper to the bidirectional TCP session,
149+ # that can be used to store anything
150+ # [...]
151+ # If the packet is available, return it. Otherwise don't.
152+ # Whenever you return a packet, the buffer will be discarded.
153+
154+
155+ HANDSHAKE_STATE_CLIENT_OPEN = 0
156+ HANDSHAKE_STATE_SERVER_OPEN = 1
157+ HANDSHAKE_STATE_OPEN = 2
158+
159+ if "handshake-state" not in session :
160+ session ["handshake-state" ] = HANDSHAKE_STATE_CLIENT_OPEN
161+
162+ if "extensions" not in session :
163+ session ["extensions" ] = {}
164+
165+
166+ if session ["handshake-state" ] == HANDSHAKE_STATE_CLIENT_OPEN :
167+ ht = HTTPRequest (data )
168+
169+ if ht .Method != b"GET" :
170+ return None
171+
172+ if not ht .Upgrade or ht .Upgrade .lower () != b"websocket" :
173+ return None
174+
175+ if b"Sec-WebSocket-Key" not in ht .Unknown_Headers :
176+ return None
177+
178+
179+ session ["handshake-key" ] = ht .Unknown_Headers [b"Sec-WebSocket-Key" ]
180+
181+ if "original" in metadata :
182+ session ["server-port" ] = metadata ["original" ][TCP ].dport
183+ else :
184+ print ("No original packet" )
185+
186+ session ["handshake-state" ] = HANDSHAKE_STATE_SERVER_OPEN
187+
188+ return ht
189+
190+ elif session ["handshake-state" ] == HANDSHAKE_STATE_SERVER_OPEN :
191+ ht = HTTPResponse (data )
192+
193+ if not ht .Upgrade .lower () == b"websocket" :
194+ return None
195+
196+ # Verify key-accept handshake:
197+ correct_accept = base64 .b64encode (sha1 (session ["handshake-key" ] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" .encode ()).digest ())
198+ if ht .Unknown_Headers [b"Sec-WebSocket-Accept" ] != correct_accept :
199+ #TODO handle or Logg wrong accept key
200+ pass
201+
202+ if b"Sec-WebSocket-Extensions" in ht .Unknown_Headers :
203+ session ["extensions" ] = {}
204+ for extension in ht .Unknown_Headers [b"Sec-WebSocket-Extensions" ].decode ().strip ().split (";" ):
205+ key_value_pair = extension .split ("=" , 1 ) + [None ]
206+ session ["extensions" ][key_value_pair [0 ].strip ()] = key_value_pair [1 ]
207+
208+ if "permessage-deflate" in session ["extensions" ]:
209+ def create_decompressor (window_bits ):
210+ decoder = zlib .decompressobj (wbits = - window_bits )
211+ def decomp (data ):
212+ nonlocal decoder
213+ return decoder .decompress (data , 0 )
214+
215+ def reset ():
216+ nonlocal decoder
217+ nonlocal window_bits
218+ decoder = zlib .decompressobj (wbits = - window_bits )
219+
220+ return (decomp , reset )
221+
222+ # Default values
223+ client_wb = 12
224+ server_wb = 15
225+
226+ # Check for new values in extensions header
227+ if "client_max_window_bits" in session ["extensions" ]:
228+ client_wb = int (session ["extensions" ]["client_max_window_bits" ])
229+
230+ if "server_max_window_bits" in session ["extensions" ]:
231+ server_wb = int (session ["extensions" ]["server_max_window_bits" ])
232+
233+
234+ session ["server-decoder" ] = create_decompressor (client_wb )
235+ session ["client-decoder" ] = create_decompressor (server_wb )
236+
237+
238+ session ["handshake-state" ] = HANDSHAKE_STATE_OPEN
239+
240+ return ht
241+
242+
243+ # Handshake is done:
244+ if "original" not in metadata :
245+ return
246+
247+ if "permessage-deflate" in session ["extensions" ]:
248+ is_server = True if metadata ["original" ][TCP ].sport == session ["server-port" ] else False
249+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ], decoder = session ["server-decoder" ] if is_server else session ["client-decoder" ])
250+ return ws
251+ else :
252+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ])
253+ return ws
0 commit comments