Skip to content

Commit ec9d431

Browse files
Cherrypick SCC with indirect effects
1 parent 9bc667f commit ec9d431

File tree

2 files changed

+441
-67
lines changed

2 files changed

+441
-67
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 118 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
2627
#include "wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940

4041
// Directly-called functions from this function.
4142
std::unordered_set<Name> calledFunctions;
43+
44+
// Types that are targets of indirect calls.
45+
std::unordered_set<HeapType> indirectCalledTypes;
4246
};
4347

4448
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +87,21 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387
if (auto* call = curr->dynCast<Call>()) {
8488
// Note the direct call.
8589
funcInfo.calledFunctions.insert(call->target);
90+
} else if (effects.calls && options.closedWorld) {
91+
HeapType type;
92+
if (auto* callRef = curr->dynCast<CallRef>()) {
93+
type = callRef->target->type.getHeapType();
94+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
95+
// nullability doesn't matter here
96+
// call_indirect is always inexact
97+
type = callIndirect->heapType;
98+
} else {
99+
assert(false && "Unexpected type of call");
100+
}
101+
102+
funcInfo.indirectCalledTypes.insert(type);
86103
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
104+
assert(!options.closedWorld);
91105
funcInfo.effects = UnknownEffects;
92106
} else {
93107
// No call here, but update throwing if we see it. (Only do so,
@@ -107,14 +121,45 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107121
return std::move(analysis.map);
108122
}
109123

