Skip to content

Commit 83f4b3b

Browse files
Cherrypick SCC with indirect effects
1 parent 0a2f0dd commit 83f4b3b

File tree

3 files changed

+468
-44
lines changed

3 files changed

+468
-44
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
2627
#include "wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940

4041
// Directly-called functions from this function.
4142
std::unordered_set<Name> calledFunctions;
43+
44+
// Types that are targets of indirect calls.
45+
std::unordered_set<HeapType> indirectCalledTypes;
4246
};
4347

4448
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +87,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387
if (auto* call = curr->dynCast<Call>()) {
8488
// Note the direct call.
8589
funcInfo.calledFunctions.insert(call->target);
90+
} else if (effects.calls && options.closedWorld) {
91+
HeapType type;
92+
if (auto* callRef = curr->dynCast<CallRef>()) {
93+
type = callRef->target->type.getHeapType();
94+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
95+
type = callIndirect->heapType;
96+
} else {
97+
assert(false && "Unexpected type of call");
98+
}
99+
100+
funcInfo.indirectCalledTypes.insert(type);
86101
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
102+
assert(!options.closedWorld);
91103
funcInfo.effects = UnknownEffects;
92104
} else {
93105
// No call here, but update throwing if we see it. (Only do so,
@@ -107,20 +119,49 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119
return std::move(analysis.map);
108120
}
109121

110-
using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
122+
using CallGraphNode = std::variant<Function*, HeapType>;
123+
using CallGraph =
124+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111125

112-
CallGraph buildCallGraph(const Module& module,
113-
const std::map<Function*, FuncInfo>& funcInfos) {
126+
CallGraph buildCallGraph(Module& module,
127+
const std::map<Function*, FuncInfo>& funcInfos,
128+
bool closedWorld) {
114129
CallGraph callGraph;
115-
for (const auto& [func, info] : funcInfos) {
116-
if (info.calledFunctions.empty()) {
117-
continue;
130+
131+
if (!closedWorld) {
132+
for (const auto& [func, info] : funcInfos) {
133+
if (info.calledFunctions.empty()) {
134+
continue;
135+
}
136+
137+
auto& callees = callGraph[func];
138+
for (Name calleeFunction : info.calledFunctions) {
139+
callees.insert(module.getFunction(calleeFunction));
140+
}
118141
}
142+
return callGraph;
143+
}
119144

120-
auto& callees = callGraph[func];
121-
for (Name callee : info.calledFunctions) {
122-
callees.insert(module.getFunction(callee));
145+
std::unordered_set<HeapType> allFunctionTypes;
146+
for (const auto& [caller, callerInfo] : funcInfos) {
147+
for (Name calleeFunction : callerInfo.calledFunctions) {
148+
callGraph[caller].insert(module.getFunction(calleeFunction));
123149
}
150+
151+
allFunctionTypes.insert(caller->type.getHeapType());
152+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
153+
callGraph[caller].insert(calleeType);
154+
allFunctionTypes.insert(calleeType);
155+
}
156+
callGraph[caller->type.getHeapType()].insert(caller);
157+
}
158+
159+
SubTypes subtypes(module);
160+
for (HeapType type : allFunctionTypes) {
161+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
162+
callGraph[type].insert(sub);
163+
return true;
164+
});
124165
}
125166

126167
return callGraph;
@@ -153,61 +194,67 @@ void propagateEffects(const Module& module,
153194
std::map<Function*, FuncInfo>& funcInfos,
154195
const CallGraph& callGraph) {
155196
struct CallGraphSCCs
156-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
197+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
157198
const std::map<Function*, FuncInfo>& funcInfos;
158-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
159-
callGraph;
199+
const CallGraph& callGraph;
160200
const Module& module;
161201

162202
CallGraphSCCs(
163-
const std::vector<Function*>& funcs,
203+
const std::vector<CallGraphNode>& nodes,
164204
const std::map<Function*, FuncInfo>& funcInfos,
165-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
166-
callGraph,
205+
const std::unordered_map<CallGraphNode,
206+
std::unordered_set<CallGraphNode>>& callGraph,
167207
const Module& module)
168-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
169-
funcs.begin(), funcs.end()),
208+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
209+
nodes.begin(), nodes.end()),
170210
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
171211

172-
void pushChildren(Function* f) {
173-
auto callees = callGraph.find(f);
212+
void pushChildren(CallGraphNode node) {
213+
auto callees = callGraph.find(node);
174214
if (callees == callGraph.end()) {
175215
return;
176216
}
177-
178-
for (auto* callee : callees->second) {
217+
for (CallGraphNode callee : callees->second) {
179218
push(callee);
180219
}
181220
}
182221
};
183222

184-
std::vector<Function*> allFuncs;
223+
// We only care about Functions that are roots, not types
224+
// A type would be a root if a function exists with that type, but no-one
225+
// indirect calls the type.
226+
std::vector<CallGraphNode> allFuncs;
185227
for (auto& [func, info] : funcInfos) {
186228
allFuncs.push_back(func);
187229
}
230+
188231
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
189232

190233
std::vector<std::optional<EffectAnalyzer>> componentEffects;
191234
// Points to an index in componentEffects
192-
std::unordered_map<Function*, Index> funcComponents;
235+
std::unordered_map<CallGraphNode, Index> funcComponents;
193236

194237
for (auto ccIterator : sccs) {
195238
std::optional<EffectAnalyzer>& ccEffects =
196239
componentEffects.emplace_back(std::in_place, passOptions, module);
240+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
197241

198-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
199-
200-
for (Function* f : ccFuncs) {
201-
funcComponents.emplace(f, componentEffects.size() - 1);
242+
std::vector<Function*> ccFuncs;
243+
for (CallGraphNode node : cc) {
244+
funcComponents.emplace(node, componentEffects.size() - 1);
245+
if (auto** func = std::get_if<Function*>(&node)) {
246+
ccFuncs.push_back(*func);
247+
}
202248
}
203249

204250
std::unordered_set<int> calleeSccs;
205-
for (Function* caller : ccFuncs) {
251+
for (CallGraphNode caller : cc) {
206252
auto callees = callGraph.find(caller);
207253
if (callees == callGraph.end()) {
208254
continue;
209255
}
210-
for (auto* callee : callees->second) {
256+
257+
for (CallGraphNode callee : callees->second) {
211258
calleeSccs.insert(funcComponents.at(callee));
212259
}
213260
}
@@ -219,11 +266,13 @@ void propagateEffects(const Module& module,
219266
}
220267

221268
// Add trap effects for potential cycles.
222-
if (ccFuncs.size() > 1) {
269+
if (cc.size() > 1) {
223270
if (ccEffects != UnknownEffects) {
224271
ccEffects->trap = true;
225272
}
226-
} else {
273+
} else if (ccFuncs.size() == 1) {
274+
// It's possible for a CC to only contain 1 type, but that is not a
275+
// cycle in the call graph.
227276
auto* func = ccFuncs[0];
228277
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
229278
if (ccEffects != UnknownEffects) {
@@ -267,7 +316,8 @@ struct GenerateGlobalEffects : public Pass {
267316
std::map<Function*, FuncInfo> funcInfos =
268317
analyzeFuncs(*module, getPassOptions());
269318

270-
auto callGraph = buildCallGraph(*module, funcInfos);
319+
auto callGraph =
320+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
271321

272322
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
273323

0 commit comments

Comments
 (0)