2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
27+ #include " support/utilities.h"
2628#include " wasm.h"
2729
2830namespace wasm {
@@ -39,6 +41,9 @@ struct FuncInfo {
3941
4042 // Directly-called functions from this function.
4143 std::unordered_set<Name> calledFunctions;
44+
45+ // Types that are targets of indirect calls.
46+ std::unordered_set<HeapType> indirectCalledTypes;
4247};
4348
4449std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8388 if (auto * call = curr->dynCast <Call>()) {
8489 // Note the direct call.
8590 funcInfo.calledFunctions .insert (call->target );
91+ } else if (effects.calls && options.closedWorld ) {
92+ HeapType type;
93+ if (auto * callRef = curr->dynCast <CallRef>()) {
94+ type = callRef->target ->type .getHeapType ();
95+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
96+ type = callIndirect->heapType ;
97+ } else {
98+ assert (false && " Unexpected type of call" );
99+ }
100+
101+ funcInfo.indirectCalledTypes .insert (type);
86102 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
103+ assert (!options.closedWorld );
91104 funcInfo.effects = UnknownEffects;
92105 } else {
93106 // No call here, but update throwing if we see it. (Only do so,
@@ -107,14 +120,52 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107120 return std::move (analysis.map );
108121}
109122
110- std::unordered_map<Function*, std::unordered_set<Function*>>
111- buildCallGraph (const Module& module ,
112- const std::map<Function*, FuncInfo>& funcInfos) {
113- std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114- for (const auto & [func, info] : funcInfos) {
115- for (Name callee : info.calledFunctions ) {
116- callGraph[func].insert (module .getFunction (callee));
123+ using CallGraphNode = std::variant<Function*, HeapType>;
124+
125+ // Build a call graph for indirect and direct calls.
126+ // key (callee) -> value (caller)
127+ // Name -> Name : callee is called directly by caller
128+ // Name -> HeapType : callee is a potential target of a virtual call
129+ // with this HeapType
130+ // HeapType -> Name : callee is indirectly called by caller
131+ // HeapType -> HeapType : callee is a subtype of caller If we're
132+ // running in an open world, we only include Name -> Name edges.
133+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
134+ buildCallGraph (Module& module ,
135+ const std::map<Function*, FuncInfo>& funcInfos,
136+ bool closedWorld) {
137+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
138+ callGraph;
139+
140+ if (!closedWorld) {
141+ for (const auto & [caller, callerInfo] : funcInfos) {
142+ for (Name calleeFunction : callerInfo.calledFunctions ) {
143+ callGraph[caller].insert (module .getFunction (calleeFunction));
144+ }
145+ }
146+ return callGraph;
147+ }
148+
149+ std::unordered_set<HeapType> allFunctionTypes;
150+ for (const auto & [caller, callerInfo] : funcInfos) {
151+ for (Name calleeFunction : callerInfo.calledFunctions ) {
152+ callGraph[caller].insert (module .getFunction (calleeFunction));
153+ }
154+
155+ allFunctionTypes.insert (caller->type .getHeapType ());
156+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
157+ callGraph[caller].insert (calleeType);
158+ allFunctionTypes.insert (calleeType);
117159 }
160+ callGraph[caller->type .getHeapType ()].insert (caller);
161+ }
162+
163+ SubTypes subtypes (module );
164+ for (HeapType type : allFunctionTypes) {
165+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
166+ callGraph[type].insert (sub);
167+ return true ;
168+ });
118169 }
119170
120171 return callGraph;
@@ -130,69 +181,79 @@ buildCallGraph(const Module& module,
130181// - Also merge the (already computed) effects of each callee CC
131182// - Add trap effects for potentially recursive call chains
132183void propagateEffects (
133- const Module& module ,
184+ Module& module ,
134185 const PassOptions& passOptions,
135186 std::map<Function*, FuncInfo>& funcInfos,
136- const std::unordered_map<Function* , std::unordered_set<Function*>>
187+ const std::unordered_map<CallGraphNode , std::unordered_set<CallGraphNode>>&
137188 callGraph) {
189+
138190 struct CallGraphSCCs
139- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
191+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
140192 const std::map<Function*, FuncInfo>& funcInfos;
141- const std::unordered_map<Function*, std::unordered_set<Function*>>&
142- callGraph;
143193 const Module& module ;
194+ const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
195+ callGraph;
144196
145197 CallGraphSCCs (
146- const std::vector<Function* >& funcs ,
198+ const std::vector<CallGraphNode >& nodes ,
147199 const std::map<Function*, FuncInfo>& funcInfos,
148- const std::unordered_map<Function*, std::unordered_set<Function*>>&
149- callGraph ,
150- const Module& module )
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
152- funcs .begin(), funcs .end()),
153- funcInfos (funcInfos), callGraph(callGraph ), module ( module ) {}
154-
155- void pushChildren (Function* f ) {
156- auto callees = callGraph.find (f );
200+ Module& module ,
201+ const std::unordered_map<CallGraphNode ,
202+ std::unordered_set<CallGraphNode>>& callGraph )
203+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
204+ nodes .begin(), nodes .end()),
205+ funcInfos (funcInfos), module ( module ), callGraph(callGraph ) {}
206+
207+ void pushChildren (CallGraphNode node ) {
208+ auto callees = callGraph.find (node );
157209 if (callees == callGraph.end ()) {
158210 return ;
159211 }
160-
161- for (auto * callee : callees->second ) {
212+ for (CallGraphNode callee : callees->second ) {
162213 push (callee);
163214 }
164215 }
165216 };
166217
167- std::vector<Function*> allFuncs;
218+ // We only care about Functions that are roots, not types
219+ // A type would be a root if a function exists with that type, but no-one
220+ // indirect calls the type.
221+ std::vector<CallGraphNode> allFuncs;
168222 for (auto & [func, info] : funcInfos) {
169223 allFuncs.push_back (func);
170224 }
171- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
172225
173- std::unordered_map<Function*, int > sccMembers;
226+ CallGraphSCCs sccs (allFuncs, funcInfos, module , callGraph);
227+
228+ std::unordered_map<CallGraphNode, int > sccMembers;
174229 std::unordered_map<int , std::optional<EffectAnalyzer>> componentEffects;
175230
176231 int ccIndex = 0 ;
177232 for (auto ccIterator : sccs) {
178233 ccIndex++;
179234 std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
235+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
181236
182237 ccEffects.emplace (passOptions, module );
183238
184- for (Function* f : ccFuncs) {
185- sccMembers.emplace (f, ccIndex);
239+ std::vector<Function*> ccFuncs;
240+ for (CallGraphNode node : cc) {
241+ sccMembers.emplace (node, ccIndex);
242+ if (auto ** func = std::get_if<Function*>(&node)) {
243+ ccFuncs.push_back (*func);
244+ }
186245 }
187246
188247 std::unordered_set<int > calleeSccs;
189- for (Function* caller : ccFuncs ) {
248+ for (CallGraphNode caller : cc ) {
190249 auto callees = callGraph.find (caller);
191- if (callees == callGraph.end ()) {
192- continue ;
193- }
194- for (auto * callee : callees->second ) {
195- calleeSccs.insert (sccMembers.at (callee));
250+ if (callees != callGraph.end ()) {
251+ for (const auto & callee : callees->second ) {
252+ auto sccIt = sccMembers.find (callee);
253+ if (sccIt != sccMembers.end ()) {
254+ calleeSccs.insert (sccIt->second );
255+ }
256+ }
196257 }
197258 }
198259
@@ -210,11 +271,12 @@ void propagateEffects(
210271 }
211272
212273 // Add trap effects for potential cycles.
213- if (ccFuncs .size () > 1 ) {
274+ if (cc .size () > 1 ) {
214275 if (ccEffects != UnknownEffects) {
215276 ccEffects->trap = true ;
216277 }
217- } else {
278+ // A cycle isn't possible for a CC that only contains a type
279+ } else if (ccFuncs.size () == 1 ) {
218280 auto * func = ccFuncs[0 ];
219281 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
220282 if (ccEffects != UnknownEffects) {
@@ -263,7 +325,8 @@ struct GenerateGlobalEffects : public Pass {
263325 std::map<Function*, FuncInfo> funcInfos =
264326 analyzeFuncs (*module , getPassOptions ());
265327
266- auto callGraph = buildCallGraph (*module , funcInfos);
328+ auto callGraph =
329+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
267330
268331 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
269332
0 commit comments