Skip to content

Commit 0418356

Browse files
Cherrypick SCC with indirect effects
1 parent 9bc667f commit 0418356

File tree

4 files changed

+492
-51
lines changed

4 files changed

+492
-51
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 106 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
27+
#include "support/utilities.h"
2628
#include "wasm.h"
2729

2830
namespace wasm {
@@ -39,6 +41,9 @@ struct FuncInfo {
3941

4042
// Directly-called functions from this function.
4143
std::unordered_set<Name> calledFunctions;
44+
45+
// Types that are targets of indirect calls.
46+
std::unordered_set<HeapType> indirectCalledTypes;
4247
};
4348

4449
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8388
if (auto* call = curr->dynCast<Call>()) {
8489
// Note the direct call.
8590
funcInfo.calledFunctions.insert(call->target);
91+
} else if (effects.calls && options.closedWorld) {
92+
HeapType type;
93+
if (auto* callRef = curr->dynCast<CallRef>()) {
94+
type = callRef->target->type.getHeapType();
95+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
96+
type = callIndirect->heapType;
97+
} else {
98+
assert(false && "Unexpected type of call");
99+
}
100+
101+
funcInfo.indirectCalledTypes.insert(type);
86102
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
103+
assert(!options.closedWorld);
91104
funcInfo.effects = UnknownEffects;
92105
} else {
93106
// No call here, but update throwing if we see it. (Only do so,
@@ -107,14 +120,52 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107120
return std::move(analysis.map);
108121
}
109122

110-
std::unordered_map<Function*, std::unordered_set<Function*>>
111-
buildCallGraph(const Module& module,
112-
const std::map<Function*, FuncInfo>& funcInfos) {
113-
std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114-
for (const auto& [func, info] : funcInfos) {
115-
for (Name callee : info.calledFunctions) {
116-
callGraph[func].insert(module.getFunction(callee));
123+
using CallGraphNode = std::variant<Function*, HeapType>;
124+
125+
// Build a call graph for indirect and direct calls.
126+
// key (callee) -> value (caller)
127+
// Name -> Name : callee is called directly by caller
128+
// Name -> HeapType : callee is a potential target of a virtual call
129+
// with this HeapType
130+
// HeapType -> Name : callee is indirectly called by caller
131+
// HeapType -> HeapType : callee is a subtype of caller If we're
132+
// running in an open world, we only include Name -> Name edges.
133+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
134+
buildCallGraph(Module& module,
135+
const std::map<Function*, FuncInfo>& funcInfos,
136+
bool closedWorld) {
137+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
138+
callGraph;
139+
140+
if (!closedWorld) {
141+
for (const auto& [caller, callerInfo] : funcInfos) {
142+
for (Name calleeFunction : callerInfo.calledFunctions) {
143+
callGraph[caller].insert(module.getFunction(calleeFunction));
144+
}
145+
}
146+
return callGraph;
147+
}
148+
149+
std::unordered_set<HeapType> allFunctionTypes;
150+
for (const auto& [caller, callerInfo] : funcInfos) {
151+
for (Name calleeFunction : callerInfo.calledFunctions) {
152+
callGraph[caller].insert(module.getFunction(calleeFunction));
153+
}
154+
155+
allFunctionTypes.insert(caller->type.getHeapType());
156+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
157+
callGraph[caller].insert(calleeType);
158+
allFunctionTypes.insert(calleeType);
117159
}
160+
callGraph[caller->type.getHeapType()].insert(caller);
161+
}
162+
163+
SubTypes subtypes(module);
164+
for (HeapType type : allFunctionTypes) {
165+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
166+
callGraph[type].insert(sub);
167+
return true;
168+
});
118169
}
119170

120171
return callGraph;
@@ -130,69 +181,79 @@ buildCallGraph(const Module& module,
130181
// - Also merge the (already computed) effects of each callee CC
131182
// - Add trap effects for potentially recursive call chains
132183
void propagateEffects(
133-
const Module& module,
184+
Module& module,
134185
const PassOptions& passOptions,
135186
std::map<Function*, FuncInfo>& funcInfos,
136-
const std::unordered_map<Function*, std::unordered_set<Function*>>
187+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
137188
callGraph) {
189+
138190
struct CallGraphSCCs
139-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
191+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
140192
const std::map<Function*, FuncInfo>& funcInfos;
141-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
142-
callGraph;
143193
const Module& module;
194+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
195+
callGraph;
144196

145197
CallGraphSCCs(
146-
const std::vector<Function*>& funcs,
198+
const std::vector<CallGraphNode>& nodes,
147199
const std::map<Function*, FuncInfo>& funcInfos,
148-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
149-
callGraph,
150-
const Module& module)
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
152-
funcs.begin(), funcs.end()),
153-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
154-
155-
void pushChildren(Function* f) {
156-
auto callees = callGraph.find(f);
200+
Module& module,
201+
const std::unordered_map<CallGraphNode,
202+
std::unordered_set<CallGraphNode>>& callGraph)
203+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
204+
nodes.begin(), nodes.end()),
205+
funcInfos(funcInfos), module(module), callGraph(callGraph) {}
206+
207+
void pushChildren(CallGraphNode node) {
208+
auto callees = callGraph.find(node);
157209
if (callees == callGraph.end()) {
158210
return;
159211
}
160-
161-
for (auto* callee : callees->second) {
212+
for (CallGraphNode callee : callees->second) {
162213
push(callee);
163214
}
164215
}
165216
};
166217

167-
std::vector<Function*> allFuncs;
218+
// We only care about Functions that are roots, not types
219+
// A type would be a root if a function exists with that type, but no-one
220+
// indirect calls the type.
221+
std::vector<CallGraphNode> allFuncs;
168222
for (auto& [func, info] : funcInfos) {
169223
allFuncs.push_back(func);
170224
}
171-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
172225

173-
std::unordered_map<Function*, int> sccMembers;
226+
CallGraphSCCs sccs(allFuncs, funcInfos, module, callGraph);
227+
228+
std::unordered_map<CallGraphNode, int> sccMembers;
174229
std::unordered_map<int, std::optional<EffectAnalyzer>> componentEffects;
175230

176231
int ccIndex = 0;
177232
for (auto ccIterator : sccs) {
178233
ccIndex++;
179234
std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
235+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
181236

182237
ccEffects.emplace(passOptions, module);
183238

184-
for (Function* f : ccFuncs) {
185-
sccMembers.emplace(f, ccIndex);
239+
std::vector<Function*> ccFuncs;
240+
for (CallGraphNode node : cc) {
241+
sccMembers.emplace(node, ccIndex);
242+
if (auto** func = std::get_if<Function*>(&node)) {
243+
ccFuncs.push_back(*func);
244+
}
186245
}
187246

188247
std::unordered_set<int> calleeSccs;
189-
for (Function* caller : ccFuncs) {
248+
for (CallGraphNode caller : cc) {
190249
auto callees = callGraph.find(caller);
191-
if (callees == callGraph.end()) {
192-
continue;
193-
}
194-
for (auto* callee : callees->second) {
195-
calleeSccs.insert(sccMembers.at(callee));
250+
if (callees != callGraph.end()) {
251+
for (const auto& callee : callees->second) {
252+
auto sccIt = sccMembers.find(callee);
253+
if (sccIt != sccMembers.end()) {
254+
calleeSccs.insert(sccIt->second);
255+
}
256+
}
196257
}
197258
}
198259

@@ -210,11 +271,12 @@ void propagateEffects(
210271
}
211272

212273
// Add trap effects for potential cycles.
213-
if (ccFuncs.size() > 1) {
274+
if (cc.size() > 1) {
214275
if (ccEffects != UnknownEffects) {
215276
ccEffects->trap = true;
216277
}
217-
} else {
278+
// A cycle isn't possible for a CC that only contains a type
279+
} else if (ccFuncs.size() == 1) {
218280
auto* func = ccFuncs[0];
219281
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
220282
if (ccEffects != UnknownEffects) {
@@ -263,7 +325,8 @@ struct GenerateGlobalEffects : public Pass {
263325
std::map<Function*, FuncInfo> funcInfos =
264326
analyzeFuncs(*module, getPassOptions());
265327

266-
auto callGraph = buildCallGraph(*module, funcInfos);
328+
auto callGraph =
329+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
267330

268331
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
269332

src/support/utilities.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class Fatal {
9494
#define WASM_UNREACHABLE(msg) wasm::handle_unreachable()
9595
#endif
9696

97+
template<class... Ts> struct overloaded : Ts... {
98+
using Ts::operator()...;
99+
};
100+
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
97101
} // namespace wasm
98102

99103
#endif // wasm_support_utilities_h

0 commit comments

Comments
 (0)