Skip to content

Commit b4e9245

Browse files
Cherrypick SCC with indirect effects
1 parent a6b6272 commit b4e9245

File tree

4 files changed

+514
-77
lines changed

4 files changed

+514
-77
lines changed

src/passes/GlobalEffects.cpp

Lines changed: 128 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121

2222
#include "ir/effects.h"
2323
#include "ir/module-utils.h"
24+
#include "ir/subtypes.h"
2425
#include "pass.h"
2526
#include "support/strongly_connected_components.h"
27+
#include "support/utilities.h"
2628
#include "wasm.h"
2729

2830
namespace wasm {
@@ -39,6 +41,9 @@ struct FuncInfo {
3941

4042
// Directly-called functions from this function.
4143
std::unordered_set<Name> calledFunctions;
44+
45+
// Types that are targets of indirect calls.
46+
std::unordered_set<HeapType> indirectCalledTypes;
4247
};
4348

4449
std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
@@ -83,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8388
if (auto* call = curr->dynCast<Call>()) {
8489
// Note the direct call.
8590
funcInfo.calledFunctions.insert(call->target);
91+
} else if (effects.calls && options.closedWorld) {
92+
HeapType type;
93+
if (auto* callRef = curr->dynCast<CallRef>()) {
94+
type = callRef->target->type.getHeapType();
95+
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
96+
type = callIndirect->heapType;
97+
} else {
98+
assert(false && "Unexpected type of call");
99+
}
100+
101+
funcInfo.indirectCalledTypes.insert(type);
86102
} else if (effects.calls) {
87-
// This is an indirect call of some sort, so we must assume the
88-
// worst. To do so, clear the effects, which indicates nothing
89-
// is known (so anything is possible).
90-
// TODO: We could group effects by function type etc.
103+
assert(!options.closedWorld);
91104
funcInfo.effects = UnknownEffects;
92105
} else {
93106
// No call here, but update throwing if we see it. (Only do so,
@@ -107,15 +120,44 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107120
return std::move(analysis.map);
108121
}
109122

110-
using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
111-
112-
CallGraph buildCallGraph(const Module& module,
113-
const std::map<Function*, FuncInfo>& funcInfos) {
123+
using CallGraphNode = std::variant<Function*, HeapType>;
124+
using CallGraph =
125+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
126+
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
127+
buildCallGraph(Module& module,
128+
const std::map<Function*, FuncInfo>& funcInfos,
129+
bool closedWorld) {
114130
CallGraph callGraph;
115-
for (const auto& [func, info] : funcInfos) {
116-
for (Name callee : info.calledFunctions) {
117-
callGraph[func].insert(module.getFunction(callee));
131+
132+
if (!closedWorld) {
133+
for (const auto& [caller, callerInfo] : funcInfos) {
134+
for (Name calleeFunction : callerInfo.calledFunctions) {
135+
callGraph[caller].insert(module.getFunction(calleeFunction));
136+
}
137+
}
138+
return callGraph;
139+
}
140+
141+
std::unordered_set<HeapType> allFunctionTypes;
142+
for (const auto& [caller, callerInfo] : funcInfos) {
143+
for (Name calleeFunction : callerInfo.calledFunctions) {
144+
callGraph[caller].insert(module.getFunction(calleeFunction));
145+
}
146+
147+
allFunctionTypes.insert(caller->type.getHeapType());
148+
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
149+
callGraph[caller].insert(calleeType);
150+
allFunctionTypes.insert(calleeType);
118151
}
152+
callGraph[caller->type.getHeapType()].insert(caller);
153+
}
154+
155+
SubTypes subtypes(module);
156+
for (HeapType type : allFunctionTypes) {
157+
subtypes.iterSubTypes(type, [&callGraph, type](HeapType sub, auto _) {
158+
callGraph[type].insert(sub);
159+
return true;
160+
});
119161
}
120162

121163
return callGraph;
@@ -148,99 +190,115 @@ void propagateEffects(const Module& module,
148190
std::map<Function*, FuncInfo>& funcInfos,
149191
const CallGraph& callGraph) {
150192
struct CallGraphSCCs
151-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
193+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
152194
const std::map<Function*, FuncInfo>& funcInfos;
153-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
154-
callGraph;
155195
const Module& module;
196+
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
197+
callGraph;
156198

157199
CallGraphSCCs(
158-
const std::vector<Function*>& funcs,
200+
const std::vector<CallGraphNode>& nodes,
159201
const std::map<Function*, FuncInfo>& funcInfos,
160-
const std::unordered_map<Function*, std::unordered_set<Function*>>&
161-
callGraph,
162-
const Module& module)
163-
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
164-
funcs.begin(), funcs.end()),
165-
funcInfos(funcInfos), callGraph(callGraph), module(module) {}
166-
167-
void pushChildren(Function* f) {
168-
auto callees = callGraph.find(f);
202+
const Module& module,
203+
const std::unordered_map<CallGraphNode,
204+
std::unordered_set<CallGraphNode>>& callGraph)
205+
: SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
206+
nodes.begin(), nodes.end()),
207+
funcInfos(funcInfos), module(module), callGraph(callGraph) {}
208+
209+
void pushChildren(CallGraphNode node) {
210+
auto callees = callGraph.find(node);
169211
if (callees == callGraph.end()) {
170212
return;
171213
}
172-
173-
for (auto* callee : callees->second) {
214+
for (CallGraphNode callee : callees->second) {
174215
push(callee);
175216
}
176217
}
177218
};
178219

179-
std::vector<Function*> allFuncs;
220+
// We only care about Functions that are roots, not types
221+
// A type would be a root if a function exists with that type, but no-one
222+
// indirect calls the type.
223+
std::vector<CallGraphNode> allFuncs;
180224
for (auto& [func, info] : funcInfos) {
181225
allFuncs.push_back(func);
182226
}
183-
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
227+
228+
CallGraphSCCs sccs(allFuncs, funcInfos, module, callGraph);
184229

185230
std::vector<std::optional<EffectAnalyzer>> componentEffects;
186231
// Points to an index in componentEffects
187-
std::unordered_map<Function*, Index> funcComponents;
232+
std::unordered_map<CallGraphNode, Index> funcComponents;
188233

189234
for (auto ccIterator : sccs) {
190235
std::optional<EffectAnalyzer>& ccEffects =
191236
componentEffects.emplace_back(std::in_place, passOptions, module);
237+
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());
192238

193-
std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());
194-
195-
for (Function* f : ccFuncs) {
196-
funcComponents.emplace(f, componentEffects.size() - 1);
197-
}
239+
ccEffects.emplace(passOptions, module);
198240

199-
std::unordered_set<int> calleeSccs;
200-
for (Function* caller : ccFuncs) {
201-
auto callees = callGraph.find(caller);
202-
if (callees == callGraph.end()) {
203-
continue;
204-
}
205-
for (auto* callee : callees->second) {
206-
calleeSccs.insert(funcComponents.at(callee));
241+
std::vector<Function*> ccFuncs;
242+
for (CallGraphNode node : cc) {
243+
funcComponents.emplace(node, componentEffects.size() - 1);
244+
if (auto** func = std::get_if<Function*>(&node)) {
245+
ccFuncs.push_back(*func);
207246
}
208-
}
209247

210-
// Merge in effects from callees
211-
for (int calleeScc : calleeSccs) {
212-
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
213-
mergeMaybeEffects(ccEffects, calleeComponentEffects);
214-
}
248+
std::unordered_set<int> calleeSccs;
249+
for (CallGraphNode caller : cc) {
250+
auto callees = callGraph.find(caller);
251+
if (callees == callGraph.end()) {
252+
continue;
253+
}
254+
for (auto callee : callees->second) {
255+
if (callees != callGraph.end()) {
256+
for (const auto& callee : callees->second) {
257+
auto sccIt = funcComponents.find(callee);
258+
if (sccIt != funcComponents.end()) {
259+
calleeSccs.insert(sccIt->second);
260+
}
261+
}
262+
}
263+
}
264+
}
215265

216-
// Add trap effects for potential cycles.
217-
if (ccFuncs.size() > 1) {
218-
if (ccEffects != UnknownEffects) {
219-
ccEffects->trap = true;
266+
// Merge in effects from callees
267+
for (int calleeScc : calleeSccs) {
268+
const auto& calleeComponentEffects = componentEffects.at(calleeScc);
269+
mergeMaybeEffects(ccEffects, calleeComponentEffects);
220270
}
221-
} else {
222-
auto* func = ccFuncs[0];
223-
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
271+
272+
// Add trap effects for potential cycles.
273+
if (cc.size() > 1) {
224274
if (ccEffects != UnknownEffects) {
225275
ccEffects->trap = true;
226276
}
277+
// A cycle isn't possible for a CC that only contains a type
278+
} else if (ccFuncs.size() == 1) {
279+
auto* func = ccFuncs[0];
280+
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
281+
if (ccEffects != UnknownEffects) {
282+
ccEffects->trap = true;
283+
}
284+
}
227285
}
228-
}
229286

230-
// Aggregate effects within this CC
231-
if (ccEffects) {
232-
for (Function* f : ccFuncs) {
233-
const auto& effects = funcInfos.at(f).effects;
234-
mergeMaybeEffects(ccEffects, effects);
287+
// Aggregate effects within this CC
288+
if (ccEffects) {
289+
for (Function* f : ccFuncs) {
290+
const auto& effects = funcInfos.at(f).effects;
291+
mergeMaybeEffects(ccEffects, effects);
292+
}
235293
}
236-
}
237294

238-
// Assign each function's effects to its CC effects.
239-
for (Function* f : ccFuncs) {
240-
if (!ccEffects) {
241-
funcInfos.at(f).effects = UnknownEffects;
242-
} else {
243-
funcInfos.at(f).effects.emplace(*ccEffects);
295+
// Assign each function's effects to its CC effects.
296+
for (Function* f : ccFuncs) {
297+
if (!ccEffects) {
298+
funcInfos.at(f).effects = UnknownEffects;
299+
} else {
300+
funcInfos.at(f).effects.emplace(*ccEffects);
301+
}
244302
}
245303
}
246304
}
@@ -262,7 +320,8 @@ struct GenerateGlobalEffects : public Pass {
262320
std::map<Function*, FuncInfo> funcInfos =
263321
analyzeFuncs(*module, getPassOptions());
264322

265-
auto callGraph = buildCallGraph(*module, funcInfos);
323+
auto callGraph =
324+
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);
266325

267326
propagateEffects(*module, getPassOptions(), funcInfos, callGraph);
268327

src/support/utilities.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ class Fatal {
9494
#define WASM_UNREACHABLE(msg) wasm::handle_unreachable()
9595
#endif
9696

97+
template<class... Ts> struct overloaded : Ts... {
98+
using Ts::operator()...;
99+
};
100+
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
97101
} // namespace wasm
98102

99103
#endif // wasm_support_utilities_h

0 commit comments

Comments
 (0)