Skip to content

Commit 0e28a7a

Browse files
Cherrypick SCC with indirect effects
1 parent a6b6272 commit 0e28a7a

File tree

3 files changed

+501
-73
lines changed

3 files changed

+501
-73
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 119 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
2627
#include "wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940

4041
// Directly-called functions from this function.
4142
std::unordered_set<Name> calledFunctions;
43+
44+
// Types that are targets of indirect calls.
45+
std::unordered_set<HeapType> indirectCalledTypes;
4246
};
4347

4448
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +87,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387
if (auto* call = curr->dynCast<Call>()) {
8488
// Note the direct call.
8589
funcInfo.calledFunctions.insert(call->target);
90+
} else if (effects.calls && options.closedWorld) {
91+
HeapType type;
92+
if (auto* callRef = curr->dynCast<CallRef>()) {
93+
type = callRef->target->type.getHeapType();
94+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
95+
type = callIndirect->heapType;
96+
} else {
97+
assert(false && "Unexpected type of call");
98+
}
99+
100+
funcInfo.indirectCalledTypes.insert(type);
86101
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
102+
assert(!options.closedWorld);
91103
funcInfo.effects = UnknownEffects;
92104
} else {
93105
// No call here, but update throwing if we see it. (Only do so,
@@ -107,15 +119,44 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119
return std::move(analysis.map);
108120
}
109121

110-
using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
122+
using CallGraphNode = std::variant<Function*, HeapType>;
123+
using CallGraph =
124+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111125

112-
CallGraph buildCallGraph(const Module& module,
113-
const std::map<Function*, FuncInfo>& funcInfos) {
126+
CallGraph buildCallGraph(Module& module,
127+
const std::map<Function*, FuncInfo>& funcInfos,
128+
bool closedWorld) {
114129
CallGraph callGraph;
115-
for (const auto& [func, info] : funcInfos) {
116-
for (Name callee : info.calledFunctions) {
117-
callGraph[func].insert(module.getFunction(callee));
130+
131+
if (!closedWorld) {
132+
for (const auto& [caller, callerInfo] : funcInfos) {
133+
for (Name calleeFunction : callerInfo.calledFunctions) {
134+
callGraph[caller].insert(module.getFunction(calleeFunction));
135+
}
136+
}
137+
return callGraph;
138+
}
139+
140+
std::unordered_set<HeapType> allFunctionTypes;
141+
for (const auto& [caller, callerInfo] : funcInfos) {
142+
for (Name calleeFunction : callerInfo.calledFunctions) {
143+
callGraph[caller].insert(module.getFunction(calleeFunction));
144+
}
145+
146+
allFunctionTypes.insert(caller->type.getHeapType());
147+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
148+
callGraph[caller].insert(calleeType);
149+
allFunctionTypes.insert(calleeType);
118150
}
151+
callGraph[caller->type.getHeapType()].insert(caller);
152+
}
153+
154+
SubTypes subtypes(module);
155+
for (HeapType type : allFunctionTypes) {
156+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
157+
callGraph[type].insert(sub);
158+
return true;
159+
});
119160
}
120161

121162
return callGraph;
@@ -148,99 +189,111 @@ void propagateEffects(const Module& module,
148189
std::map<Function*, FuncInfo>& funcInfos,
149190
const CallGraph& callGraph) {
150191
struct CallGraphSCCs
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
192+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
152193
const std::map<Function*, FuncInfo>& funcInfos;
153-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
154-
callGraph;
194+
const CallGraph& callGraph;
155195
const Module& module;
156196

157197
CallGraphSCCs(
158-
const std::vector<Function*>& funcs,
198+
const std::vector<CallGraphNode>& nodes,
159199
const std::map<Function*, FuncInfo>& funcInfos,
160-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
161-
callGraph,
200+
const std::unordered_map<CallGraphNode,
201+
std::unordered_set<CallGraphNode>>& callGraph,
162202
const Module& module)
163-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
164-
funcs.begin(), funcs.end()),
203+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
204+
nodes.begin(), nodes.end()),
165205
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
166206

