Skip to content

Commit 8a7d18d

Browse files
authored
Merge pull request #428 from viperproject/meilers_smt_types
Backend support for SMTLib types (particularly bitvectors and floats)
2 parents 6450b61 + 075fd4b commit 8a7d18d

9 files changed

Lines changed: 308 additions & 5 deletions

File tree

src/main/scala/viper/silver/ast/Expression.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,17 @@ object DomainFuncApp {
385385
DomainFuncApp(func.name,args,typVarMap)(pos, info, func.typ.substitute(typVarMap), func.domainName, errT)
386386
}
387387

388+
// --- References to backend (i.e., SMTLIB or Boogie 'builtin') functions
389+
390+
case class BackendFuncApp(backendFunc: BackendFunc, args: Seq[Exp])
391+
(val pos: Position = NoPosition, val info: Info = NoInfo, val errT: ErrorTrafo = NoTrafos)
392+
extends AbstractDomainFuncApp {
393+
override lazy val check : Seq[ConsistencyError] = args.flatMap(Consistency.checkPure)
394+
override def func = (p: Program) => backendFunc
395+
def funcname = backendFunc.name
396+
override def typ = backendFunc.typ
397+
}
398+
388399
// --- Field and predicate accesses
389400

390401
/** A common trait for expressions accessing a location. */

src/main/scala/viper/silver/ast/Program.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,13 @@ case object NotOp extends UnOp with BoolDomainFunc {
770770
lazy val fixity = Prefix
771771
}
772772

773+
774+
case class BackendFunc(name: String, smtName: String, override val typ: Type, override val formalArgs: Seq[LocalVarDecl])
775+
extends Node with AbstractDomainFunc with BuiltinDomainFunc
776+
773777
/**
774778
* The Extension Member trait provides the way to expand the Ast to include new Top Level declarations
775779
*/
776780
trait ExtensionMember extends Member{
777781
def extensionSubnodes: Seq[Node]
778-
}
782+
}

src/main/scala/viper/silver/ast/Type.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ case class TypeVar(name: String) extends Type {
183183
//def !=(other: TypeVar) = name != other
184184
}
185185

186+
case class BackendType(boogieName: String, smtName: String) extends AtomicType
187+
186188
trait ExtensionType extends Type{
187189
def getAstType: Type = ???
188190
}

src/main/scala/viper/silver/ast/pretty/PrettyPrinter.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,8 @@ object FastPrettyPrinter extends FastPrettyPrinterBase with BracketPrettyPrinter
624624
case dt@DomainType(domainName, typVarsMap) =>
625625
val typArgs = dt.typeParameters map (t => show(typVarsMap.getOrElse(t, t)))
626626
text(domainName) <> (if (typArgs.isEmpty) nil else brackets(ssep(typArgs, char (',') <> space)))
627+
case BackendType(boogieName, _) if boogieName != null => boogieName
628+
case BackendType(_, smtName) => smtName
627629
}
628630
}
629631

