2121
2222#include " ir/effects.h"
2323#include " ir/module-utils.h"
24+ #include " ir/subtypes.h"
2425#include " pass.h"
2526#include " support/strongly_connected_components.h"
2627#include " wasm.h"
@@ -39,6 +40,9 @@ struct FuncInfo {
3940
4041 // Directly-called functions from this function.
4142 std::unordered_set<Name> calledFunctions;
43+
44+ // Types that are targets of indirect calls.
45+ std::unordered_set<HeapType> indirectCalledTypes;
4246};
4347
4448std::map<Function*, FuncInfo> analyzeFuncs (Module& module ,
@@ -83,11 +87,21 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
8387 if (auto * call = curr->dynCast <Call>()) {
8488 // Note the direct call.
8589 funcInfo.calledFunctions .insert (call->target );
90+ } else if (effects.calls && options.closedWorld ) {
91+ HeapType type;
92+ if (auto * callRef = curr->dynCast <CallRef>()) {
93+ type = callRef->target ->type .getHeapType ();
94+ } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
95+ // nullability doesn't matter here
96+ // call_indirect is always inexact
97+ type = callIndirect->heapType ;
98+ } else {
99+ assert (false && " Unexpected type of call" );
100+ }
101+
102+ funcInfo.indirectCalledTypes .insert (type);
86103 } else if (effects.calls ) {
87- // This is an indirect call of some sort, so we must assume the
88- // worst. To do so, clear the effects, which indicates nothing
89- // is known (so anything is possible).
90- // TODO: We could group effects by function type etc.
104+ assert (!options.closedWorld );
91105 funcInfo.effects = UnknownEffects;
92106 } else {
93107 // No call here, but update throwing if we see it. (Only do so,
@@ -107,14 +121,45 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
107121 return std::move (analysis.map );
108122}
109123
110- std::unordered_map<Function*, std::unordered_set<Function*>>
111- buildCallGraph (const Module& module ,
112- const std::map<Function*, FuncInfo>& funcInfos) {
113- std::unordered_map<Function*, std::unordered_set<Function*>> callGraph;
114- for (const auto & [func, info] : funcInfos) {
115- for (Name callee : info.calledFunctions ) {
116- callGraph[func].insert (module .getFunction (callee));
124+ using CallGraphNode = std::variant<Function*, HeapType>;
125+
126+ // Build a call graph for indirect and direct calls.
127+ // key (callee) -> value (caller)
128+ // Name -> Name : callee is called directly by caller
129+ // Name -> HeapType : callee is a potential target of a virtual call
130+ // with this HeapType
131+ // HeapType -> Name : callee is indirectly called by caller
132+ // HeapType -> HeapType : callee is a subtype of caller If we're
133+ // running in an open world, we only include Name -> Name edges.
134+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
135+ buildCallGraph (Module& module ,
136+ const std::map<Function*, FuncInfo>& funcInfos,
137+ bool closedWorld) {
138+ std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
139+ callGraph;
140+ std::unordered_set<HeapType> allFunctionTypes;
141+ for (const auto & [caller, callerInfo] : funcInfos) {
142+ for (Name calleeFunction : callerInfo.calledFunctions ) {
143+ callGraph[caller].insert (module .getFunction (calleeFunction));
144+ }
145+ if (!closedWorld) {
146+ continue ;
147+ }
148+
149+ allFunctionTypes.insert (caller->type .getHeapType ());
150+ for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
151+ callGraph[caller].insert (calleeType);
152+ allFunctionTypes.insert (calleeType);
117153 }
154+ callGraph[caller->type .getHeapType ()].insert (caller);
155+ }
156+
157+ SubTypes subtypes (module );
158+ for (HeapType type : allFunctionTypes) {
159+ subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
160+ callGraph[type].insert (sub);
161+ return true ;
162+ });
118163 }
119164
120165 return callGraph;
@@ -123,98 +168,108 @@ buildCallGraph(const Module& module,
123168// Propagate effects from callees to callers transitively
124169// e.g. if A -> B -> C (A calls B which calls C)
125170// Then B inherits effects from C and A inherits effects from both B and C.
126- //
127- // Generate SCC for the call graph, then traverse it in reverse topological
128- // order processing each callee before its callers. When traversing:
129- // - Merge all of the effects of functions within the CC
130- // - Also merge the (already computed) effects of each callee CC
131- // - Add trap effects for potentially recursive call chains
132171void propagateEffects (
133- const Module& module ,
172+ Module& module ,
134173 const PassOptions& passOptions,
135174 std::map<Function*, FuncInfo>& funcInfos,
136- const std::unordered_map<Function* , std::unordered_set<Function*>>
175+ const std::unordered_map<CallGraphNode , std::unordered_set<CallGraphNode>>&
137176 callGraph) {
177+
138178 struct CallGraphSCCs
139- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs> {
179+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs> {
140180 const std::map<Function*, FuncInfo>& funcInfos;
141- const std::unordered_map<Function*, std::unordered_set<Function*>>&
142- callGraph;
143181 const Module& module ;
182+ const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
183+ callGraph;
144184
145185 CallGraphSCCs (
146- const std::vector<Function* >& funcs ,
186+ const std::vector<CallGraphNode >& nodes ,
147187 const std::map<Function*, FuncInfo>& funcInfos,
148- const std::unordered_map<Function*, std::unordered_set<Function*>>&
149- callGraph ,
150- const Module& module )
151- : SCCs<std::vector<Function* >::const_iterator, CallGraphSCCs>(
152- funcs .begin(), funcs .end()),
153- funcInfos (funcInfos), callGraph(callGraph ), module ( module ) {}
154-
155- void pushChildren (Function* f ) {
156- auto callees = callGraph.find (f );
188+ Module& module ,
189+ const std::unordered_map<CallGraphNode ,
190+ std::unordered_set<CallGraphNode>>& callGraph )
191+ : SCCs<std::vector<CallGraphNode >::const_iterator, CallGraphSCCs>(
192+ nodes .begin(), nodes .end()),
193+ funcInfos (funcInfos), module ( module ), callGraph(callGraph ) {}
194+
195+ void pushChildren (CallGraphNode node ) {
196+ auto callees = callGraph.find (node );
157197 if (callees == callGraph.end ()) {
158198 return ;
159199 }
160-
161- for (auto * callee : callees->second ) {
200+ for (const auto & callee : callees->second ) {
162201 push (callee);
163202 }
164203 }
165204 };
166205
167- std::vector<Function*> allFuncs;
206+ std::vector<CallGraphNode> funcs;
207+ // We only care about Functions that are roots, not types
168208 for (auto & [func, info] : funcInfos) {
169- allFuncs .push_back (func);
209+ funcs .push_back (func);
170210 }
171- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
172211
173- std::unordered_map<Function*, int > sccMembers;
212+ CallGraphSCCs sccs (funcs, funcInfos, module , callGraph);
213+
214+ std::unordered_map<CallGraphNode, int > sccMembers;
174215 std::unordered_map<int , std::optional<EffectAnalyzer>> componentEffects;
175216
176217 int ccIndex = 0 ;
177218 for (auto ccIterator : sccs) {
219+ std::vector<CallGraphNode> cc (ccIterator.begin (), ccIterator.end ());
178220 ccIndex++;
179221 std::optional<EffectAnalyzer>& ccEffects = componentEffects[ccIndex];
180- std::vector<Function*> ccFuncs (ccIterator.begin (), ccIterator.end ());
181-
182222 ccEffects.emplace (passOptions, module );
183223
224+ std::vector<Function*> ccFuncs;
225+ std::vector<HeapType> ccTypes;
226+ for (auto v : cc) {
227+ if (auto ** func = std::get_if<Function*>(&v)) {
228+ ccFuncs.push_back (*func);
229+ } else {
230+ ccTypes.push_back (std::get<HeapType>(v));
231+ }
232+ }
233+
184234 for (Function* f : ccFuncs) {
185235 sccMembers.emplace (f, ccIndex);
186236 }
237+ for (HeapType t : ccTypes) {
238+ sccMembers.emplace (t, ccIndex);
239+ }
187240
188241 std::unordered_set<int > calleeSccs;
189- for (Function* caller : ccFuncs ) {
242+ for (const auto & caller : cc ) {
190243 auto callees = callGraph.find (caller);
191- if (callees == callGraph.end ()) {
192- continue ;
193- }
194- for (auto * callee : callees->second ) {
195- calleeSccs.insert (sccMembers.at (callee));
244+ if (callees != callGraph.end ()) {
245+ for (const auto & callee : callees->second ) {
246+ auto sccIt = sccMembers.find (callee);
247+ if (sccIt != sccMembers.end ()) {
248+ calleeSccs.insert (sccIt->second );
249+ }
250+ }
196251 }
197252 }
198253
199- // Merge in effects from callees
200254 for (int calleeScc : calleeSccs) {
201255 const auto & calleeComponentEffects = componentEffects.at (calleeScc);
202256 if (calleeComponentEffects == UnknownEffects) {
203257 ccEffects = UnknownEffects;
204258 break ;
205259 }
206260
207- else if (ccEffects != UnknownEffects ) {
261+ else if (ccEffects) {
208262 ccEffects->mergeIn (*calleeComponentEffects);
209263 }
210264 }
211265
212266 // Add trap effects for potential cycles.
213- if (ccFuncs .size () > 1 ) {
267+ if (cc .size () > 1 ) {
214268 if (ccEffects != UnknownEffects) {
215269 ccEffects->trap = true ;
216270 }
217- } else {
271+ // A cycle isn't possible for a CC that only contains a type
272+ } else if (ccFuncs.size () == 1 ) {
218273 auto * func = ccFuncs[0 ];
219274 if (funcInfos.at (func).calledFunctions .contains (func->name )) {
220275 if (ccEffects != UnknownEffects) {
@@ -223,8 +278,7 @@ void propagateEffects(
223278 }
224279 }
225280
226- // Aggregate effects within this CC
227- if (ccEffects) {
281+ if (ccEffects)
228282 for (Function* f : ccFuncs) {
229283 const auto & effects = funcInfos.at (f).effects ;
230284 if (effects == UnknownEffects) {
@@ -234,9 +288,7 @@ void propagateEffects(
234288
235289 ccEffects->mergeIn (*effects);
236290 }
237- }
238291
239- // Assign each function's effects to its CC effects.
240292 for (Function* f : ccFuncs) {
241293 if (!ccEffects) {
242294 funcInfos.at (f).effects = UnknownEffects;
@@ -247,27 +299,26 @@ void propagateEffects(
247299 }
248300}
249301
250- void copyEffectsToFunctions (const std::map<Function*, FuncInfo> funcInfos) {
251- for (auto & [func, info] : funcInfos) {
252- func->effects .reset ();
253- if (!info.effects ) {
254- continue ;
255- }
256-
257- func->effects = std::make_shared<EffectAnalyzer>(*info.effects );
258- }
259- }
260-
261302struct GenerateGlobalEffects : public Pass {
262303 void run (Module* module ) override {
263304 std::map<Function*, FuncInfo> funcInfos =
264305 analyzeFuncs (*module , getPassOptions ());
265306
266- auto callGraph = buildCallGraph (*module , funcInfos);
307+ auto callGraph =
308+ buildCallGraph (*module , funcInfos, getPassOptions ().closedWorld );
267309
268310 propagateEffects (*module , getPassOptions (), funcInfos, callGraph);
269311
270- copyEffectsToFunctions (funcInfos);
312+ // Generate the final data, starting from a blank slate where nothing is
313+ // known.
314+ for (auto & [func, info] : funcInfos) {
315+ func->effects .reset ();
316+ if (!info.effects ) {
317+ continue ;
318+ }
319+
320+ func->effects = std::make_shared<EffectAnalyzer>(*info.effects );
321+ }
271322 }
272323};
273324
0 commit comments