Skip to content

Commit d43c40d

Browse files
authored
[NFC] Use Type instead of HeapType for functions (#7971)
This will let us enforce that the type of a `ref.func` is equal to the type of the referenced function, even once we introduce inexact function imports. This PR just stores the Type there.
1 parent 7448451 commit d43c40d

40 files changed

+266
-172
lines changed

src/binaryen-c.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4973,7 +4973,7 @@ static BinaryenFunctionRef addFunctionInternal(BinaryenModuleRef module,
49734973
BinaryenExpressionRef body) {
49744974
auto* ret = new Function;
49754975
ret->setExplicitName(name);
4976-
ret->type = type;
4976+
ret->type = Type(type, NonNullable, Exact);
49774977
for (BinaryenIndex i = 0; i < numVarTypes; i++) {
49784978
ret->vars.push_back(Type(varTypes[i]));
49794979
}
@@ -5097,7 +5097,8 @@ void BinaryenAddFunctionImport(BinaryenModuleRef module,
50975097
func->module = externalModuleName;
50985098
func->base = externalBaseName;
50995099
// TODO: Take a HeapType rather than params and results.
5100-
func->type = Signature(Type(params), Type(results));
5100+
func->type =
5101+
Type(Signature(Type(params), Type(results)), NonNullable, Exact);
51015102
((Module*)module)->addFunction(std::move(func));
51025103
} else {
51035104
// already exists so just set module and base
@@ -5285,7 +5286,8 @@ BinaryenAddActiveElementSegment(BinaryenModuleRef module,
52855286
Fatal() << "invalid function '" << funcNames[i] << "'.";
52865287
}
52875288
segment->data.push_back(
5288-
Builder(*(Module*)module).makeRefFunc(funcNames[i], func->type));
5289+
Builder(*(Module*)module)
5290+
.makeRefFunc(funcNames[i], func->type.getHeapType()));
52895291
}
52905292
return ((Module*)module)->addElementSegment(std::move(segment));
52915293
}
@@ -5302,7 +5304,8 @@ BinaryenAddPassiveElementSegment(BinaryenModuleRef module,
53025304
Fatal() << "invalid function '" << funcNames[i] << "'.";
53035305
}
53045306
segment->data.push_back(
5305-
Builder(*(Module*)module).makeRefFunc(funcNames[i], func->type));
5307+
Builder(*(Module*)module)
5308+
.makeRefFunc(funcNames[i], func->type.getHeapType()));
53065309
}
53075310
return ((Module*)module)->addElementSegment(std::move(segment));
53085311
}
@@ -6017,10 +6020,10 @@ void BinaryenFunctionSetBody(BinaryenFunctionRef func,
60176020
((Function*)func)->body = (Expression*)body;
60186021
}
60196022
BinaryenHeapType BinaryenFunctionGetType(BinaryenFunctionRef func) {
6020-
return ((Function*)func)->type.getID();
6023+
return ((Function*)func)->type.getHeapType().getID();
60216024
}
60226025
void BinaryenFunctionSetType(BinaryenFunctionRef func, BinaryenHeapType type) {
6023-
((Function*)func)->type = HeapType(type);
6026+
((Function*)func)->type = Type(HeapType(type), NonNullable, Exact);
60246027
}
60256028
void BinaryenFunctionOptimize(BinaryenFunctionRef func,
60266029
BinaryenModuleRef module) {

src/ir/module-splitting.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,10 @@ void ModuleSplitter::setupJSPI() {
370370
primary.removeExport(LOAD_SECONDARY_MODULE);
371371
} else {
372372
// Add an imported function to load the secondary module.
373-
auto import = Builder::makeFunction(ModuleSplitting::LOAD_SECONDARY_MODULE,
374-
Signature(Type::none, Type::none),
375-
{});
373+
auto import = Builder::makeFunction(
374+
ModuleSplitting::LOAD_SECONDARY_MODULE,
375+
Type(Signature(Type::none, Type::none), NonNullable, Exact),
376+
{});
376377
import->module = ENV;
377378
import->base = ModuleSplitting::LOAD_SECONDARY_MODULE;
378379
primary.addFunction(std::move(import));
@@ -689,14 +690,15 @@ void ModuleSplitter::indirectCallsToSecondaryFunctions() {
689690
Builder builder(*getModule());
690691
Index secIndex = parent.funcToSecondaryIndex.at(curr->target);
691692
auto* func = parent.secondaries.at(secIndex)->getFunction(curr->target);
692-
auto tableSlot = parent.tableManager.getSlot(curr->target, func->type);
693+
auto tableSlot =
694+
parent.tableManager.getSlot(curr->target, func->type.getHeapType());
693695

694696
replaceCurrent(parent.maybeLoadSecondary(
695697
builder,
696698
builder.makeCallIndirect(tableSlot.tableName,
697699
tableSlot.makeExpr(parent.primary),
698700
curr->operands,
699-
func->type,
701+
func->type.getHeapType(),
700702
curr->isReturn)));
701703
}
702704
};
@@ -786,7 +788,8 @@ void ModuleSplitter::setupTablePatching() {
786788
primary, std::string("placeholder_") + placeholder->base.toString());
787789
placeholder->hasExplicitName = true;
788790
placeholder->type = secondaryFunc->type;
789-
elem = Builder(primary).makeRefFunc(placeholder->name, placeholder->type);
791+
elem = Builder(primary).makeRefFunc(placeholder->name,
792+
placeholder->type.getHeapType());
790793
primary.addFunction(std::move(placeholder));
791794
});
792795