@@ -768,13 +770,14 @@ object FastPrettyPrinter extends FastPrettyPrinterBase with BracketPrettyPrinter
768770
text("acc") <> parens(show(loc) <> "," <+> show(perm))
769771
case FuncApp(funcname, args) =>
770772
text(funcname) <> parens(ssep(args map show, char (',') <> space))
771-
case DomainFuncApp(funcname, args, tvMap) =>
773+
case dfa@DomainFuncApp(funcname, args, tvMap) =>
772774
if (tvMap.nonEmpty)
773775
// Type may be underconstrained, so to be safe we explicitly print out the type.
774-
text(funcname) <> parens(ssep(args map show, char (',') <> space))
776+
parens(text(funcname) <> parens(ssep(args map show, char (',') <> space)) <> char(':') <+> show(dfa.typ))
775777
else
776778
text(funcname) <> parens(ssep(args map show, char (',') <> space))
777-
779+
case BackendFuncApp(func, args) =>
780+
text(func.name) <> parens(ssep(args map show, char(',') <> space))
778781
case EmptySeq(elemTyp) =>
779782
text("Seq[") <> showType(elemTyp) <> "]()"
780783
case ExplicitSeq(elems) =>
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package viper.silver.ast.utility
2+
3+
import viper.silver.ast.{Bool, Int, LocalVarDecl, BackendFunc, BackendType}
4+
5+
/**
6+
* A factory for fixed-size bitvectors that offers convenient access to bitvector types, as well as
7+
* function definitions for unary and binary functions on bitvectors, as well as conversions from and
8+
* to integers.
9+
*/
10+
case class BVFactory(size: Int) {
11+
lazy val typ = BackendType(s"bv${size}", s"(_ BitVec ${size})")
12+
13+
def xor(name: String) = BackendFunc(name, "bvxor", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
14+
def and(name: String) = BackendFunc(name, "bvand", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
15+
def or(name: String) = BackendFunc(name, "bvor", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
16+
def add(name: String) = BackendFunc(name, "bvadd", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
17+
def mul(name: String) = BackendFunc(name, "bvmul", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
18+
def shl(name: String) = BackendFunc(name, "bvshl", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
19+
def shr(name: String) = BackendFunc(name, "bvshr", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
20+
21+
def not(name: String) = BackendFunc(name, "bvnot", typ, Seq(LocalVarDecl("x", typ)()))
22+
def neg(name: String) = BackendFunc(name, "bvneg", typ, Seq(LocalVarDecl("x", typ)()))
23+
24+
def from_int(name: String) = BackendFunc(name, s"(_ int2bv ${size})", typ, Seq(LocalVarDecl("x", Int)()))
25+
def to_int(name: String) = BackendFunc(name, s"(_ bv2int ${size})", Int, Seq(LocalVarDecl("x", typ)()))
26+
def from_nat(name: String) = BackendFunc(name, s"(_ nat2bv ${size})", typ, Seq(LocalVarDecl("x", Int)()))
27+
def to_nat(name: String) = BackendFunc(name, s"(_ bv2nat ${size})", Int, Seq(LocalVarDecl("x", typ)()))
28+
}
29+
30+
/**
31+
* Rounding modes for floating point operations.
32+
*/
33+
object RoundingMode extends Enumeration {
34+
type RoundingMode = Value
35+
val RNE, RNA, RTP, RTN, RTZ = Value
36+
}
37+
import RoundingMode._
38+
39+
/**
40+
* A factory for IEEE floating point numbers with "exp" bits for the exponent, "mant" bits for the significant,
41+
* including the hidden bit, and a given rounding mode for all operations that use one.
42+
* Offers access to types, unary and binary operations, comparisons, and conversions from and to
43+
* bitvectors of size exp + mant.
44+
*/
45+
case class FloatFactory(mant: Int, exp: Int, roundingMode: RoundingMode) {
46+
47+
lazy val typ = BackendType(s"float${mant}e${exp}", s"(_ FloatingPoint ${exp} ${mant})")
48+
49+
def neg(name: String) = BackendFunc(name, "fp.neg", typ, Seq(LocalVarDecl("x", typ)()))
50+
def abs(name: String) = BackendFunc(name, "fp.abs", typ, Seq(LocalVarDecl("x", typ)()))
51+
52+
def add(name: String) = BackendFunc(name, s"fp.add ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
53+
def sub(name: String) = BackendFunc(name, s"fp.sub ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
54+
def mul(name: String) = BackendFunc(name, s"fp.mul ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
55+
def div(name: String) = BackendFunc(name, s"fp.div ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
56+
def min(name: String) = BackendFunc(name, s"fp.min ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
57+
def max(name: String) = BackendFunc(name, s"fp.max ${roundingMode}", typ, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
58+
59+
def eq(name: String) = BackendFunc(name, s"fp.eq", Bool, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
60+
def leq(name: String) = BackendFunc(name, s"fp.leq", Bool, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
61+
def geq(name: String) = BackendFunc(name, s"fp.geq", Bool, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
62+
def lt(name: String) = BackendFunc(name, s"fp.lt", Bool, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
63+
def gt(name: String) = BackendFunc(name, s"fp.gt", Bool, Seq(LocalVarDecl("x", typ)(), LocalVarDecl("y", typ)()))
64+
65+
def isZero(name: String) = BackendFunc(name, s"fp.isZero", Bool, Seq(LocalVarDecl("x", typ)()))
66+
def isInfinite(name: String) = BackendFunc(name, s"fp.isInfinite", Bool, Seq(LocalVarDecl("x", typ)()))
67+
def isNaN(name: String) = BackendFunc(name, s"fp.isNaN", Bool, Seq(LocalVarDecl("x", typ)()))
68+
def isNegative(name: String) = BackendFunc(name, s"fp.isNegative", Bool, Seq(LocalVarDecl("x", typ)()))
69+
def isPositive(name: String) = BackendFunc(name, s"fp.isPositive", Bool, Seq(LocalVarDecl("x", typ)()))
70+
71+
def from_bv(name: String) = BackendFunc(name, s"(_ to_fp ${exp} ${mant}) ", typ, Seq(LocalVarDecl("x", BVFactory(mant+exp).typ)()))
72+
def to_bv(name: String) = BackendFunc(name, s"(_ fp.to_sbv ${exp+mant}) ${roundingMode} ", BVFactory(mant+exp).typ, Seq(LocalVarDecl("x", typ)()))
73+
}

src/main/scala/viper/silver/ast/utility/Expressions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object Expressions {
3333
| _: PermExp
3434
| _: FuncApp
3535
| _: DomainFuncApp
36+
| _: BackendFuncApp
3637
| _: LocationAccess
3738
| _: AbstractLocalVar
3839
| _: SeqExp

src/main/scala/viper/silver/ast/utility/Nodes.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ object Nodes {
106106
case FuncApp(_, args) => args
107107
case DomainFuncApp(_, args, m) =>
108108
args ++ m.keys ++ m.values
109+
case BackendFuncApp(_, args) => args
109110

110111
case EmptySeq(elemTyp) => Seq(elemTyp)
111112
case ExplicitSeq(elems) => elems

src/main/scala/viper/silver/ast/utility/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ object Types {
6969
* @return The list of transitive type constituents of `typ`.
7070
*/
7171
def typeConstituents(typ: Type): List[Type] = typ match {
72-
case Int | Bool | Perm | Ref | InternalType | _: TypeVar | Wand => Nil
72+
case Int | Bool | Perm | Ref | InternalType | _: TypeVar | Wand | _: BackendType => Nil
7373
case dt: DomainType => dt.typeParameters.map(_.substitute(dt.typVarsMap)).toList
7474
case SeqType(elementType) => elementType :: typeConstituents(elementType)
7575
case SetType(elementType) => elementType :: typeConstituents(elementType)
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package viper.silver.testing
2+
3+
import org.scalatest.{BeforeAndAfterAllConfigMap, ConfigMap, FunSuite, Matchers}
4+
import viper.silver.ast.{AnySetContains, Assert, EqCmp, Exp, Field, FieldAccess, FieldAccessPredicate, FullPerm, Function, Inhale, IntLit, LocalVarAssign, LocalVarDecl, Method, Program, Ref, Result, BackendFuncApp, Seqn, SetType, Stmt}
5+
import viper.silver.ast.utility.{BVFactory, FloatFactory, RoundingMode}
6+
import viper.silver.verifier.{Failure, Success, Verifier}
7+
import viper.silver.verifier.errors.{AssertFailed, PostconditionViolated}
8+
9+
trait BackendTypeTest extends FunSuite with Matchers with BeforeAndAfterAllConfigMap {
10+
11+
def generateTypeCombinationTest(success: Boolean) : (Program, Assert) = {
12+
val t = if (success) BVFactory(23).typ else FloatFactory(23, 11, RoundingMode.RNE).typ
13+
val p1_decl = LocalVarDecl("three", t)()
14+
val p1_ref = p1_decl.localVar
15+
val p2_decl = LocalVarDecl("lol", SetType(t))()
16+
val p2_ref= p2_decl.localVar
17+
val element_in_param = AnySetContains(p1_ref, p2_ref)()
18+
19+
val assume = Inhale(element_in_param)()
20+
val assert = Assert(element_in_param)()
21+
val body = if (success) Seq(assume, assert) else Seq(assert)
22+
(wrapInProgram(body, Seq(p1_decl, p2_decl), Seq()), assert)
23+
}
24+
25+
def generateFieldTypeTest(success: Boolean) : (Program, Assert) = {
26+
val t = if (!success) BVFactory(23).typ else FloatFactory(23, 11, RoundingMode.RNE).typ
27+
val field = Field("f", t)()
28+
val p1_decl = LocalVarDecl("three", Ref)()
29+
val p1_ref = p1_decl.localVar
30+
val p2_decl = LocalVarDecl("lol", SetType(t))()
31+
val p2_ref= p2_decl.localVar
32+
val fieldAcc = FieldAccess(p1_ref, field)()
33+
val perm = FieldAccessPredicate(fieldAcc, FullPerm()())()
34+
val element_in_param = AnySetContains(fieldAcc, p2_ref)()
35+
36+
val getPerm = Inhale(perm)()
37+
val assume = Inhale(element_in_param)()
38+
val assert = Assert(element_in_param)()
39+
val body = if (success) Seq(getPerm, assume, assert) else Seq(getPerm, assert)
40+
(wrapInProgram(body, Seq(p1_decl, p2_decl), Seq(), fields = Seq(field)), assert)
41+
}
42+
43+
def generateFloatOpTest(success: Boolean) : (Program, Assert) = {
44+
val rne = RoundingMode.RNE
45+
val fp = FloatFactory(24, 8, rne)
46+
val first = 1081081856 // 3.75
47+
val second = 1103888384 // 25.5
48+
val result = 1105854464 // 29.25
49+
val bv32 = BVFactory(32)
50+
val from_int = bv32.from_int("toBV32")
51+
val to_fp = fp.from_bv("tofp")
52+
val fp_eq = fp.eq("fp_eq")
53+
val fp_add = fp.add("fp_add")
54+
55+
val first_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(first)()))()))()
56+
val second_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(second)()))()))()
57+
val result_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(result)()))()))()
58+
59+
val zero_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(0)()))()))()
60+
61+
val addition = BackendFuncApp(fp_add, Seq(first_float, second_float))()
62+
val result_addition = BackendFuncApp(fp_add, Seq(result_float, if (success) zero_float else first_float))()
63+
64+
val equality = BackendFuncApp(fp_eq, Seq(addition, result_addition))()
65+
val assert = Assert(equality)()
66+
(wrapInProgram(Seq(assert), Seq(), Seq()), assert)
67+
}
68+
69+
def generateFloatOpFunctionTest(success: Boolean) : (Program, Function, Exp) = {
70+
val rne = RoundingMode.RNE
71+
val fp = FloatFactory(24, 8, rne)
72+
val first = 1081081856 // 3.75
73+
val second = 1103888384 // 25.5
74+
val result = 1105854464 // 29.25
75+
val bv32 = BVFactory(32)
76+
val from_int = bv32.from_int("toBV32")
77+
val to_fp = fp.from_bv("tofp")
78+
val fp_eq = fp.eq("fp_eq")
79+
val fp_add = fp.add("fp_add")
80+
81+
val first_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(first)()))()))()
82+
val second_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(second)()))()))()
83+
val result_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(result)()))()))()
84+
85+
val zero_float = BackendFuncApp(to_fp, Seq(BackendFuncApp(from_int, Seq(IntLit(0)()))()))()
86+
87+
val addition = BackendFuncApp(fp_add, Seq(first_float, second_float))()
88+
val result_addition = BackendFuncApp(fp_add, Seq(result_float, if (success) zero_float else first_float))()
89+
90+
val equality = BackendFuncApp(fp_eq, Seq(Result(fp.typ)(), result_addition))()
91+
92+
val fun = Function("test", Seq(), fp.typ, Seq(), Seq(equality), Some(addition))()
93+
val program = Program(Seq(), Seq(), Seq(fun), Seq(), Seq(), Seq())()
94+
(program, fun, equality)
95+
}
96+
97+
def generateBvOpTest(success: Boolean) : (Program, Assert) = {
98+
val bv23 = BVFactory(23)
99+
val from_int = bv23.from_int("toBV23")
100+
val two_lit = IntLit(2)()
101+
val three_lit = IntLit(3)()
102+
val one_lit = IntLit(1) ()
103+
val two = BackendFuncApp(from_int, Seq(two_lit))()
104+
val three = BackendFuncApp(from_int, Seq(three_lit))()
105+
val one = BackendFuncApp(from_int, Seq(one_lit))()
106+
val result_decl = LocalVarDecl("three", bv23.typ)()
107+
val result_ref = result_decl.localVar
108+
val assign = LocalVarAssign(result_ref, if (success) three else one)()
109+
val xor = bv23.xor("xorBV23")
110+
val xor_app = BackendFuncApp(xor, Seq(one, two))()
111+
val equality1 = EqCmp(result_ref, xor_app)()
112+
val assertion1 = Assert(equality1)()
113+
(wrapInProgram(Seq(assign, assertion1), Seq(), Seq(result_decl)), assertion1)
114+
}
115+
116+
def wrapInProgram(stmts: Seq[Stmt], params: Seq[LocalVarDecl], vars: Seq[LocalVarDecl], fields: Seq[Field] = Seq()): Program = {
117+
val block = Seqn(stmts, vars)()
118+
val method = Method("test", params, Seq(), Seq(), Seq(), Some(block))()
119+
Program(Seq(), fields, Seq(), Seq(), Seq(method), Seq())()
120+
}
121+
122+
val verifier : Verifier
123+
124+
override def beforeAll(configMap: ConfigMap) {
125+
verifier.parseCommandLine(Seq("dummy.vpr"))
126+
verifier.start()
127+
}
128+
129+
override def afterAll(configMap: ConfigMap) {
130+
verifier.stop()
131+
}
132+
133+
test("typeCombinationSuccess") {
134+
val (prog, assertNode) = generateTypeCombinationTest(true)
135+
val res = verifier.verify(prog)
136+
assert(res == Success)
137+
}
138+
139+
test("typeCombinationFail") {
140+
val (prog, assertNode) = generateTypeCombinationTest(false)
141+
val res = verifier.verify(prog)
142+
assert(res match {
143+
case Failure(Seq(AssertFailed(a, _, _))) if a == assertNode => true
144+
case _ => false
145+
})
146+
}
147+
148+
test("fieldTypeSuccess") {
149+
val (prog, assertNode) = generateFieldTypeTest(true)
150+
val res = verifier.verify(prog)
151+
assert(res == Success)
152+
}
153+
154+
test("fieldTypeFail") {
155+
val (prog, assertNode) = generateFieldTypeTest(false)
156+
val res = verifier.verify(prog)
157+
assert(res match {
158+
case Failure(Seq(AssertFailed(a, _, _))) if a == assertNode => true
159+
case _ => false
160+
})
161+
}
162+
163+
test("bvOpSuccess") {
164+
val (prog, assertNode) = generateBvOpTest(true)
165+
val res = verifier.verify(prog)
166+
assert(res == Success)
167+
}
168+
169+
test("bvOpFail") {
170+
val (prog, assertNode) = generateBvOpTest(false)
171+
val res = verifier.verify(prog)
172+
assert(res match {
173+
case Failure(Seq(AssertFailed(a, _, _))) if a == assertNode => true
174+
case _ => false
175+
})
176+
}
177+
178+
test("floatOpSuccess") {
179+
val (prog, assertNode) = generateFloatOpTest(true)
180+
val res = verifier.verify(prog)
181+
assert(res == Success)
182+
}
183+
184+
test("floatOpFail") {
185+
val (prog, assertNode) = generateFloatOpTest(false)
186+
val res = verifier.verify(prog)
187+
assert(res match {
188+
case Failure(Seq(AssertFailed(a, _, _))) if a == assertNode => true
189+
case _ => false
190+
})
191+
}
192+
193+
test("floatOpFunctionSuccess") {
194+
val (prog, fun, exp) = generateFloatOpFunctionTest(true)
195+
val res = verifier.verify(prog)
196+
assert(res == Success)
197+
}
198+
199+
test("floatOpFunctionFail") {
200+
val (prog, fun, exp) = generateFloatOpFunctionTest(false)
201+
val res = verifier.verify(prog)
202+
assert(res match {
203+
case Failure(Seq(PostconditionViolated(e, f, _, _))) if e == exp && fun == f => true
204+
case _ => false
205+
})
206+
}
207+
208+
}

0 commit comments

Comments
 (0)