@@ -12,7 +12,7 @@ import org.jgrapht.alg.connectivity.GabowStrongConnectivityInspector
1212import org .jgrapht .graph .{DefaultDirectedGraph , DefaultEdge }
1313import org .jgrapht .traverse .TopologicalOrderIterator
1414
15- import scala .collection .mutable .{Set => MSet , ListBuffer }
15+ import scala .collection .mutable .{ListBuffer , Set => MSet }
1616import scala .jdk .CollectionConverters ._
1717
1818/**
@@ -47,11 +47,10 @@ object Functions {
4747
4848 /** Returns the call graph of a given program (also considering specifications as calls).
4949 *
50- * TODO: Memoize invocations of `getFunctionCallgraph`.
50+ * TODO: Memoize invocations of `getFunctionCallgraph`. Note that it's unclear how to derive a useful key from `subs`
5151 */
5252 def getFunctionCallgraph (program : Program , subs : Function => Seq [Exp ] = allSubexpressions)
5353 : DefaultDirectedGraph [Function , DefaultEdge ] = {
54-
5554 val graph = new DefaultDirectedGraph [Function , DefaultEdge ](classOf [DefaultEdge ])
5655
5756 for (f <- program.functions) {
@@ -214,7 +213,7 @@ object Functions {
214213 }
215214
216215 /** Returns all cycles formed by functions that (transitively through certain subexpressions)
217- * recurses via certain expressions.
216+ * recurse via certain expressions.
218217 *
219218 * @param program The program that defines the functions to check for cycles.
220219 * @param via The expression the cycle has to go through.
@@ -233,12 +232,34 @@ object Functions {
233232
234233 program.functions.flatMap(func => {
235234 val graph = getFunctionCallgraph(program, viaSubs(func))
236- val cycleDetector = new CycleDetector (graph)
237- val cycle = cycleDetector.findCyclesContainingVertex(func).asScala
238- if (cycle.isEmpty)
239- None
240- else
241- Some (func -> cycle.toSet)
235+ findCycles(graph, func)
242236 }).toMap[Function , Set [Function ]]
243237 }
238+
239+ /** Returns all cycles formed by functions that (transitively through certain subexpressions)
240+ * recurse via certain expressions. This is an optimized version of `findFunctionCyclesVia` in case
241+ * `via` and `subs` are equivalent.
242+ *
243+ * @param program The program that defines the functions to check for cycles.
244+ * @param via The expression the cycle has to go through.
245+ * @return A map from functions to sets of functions. If a function `f` maps to a set of
246+ * functions `fs`, then `f` (transitively) recurses via, and the
247+ * formed cycles involves the set of functions `fs`.
248+ */
249+ def findFunctionCyclesViaOptimized (program : Program , via : Function => Seq [Exp ])
250+ : Map [Function , Set [Function ]] = {
251+ val graph = getFunctionCallgraph(program, via)
252+ program.functions.flatMap(func => {
253+ findCycles(graph, func)
254+ }).toMap[Function , Set [Function ]]
255+ }
256+
257+ private def findCycles (graph : DefaultDirectedGraph [Function , DefaultEdge ], func : Function ): Option [(Function , Set [Function ])] = {
258+ val cycleDetector = new CycleDetector (graph)
259+ val cycle = cycleDetector.findCyclesContainingVertex(func).asScala
260+ if (cycle.isEmpty)
261+ None
262+ else
263+ Some (func -> cycle.toSet)
264+ }
244265}
0 commit comments