2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
27+ #include " support/utilities.h"
2628#include " wasm.h"
2729
2830namespace wasm {
@@ -39,6 +41,9 @@ struct FuncInfo {
3941
4042 // Directly-called functions from this function.
4143 std::unordered_set<Name> calledFunctions;
44+
45+ // Types that are targets of indirect calls.
46+ std::unordered_set<HeapType> indirectCalledTypes;
4247};
4348
4449std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +88,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8388 if (auto * call = curr->dynCast <Call>()) {
8489 // Note the direct call.
8590 funcInfo.calledFunctions .insert (call->target );
91+ } else if (effects.calls && options.closedWorld ) {
92+ HeapType type;
93+ if (auto * callRef = curr->dynCast <CallRef>()) {
94+ type = callRef->target ->type .getHeapType ();
95+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
96+ type = callIndirect->heapType ;
97+ } else {
98+ assert (false && " Unexpected type of call" );
99+ }
100+
101+ funcInfo.indirectCalledTypes .insert (type);
86102 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
103+ assert (!options.closedWorld );
91104 funcInfo.effects = UnknownEffects;
92105 } else {
93106 // No call here, but update throwing if we see it. (Only do so,
@@ -107,15 +120,44 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107120 return std::move (analysis.map );
108121}
109122
110- using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
123+ using CallGraphNode = std::variant<Function*, HeapType>;
124+ using CallGraph =
125+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111126
112- CallGraph buildCallGraph (const Module& module ,
113- const std::map<Function*, FuncInfo>& funcInfos) {
127+ CallGraph buildCallGraph (Module& module ,
128+ const std::map<Function*, FuncInfo>& funcInfos,
129+ bool closedWorld) {
114130 CallGraph callGraph;
115- for (const auto & [func, info] : funcInfos) {
116- for (Name callee : info.calledFunctions ) {
117- callGraph[func].insert (module .getFunction (callee));
131+
132+ if (!closedWorld) {
133+ for (const auto & [caller, callerInfo] : funcInfos) {
134+ for (Name calleeFunction : callerInfo.calledFunctions ) {
135+ callGraph[caller].insert (module .getFunction (calleeFunction));
136+ }
137+ }
138+ return callGraph;
139+ }
140+
141+ std::unordered_set<HeapType> allFunctionTypes;
142+ for (const auto & [caller, callerInfo] : funcInfos) {
143+ for (Name calleeFunction : callerInfo.calledFunctions ) {
144+ callGraph[caller].insert (module .getFunction (calleeFunction));
145+ }
146+
147+ allFunctionTypes.insert (caller->type .getHeapType ());
148+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
149+ callGraph[caller].insert (calleeType);
150+ allFunctionTypes.insert (calleeType);
118151 }
152+ callGraph[caller->type .getHeapType ()].insert (caller);
153+ }
154+
155+ SubTypes subtypes (module );
156+ for (HeapType type : allFunctionTypes) {
157+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
158+ callGraph[type].insert (sub);
159+ return true ;
160+ });
119161 }
120162
121163 return callGraph;
@@ -148,99 +190,111 @@ void propagateEffects(const Module& module,
148190 std::map<Function*, FuncInfo>& funcInfos,
149191 const CallGraph& callGraph) {
150192 struct CallGraphSCCs
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
193+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
152194 const std::map<Function*, FuncInfo>& funcInfos;
153- const std::unordered_map<Function*, std::unordered_set<Function*>>&
154- callGraph;
195+ const CallGraph& callGraph;
155196 const Module& module ;
156197
157198 CallGraphSCCs (
158- const std::vector<Function* >& funcs ,
199+ const std::vector<CallGraphNode >& nodes ,
159200 const std::map<Function*, FuncInfo>& funcInfos,
160- const std::unordered_map<Function*, std::unordered_set<Function*>>&
161- callGraph,
201+ const std::unordered_map<CallGraphNode,
202+ std::unordered_set<CallGraphNode>>& callGraph,
162203 const Module& module )
163- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
164- funcs .begin(), funcs .end()),
165- funcInfos (funcInfos), callGraph(callGraph ), module ( module ) {}
204+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
205+ nodes .begin(), nodes .end()),
206+ funcInfos (funcInfos), module ( module ), callGraph(callGraph ) {}
166207
167- void pushChildren (Function* f ) {
168- auto callees = callGraph.find (f );
208+ void pushChildren (CallGraphNode node ) {
209+ auto callees = callGraph.find (node );
169210 if (callees == callGraph.end ()) {
170211 return ;
171212 }
172-
173- for (auto * callee : callees->second ) {
213+ for (CallGraphNode callee : callees->second ) {
174214 push (callee);
175215 }
176216 }
177217 };
178218
179- std::vector<Function*> allFuncs;
219+ // We only care about Functions that are roots, not types
220+ // A type would be a root if a function exists with that type, but no-one
221+ // indirect calls the type.
222+ std::vector<CallGraphNode> allFuncs;
180223 for (auto & [func, info] : funcInfos) {
181224 allFuncs.push_back (func);
182225 }
226+
183227 CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
184228
185229 std::vector<std::optional<EffectAnalyzer>> componentEffects;
186230 // Points to an index in componentEffects
187- std::unordered_map<Function* , Index> funcComponents;
231+ std::unordered_map<CallGraphNode , Index> funcComponents;
188232
189233 for (auto ccIterator : sccs) {
190234 std::optional<EffectAnalyzer>& ccEffects =
191235 componentEffects.emplace_back (std::in_place, passOptions, module );
236+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
192237
193- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
194-
195- for (Function* f : ccFuncs) {
196- funcComponents.emplace (f, componentEffects.size () - 1 );
197- }
198-
199- std::unordered_set<int > calleeSccs;
200- for (Function* caller : ccFuncs) {
201- auto callees = callGraph.find (caller);
202- if (callees == callGraph.end ()) {
203- continue ;
204- }
205- for (auto * callee : callees->second ) {
206- calleeSccs.insert (funcComponents.at (callee));
238+ std::vector<Function*> ccFuncs;
239+ for (CallGraphNode node : cc) {
240+ funcComponents.emplace (node, componentEffects.size () - 1 );
241+ if (auto ** func = std::get_if<Function*>(&node)) {
242+ ccFuncs.push_back (*func);
207243 }
208- }
209244
210- // Merge in effects from callees
211- for (int calleeScc : calleeSccs) {
212- const auto & calleeComponentEffects = componentEffects.at (calleeScc);
213- mergeMaybeEffects (ccEffects, calleeComponentEffects);
214- }
245+ std::unordered_set<int > calleeSccs;
246+ for (CallGraphNode caller : cc) {
247+ auto callees = callGraph.find (caller);
248+ if (callees == callGraph.end ()) {
249+ continue ;
250+ }
251+ if (callees != callGraph.end ()) {
252+ for (CallGraphNode callee : callees->second ) {
253+ auto sccIt = funcComponents.find (callee);
254+ if (sccIt != funcComponents.end ()) {
255+ calleeSccs.insert (sccIt->second );
256+ }
257+ }
258+ }
259+ }
215260
216- // Add trap effects for potential cycles.
217- if (ccFuncs. size () > 1 ) {
218- if (ccEffects != UnknownEffects) {
219- ccEffects-> trap = true ;
261+ // Merge in effects from callees
262+ for ( int calleeScc : calleeSccs ) {
263+ const auto & calleeComponentEffects = componentEffects. at (calleeScc);
264+ mergeMaybeEffects ( ccEffects, calleeComponentEffects) ;
220265 }
221- } else {
222- auto * func = ccFuncs[ 0 ];
223- if (funcInfos. at (func). calledFunctions . contains (func-> name ) ) {
266+
267+ // Add trap effects for potential cycles.
268+ if (cc. size () > 1 ) {
224269 if (ccEffects != UnknownEffects) {
225270 ccEffects->trap = true ;
226271 }
272+ } else if (ccFuncs.size () == 1 ) {
273+ // It's possible for a CC to only contain 1 type, but that is not a
274+ // cycle in the call graph.
275+ auto * func = ccFuncs[0 ];
276+ if (funcInfos.at (func).calledFunctions .contains (func->name )) {
277+ if (ccEffects != UnknownEffects) {
278+ ccEffects->trap = true ;
279+ }
280+ }
227281 }
228- }
229282
230- // Aggregate effects within this CC
231- if (ccEffects) {
232- for (Function* f : ccFuncs) {
233- const auto & effects = funcInfos.at (f).effects ;
234- mergeMaybeEffects (ccEffects, effects);
283+ // Aggregate effects within this CC
284+ if (ccEffects) {
285+ for (Function* f : ccFuncs) {
286+ const auto & effects = funcInfos.at (f).effects ;
287+ mergeMaybeEffects (ccEffects, effects);
288+ }
235289 }
236- }
237290
238- // Assign each function's effects to its CC effects.
239- for (Function* f : ccFuncs) {
240- if (!ccEffects) {
241- funcInfos.at (f).effects = UnknownEffects;
242- } else {
243- funcInfos.at (f).effects .emplace (*ccEffects);
291+ // Assign each function's effects to its CC effects.
292+ for (Function* f : ccFuncs) {
293+ if (!ccEffects) {
294+ funcInfos.at (f).effects = UnknownEffects;
295+ } else {
296+ funcInfos.at (f).effects .emplace (*ccEffects);
297+ }
244298 }
245299 }
246300 }
@@ -262,7 +316,8 @@ struct GenerateGlobalEffects : public Pass {
262316 std::map<Function*, FuncInfo> funcInfos =
263317 analyzeFuncs (*module , getPassOptions ());
264318
265- auto callGraph = buildCallGraph (*module , funcInfos);
319+ auto callGraph =
320+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
266321
267322 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
268323
0 commit comments