Skip to content

Commit fc87f14

Browse files
committed
Make local initialization work for arbitrary local values
We change the type of the local into a nullable type if we cannot use a placeholder value
1 parent 86beb38 commit fc87f14

File tree

4 files changed

+137
-7
lines changed

4 files changed

+137
-7
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ module Generate (Target : Target_sig.S) = struct
12871287
| Some loc -> event loc
12881288
| None -> return ())
12891289
in
1290-
let body = post_process_function_body ~param_names ~locals body in
1290+
let locals, body = post_process_function_body ~param_names ~locals body in
12911291
W.Function
12921292
{ name =
12931293
(match name_opt with

compiler/lib-wasm/initialize_locals.ml

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,112 @@ and scan_instructions ctx l =
110110
let ctx = fork_context ctx in
111111
List.iter ~f:(fun i -> scan_instruction ctx i) l
112112

113+
let rec rewrite_expression uninitialized (e : Wasm_ast.expression) =
114+
match e with
115+
| Const _ | GlobalGet _ | Pop _ | RefFunc _ | RefNull _ -> e
116+
| UnOp (op, e') -> UnOp (op, rewrite_expression uninitialized e')
117+
| I32WrapI64 e' -> I32WrapI64 (rewrite_expression uninitialized e')
118+
| I64ExtendI32 (s, e') -> I64ExtendI32 (s, rewrite_expression uninitialized e')
119+
| F32DemoteF64 e' -> F32DemoteF64 (rewrite_expression uninitialized e')
120+
| F64PromoteF32 e' -> F64PromoteF32 (rewrite_expression uninitialized e')
121+
| RefI31 e' -> RefI31 (rewrite_expression uninitialized e')
122+
| I31Get (s, e') -> I31Get (s, rewrite_expression uninitialized e')
123+
| ArrayLen e' -> ArrayLen (rewrite_expression uninitialized e')
124+
| StructGet (s, ty, i, e') -> StructGet (s, ty, i, rewrite_expression uninitialized e')
125+
| RefCast (ty, e') -> RefCast (ty, rewrite_expression uninitialized e')
126+
| RefTest (ty, e') -> RefTest (ty, rewrite_expression uninitialized e')
127+
| Br_on_cast (i, ty, ty', e') ->
128+
Br_on_cast (i, ty, ty', rewrite_expression uninitialized e')
129+
| Br_on_cast_fail (i, ty, ty', e') ->
130+
Br_on_cast_fail (i, ty, ty', rewrite_expression uninitialized e')
131+
| Br_on_null (i, e') -> Br_on_null (i, rewrite_expression uninitialized e')
132+
| BinOp (op, e', e'') ->
133+
BinOp (op, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
134+
| ArrayNew (ty, e', e'') ->
135+
ArrayNew
136+
(ty, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
137+
| ArrayNewData (ty, i, e', e'') ->
138+
ArrayNewData
139+
(ty, i, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
140+
| ArrayGet (s, ty, e', e'') ->
141+
ArrayGet
142+
(s, ty, rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
143+
| RefEq (e', e'') ->
144+
RefEq (rewrite_expression uninitialized e', rewrite_expression uninitialized e'')
145+
| LocalGet i ->
146+
if Code.Var.Hashtbl.mem uninitialized i
147+
then RefCast (Code.Var.Hashtbl.find uninitialized i, e)
148+
else e
149+
| LocalTee (i, e') ->
150+
let e = Wasm_ast.LocalTee (i, rewrite_expression uninitialized e') in
151+
if Code.Var.Hashtbl.mem uninitialized i
152+
then RefCast (Code.Var.Hashtbl.find uninitialized i, e)
153+
else e
154+
| Call_ref (f, e', l) ->
155+
Call_ref
156+
(f, rewrite_expression uninitialized e', rewrite_expressions uninitialized l)
157+
| Call (f, l) -> Call (f, rewrite_expressions uninitialized l)
158+
| ArrayNewFixed (ty, l) -> ArrayNewFixed (ty, rewrite_expressions uninitialized l)
159+
| StructNew (ty, l) -> StructNew (ty, rewrite_expressions uninitialized l)
160+
| BlockExpr (ty, l) -> BlockExpr (ty, rewrite_instructions uninitialized l)
161+
| Seq (l, e') ->
162+
Seq (rewrite_instructions uninitialized l, rewrite_expression uninitialized e')
163+
| IfExpr (ty, cond, e1, e2) ->
164+
IfExpr
165+
( ty
166+
, rewrite_expression uninitialized cond
167+
, rewrite_expression uninitialized e1
168+
, rewrite_expression uninitialized e2 )
169+
| Try (ty, body, catches) -> Try (ty, rewrite_instructions uninitialized body, catches)
170+
| ExternConvertAny e' -> ExternConvertAny (rewrite_expression uninitialized e')
171+
| AnyConvertExtern e' -> AnyConvertExtern (rewrite_expression uninitialized e')
172+
173+
and rewrite_expressions uninitialized l =
174+
List.map ~f:(fun e -> rewrite_expression uninitialized e) l
175+
176+
and rewrite_instruction uninitialized i =
177+
match i with
178+
| Wasm_ast.Drop e -> Wasm_ast.Drop (rewrite_expression uninitialized e)
179+
| GlobalSet (x, e) -> GlobalSet (x, rewrite_expression uninitialized e)
180+
| Br (i, Some e) -> Br (i, Some (rewrite_expression uninitialized e))
181+
| Br_if (i, e) -> Br_if (i, rewrite_expression uninitialized e)
182+
| Br_table (e, l, i) -> Br_table (rewrite_expression uninitialized e, l, i)
183+
| Throw (t, e) -> Throw (t, rewrite_expression uninitialized e)
184+
| Return (Some e) -> Return (Some (rewrite_expression uninitialized e))
185+
| Push e -> Push (rewrite_expression uninitialized e)
186+
| StructSet (ty, i, e, e') ->
187+
StructSet
188+
(ty, i, rewrite_expression uninitialized e, rewrite_expression uninitialized e')
189+
| LocalSet (i, e) -> LocalSet (i, rewrite_expression uninitialized e)
190+
| Loop (ty, l) -> Loop (ty, rewrite_instructions uninitialized l)
191+
| Block (ty, l) -> Block (ty, rewrite_instructions uninitialized l)
192+
| If (ty, e, l, l') ->
193+
If
194+
( ty
195+
, rewrite_expression uninitialized e
196+
, rewrite_instructions uninitialized l
197+
, rewrite_instructions uninitialized l' )
198+
| CallInstr (f, l) -> CallInstr (f, rewrite_expressions uninitialized l)
199+
| Return_call (f, l) -> Return_call (f, rewrite_expressions uninitialized l)
200+
| Br (_, None) | Return None | Rethrow _ | Nop | Unreachable | Event _ -> i
201+
| ArraySet (ty, e, e', e'') ->
202+
ArraySet
203+
( ty
204+
, rewrite_expression uninitialized e
205+
, rewrite_expression uninitialized e'
206+
, rewrite_expression uninitialized e'' )
207+
| Return_call_ref (f, e', l) ->
208+
Return_call_ref
209+
(f, rewrite_expression uninitialized e', rewrite_expressions uninitialized l)
210+
211+
and rewrite_instructions uninitialized l =
212+
List.map ~f:(fun i -> rewrite_instruction uninitialized i) l
213+
214+
let has_default (ty : Wasm_ast.heap_type) =
215+
match ty with
216+
| Any | Eq | I31 -> true
217+
| Func | Extern | Array | Struct | None_ | Type _ -> false
218+
113219
let f ~param_names ~locals instrs =
114220
let ctx =
115221
{ initialized = Code.Var.Set.empty; uninitialized = ref Code.Var.Set.empty }
@@ -122,7 +228,31 @@ let f ~param_names ~locals instrs =
122228
| Ref { nullable = false; _ } -> ())
123229
locals;
124230
scan_instructions ctx instrs;
125-
List.map
126-
~f:(fun i -> Wasm_ast.LocalSet (i, RefI31 (Const (I32 0l))))
127-
(Code.Var.Set.elements !(ctx.uninitialized))
128-
@ instrs
231+
let local_types = Code.Var.Hashtbl.create 16 in
232+
let locals =
233+
List.map
234+
~f:(fun ((var, typ) as local) ->
235+
match typ with
236+
| Ref ({ nullable = false; typ } as ref_typ) ->
237+
if Code.Var.Set.mem var !(ctx.uninitialized) && not (has_default typ)
238+
then (
239+
Code.Var.Hashtbl.add local_types var ref_typ;
240+
var, Wasm_ast.Ref { nullable = true; typ })
241+
else local
242+
| I32 | I64 | F32 | F64 | Ref { nullable = true; _ } -> local)
243+
locals
244+
in
245+
let initializations =
246+
List.filter_map
247+
~f:(fun i ->
248+
if Code.Var.Hashtbl.mem local_types i
249+
then None
250+
else Some (Wasm_ast.LocalSet (i, RefI31 (Const (I32 0l)))))
251+
(Code.Var.Set.elements !(ctx.uninitialized))
252+
in
253+
let instrs =
254+
if Code.Var.Hashtbl.length local_types = 0
255+
then instrs
256+
else rewrite_instructions local_types instrs
257+
in
258+
locals, initializations @ instrs

compiler/lib-wasm/initialize_locals.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ val f :
2020
param_names:Wasm_ast.var list
2121
-> locals:(Wasm_ast.var * Wasm_ast.value_type) list
2222
-> Wasm_ast.instruction list
23-
-> Wasm_ast.instruction list
23+
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list

compiler/lib-wasm/target_sig.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ module type S = sig
278278
param_names:Wasm_ast.var list
279279
-> locals:(Wasm_ast.var * Wasm_ast.value_type) list
280280
-> Wasm_ast.instruction list
281-
-> Wasm_ast.instruction list
281+
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
282282

283283
val entry_point :
284284
toplevel_fun:Wasm_ast.var

0 commit comments

Comments
 (0)