167-
void pushChildren(Function* f) {
168-
auto callees = callGraph.find(f);
207+
void pushChildren(CallGraphNode node) {
208+
auto callees = callGraph.find(node);
169209
if (callees == callGraph.end()) {
170210
return;
171211
}
172-
173-
for (auto* callee : callees->second) {
212+
for (CallGraphNode callee : callees->second) {
174213
push(callee);
175214
}
176215
}
177216
};
178217

179-
std::vector<Function*> allFuncs;
218+
// We only care about Functions that are roots, not types
219+
// A type would be a root if a function exists with that type, but no-one
220+
// indirect calls the type.
221+
std::vector<CallGraphNode> allFuncs;
180222
for (auto& [func, info] : funcInfos) {
181223
allFuncs.push_back(func);
182224
}
225+
183226
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
184227

185228
std::vector<std::optional<EffectAnalyzer>> componentEffects;
186229
// Points to an index in componentEffects
187-
std::unordered_map<Function*, Index> funcComponents;
230+
std::unordered_map<CallGraphNode, Index> funcComponents;
188231

189232
for (auto ccIterator : sccs) {
190233
std::optional<EffectAnalyzer>& ccEffects =
191234
componentEffects.emplace_back(std::in_place, passOptions, module);
235+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
192236

193-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
194-
195-
for (Function* f : ccFuncs) {
196-
funcComponents.emplace(f, componentEffects.size() - 1);
197-
}
198-
199-
std::unordered_set<int> calleeSccs;
200-
for (Function* caller : ccFuncs) {
201-
auto callees = callGraph.find(caller);
202-
if (callees == callGraph.end()) {
203-
continue;
204-
}
205-
for (auto* callee : callees->second) {
206-
calleeSccs.insert(funcComponents.at(callee));
237+
std::vector<Function*> ccFuncs;
238+
for (CallGraphNode node : cc) {
239+
funcComponents.emplace(node, componentEffects.size() - 1);
240+
if (auto** func = std::get_if<Function*>(&node)) {
241+
ccFuncs.push_back(*func);
207242
}
208-
}
209243

210-
// Merge in effects from callees
211-
for (int calleeScc : calleeSccs) {
212-
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
213-
mergeMaybeEffects(ccEffects, calleeComponentEffects);
214-
}
244+
std::unordered_set<int> calleeSccs;
245+
for (CallGraphNode caller : cc) {
246+
auto callees = callGraph.find(caller);
247+
if (callees == callGraph.end()) {
248+
continue;
249+
}
250+
if (callees != callGraph.end()) {
251+
for (CallGraphNode callee : callees->second) {
252+
auto sccIt = funcComponents.find(callee);
253+
if (sccIt != funcComponents.end()) {
254+
calleeSccs.insert(sccIt->second);
255+
}
256+
}
257+
}
258+
}
215259

216-
// Add trap effects for potential cycles.
217-
if (ccFuncs.size() > 1) {
218-
if (ccEffects != UnknownEffects) {
219-
ccEffects->trap = true;
260+
// Merge in effects from callees
261+
for (int calleeScc : calleeSccs) {
262+
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
263+
mergeMaybeEffects(ccEffects, calleeComponentEffects);
220264
}
221-
} else {
222-
auto* func = ccFuncs[0];
223-
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
265+
266+
// Add trap effects for potential cycles.
267+
if (cc.size() > 1) {
224268
if (ccEffects != UnknownEffects) {
225269
ccEffects->trap = true;
226270
}
271+
} else if (ccFuncs.size() == 1) {
272+
// It's possible for a CC to only contain 1 type, but that is not a
273+
// cycle in the call graph.
274+
auto* func = ccFuncs[0];
275+
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
276+
if (ccEffects != UnknownEffects) {
277+
ccEffects->trap = true;
278+
}
279+
}
227280
}
228-
}
229281

230-
// Aggregate effects within this CC
231-
if (ccEffects) {
232-
for (Function* f : ccFuncs) {
233-
const auto& effects = funcInfos.at(f).effects;
234-
mergeMaybeEffects(ccEffects, effects);
282+
// Aggregate effects within this CC
283+
if (ccEffects) {
284+
for (Function* f : ccFuncs) {
285+
const auto& effects = funcInfos.at(f).effects;
286+
mergeMaybeEffects(ccEffects, effects);
287+
}
235288
}
236-
}
237289

238-
// Assign each function's effects to its CC effects.
239-
for (Function* f : ccFuncs) {
240-
if (!ccEffects) {
241-
funcInfos.at(f).effects = UnknownEffects;
242-
} else {
243-
funcInfos.at(f).effects.emplace(*ccEffects);
290+
// Assign each function's effects to its CC effects.
291+
for (Function* f : ccFuncs) {
292+
if (!ccEffects) {
293+
funcInfos.at(f).effects = UnknownEffects;
294+
} else {
295+
funcInfos.at(f).effects.emplace(*ccEffects);
296+
}
244297
}
245298
}
246299
}
@@ -262,7 +315,8 @@ struct GenerateGlobalEffects : public Pass {
262315
std::map<Function*, FuncInfo> funcInfos =
263316
analyzeFuncs(*module, getPassOptions());
264317

265-
auto callGraph = buildCallGraph(*module, funcInfos);
318+
auto callGraph =
319+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
266320

267321
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
268322

0 commit comments

Comments
 (0)