2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
2627#include " wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940
4041 // Directly-called functions from this function.
4142 std::unordered_set<Name> calledFunctions;
43+
44+ // Types that are targets of indirect calls.
45+ std::unordered_set<HeapType> indirectCalledTypes;
4246};
4347
4448std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +87,21 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387 if (auto * call = curr->dynCast <Call>()) {
8488 // Note the direct call.
8589 funcInfo.calledFunctions .insert (call->target );
90+ } else if (effects.calls && options.closedWorld ) {
91+ HeapType type;
92+ if (auto * callRef = curr->dynCast <CallRef>()) {
93+ type = callRef->target ->type .getHeapType ();
94+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
95+ // nullability doesn't matter here
96+ // call_indirect is always inexact
97+ type = callIndirect->heapType ;
98+ } else {
99+ assert (false && " Unexpected type of call" );
100+ }
101+
102+ funcInfo.indirectCalledTypes .insert (type);
86103 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
104+ assert (!options.closedWorld );
91105 funcInfo.effects = UnknownEffects;
92106 } else {
93107 // No call here, but update throwing if we see it. (Only do so,
@@ -107,14 +121,52 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107121 return std::move (analysis.map );
108122}
109123
110- std::unordered_map<Function*, std::unordered_set<Function*>>
111- buildCallGraph (const Module& module ,
112- const std::map<Function*, FuncInfo>& funcInfos) {
113- std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114- for (const auto & [func, info] : funcInfos) {
115- for (Name callee : info.calledFunctions ) {
116- callGraph[func].insert (module .getFunction (callee));
124+ using CallGraphNode = std::variant<Function*, HeapType>;
125+
126+ // Build a call graph for indirect and direct calls.
127+ // key (callee) -> value (caller)
128+ // Name -> Name : callee is called directly by caller
129+ // Name -> HeapType : callee is a potential target of a virtual call
130+ // with this HeapType
131+ // HeapType -> Name : callee is indirectly called by caller
132+ // HeapType -> HeapType : callee is a subtype of caller If we're
133+ // running in an open world, we only include Name -> Name edges.
134+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
135+ buildCallGraph (Module& module ,
136+ const std::map<Function*, FuncInfo>& funcInfos,
137+ bool closedWorld) {
138+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
139+ callGraph;
140+
141+ if (!closedWorld) {
142+ for (const auto & [caller, callerInfo] : funcInfos) {
143+ for (Name calleeFunction : callerInfo.calledFunctions ) {
144+ callGraph[caller].insert (module .getFunction (calleeFunction));
145+ }
146+ }
147+ return callGraph;
148+ }
149+
150+ std::unordered_set<HeapType> allFunctionTypes;
151+ for (const auto & [caller, callerInfo] : funcInfos) {
152+ for (Name calleeFunction : callerInfo.calledFunctions ) {
153+ callGraph[caller].insert (module .getFunction (calleeFunction));
117154 }
155+
156+ allFunctionTypes.insert (caller->type .getHeapType ());
157+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
158+ callGraph[caller].insert (calleeType);
159+ allFunctionTypes.insert (calleeType);
160+ }
161+ callGraph[caller->type .getHeapType ()].insert (caller);
162+ }
163+
164+ SubTypes subtypes (module );
165+ for (HeapType type : allFunctionTypes) {
166+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
167+ callGraph[type].insert (sub);
168+ return true ;
169+ });
118170 }
119171
120172 return callGraph;
@@ -123,76 +175,89 @@ buildCallGraph(const Module& module,
123175// Propagate effects from callees to callers transitively
124176// e.g. if A -> B -> C (A calls B which calls C)
125177// Then B inherits effects from C and A inherits effects from both B and C.
126- //
127- // Generate SCC for the call graph, then traverse it in reverse topological
128- // order processing each callee before its callers. When traversing:
129- // - Merge all of the effects of functions within the CC
130- // - Also merge the (already computed) effects of each callee CC
131- // - Add trap effects for potentially recursive call chains
132178void propagateEffects (
133- const Module& module ,
179+ Module& module ,
134180 const PassOptions& passOptions,
135181 std::map<Function*, FuncInfo>& funcInfos,
136- const std::unordered_map<Function* , std::unordered_set<Function*>>
182+ const std::unordered_map<CallGraphNode , std::unordered_set<CallGraphNode>>&
137183 callGraph) {
184+
138185 struct CallGraphSCCs
139- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
186+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
140187 const std::map<Function*, FuncInfo>& funcInfos;
141- const std::unordered_map<Function*, std::unordered_set<Function*>>&
142- callGraph;
143188 const Module& module ;
189+ const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
190+ callGraph;
144191
145192 CallGraphSCCs (
146- const std::vector<Function* >& funcs ,
193+ const std::vector<CallGraphNode >& nodes ,
147194 const std::map<Function*, FuncInfo>& funcInfos,
148- const std::unordered_map<Function*, std::unordered_set<Function*>>&
149- callGraph ,
150- const Module& module )
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
152- funcs .begin(), funcs .end()),
153- funcInfos (funcInfos), callGraph(callGraph ), module ( module ) {}
154-
155- void pushChildren (Function* f ) {
156- auto callees = callGraph.find (f );
195+ Module& module ,
196+ const std::unordered_map<CallGraphNode ,
197+ std::unordered_set<CallGraphNode>>& callGraph )
198+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
199+ nodes .begin(), nodes .end()),
200+ funcInfos (funcInfos), module ( module ), callGraph(callGraph ) {}
201+
202+ void pushChildren (CallGraphNode node ) {
203+ auto callees = callGraph.find (node );
157204 if (callees == callGraph.end ()) {
158205 return ;
159206 }
160-
161- for (auto * callee : callees->second ) {
207+ for (const auto & callee : callees->second ) {
162208 push (callee);
163209 }
164210 }
165211 };
166212
167- std::vector<Function*> allFuncs;
213+ std::vector<CallGraphNode> allFuncs;
214+ // We only care about Functions that are roots, not types
215+ // A type would be a root if a function exists with that type, but no-one
216+ // indirect calls the type.
168217 for (auto & [func, info] : funcInfos) {
169218 allFuncs.push_back (func);
170219 }
171- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
172220
173- std::unordered_map<Function*, int > sccMembers;
221+ CallGraphSCCs sccs (allFuncs, funcInfos, module , callGraph);
222+
223+ std::unordered_map<CallGraphNode, int > sccMembers;
174224 std::unordered_map<int , std::optional<EffectAnalyzer>> componentEffects;
175225
176226 int ccIndex = 0 ;
177227 for (auto ccIterator : sccs) {
178228 ccIndex++;
179229 std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
230+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
181231
182232 ccEffects.emplace (passOptions, module );
183233
234+ std::vector<Function*> ccFuncs;
235+ std::vector<HeapType> ccTypes;
236+ for (auto v : cc) {
237+ if (auto ** func = std::get_if<Function*>(&v)) {
238+ ccFuncs.push_back (*func);
239+ } else {
240+ ccTypes.push_back (std::get<HeapType>(v));
241+ }
242+ }
243+
184244 for (Function* f : ccFuncs) {
185245 sccMembers.emplace (f, ccIndex);
186246 }
247+ for (HeapType t : ccTypes) {
248+ sccMembers.emplace (t, ccIndex);
249+ }
187250
188251 std::unordered_set<int > calleeSccs;
189- for (Function* caller : ccFuncs ) {
252+ for (const auto & caller : cc ) {
190253 auto callees = callGraph.find (caller);
191- if (callees == callGraph.end ()) {
192- continue ;
193- }
194- for (auto * callee : callees->second ) {
195- calleeSccs.insert (sccMembers.at (callee));
254+ if (callees != callGraph.end ()) {
255+ for (const auto & callee : callees->second ) {
256+ auto sccIt = sccMembers.find (callee);
257+ if (sccIt != sccMembers.end ()) {
258+ calleeSccs.insert (sccIt->second );
259+ }
260+ }
196261 }
197262 }
198263
@@ -204,17 +269,18 @@ void propagateEffects(
204269 break ;
205270 }
206271
207- else if (ccEffects != UnknownEffects ) {
272+ else if (ccEffects) {
208273 ccEffects->mergeIn (*calleeComponentEffects);
209274 }
210275 }
211276
212277 // Add trap effects for potential cycles.
213- if (ccFuncs .size () > 1 ) {
278+ if (cc .size () > 1 ) {
214279 if (ccEffects != UnknownEffects) {
215280 ccEffects->trap = true ;
216281 }
217- } else {
282+ // A cycle isn't possible for a CC that only contains a type
283+ } else if (ccFuncs.size () == 1 ) {
218284 auto * func = ccFuncs[0 ];
219285 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
220286 if (ccEffects != UnknownEffects) {
@@ -263,7 +329,8 @@ struct GenerateGlobalEffects : public Pass {
263329 std::map<Function*, FuncInfo> funcInfos =
264330 analyzeFuncs (*module , getPassOptions ());
265331
266- auto callGraph = buildCallGraph (*module , funcInfos);
332+ auto callGraph =
333+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
267334
268335 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
269336
0 commit comments