Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/main/scala/viper/silver/ast/utility/Simplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package viper.silver.ast.utility

import viper.silver.ast.utility.Statements.EmptyStmt
import viper.silver.ast._
import viper.silver.ast.utility.rewriter._

Expand All @@ -15,14 +16,15 @@ import viper.silver.ast.utility.rewriter._
object Simplifier {

/**
* Simplify `expression`, in particular by making use of literals. For
* Simplify `n`, in particular by making use of literals. For
* example, `!true` is replaced by `false`. Division and modulo with divisor
* 0 are not treated. Note that an expression with non-terminating evaluation due to endless recursion
* might be transformed to terminating expression.
*/
def simplify(expression: Exp): Exp = {
def simplify[N <: Node](n: N): N = {
/* Always simplify children first, then treat parent. */
StrategyBuilder.Slim[Node]({
// expression simplifications
case root @ Not(BoolLit(literal)) =>
BoolLit(!literal)(root.pos, root.info)
case Not(Not(single)) => single
Expand Down Expand Up @@ -112,7 +114,16 @@ object Simplifier {
IntLit(left / right)(root.pos, root.info)
case root @ Mod(IntLit(left), IntLit(right)) if right != bigIntZero =>
IntLit((right.abs + (left % right)) % right.abs)(root.pos, root.info)
}, Traverse.BottomUp) execute[Exp](expression)

// statement simplifications
case Seqn(EmptyStmt, _) => EmptyStmt // remove empty Seqn (including unnecessary scopedDecls)
case s@Seqn(ss, scopedDecls) if ss.contains(EmptyStmt) => // remove empty statements
val newSS = ss.filterNot(_ == EmptyStmt)
Seqn(newSS, scopedDecls)(s.pos, s.info, s.errT)
case If(_, EmptyStmt, EmptyStmt) => EmptyStmt // remove empty If clause
case If(TrueLit(), thn, _) => thn // remove trivial If conditions
case If(FalseLit(), _, els) => els // remove trivial If conditions
}, Traverse.BottomUp) execute n
}

private val bigIntZero = BigInt(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@

package viper.silver.plugin.standard.termination.transformation

import viper.silver.ast.{Exp, Stmt, _}
import viper.silver.ast._
import viper.silver.ast.utility.Statements.EmptyStmt
import viper.silver.ast.utility.ViperStrategy
import viper.silver.ast.utility.rewriter.{ContextCustom, RepeatedStrategy, Strategy, Traverse}
import viper.silver.ast.utility.{Expressions, ViperStrategy}
import viper.silver.ast.utility.rewriter.{ContextCustom, Strategy, Traverse}
import viper.silver.verifier.ConsistencyError

/**
* A basic interface which helps to rewrite an expression (e.g. a function body) into a stmt (e.g. for a method body).
* Some default transformations are already implemented.
*/
trait ExpTransformer extends ErrorReporter {
trait ExpTransformer extends ProgramManager with ErrorReporter {

/**
* Transforms an expression into a statement.
*
* @return a statement representing the expression.
*/
def transformExp: PartialFunction[(Exp, ExpressionContext), Stmt] = {
case (CondExp(cond, thn, els), c) =>
def transformExp(e: Exp, c: ExpressionContext): Stmt = e match {
case CondExp(cond, thn, els) =>
val condStmt = transformExp(cond, c)
val thnStmt = transformExp(thn, c)
val elsStmt = transformExp(els, c)
Expand All @@ -32,23 +33,31 @@ trait ExpTransformer extends ErrorReporter {

val stmts = Seq(condStmt, ifStmt)
Seqn(stmts, Nil)()
case (Unfolding(acc, unfBody), c) =>
case Unfolding(acc, unfBody) =>
val permCheck = transformExp(acc.perm, c)
val unfoldBody = transformExp(unfBody, c)
// only unfold and fold if body contains something
val (unfold, fold) = (Unfold(acc)(), Fold(acc)())

val stmts = Seq(permCheck, unfold, unfoldBody, fold)
Seqn(stmts, Nil)()
case (inex: InhaleExhaleExp, c) =>
case Applying(wand, body) =>
// note that this case is untested -- it's not possible to write a function with an `applying` expression
val nonDetVarDecl = LocalVarDecl(uniqueName("b"), Bool)(e.pos, e.info, e.errT)
val bodyStmt = transformExp(body, c)
val killBranchStmt = Inhale(FalseLit()(e.pos, e.info, e.errT))(e.pos, e.info, e.errT)
val thnStmt = Seqn(Seq(Apply(wand)(e.pos, e.info, e.errT), bodyStmt, killBranchStmt), Nil)()
val ifStmt = If(nonDetVarDecl.localVar, thnStmt, EmptyStmt)(e.pos, e.info, e.errT)
Seqn(Seq(ifStmt), Seq(nonDetVarDecl))(e.pos, e.info, e.errT)
case inex: InhaleExhaleExp =>
val inhaleStmt = transformExp(inex.in, c)
val exhaleStmt = transformExp(inex.ex, c)

c.conditionInEx match {
case Some(conditionVar) => If(conditionVar.localVar, Seqn(Seq(inhaleStmt), Nil)(), Seqn(Seq(exhaleStmt), Nil)())()
case None => Seqn(Seq(inhaleStmt, exhaleStmt), Nil)()
}
case (letExp: Let, c) =>
case letExp: Let =>
val expressionStmt = transformExp(letExp.exp, c)
val localVarDecl = letExp.variable

Expand All @@ -58,7 +67,7 @@ trait ExpTransformer extends ErrorReporter {

Seqn(Seq(expressionStmt, inhaleEq, bodyStmt), Seq(localVarDecl))()

case (b: BinExp, c) =>
case b: BinExp =>
val left = transformExp(b.left, c)
val right = transformExp(b.right, c)

Expand All @@ -75,80 +84,62 @@ trait ExpTransformer extends ErrorReporter {
Seqn(Seq(right), Nil)()
}
Seqn(Seq(left, rightSCE), Nil)()
case (sq: SeqExp, c) => sq match {
case ExplicitSeq(elems) =>
Seqn(elems.map(transformExp(_, c)), Nil)(sq.pos)
case RangeSeq(low, high) =>
Seqn(Seq(transformExp(low, c),
transformExp(high, c)), Nil)(sq.pos)
case SeqAppend(left, right) =>
Seqn(Seq(transformExp(left, c),
transformExp(right, c)), Nil)(sq.pos)
case SeqIndex(s, idx) =>
Seqn(Seq(transformExp(s, c),
transformExp(idx, c)), Nil)(sq.pos)
case SeqTake(s, n) =>
Seqn(Seq(transformExp(s, c),
transformExp(n, c)), Nil)(sq.pos)
case SeqDrop(s, n) =>
Seqn(Seq(transformExp(s, c),
transformExp(n, c)), Nil)(sq.pos)
case SeqContains(elem, s) =>
Seqn(Seq(transformExp(elem, c),
transformExp(s, c)), Nil)(sq.pos)
case SeqUpdate(s, idx, elem) =>
Seqn(Seq(transformExp(s, c),
transformExp(idx, c),
transformExp(elem, c)), Nil)(sq.pos)
case SeqLength(s) =>
Seqn(Seq(transformExp(s, c)), Nil)(sq.pos)
case EmptySeq(_) => EmptyStmt
case unsupportedExp => transformUnknownExp(unsupportedExp, c)
EmptyStmt
}
case (mp: MapExp, c) => mp match {
case EmptyMap(_, _) => EmptyStmt
case ExplicitMap(elems) => Seqn(elems.map(transformExp(_, c)), Nil)(mp.pos)
case Maplet(key, value) => Seqn(Seq(transformExp(key, c), transformExp(value, c)), Nil)(mp.pos)
case MapCardinality(base) => Seqn(Seq(transformExp(base, c)), Nil)(mp.pos)
case MapContains(key, base) => Seqn(Seq(transformExp(key, c), transformExp(base, c)), Nil)(mp.pos)
case MapLookup(base, key) => Seqn(Seq(transformExp(base, c), transformExp(key, c)), Nil)(mp.pos)
case MapUpdate(base, key, value) => Seqn(Seq(transformExp(base, c), transformExp(key, c), transformExp(value, c)), Nil)(mp.pos)
}
case (st: ExplicitSet, c) =>
Seqn(st.elems.map(transformExp(_, c)), Nil)(st.pos)
case (mst: ExplicitMultiset, c) =>
Seqn(mst.elems.map(transformExp(_, c)), Nil)(mst.pos)
case (md: MapDomain, c) => Seqn(Seq(transformExp(md.base, c)), Nil)(md.pos)
case (mr: MapRange, c) => Seqn(Seq(transformExp(mr.base, c)), Nil)(mr.pos)
case (u: UnExp, c) => transformExp(u.exp, c)
case (_: Literal, _) => EmptyStmt
case (_: AbstractLocalVar, _) => EmptyStmt
case (_: AbstractConcretePerm, _) => EmptyStmt
case (_: LocationAccess, _) => EmptyStmt

case (ap: AccessPredicate, c) =>
val check = transformExp(ap.perm, c)

val inhale = Inhale(ap)(ap.pos)
case _: Literal => EmptyStmt
case _: AbstractLocalVar => EmptyStmt
case _: AbstractConcretePerm => EmptyStmt
case _: WildcardPerm => EmptyStmt
case _: EpsilonPerm => EmptyStmt
case _: CurrentPerm => EmptyStmt
case _: LocationAccess => EmptyStmt

case ap: AccessPredicate =>
val check = transformExp(ap.perm, c)
val inhale = Inhale(ap)(ap.pos)
Seqn(Seq(check, inhale), Nil)()
case (fa: FuncLikeApp, c) =>
case fa: Forall =>
// we turn the quantified variables into local variables with arbitrary value and show that the expression holds
// for arbitrary values, which is similar to a forall introduction
val (localDeclMapping, transformedExp) = substituteWithFreshVars(fa.variables, fa.exp)
val expressionStmt = transformExp(transformedExp, c)
Seqn(Seq(expressionStmt), localDeclMapping.map(_._2))(fa.pos, fa.info, fa.errT)
case fp: ForPerm =>
// let's pick arbitrary values for the quantified variables and check the body given that the current heap has
// sufficient permissions
val (localDeclMapping, transformedExp) = substituteWithFreshVars(fp.variables, fp.exp)
val transformedRes = applySubstitution(localDeclMapping, fp.resource)
val expressionStmt = transformExp(transformedExp, c)
val killBranchStmt = Inhale(FalseLit()(e.pos, e.info, e.errT))(e.pos, e.info, e.errT)
val thnStmt = Seqn(Seq(expressionStmt, killBranchStmt), Nil)(e.pos, e.info, e.errT)
val ifCond = GtCmp(CurrentPerm(transformedRes)(e.pos, e.info, e.errT), NoPerm()(e.pos, e.info, e.errT))(e.pos, e.info, e.errT)
val ifStmt = If(ifCond, thnStmt, EmptyStmt)(e.pos, e.info, e.errT)
Seqn(Seq(ifStmt), localDeclMapping.map(_._2))(e.pos, e.info, e.errT)
case ex: Exists =>
// we perform existential elimination by retrieving witnesses for the quantified variables
val (localDeclMapping, transformedExp) = substituteWithFreshVars(ex.variables, ex.exp)
// we can't use an assume statement at this point because the `assume`s have already been rewritten
// furthermore, Viper only allows pure existentially quantified expressions
val inhaleWitnesses = Inhale(transformedExp)(ex.pos, ex.info, ex.errT)
val expressionStmt = transformExp(transformedExp, c)
Seqn(Seq(inhaleWitnesses, expressionStmt), localDeclMapping.map(_._2))(ex.pos, ex.info, ex.errT)
case fa: FuncLikeApp =>
val argStmts = fa.args.map(transformExp(_, c))
Seqn(argStmts, Nil)()
case (unknownExp, c) =>
transformUnknownExp(unknownExp, c)
case e: ExtensionExp => reportUnsupportedExp(e)

case _ =>
val sub = e.subExps.map(transformExp(_, c))
Seqn(sub, Nil)()
}

/**
* Expression transformer if no default is defined.
* Calls transformExp on all subExps of e.
* To change or extend the default transformer for unknown expressions
* override this method (and possibly combine it with super.transformUnknownExp).
*/
def transformUnknownExp(e: Exp, c: ExpressionContext): Stmt = {
val sub = e.subExps.map(transformExp(_, c))
Seqn(sub, Nil)()
* Issues a consistency error for unsupported expressions.
*
* @param unsupportedExp to be reported.
*/
def reportUnsupportedExp(unsupportedExp: Exp): Stmt = {
reportError(ConsistencyError("Unsupported expression detected: " + unsupportedExp + ", " + unsupportedExp.getClass, unsupportedExp.pos))
EmptyStmt
}

/**
Expand Down Expand Up @@ -182,37 +173,27 @@ trait ExpTransformer extends ErrorReporter {
case Forall(_, _, exp) => Seq(exp)
}

/**
* Turns `vars` into new local variable declarations with a unique name and replaces the corresponding local variable uses in `exp`.
* Returns a mapping of old variable declarations to new ones and the transformed expression
*/
protected def substituteWithFreshVars[E <: Exp](vars: Seq[LocalVarDecl], exp: E): (Seq[(LocalVarDecl, LocalVarDecl)], E) = {
val declMapping = vars.map(oldDecl =>
oldDecl -> LocalVarDecl(uniqueName(oldDecl.name), oldDecl.typ)(oldDecl.pos, oldDecl.info, oldDecl.errT))
val transformedExp = applySubstitution(declMapping, exp)
(declMapping, transformedExp)
}

/**
* The simplifyStmts Strategy can be used to simplify statements
* by e.g. removing or combining nested Seqn or If clauses.
* This is in particular useful for the expression to statement transformer
* because this often creates nested and empty Seqn and If clauses.
*
* This is a repeating strategy because the removal of unneccessary unfold/fold statements requires EmptyStmts
* to be removed beforehand. Therefore, it should only be used with an maximum number of iterations (2).
*
*/
val simplifyStmts: RepeatedStrategy[Node] = ViperStrategy.Slim({
case Seqn(EmptyStmt, _) => // remove empty Seqn (including unnecessary scopedDecls)
EmptyStmt
case s@Seqn(Seq(Seqn(ss, scopedDecls2)), scopedDecls1) => // combine nested Seqn
Seqn(ss, scopedDecls1 ++ scopedDecls2)(s.pos, s.info, s.errT)
case s@Seqn(ss, scopedDecls) if ss.contains(EmptyStmt) => // remove empty statements
val newSS = ss.filterNot(_ == EmptyStmt)
Seqn(newSS, scopedDecls)(s.pos, s.info, s.errT)
case Seqn(Seq(Unfold(acc1), Fold(acc2)), _) if acc1 == acc2 => // remove unfold/fold with nothing in between
EmptyStmt
case If(_, EmptyStmt, EmptyStmt) => // remove empty If clause
EmptyStmt
case i@If(c, EmptyStmt, els) => // change If with only els to If with only thn
If(Not(c)(c.pos), els, EmptyStmt)(i.pos, i.info, i.errT)
case i@If(c1, Seqn(Seq(If(c2, thn, EmptyStmt)), Nil), EmptyStmt) => // combine nested if clauses
If(And(c1, c2)(), thn, EmptyStmt)(i.pos, i.info, i.errT)
}, Traverse.BottomUp)
.repeat
* Replaces uses of local variables in `exp` based on `mapping`.
* `mapping` maps local variable declarations to new declarations and this transformation replaces the corresponding
* local variable uses.
*/
protected def applySubstitution[E <: Exp](mapping: Seq[(LocalVarDecl, LocalVarDecl)], exp: E): E = {
Expressions.instantiateVariables(exp, mapping.map(_._1.localVar), mapping.map(_._2.localVar))
}
}

trait ExpressionContext {
val conditionInEx: Option[LocalVarDecl]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ package viper.silver.plugin.standard.termination.transformation
import org.jgrapht.graph.{DefaultDirectedGraph, DefaultEdge}
import viper.silver.ast.utility.Statements.EmptyStmt
import viper.silver.ast.utility.rewriter.Traverse
import viper.silver.ast.utility.ViperStrategy
import viper.silver.ast.utility.{Simplifier, ViperStrategy}
import viper.silver.ast.{And, Bool, ErrTrafo, Exp, FalseLit, FuncApp, Function, LocalVarDecl, Method, Node, NodeTrafo, Old, Result, Seqn, Stmt}
import viper.silver.plugin.standard.termination.{DecreasesSpecification, FunctionTerminationError}
import viper.silver.verifier.ConsistencyError
import viper.silver.verifier.errors.AssertFailed

trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransformer with NestedPredicates with ErrorReporter {
Expand Down Expand Up @@ -54,7 +53,7 @@ trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransform
val context = FContext(f)

val proofMethodBody: Stmt = {
val stmt: Stmt = simplifyStmts.execute(transformExp(f.body.get, context))
val stmt: Stmt = Simplifier.simplify(transformExp(f.body.get, context))
if (requireNestedInfo) {
addNestedPredicateInformation.execute(stmt)
} else {
Expand Down Expand Up @@ -89,7 +88,7 @@ trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransform
.reduce((e, p) => And(e, p)())

val proofMethodBody: Stmt = {
val stmt: Stmt = simplifyStmts.execute(transformExp(posts, context))
val stmt: Stmt = Simplifier.simplify(transformExp(posts, context))
if (requireNestedInfo) {
addNestedPredicateInformation.execute(stmt)
} else {
Expand All @@ -114,7 +113,7 @@ trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransform
// concatenate all pres
val pres = f.pres.reduce((e, p) => And(e, p)())

val proofMethodBody: Stmt = simplifyStmts.execute(transformExp(pres, context))
val proofMethodBody: Stmt = Simplifier.simplify(transformExp(pres, context))

if (proofMethodBody != EmptyStmt) {
val proofMethod = Method(proofMethodName, f.formalArgs, Nil, Nil, Nil,
Expand All @@ -137,8 +136,8 @@ trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransform
*
* @return a statement representing the expression
*/
override val transformExp: PartialFunction[(Exp, ExpressionContext), Stmt] = {
case (functionCall: FuncApp, context: FunctionContext) =>
override def transformExp(e: Exp, c: ExpressionContext): Stmt = (e, c) match {
case (functionCall: FuncApp, context: FunctionContext @unchecked) =>
val stmts = collection.mutable.ArrayBuffer[Stmt]()

// check the arguments
Expand Down Expand Up @@ -219,21 +218,7 @@ trait FunctionCheck extends ProgramManager with DecreasesCheck with ExpTransform
// should not happen
}
Seqn(stmts.toSeq, Nil)()
case default => super.transformExp(default)
}

override def transformUnknownExp(e: Exp, c: ExpressionContext): Stmt = {
reportUnsupportedExp(e)
EmptyStmt
}

/**
* Issues a consistency error for unsupported expressions.
*
* @param unsupportedExp to be reported.
*/
def reportUnsupportedExp(unsupportedExp: Exp): Unit = {
reportError(ConsistencyError("Unsupported expression detected: " + unsupportedExp + ", " + unsupportedExp.getClass, unsupportedExp.pos))
case _ => super.transformExp(e, c)
}

// context creator
Expand Down
Loading