@@ -827,7 +830,8 @@ void ModuleSplitter::setupTablePatching() {
827830
// primarySeg->data[i] is a placeholder, so use the secondary
828831
// function.
829832
auto* func = replacement->second;
830-
auto* ref = Builder(secondary).makeRefFunc(func->name, func->type);
833+
auto* ref = Builder(secondary).makeRefFunc(func->name,
834+
func->type.getHeapType());
831835
secondaryElems.push_back(ref);
832836
++replacement;
833837
} else if (auto* get = primarySeg->data[i]->dynCast<RefFunc>()) {
@@ -869,7 +873,7 @@ void ModuleSplitter::setupTablePatching() {
869873
}
870874
auto* func = curr->second;
871875
currData.push_back(
872-
Builder(secondary).makeRefFunc(func->name, func->type));
876+
Builder(secondary).makeRefFunc(func->name, func->type.getHeapType()));
873877
}
874878
if (currData.size()) {
875879
finishSegment();

src/ir/module-utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,14 +657,14 @@ std::vector<HeapType> getPublicHeapTypes(Module& wasm) {
657657
// We can ignore call.without.effects, which is implemented as an import but
658658
// functionally is a call within the module.
659659
if (!Intrinsics(wasm).isCallWithoutEffects(func)) {
660-
notePublic(func->type);
660+
notePublic(func->type.getHeapType());
661661
}
662662
});
663663
for (auto& ex : wasm.exports) {
664664
switch (ex->kind) {
665665
case ExternalKind::Function: {
666666
auto* func = wasm.getFunction(*ex->getInternalName());
667-
notePublic(func->type);
667+
notePublic(func->type.getHeapType());
668668
continue;
669669
}
670670
case ExternalKind::Table: {

src/ir/possible-contents.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -651,12 +651,13 @@ struct InfoCollector
651651
// actually have a RefFunc.
652652
auto* func = getModule()->getFunction(curr->func);
653653
for (Index i = 0; i < func->getParams().size(); i++) {
654-
info.links.push_back(
655-
{SignatureParamLocation{func->type, i}, ParamLocation{func, i}});
654+
info.links.push_back({SignatureParamLocation{func->type.getHeapType(), i},
655+
ParamLocation{func, i}});
656656
}
657657
for (Index i = 0; i < func->getResults().size(); i++) {
658658
info.links.push_back(
659-
{ResultLocation{func, i}, SignatureResultLocation{func->type, i}});
659+
{ResultLocation{func, i},
660+
SignatureResultLocation{func->type.getHeapType(), i}});
660661
}
661662

662663
if (!options.closedWorld) {
@@ -1759,9 +1760,9 @@ void TNHOracle::infer() {
17591760
continue;
17601761
}
17611762
while (1) {
1762-
typeFunctions[type].push_back(func.get());
1763-
if (auto super = type.getDeclaredSuperType()) {
1764-
type = *super;
1763+
typeFunctions[type.getHeapType()].push_back(func.get());
1764+
if (auto super = type.getHeapType().getDeclaredSuperType()) {
1765+
type = type.with(*super);
17651766
} else {
17661767
break;
17671768
}
@@ -1859,8 +1860,9 @@ void TNHOracle::infer() {
18591860
// as other opts will make this call direct later, after which a
18601861
// lot of other optimizations become possible anyhow.
18611862
auto target = possibleTargets[0]->name;
1862-
info.inferences[call->target] = PossibleContents::literal(
1863-
Literal::makeFunc(target, wasm.getFunction(target)->type));
1863+
info.inferences[call->target] =
1864+
PossibleContents::literal(Literal::makeFunc(
1865+
target, wasm.getFunction(target)->type.getHeapType()));
18641866
continue;
18651867
}
18661868

src/ir/table-utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ inline Index append(Table& table, Name name, Module& wasm) {
9292

9393
auto* func = wasm.getFunctionOrNull(name);
9494
assert(func != nullptr && "Cannot append non-existing function to a table.");
95-
segment->data.push_back(Builder(wasm).makeRefFunc(name, func->type));
95+
segment->data.push_back(
96+
Builder(wasm).makeRefFunc(name, func->type.getHeapType()));
9697
table.initial++;
9798
return tableIndex;
9899
}

src/parser/contexts.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,7 +1418,7 @@ struct ParseModuleTypesCtx : TypeParserCtx<ParseModuleTypesCtx>,
14181418
if (!type.type.isSignature()) {
14191419
return in.err(pos, "expected signature type");
14201420
}
1421-
f->type = type.type;
1421+
f->type = f->type.with(type.type);
14221422
// If we are provided with too many names (more than the function has), we
14231423
// will error on that later when we check the signature matches the type.
14241424
// For now, avoid asserting in setLocalName.
@@ -1601,7 +1601,7 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx>, AnnotationParserCtx {
16011601
elems.push_back(expr);
16021602
}
16031603
void appendFuncElem(std::vector<Expression*>& elems, Name func) {
1604-
auto type = wasm.getFunction(func)->type;
1604+
auto type = wasm.getFunction(func)->type.getHeapType();
16051605
elems.push_back(builder.makeRefFunc(func, type));
16061606
}
16071607

src/passes/Directize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
153153
return CallUtils::Trap{};
154154
}
155155
auto* func = getModule()->getFunction(name);
156-
if (!HeapType::isSubType(func->type, original->heapType)) {
156+
if (!HeapType::isSubType(func->type.getHeapType(), original->heapType)) {
157157
return CallUtils::Trap{};
158158
}
159159
return CallUtils::Known{name};

src/passes/FuncCastEmulation.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ struct FuncCastEmulation : public Pass {
178178
}
179179
auto* thunk = iter->second;
180180
ref->func = thunk->name;
181-
ref->finalize(thunk->type);
181+
ref->finalize(thunk->type.getHeapType());
182182
}
183183
}
184184

@@ -209,11 +209,11 @@ struct FuncCastEmulation : public Pass {
209209
for (Index i = 0; i < numParams; i++) {
210210
thunkParams.push_back(Type::i64);
211211
}
212-
auto thunkFunc =
213-
builder.makeFunction(thunk,
214-
Signature(Type(thunkParams), Type::i64),
215-
{}, // no vars
216-
toABI(call, module));
212+
auto thunkFunc = builder.makeFunction(
213+
thunk,
214+
Type(Signature(Type(thunkParams), Type::i64), NonNullable, Exact),
215+
{}, // no vars
216+
toABI(call, module));
217217
thunkFunc->hasExplicitName = true;
218218
return module->addFunction(std::move(thunkFunc));
219219
}

src/passes/GenerateDynCalls.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct GenerateDynCalls : public WalkerPass<PostWalker<GenerateDynCalls>> {
6161
std::vector<Name> tableSegmentData;
6262
ElementUtils::iterElementSegmentFunctionNames(
6363
it->get(), [&](Name name, Index) {
64-
generateDynCallThunk(wasm->getFunction(name)->type);
64+
generateDynCallThunk(wasm->getFunction(name)->type.getHeapType());
6565
});
6666
}
6767
}
@@ -70,7 +70,7 @@ struct GenerateDynCalls : public WalkerPass<PostWalker<GenerateDynCalls>> {
7070
// Generate dynCalls for invokes
7171
if (func->imported() && func->module == ENV &&
7272
func->base.startsWith("invoke_")) {
73-
Signature sig = func->type.getSignature();
73+
Signature sig = func->type.getHeapType().getSignature();
7474
// The first parameter is a pointer to the original function that's called
7575
// by the invoke, so skip it
7676
std::vector<Type> newParams(sig.params.begin() + 1, sig.params.end());
@@ -155,7 +155,11 @@ void GenerateDynCalls::generateDynCallThunk(HeapType funcType) {
155155
params.push_back(param);
156156
}
157157
auto f = builder.makeFunction(
158-
name, std::move(namedParams), Signature(Type(params), sig.results), {});
158+
name,
159+
std::move(namedParams),
160+
Type(Signature(Type(params), sig.results), NonNullable, Exact),
161+
{},
162+
nullptr);
159163
f->hasExplicitName = true;
160164
Expression* fptr = builder.makeLocalGet(0, table->addressType);
161165
std::vector<Expression*> args;

src/passes/JSPI.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ struct JSPI : public Pass {
9797
if (wasmSplit) {
9898
// Make an import for the load secondary module function so a JSPI wrapper
9999
// version will be created.
100-
auto import =
101-
Builder::makeFunction(ModuleSplitting::LOAD_SECONDARY_MODULE,
102-
Signature(Type::none, Type::none),
103-
{});
100+
auto import = Builder::makeFunction(
101+
ModuleSplitting::LOAD_SECONDARY_MODULE,
102+
Type(Signature(Type::none, Type::none), NonNullable, Exact),
103+
{});
104104
import->module = ENV;
105105
import->base = ModuleSplitting::LOAD_SECONDARY_MODULE;
106106
module->addFunction(std::move(import));
@@ -152,7 +152,8 @@ struct JSPI : public Pass {
152152
continue;
153153
}
154154
auto* replacementRef = builder.makeRefFunc(
155-
iter->second, module->getFunction(iter->second)->type);
155+
iter->second,
156+
module->getFunction(iter->second)->type.getHeapType());
156157
segment->data[i] = replacementRef;
157158
}
158159
}
@@ -213,12 +214,12 @@ struct JSPI : public Pass {
213214
block->list.push_back(builder.makeConst(0));
214215
}
215216
block->finalize();
216-
auto wrapperFunc =
217-
Builder::makeFunction(wrapperName,
218-
std::move(namedWrapperParams),
219-
Signature(Type(wrapperParams), resultsType),
220-
{},
221-
block);
217+
auto wrapperFunc = Builder::makeFunction(
218+
wrapperName,
219+
std::move(namedWrapperParams),
220+
Type(Signature(Type(wrapperParams), resultsType), NonNullable, Exact),
221+
{},
222+
block);
222223
return module->addFunction(std::move(wrapperFunc))->name;
223224
}
224225

@@ -276,7 +277,8 @@ struct JSPI : public Pass {
276277
block->finalize();
277278
call->type = im->getResults();
278279
stub->body = block;
279-
wrapperIm->type = Signature(Type(params), call->type);
280+
wrapperIm->type =
281+
Type(Signature(Type(params), call->type), NonNullable, Exact);
280282

281283
if (wasmSplit && im->name == ModuleSplitting::LOAD_SECONDARY_MODULE) {
282284
// In non-debug builds the name of the JSPI wrapper function for loading

0 commit comments

Comments
 (0)