2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
2627#include " wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940
4041 // Directly-called functions from this function.
4142 std::unordered_set<Name> calledFunctions;
43+
44+ // Types that are targets of indirect calls.
45+ std::unordered_set<HeapType> indirectCalledTypes;
4246};
4347
4448std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +87,19 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387 if (auto * call = curr->dynCast <Call>()) {
8488 // Note the direct call.
8589 funcInfo.calledFunctions .insert (call->target );
90+ } else if (effects.calls && options.closedWorld ) {
91+ HeapType type;
92+ if (auto * callRef = curr->dynCast <CallRef>()) {
93+ type = callRef->target ->type .getHeapType ();
94+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
95+ type = callIndirect->heapType ;
96+ } else {
97+ assert (false && " Unexpected type of call" );
98+ }
99+
100+ funcInfo.indirectCalledTypes .insert (type);
86101 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
102+ assert (!options.closedWorld );
91103 funcInfo.effects = UnknownEffects;
92104 } else {
93105 // No call here, but update throwing if we see it. (Only do so,
@@ -107,20 +119,49 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107119 return std::move (analysis.map );
108120}
109121
110- using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
122+ using CallGraphNode = std::variant<Function*, HeapType>;
123+ using CallGraph =
124+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
111125
112- CallGraph buildCallGraph (const Module& module ,
113- const std::map<Function*, FuncInfo>& funcInfos) {
126+ CallGraph buildCallGraph (Module& module ,
127+ const std::map<Function*, FuncInfo>& funcInfos,
128+ bool closedWorld) {
114129 CallGraph callGraph;
115- for (const auto & [func, info] : funcInfos) {
116- if (info.calledFunctions .empty ()) {
117- continue ;
130+
131+ if (!closedWorld) {
132+ for (const auto & [func, info] : funcInfos) {
133+ if (info.calledFunctions .empty ()) {
134+ continue ;
135+ }
136+
137+ auto & callees = callGraph[func];
138+ for (Name calleeFunction : info.calledFunctions ) {
139+ callees.insert (module .getFunction (calleeFunction));
140+ }
118141 }
142+ return callGraph;
143+ }
119144
120- auto & callees = callGraph[func];
121- for (Name callee : info.calledFunctions ) {
122- callees.insert (module .getFunction (callee));
145+ std::unordered_set<HeapType> allFunctionTypes;
146+ for (const auto & [caller, callerInfo] : funcInfos) {
147+ for (Name calleeFunction : callerInfo.calledFunctions ) {
148+ callGraph[caller].insert (module .getFunction (calleeFunction));
123149 }
150+
151+ allFunctionTypes.insert (caller->type .getHeapType ());
152+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
153+ callGraph[caller].insert (calleeType);
154+ allFunctionTypes.insert (calleeType);
155+ }
156+ callGraph[caller->type .getHeapType ()].insert (caller);
157+ }
158+
159+ SubTypes subtypes (module );
160+ for (HeapType type : allFunctionTypes) {
161+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
162+ callGraph[type].insert (sub);
163+ return true ;
164+ });
124165 }
125166
126167 return callGraph;
@@ -153,61 +194,67 @@ void propagateEffects(const Module& module,
153194 std::map<Function*, FuncInfo>& funcInfos,
154195 const CallGraph& callGraph) {
155196 struct CallGraphSCCs
156- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
197+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
157198 const std::map<Function*, FuncInfo>& funcInfos;
158- const std::unordered_map<Function*, std::unordered_set<Function*>>&
159- callGraph;
199+ const CallGraph& callGraph;
160200 const Module& module ;
161201
162202 CallGraphSCCs (
163- const std::vector<Function* >& funcs ,
203+ const std::vector<CallGraphNode >& nodes ,
164204 const std::map<Function*, FuncInfo>& funcInfos,
165- const std::unordered_map<Function*, std::unordered_set<Function*>>&
166- callGraph,
205+ const std::unordered_map<CallGraphNode,
206+ std::unordered_set<CallGraphNode>>& callGraph,
167207 const Module& module )
168- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
169- funcs .begin(), funcs .end()),
208+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
209+ nodes .begin(), nodes .end()),
170210 funcInfos (funcInfos), callGraph(callGraph), module (module ) {}
171211
172- void pushChildren (Function* f ) {
173- auto callees = callGraph.find (f );
212+ void pushChildren (CallGraphNode node ) {
213+ auto callees = callGraph.find (node );
174214 if (callees == callGraph.end ()) {
175215 return ;
176216 }
177-
178- for (auto * callee : callees->second ) {
217+ for (CallGraphNode callee : callees->second ) {
179218 push (callee);
180219 }
181220 }
182221 };
183222
184- std::vector<Function*> allFuncs;
223+ // We only care about Functions that are roots, not types
224+ // A type would be a root if a function exists with that type, but no-one
225+ // indirect calls the type.
226+ std::vector<CallGraphNode> allFuncs;
185227 for (auto & [func, info] : funcInfos) {
186228 allFuncs.push_back (func);
187229 }
230+
188231 CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
189232
190233 std::vector<std::optional<EffectAnalyzer>> componentEffects;
191234 // Points to an index in componentEffects
192- std::unordered_map<Function* , Index> funcComponents;
235+ std::unordered_map<CallGraphNode , Index> funcComponents;
193236
194237 for (auto ccIterator : sccs) {
195238 std::optional<EffectAnalyzer>& ccEffects =
196239 componentEffects.emplace_back (std::in_place, passOptions, module );
240+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
197241
198- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
199-
200- for (Function* f : ccFuncs) {
201- funcComponents.emplace (f, componentEffects.size () - 1 );
242+ std::vector<Function*> ccFuncs;
243+ for (CallGraphNode node : cc) {
244+ funcComponents.emplace (node, componentEffects.size () - 1 );
245+ if (auto ** func = std::get_if<Function*>(&node)) {
246+ ccFuncs.push_back (*func);
247+ }
202248 }
203249
204250 std::unordered_set<int > calleeSccs;
205- for (Function* caller : ccFuncs ) {
251+ for (CallGraphNode caller : cc ) {
206252 auto callees = callGraph.find (caller);
207253 if (callees == callGraph.end ()) {
208254 continue ;
209255 }
210- for (auto * callee : callees->second ) {
256+
257+ for (CallGraphNode callee : callees->second ) {
211258 calleeSccs.insert (funcComponents.at (callee));
212259 }
213260 }
@@ -219,11 +266,13 @@ void propagateEffects(const Module& module,
219266 }
220267
221268 // Add trap effects for potential cycles.
222- if (ccFuncs .size () > 1 ) {
269+ if (cc .size () > 1 ) {
223270 if (ccEffects != UnknownEffects) {
224271 ccEffects->trap = true ;
225272 }
226- } else {
273+ } else if (ccFuncs.size () == 1 ) {
274+ // It's possible for a CC to only contain 1 type, but that is not a
275+ // cycle in the call graph.
227276 auto * func = ccFuncs[0 ];
228277 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
229278 if (ccEffects != UnknownEffects) {
@@ -267,7 +316,8 @@ struct GenerateGlobalEffects : public Pass {
267316 std::map<Function*, FuncInfo> funcInfos =
268317 analyzeFuncs (*module , getPassOptions ());
269318
270- auto callGraph = buildCallGraph (*module , funcInfos);
319+ auto callGraph =
320+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
271321
272322 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
273323
0 commit comments