Skip to content

Commit b7ce0b6

Browse files
Cherrypick SCC with indirect effects
1 parent a6b6272 commit b7ce0b6

File tree

3 files changed

+503
-74
lines changed

3 files changed

+503
-74
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 121 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
27+
#include "support/utilities.h"
2628
#include "wasm.h"
2729

2830
namespace wasm {
@@ -39,6 +41,9 @@ struct FuncInfo {
3941

4042
// Directly-called functions from this function.
4143
std::unordered_set<Name> calledFunctions;
44+
45+
// Types that are targets of indirect calls.
46+
std::unordered_set<HeapType> indirectCalledTypes;
4247
};
4348

4449
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8388
if (auto* call = curr->dynCast<Call>()) {
8489
// Note the direct call.
8590
funcInfo.calledFunctions.insert(call->target);
91+
} else if (effects.calls && options.closedWorld) {
92+
HeapType type;
93+
if (auto* callRef = curr->dynCast<CallRef>()) {
94+
type = callRef->target->type.getHeapType();
95+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
96+
type = callIndirect->heapType;
97+
} else {
98+
assert(false && "Unexpected type of call");
99+
}
100+
101+
funcInfo.indirectCalledTypes.insert(type);
86102
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
103+
assert(!options.closedWorld);
91104
funcInfo.effects = UnknownEffects;
92105
} else {
93106
// No call here, but update throwing if we see it. (Only do so,
@@ -107,15 +120,44 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107120
return std::move(analysis.map);
108121
}
109122

110-
using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
123+
using CallGraphNode = std::variant<Function*, HeapType>;
124+
using CallGraph =
125+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111126

112-
CallGraph buildCallGraph(const Module& module,
113-
const std::map<Function*, FuncInfo>& funcInfos) {
127+
CallGraph buildCallGraph(Module& module,
128+
const std::map<Function*, FuncInfo>& funcInfos,
129+
bool closedWorld) {
114130
CallGraph callGraph;
115-
for (const auto& [func, info] : funcInfos) {
116-
for (Name callee : info.calledFunctions) {
117-
callGraph[func].insert(module.getFunction(callee));
131+
132+
if (!closedWorld) {
133+
for (const auto& [caller, callerInfo] : funcInfos) {
134+
for (Name calleeFunction : callerInfo.calledFunctions) {
135+
callGraph[caller].insert(module.getFunction(calleeFunction));
136+
}
137+
}
138+
return callGraph;
139+
}
140+
141+
std::unordered_set<HeapType> allFunctionTypes;
142+
for (const auto& [caller, callerInfo] : funcInfos) {
143+
for (Name calleeFunction : callerInfo.calledFunctions) {
144+
callGraph[caller].insert(module.getFunction(calleeFunction));
145+
}
146+
147+
allFunctionTypes.insert(caller->type.getHeapType());
148+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
149+
callGraph[caller].insert(calleeType);
150+
allFunctionTypes.insert(calleeType);
118151
}
152+
callGraph[caller->type.getHeapType()].insert(caller);
153+
}
154+
155+
SubTypes subtypes(module);
156+
for (HeapType type : allFunctionTypes) {
157+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
158+
callGraph[type].insert(sub);
159+
return true;
160+
});
119161
}
120162

121163
return callGraph;
@@ -148,99 +190,111 @@ void propagateEffects(const Module& module,
148190
std::map<Function*, FuncInfo>& funcInfos,
149191
const CallGraph& callGraph) {
150192
struct CallGraphSCCs
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
193+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
152194
const std::map<Function*, FuncInfo>& funcInfos;
153-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
154-
callGraph;
195+
const CallGraph& callGraph;
155196
const Module& module;
156197

157198
CallGraphSCCs(
158-
const std::vector<Function*>& funcs,
199+
const std::vector<CallGraphNode>& nodes,
159200
const std::map<Function*, FuncInfo>& funcInfos,
160-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
161-
callGraph,
201+
const std::unordered_map<CallGraphNode,
202+
std::unordered_set<CallGraphNode>>& callGraph,
162203
const Module& module)
163-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
164-
funcs.begin(), funcs.end()),
165-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
204+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
205+
nodes.begin(), nodes.end()),
206+
funcInfos(funcInfos), module(module), callGraph(callGraph) {}
166207

167-
void pushChildren(Function* f) {
168-
auto callees = callGraph.find(f);
208+
void pushChildren(CallGraphNode node) {
209+
auto callees = callGraph.find(node);
169210
if (callees == callGraph.end()) {
170211
return;
171212
}
172-
173-
for (auto* callee : callees->second) {
213+
for (CallGraphNode callee : callees->second) {
174214
push(callee);
175215
}
176216
}
177217
};
178218

179-
std::vector<Function*> allFuncs;
219+
// We only care about Functions that are roots, not types
220+
// A type would be a root if a function exists with that type, but no-one
221+
// indirect calls the type.
222+
std::vector<CallGraphNode> allFuncs;
180223
for (auto& [func, info] : funcInfos) {
181224
allFuncs.push_back(func);
182225
}
226+
183227
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
184228

185229
std::vector<std::optional<EffectAnalyzer>> componentEffects;
186230
// Points to an index in componentEffects
187-
std::unordered_map<Function*, Index> funcComponents;
231+
std::unordered_map<CallGraphNode, Index> funcComponents;
188232

189233
for (auto ccIterator : sccs) {
190234
std::optional<EffectAnalyzer>& ccEffects =
191235
componentEffects.emplace_back(std::in_place, passOptions, module);
236+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
192237

193-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
194-
195-
for (Function* f : ccFuncs) {
196-
funcComponents.emplace(f, componentEffects.size() - 1);
197-
}
198-
199-
std::unordered_set<int> calleeSccs;
200-
for (Function* caller : ccFuncs) {
201-
auto callees = callGraph.find(caller);
202-
if (callees == callGraph.end()) {
203-
continue;
204-
}
205-
for (auto* callee : callees->second) {
206-
calleeSccs.insert(funcComponents.at(callee));
238+
std::vector<Function*> ccFuncs;
239+
for (CallGraphNode node : cc) {
240+
funcComponents.emplace(node, componentEffects.size() - 1);
241+
if (auto** func = std::get_if<Function*>(&node)) {
242+
ccFuncs.push_back(*func);
207243
}
208-
}
209244

210-
// Merge in effects from callees
211-
for (int calleeScc : calleeSccs) {
212-
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
213-
mergeMaybeEffects(ccEffects, calleeComponentEffects);
214-
}
245+
std::unordered_set<int> calleeSccs;
246+
for (CallGraphNode caller : cc) {
247+
auto callees = callGraph.find(caller);
248+
if (callees == callGraph.end()) {
249+
continue;
250+
}
251+
if (callees != callGraph.end()) {
252+
for (CallGraphNode callee : callees->second) {
253+
auto sccIt = funcComponents.find(callee);
254+
if (sccIt != funcComponents.end()) {
255+
calleeSccs.insert(sccIt->second);
256+
}
257+
}
258+
}
259+
}
215260

216-
// Add trap effects for potential cycles.
217-
if (ccFuncs.size() > 1) {
218-
if (ccEffects != UnknownEffects) {
219-
ccEffects->trap = true;
261+
// Merge in effects from callees
262+
for (int calleeScc : calleeSccs) {
263+
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
264+
mergeMaybeEffects(ccEffects, calleeComponentEffects);
220265
}
221-
} else {
222-
auto* func = ccFuncs[0];
223-
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
266+
267+
// Add trap effects for potential cycles.
268+
if (cc.size() > 1) {
224269
if (ccEffects != UnknownEffects) {
225270
ccEffects->trap = true;
226271
}
272+
} else if (ccFuncs.size() == 1) {
273+
// It's possible for a CC to only contain 1 type, but that is not a
274+
// cycle in the call graph.
275+
auto* func = ccFuncs[0];
276+
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
277+
if (ccEffects != UnknownEffects) {
278+
ccEffects->trap = true;
279+
}
280+
}
227281
}
228-
}
229282

230-
// Aggregate effects within this CC
231-
if (ccEffects) {
232-
for (Function* f : ccFuncs) {
233-
const auto& effects = funcInfos.at(f).effects;
234-
mergeMaybeEffects(ccEffects, effects);
283+
// Aggregate effects within this CC
284+
if (ccEffects) {
285+
for (Function* f : ccFuncs) {
286+
const auto& effects = funcInfos.at(f).effects;
287+
mergeMaybeEffects(ccEffects, effects);
288+
}
235289
}
236-
}
237290

238-
// Assign each function's effects to its CC effects.
239-
for (Function* f : ccFuncs) {
240-
if (!ccEffects) {
241-
funcInfos.at(f).effects = UnknownEffects;
242-
} else {
243-
funcInfos.at(f).effects.emplace(*ccEffects);
291+
// Assign each function's effects to its CC effects.
292+
for (Function* f : ccFuncs) {
293+
if (!ccEffects) {
294+
funcInfos.at(f).effects = UnknownEffects;
295+
} else {
296+
funcInfos.at(f).effects.emplace(*ccEffects);
297+
}
244298
}
245299
}
246300
}
@@ -262,7 +316,8 @@ struct GenerateGlobalEffects : public Pass {
262316
std::map<Function*, FuncInfo> funcInfos =
263317
analyzeFuncs(*module, getPassOptions());
264318

265-
auto callGraph = buildCallGraph(*module, funcInfos);
319+
auto callGraph =
320+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
266321

267322
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
268323

0 commit comments

Comments
 (0)