2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
2627#include " wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940
4041 // Directly-called functions from this function.
4142 std::unordered_set<Name> calledFunctions;
43+
44+ // Types that are targets of indirect calls.
45+ std::unordered_set<HeapType> indirectCalledTypes;
4246};
4347
4448std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +87,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387 if (auto * call = curr->dynCast <Call>()) {
8488 // Note the direct call.
8589 funcInfo.calledFunctions .insert (call->target );
90+ } else if (effects.calls && options.closedWorld ) {
91+ HeapType type;
92+ if (auto * callRef = curr->dynCast <CallRef>()) {
93+ type = callRef->target ->type .getHeapType ();
94+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
95+ type = callIndirect->heapType ;
96+ } else {
97+ assert (false && " Unexpected type of call" );
98+ }
99+
100+ funcInfo.indirectCalledTypes .insert (type);
86101 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
102+ assert (!options.closedWorld );
91103 funcInfo.effects = UnknownEffects;
92104 } else {
93105 // No call here, but update throwing if we see it. (Only do so,
@@ -107,15 +119,44 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119 return std::move (analysis.map );
108120}
109121
110- using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
122+ using CallGraphNode = std::variant<Function*, HeapType>;
123+ using CallGraph =
124+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111125
112- CallGraph buildCallGraph (const Module& module ,
113- const std::map<Function*, FuncInfo>& funcInfos) {
126+ CallGraph buildCallGraph (Module& module ,
127+ const std::map<Function*, FuncInfo>& funcInfos,
128+ bool closedWorld) {
114129 CallGraph callGraph;
115- for (const auto & [func, info] : funcInfos) {
116- for (Name callee : info.calledFunctions ) {
117- callGraph[func].insert (module .getFunction (callee));
130+
131+ if (!closedWorld) {
132+ for (const auto & [caller, callerInfo] : funcInfos) {
133+ for (Name calleeFunction : callerInfo.calledFunctions ) {
134+ callGraph[caller].insert (module .getFunction (calleeFunction));
135+ }
136+ }
137+ return callGraph;
138+ }
139+
140+ std::unordered_set<HeapType> allFunctionTypes;
141+ for (const auto & [caller, callerInfo] : funcInfos) {
142+ for (Name calleeFunction : callerInfo.calledFunctions ) {
143+ callGraph[caller].insert (module .getFunction (calleeFunction));
144+ }
145+
146+ allFunctionTypes.insert (caller->type .getHeapType ());
147+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
148+ callGraph[caller].insert (calleeType);
149+ allFunctionTypes.insert (calleeType);
118150 }
151+ callGraph[caller->type .getHeapType ()].insert (caller);
152+ }
153+
154+ SubTypes subtypes (module );
155+ for (HeapType type : allFunctionTypes) {
156+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
157+ callGraph[type].insert (sub);
158+ return true ;
159+ });
119160 }
120161
121162 return callGraph;
@@ -148,99 +189,111 @@ void propagateEffects(const Module& module,
148189 std::map<Function*, FuncInfo>& funcInfos,
149190 const CallGraph& callGraph) {
150191 struct CallGraphSCCs
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
192+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
152193 const std::map<Function*, FuncInfo>& funcInfos;
153- const std::unordered_map<Function*, std::unordered_set<Function*>>&
154- callGraph;
194+ const CallGraph& callGraph;
155195 const Module& module ;
156196
157197 CallGraphSCCs (
158- const std::vector<Function* >& funcs ,
198+ const std::vector<CallGraphNode >& nodes ,
159199 const std::map<Function*, FuncInfo>& funcInfos,
160- const std::unordered_map<Function*, std::unordered_set<Function*>>&
161- callGraph,
200+ const std::unordered_map<CallGraphNode,
201+ std::unordered_set<CallGraphNode>>& callGraph,
162202 const Module& module )
163- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
164- funcs .begin(), funcs .end()),
203+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
204+ nodes .begin(), nodes .end()),
165205 funcInfos (funcInfos), callGraph(callGraph), module (module ) {}
166206
167- void pushChildren (Function* f ) {
168- auto callees = callGraph.find (f );
207+ void pushChildren (CallGraphNode node ) {
208+ auto callees = callGraph.find (node );
169209 if (callees == callGraph.end ()) {
170210 return ;
171211 }
172-
173- for (auto * callee : callees->second ) {
212+ for (CallGraphNode callee : callees->second ) {
174213 push (callee);
175214 }
176215 }
177216 };
178217
179- std::vector<Function*> allFuncs;
218+ // We only care about Functions that are roots, not types
219+ // A type would be a root if a function exists with that type, but no-one
220+ // indirect calls the type.
221+ std::vector<CallGraphNode> allFuncs;
180222 for (auto & [func, info] : funcInfos) {
181223 allFuncs.push_back (func);
182224 }
225+
183226 CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
184227
185228 std::vector<std::optional<EffectAnalyzer>> componentEffects;
186229 // Points to an index in componentEffects
187- std::unordered_map<Function* , Index> funcComponents;
230+ std::unordered_map<CallGraphNode , Index> funcComponents;
188231
189232 for (auto ccIterator : sccs) {
190233 std::optional<EffectAnalyzer>& ccEffects =
191234 componentEffects.emplace_back (std::in_place, passOptions, module );
235+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
192236
193- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
194-
195- for (Function* f : ccFuncs) {
196- funcComponents.emplace (f, componentEffects.size () - 1 );
197- }
198-
199- std::unordered_set<int > calleeSccs;
200- for (Function* caller : ccFuncs) {
201- auto callees = callGraph.find (caller);
202- if (callees == callGraph.end ()) {
203- continue ;
204- }
205- for (auto * callee : callees->second ) {
206- calleeSccs.insert (funcComponents.at (callee));
237+ std::vector<Function*> ccFuncs;
238+ for (CallGraphNode node : cc) {
239+ funcComponents.emplace (node, componentEffects.size () - 1 );
240+ if (auto ** func = std::get_if<Function*>(&node)) {
241+ ccFuncs.push_back (*func);
207242 }
208- }
209243
210- // Merge in effects from callees
211- for (int calleeScc : calleeSccs) {
212- const auto & calleeComponentEffects = componentEffects.at (calleeScc);
213- mergeMaybeEffects (ccEffects, calleeComponentEffects);
214- }
244+ std::unordered_set<int > calleeSccs;
245+ for (CallGraphNode caller : cc) {
246+ auto callees = callGraph.find (caller);
247+ if (callees == callGraph.end ()) {
248+ continue ;
249+ }
250+ if (callees != callGraph.end ()) {
251+ for (CallGraphNode callee : callees->second ) {
252+ auto sccIt = funcComponents.find (callee);
253+ if (sccIt != funcComponents.end ()) {
254+ calleeSccs.insert (sccIt->second );
255+ }
256+ }
257+ }
258+ }
215259
216- // Add trap effects for potential cycles.
217- if (ccFuncs. size () > 1 ) {
218- if (ccEffects != UnknownEffects) {
219- ccEffects-> trap = true ;
260+ // Merge in effects from callees
261+ for ( int calleeScc : calleeSccs ) {
262+ const auto & calleeComponentEffects = componentEffects. at (calleeScc);
263+ mergeMaybeEffects ( ccEffects, calleeComponentEffects) ;
220264 }
221- } else {
222- auto * func = ccFuncs[ 0 ];
223- if (funcInfos. at (func). calledFunctions . contains (func-> name ) ) {
265+
266+ // Add trap effects for potential cycles.
267+ if (cc. size () > 1 ) {
224268 if (ccEffects != UnknownEffects) {
225269 ccEffects->trap = true ;
226270 }
271+ } else if (ccFuncs.size () == 1 ) {
272+ // It's possible for a CC to only contain 1 type, but that is not a
273+ // cycle in the call graph.
274+ auto * func = ccFuncs[0 ];
275+ if (funcInfos.at (func).calledFunctions .contains (func->name )) {
276+ if (ccEffects != UnknownEffects) {
277+ ccEffects->trap = true ;
278+ }
279+ }
227280 }
228- }
229281
230- // Aggregate effects within this CC
231- if (ccEffects) {
232- for (Function* f : ccFuncs) {
233- const auto & effects = funcInfos.at (f).effects ;
234- mergeMaybeEffects (ccEffects, effects);
282+ // Aggregate effects within this CC
283+ if (ccEffects) {
284+ for (Function* f : ccFuncs) {
285+ const auto & effects = funcInfos.at (f).effects ;
286+ mergeMaybeEffects (ccEffects, effects);
287+ }
235288 }
236- }
237289
238- // Assign each function's effects to its CC effects.
239- for (Function* f : ccFuncs) {
240- if (!ccEffects) {
241- funcInfos.at (f).effects = UnknownEffects;
242- } else {
243- funcInfos.at (f).effects .emplace (*ccEffects);
290+ // Assign each function's effects to its CC effects.
291+ for (Function* f : ccFuncs) {
292+ if (!ccEffects) {
293+ funcInfos.at (f).effects = UnknownEffects;
294+ } else {
295+ funcInfos.at (f).effects .emplace (*ccEffects);
296+ }
244297 }
245298 }
246299 }
@@ -262,7 +315,8 @@ struct GenerateGlobalEffects : public Pass {
262315 std::map<Function*, FuncInfo> funcInfos =
263316 analyzeFuncs (*module , getPassOptions ());
264317
265- auto callGraph = buildCallGraph (*module , funcInfos);
318+ auto callGraph =
319+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
266320
267321 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
268322
0 commit comments