110-
std::unordered_map<Function*, std::unordered_set<Function*>>
111-
buildCallGraph(const Module& module,
112-
const std::map<Function*, FuncInfo>& funcInfos) {
113-
std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114-
for (const auto& [func, info] : funcInfos) {
115-
for (Name callee : info.calledFunctions) {
116-
callGraph[func].insert(module.getFunction(callee));
124+
using CallGraphNode = std::variant<Function*, HeapType>;
125+
126+
// Build a call graph for indirect and direct calls.
127+
// key (callee) -> value (caller)
128+
// Name -> Name : callee is called directly by caller
129+
// Name -> HeapType : callee is a potential target of a virtual call
130+
// with this HeapType
131+
// HeapType -> Name : callee is indirectly called by caller
132+
// HeapType -> HeapType : callee is a subtype of caller If we're
133+
// running in an open world, we only include Name -> Name edges.
134+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
135+
buildCallGraph(Module& module,
136+
const std::map<Function*, FuncInfo>& funcInfos,
137+
bool closedWorld) {
138+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
139+
callGraph;
140+
std::unordered_set<HeapType> allFunctionTypes;
141+
for (const auto& [caller, callerInfo] : funcInfos) {
142+
for (Name calleeFunction : callerInfo.calledFunctions) {
143+
callGraph[caller].insert(module.getFunction(calleeFunction));
144+
}
145+
if (!closedWorld) {
146+
continue;
147+
}
148+
149+
allFunctionTypes.insert(caller->type.getHeapType());
150+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
151+
callGraph[caller].insert(calleeType);
152+
allFunctionTypes.insert(calleeType);
117153
}
154+
callGraph[caller->type.getHeapType()].insert(caller);
155+
}
156+
157+
SubTypes subtypes(module);
158+
for (HeapType type : allFunctionTypes) {
159+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
160+
callGraph[type].insert(sub);
161+
return true;
162+
});
118163
}
119164

120165
return callGraph;
@@ -123,98 +168,108 @@ buildCallGraph(const Module& module,
123168
// Propagate effects from callees to callers transitively
124169
// e.g. if A -> B -> C (A calls B which calls C)
125170
// Then B inherits effects from C and A inherits effects from both B and C.
126-
//
127-
// Generate SCC for the call graph, then traverse it in reverse topological
128-
// order processing each callee before its callers. When traversing:
129-
// - Merge all of the effects of functions within the CC
130-
// - Also merge the (already computed) effects of each callee CC
131-
// - Add trap effects for potentially recursive call chains
132171
void propagateEffects(
133-
const Module& module,
172+
Module& module,
134173
const PassOptions& passOptions,
135174
std::map<Function*, FuncInfo>& funcInfos,
136-
const std::unordered_map<Function*, std::unordered_set<Function*>>
175+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
137176
callGraph) {
177+
138178
struct CallGraphSCCs
139-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
179+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
140180
const std::map<Function*, FuncInfo>& funcInfos;
141-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
142-
callGraph;
143181
const Module& module;
182+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
183+
callGraph;
144184

145185
CallGraphSCCs(
146-
const std::vector<Function*>& funcs,
186+
const std::vector<CallGraphNode>& nodes,
147187
const std::map<Function*, FuncInfo>& funcInfos,
148-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
149-
callGraph,
150-
const Module& module)
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
152-
funcs.begin(), funcs.end()),
153-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
154-
155-
void pushChildren(Function* f) {
156-
auto callees = callGraph.find(f);
188+
Module& module,
189+
const std::unordered_map<CallGraphNode,
190+
std::unordered_set<CallGraphNode>>& callGraph)
191+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
192+
nodes.begin(), nodes.end()),
193+
funcInfos(funcInfos), module(module), callGraph(callGraph) {}
194+
195+
void pushChildren(CallGraphNode node) {
196+
auto callees = callGraph.find(node);
157197
if (callees == callGraph.end()) {
158198
return;
159199
}
160-
161-
for (auto* callee : callees->second) {
200+
for (const auto& callee : callees->second) {
162201
push(callee);
163202
}
164203
}
165204
};
166205

167-
std::vector<Function*> allFuncs;
206+
std::vector<CallGraphNode> funcs;
207+
// We only care about Functions that are roots, not types
168208
for (auto& [func, info] : funcInfos) {
169-
allFuncs.push_back(func);
209+
funcs.push_back(func);
170210
}
171-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
172211

173-
std::unordered_map<Function*, int> sccMembers;
212+
CallGraphSCCs sccs(funcs, funcInfos, module, callGraph);
213+
214+
std::unordered_map<CallGraphNode, int> sccMembers;
174215
std::unordered_map<int, std::optional<EffectAnalyzer>> componentEffects;
175216

176217
int ccIndex = 0;
177218
for (auto ccIterator : sccs) {
219+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
178220
ccIndex++;
179221
std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
181-
182222
ccEffects.emplace(passOptions, module);
183223

224+
std::vector<Function*> ccFuncs;
225+
std::vector<HeapType> ccTypes;
226+
for (auto v : cc) {
227+
if (auto** func = std::get_if<Function*>(&v)) {
228+
ccFuncs.push_back(*func);
229+
} else {
230+
ccTypes.push_back(std::get<HeapType>(v));
231+
}
232+
}
233+
184234
for (Function* f : ccFuncs) {
185235
sccMembers.emplace(f, ccIndex);
186236
}
237+
for (HeapType t : ccTypes) {
238+
sccMembers.emplace(t, ccIndex);
239+
}
187240

188241
std::unordered_set<int> calleeSccs;
189-
for (Function* caller : ccFuncs) {
242+
for (const auto& caller : cc) {
190243
auto callees = callGraph.find(caller);
191-
if (callees == callGraph.end()) {
192-
continue;
193-
}
194-
for (auto* callee : callees->second) {
195-
calleeSccs.insert(sccMembers.at(callee));
244+
if (callees != callGraph.end()) {
245+
for (const auto& callee : callees->second) {
246+
auto sccIt = sccMembers.find(callee);
247+
if (sccIt != sccMembers.end()) {
248+
calleeSccs.insert(sccIt->second);
249+
}
250+
}
196251
}
197252
}
198253

199-
// Merge in effects from callees
200254
for (int calleeScc : calleeSccs) {
201255
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
202256
if (calleeComponentEffects == UnknownEffects) {
203257
ccEffects = UnknownEffects;
204258
break;
205259
}
206260

207-
else if (ccEffects != UnknownEffects) {
261+
else if (ccEffects) {
208262
ccEffects->mergeIn(*calleeComponentEffects);
209263
}
210264
}
211265

212266
// Add trap effects for potential cycles.
213-
if (ccFuncs.size() > 1) {
267+
if (cc.size() > 1) {
214268
if (ccEffects != UnknownEffects) {
215269
ccEffects->trap = true;
216270
}
217-
} else {
271+
// A cycle isn't possible for a CC that only contains a type
272+
} else if (ccFuncs.size() == 1) {
218273
auto* func = ccFuncs[0];
219274
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
220275
if (ccEffects != UnknownEffects) {
@@ -223,8 +278,7 @@ void propagateEffects(
223278
}
224279
}
225280

226-
// Aggregate effects within this CC
227-
if (ccEffects) {
281+
if (ccEffects)
228282
for (Function* f : ccFuncs) {
229283
const auto& effects = funcInfos.at(f).effects;
230284
if (effects == UnknownEffects) {
@@ -234,9 +288,7 @@ void propagateEffects(
234288

235289
ccEffects->mergeIn(*effects);
236290
}
237-
}
238291

239-
// Assign each function's effects to its CC effects.
240292
for (Function* f : ccFuncs) {
241293
if (!ccEffects) {
242294
funcInfos.at(f).effects = UnknownEffects;
@@ -247,27 +299,26 @@ void propagateEffects(
247299
}
248300
}
249301

250-
void copyEffectsToFunctions(const std::map<Function*, FuncInfo> funcInfos) {
251-
for (auto& [func, info] : funcInfos) {
252-
func->effects.reset();
253-
if (!info.effects) {
254-
continue;
255-
}
256-
257-
func->effects = std::make_shared<EffectAnalyzer>(*info.effects);
258-
}
259-
}
260-
261302
struct GenerateGlobalEffects : public Pass {
262303
void run(Module* module) override {
263304
std::map<Function*, FuncInfo> funcInfos =
264305
analyzeFuncs(*module, getPassOptions());
265306

266-
auto callGraph = buildCallGraph(*module, funcInfos);
307+
auto callGraph =
308+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
267309

268310
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
269311

270-
copyEffectsToFunctions(funcInfos);
312+
// Generate the final data, starting from a blank slate where nothing is
313+
// known.
314+
for (auto& [func, info] : funcInfos) {
315+
func->effects.reset();
316+
if (!info.effects) {
317+
continue;
318+
}
319+
320+
func->effects = std::make_shared<EffectAnalyzer>(*info.effects);
321+
}
271322
}
272323
};
273324

0 commit comments

Comments
 (0)