@@ -110,6 +110,112 @@ and scan_instructions ctx l =
110110 let ctx = fork_context ctx in
111111 List. iter ~f: (fun i -> scan_instruction ctx i) l
112112
113+ let rec rewrite_expression uninitialized (e : Wasm_ast.expression ) =
114+ match e with
115+ | Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ -> e
116+ | UnOp (op , e' ) -> UnOp (op, rewrite_expression uninitialized e')
117+ | I32WrapI64 e' -> I32WrapI64 (rewrite_expression uninitialized e')
118+ | I64ExtendI32 (s , e' ) -> I64ExtendI32 (s, rewrite_expression uninitialized e')
119+ | F32DemoteF64 e' -> F32DemoteF64 (rewrite_expression uninitialized e')
120+ | F64PromoteF32 e' -> F64PromoteF32 (rewrite_expression uninitialized e')
121+ | RefI31 e' -> RefI31 (rewrite_expression uninitialized e')
122+ | I31Get (s , e' ) -> I31Get (s, rewrite_expression uninitialized e')
123+ | ArrayLen e' -> ArrayLen (rewrite_expression uninitialized e')
124+ | StructGet (s , ty , i , e' ) -> StructGet (s, ty, i, rewrite_expression uninitialized e')
125+ | RefCast (ty , e' ) -> RefCast (ty, rewrite_expression uninitialized e')
126+ | RefTest (ty , e' ) -> RefTest (ty, rewrite_expression uninitialized e')
127+ | Br_on_cast (i , ty , ty' , e' ) ->
128+ Br_on_cast (i, ty, ty', rewrite_expression uninitialized e')
129+ | Br_on_cast_fail (i , ty , ty' , e' ) ->
130+ Br_on_cast_fail (i, ty, ty', rewrite_expression uninitialized e')
131+ | Br_on_null (i , e' ) -> Br_on_null (i, rewrite_expression uninitialized e')
132+ | BinOp (op , e' , e'' ) ->
133+ BinOp (op, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
134+ | ArrayNew (ty , e' , e'' ) ->
135+ ArrayNew
136+ (ty, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
137+ | ArrayNewData (ty , i , e' , e'' ) ->
138+ ArrayNewData
139+ (ty, i, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
140+ | ArrayGet (s , ty , e' , e'' ) ->
141+ ArrayGet
142+ (s, ty, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
143+ | RefEq (e' , e'' ) ->
144+ RefEq (rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
145+ | LocalGet i ->
146+ if Code.Var.Hashtbl. mem uninitialized i
147+ then RefCast (Code.Var.Hashtbl. find uninitialized i, e)
148+ else e
149+ | LocalTee (i , e' ) ->
150+ let e = Wasm_ast. LocalTee (i, rewrite_expression uninitialized e') in
151+ if Code.Var.Hashtbl. mem uninitialized i
152+ then RefCast (Code.Var.Hashtbl. find uninitialized i, e)
153+ else e
154+ | Call_ref (f , e' , l ) ->
155+ Call_ref
156+ (f, rewrite_expression uninitialized e', rewrite_expressions uninitialized l)
157+ | Call (f , l ) -> Call (f, rewrite_expressions uninitialized l)
158+ | ArrayNewFixed (ty , l ) -> ArrayNewFixed (ty, rewrite_expressions uninitialized l)
159+ | StructNew (ty , l ) -> StructNew (ty, rewrite_expressions uninitialized l)
160+ | BlockExpr (ty , l ) -> BlockExpr (ty, rewrite_instructions uninitialized l)
161+ | Seq (l , e' ) ->
162+ Seq (rewrite_instructions uninitialized l, rewrite_expression uninitialized e')
163+ | IfExpr (ty , cond , e1 , e2 ) ->
164+ IfExpr
165+ ( ty
166+ , rewrite_expression uninitialized cond
167+ , rewrite_expression uninitialized e1
168+ , rewrite_expression uninitialized e2 )
169+ | Try (ty , body , catches ) -> Try (ty, rewrite_instructions uninitialized body, catches)
170+ | ExternConvertAny e' -> ExternConvertAny (rewrite_expression uninitialized e')
171+ | AnyConvertExtern e' -> AnyConvertExtern (rewrite_expression uninitialized e')
172+
173+ and rewrite_expressions uninitialized l =
174+ List. map ~f: (fun e -> rewrite_expression uninitialized e) l
175+
176+ and rewrite_instruction uninitialized i =
177+ match i with
178+ | Wasm_ast. Drop e -> Wasm_ast. Drop (rewrite_expression uninitialized e)
179+ | GlobalSet (x , e ) -> GlobalSet (x, rewrite_expression uninitialized e)
180+ | Br (i , Some e ) -> Br (i, Some (rewrite_expression uninitialized e))
181+ | Br_if (i , e ) -> Br_if (i, rewrite_expression uninitialized e)
182+ | Br_table (e , l , i ) -> Br_table (rewrite_expression uninitialized e, l, i)
183+ | Throw (t , e ) -> Throw (t, rewrite_expression uninitialized e)
184+ | Return (Some e ) -> Return (Some (rewrite_expression uninitialized e))
185+ | Push e -> Push (rewrite_expression uninitialized e)
186+ | StructSet (ty , i , e , e' ) ->
187+ StructSet
188+ (ty, i, rewrite_expression uninitialized e, rewrite_expression uninitialized e')
189+ | LocalSet (i , e ) -> LocalSet (i, rewrite_expression uninitialized e)
190+ | Loop (ty , l ) -> Loop (ty, rewrite_instructions uninitialized l)
191+ | Block (ty , l ) -> Block (ty, rewrite_instructions uninitialized l)
192+ | If (ty , e , l , l' ) ->
193+ If
194+ ( ty
195+ , rewrite_expression uninitialized e
196+ , rewrite_instructions uninitialized l
197+ , rewrite_instructions uninitialized l' )
198+ | CallInstr (f , l ) -> CallInstr (f, rewrite_expressions uninitialized l)
199+ | Return_call (f , l ) -> Return_call (f, rewrite_expressions uninitialized l)
200+ | Br (_ , None ) | Return None | Rethrow _ | Nop | Unreachable | Event _ -> i
201+ | ArraySet (ty , e , e' , e'' ) ->
202+ ArraySet
203+ ( ty
204+ , rewrite_expression uninitialized e
205+ , rewrite_expression uninitialized e'
206+ , rewrite_expression uninitialized e'' )
207+ | Return_call_ref (f , e' , l ) ->
208+ Return_call_ref
209+ (f, rewrite_expression uninitialized e', rewrite_expressions uninitialized l)
210+
211+ and rewrite_instructions uninitialized l =
212+ List. map ~f: (fun i -> rewrite_instruction uninitialized i) l
213+
214+ let has_default (ty : Wasm_ast.heap_type ) =
215+ match ty with
216+ | Any | Eq | I31 -> true
217+ | Func | Extern | Array | Struct | None_ | Type _ -> false
218+
113219let f ~param_names ~locals instrs =
114220 let ctx =
115221 { initialized = Code.Var.Set. empty; uninitialized = ref Code.Var.Set. empty }
@@ -122,7 +228,31 @@ let f ~param_names ~locals instrs =
122228 | Ref { nullable = false ; _ } -> () )
123229 locals;
124230 scan_instructions ctx instrs;
125- List. map
126- ~f: (fun i -> Wasm_ast. LocalSet (i, RefI31 (Const (I32 0l ))))
127- (Code.Var.Set. elements ! (ctx.uninitialized))
128- @ instrs
231+ let local_types = Code.Var.Hashtbl. create 16 in
232+ let locals =
233+ List. map
234+ ~f: (fun ((var , typ ) as local ) ->
235+ match typ with
236+ | Ref ({ nullable = false ; typ } as ref_typ ) ->
237+ if Code.Var.Set. mem var ! (ctx.uninitialized) && not (has_default typ)
238+ then (
239+ Code.Var.Hashtbl. add local_types var ref_typ;
240+ var, Wasm_ast. Ref { nullable = true ; typ })
241+ else local
242+ | I32 | I64 | F32 | F64 | Ref { nullable = true ; _ } -> local)
243+ locals
244+ in
245+ let initializations =
246+ List. filter_map
247+ ~f: (fun i ->
248+ if Code.Var.Hashtbl. mem local_types i
249+ then None
250+ else Some (Wasm_ast. LocalSet (i, RefI31 (Const (I32 0l )))))
251+ (Code.Var.Set. elements ! (ctx.uninitialized))
252+ in
253+ let instrs =
254+ if Code.Var.Hashtbl. length local_types = 0
255+ then instrs
256+ else rewrite_instructions local_types instrs
257+ in
258+ locals, initializations @ instrs
0 commit comments