diff --git a/qiskit/circuit/classical/expr/__init__.py b/qiskit/circuit/classical/expr/__init__.py index a9c152ef229d..24fdb924b051 100644 --- a/qiskit/circuit/classical/expr/__init__.py +++ b/qiskit/circuit/classical/expr/__init__.py @@ -134,6 +134,10 @@ .. autofunction:: greater_equal .. autofunction:: shift_left .. autofunction:: shift_right +.. autofunction:: add +.. autofunction:: sub +.. autofunction:: mul +.. autofunction:: div You can index into unsigned integers and bit-likes using another unsigned integer of any width. This includes in storing operations, if the target of the index is writeable. @@ -214,6 +218,10 @@ "greater", "greater_equal", "index", + "add", + "sub", + "mul", + "div", "lift_legacy_condition", ] @@ -238,5 +246,9 @@ shift_left, shift_right, index, + add, + sub, + mul, + div, lift_legacy_condition, ) diff --git a/qiskit/circuit/classical/expr/constructors.py b/qiskit/circuit/classical/expr/constructors.py index 72744f55d940..9cc8da16e250 100644 --- a/qiskit/circuit/classical/expr/constructors.py +++ b/qiskit/circuit/classical/expr/constructors.py @@ -36,6 +36,10 @@ "shift_left", "shift_right", "index", + "add", + "sub", + "mul", + "div", "lift_legacy_condition", ] @@ -564,3 +568,209 @@ def index(target: typing.Any, index: typing.Any, /) -> Expr: if target.type.kind is not types.Uint or index.type.kind is not types.Uint: raise TypeError(f"invalid types for indexing: '{target.type}' and '{index.type}'") return Index(target, index, types.Bool()) + + +def _binary_sum(op: Binary.Op, left: typing.Any, right: typing.Any) -> Expr: + left, right = _lift_binary_operands(left, right) + if left.type.kind is right.type.kind and left.type.kind in { + types.Uint, + types.Float, + types.Duration, + }: + type = types.greater(left.type, right.type) + return Binary( + op, + _coerce_lossless(left, type), + _coerce_lossless(right, type), + type, + ) + raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'") + + +def add(left: typing.Any, right: typing.Any, /) -> Expr: + """Create an addition expression node from the given values, resolving any implicit casts and + lifting the values into :class:`Value` nodes if required. + + Examples: + Addition of two floating point numbers:: + + >>> from qiskit.circuit.classical import expr + >>> expr.add(5.0, 2.0) + Binary(\ +Binary.Op.ADD, \ +Value(5.0, Float()), \ +Value(2.0, Float()), \ +Float()) + + Addition of two durations:: + + >>> from qiskit.circuit import Duration + >>> from qiskit.circuit.classical import expr + >>> expr.add(Duration.dt(1000), Duration.dt(1000)) + Binary(\ +Binary.Op.ADD, \ +Value(Duration.dt(1000), Duration()), \ +Value(Duration.dt(1000), Duration()), \ +Duration()) + """ + return _binary_sum(Binary.Op.ADD, left, right) + + +def sub(left: typing.Any, right: typing.Any, /) -> Expr: + """Create a subtraction expression node from the given values, resolving any implicit casts and + lifting the values into :class:`Value` nodes if required. + + Examples: + Subtraction of two floating point numbers:: + + >>> from qiskit.circuit.classical import expr + >>> expr.sub(5.0, 2.0) + Binary(\ +Binary.Op.SUB, \ +Value(5.0, Float()), \ +Value(2.0, Float()), \ +Float()) + + Subtraction of two durations:: + + >>> from qiskit.circuit import Duration + >>> from qiskit.circuit.classical import expr + >>> expr.add(Duration.dt(1000), Duration.dt(1000)) + Binary(\ +Binary.Op.SUB, \ +Value(Duration.dt(1000), Duration()), \ +Value(Duration.dt(1000), Duration()), \ +Duration()) + """ + return _binary_sum(Binary.Op.SUB, left, right) + + +def mul(left: typing.Any, right: typing.Any) -> Expr: + """Create a multiplication expression node from the given values, resolving any implicit casts and + lifting the values into :class:`Value` nodes if required. + + This can be used to multiply numeric operands of the same type kind, or to multiply a duration + operand by a numeric operand. + + Examples: + Multiplication of two floating point numbers:: + + >>> from qiskit.circuit.classical import expr + >>> expr.mul(5.0, 2.0) + Binary(\ +Binary.Op.MUL, \ +Value(5.0, Float()), \ +Value(2.0, Float()), \ +Float()) + + Multiplication of a duration by a float:: + + >>> from qiskit.circuit import Duration + >>> from qiskit.circuit.classical import expr + >>> expr.mul(Duration.dt(1000), 0.5) + Binary(\ +Binary.Op.MUL, \ +Value(Duration.dt(1000), Duration()), \ +Value(0.5, Float()), \ +Duration()) + """ + left, right = _lift_binary_operands(left, right) + type: types.Type + if left.type.kind is right.type.kind is types.Duration: + raise TypeError("cannot multiply two durations") + if left.type.kind is right.type.kind and left.type.kind in {types.Uint, types.Float}: + type = types.greater(left.type, right.type) + left = _coerce_lossless(left, type) + right = _coerce_lossless(right, type) + elif left.type.kind is types.Duration and right.type.kind in {types.Uint, types.Float}: + if not right.const: + raise ValueError( + f"multiplying operands '{left}' and '{right}' would result in a non-const '{left.type}'" + ) + type = left.type + elif right.type.kind is types.Duration and left.type.kind in {types.Uint, types.Float}: + if not left.const: + raise ValueError( + f"multiplying operands '{left}' and '{right}' would result in a non-const '{right.type}'" + ) + type = right.type + else: + raise TypeError(f"invalid types for '{Binary.Op.MUL}': '{left.type}' and '{right.type}'") + return Binary( + Binary.Op.MUL, + left, + right, + type, + ) + + +def div(left: typing.Any, right: typing.Any) -> Expr: + """Create a division expression node from the given values, resolving any implicit casts and + lifting the values into :class:`Value` nodes if required. + + This can be used to divide numeric operands of the same type kind, to divide a + :class`~.types.Duration` operand by a numeric operand, or to divide two + :class`~.types.Duration` operands which yields an expression of type + :class:`~.types.Float`. + + Examples: + Division of two floating point numbers:: + + >>> from qiskit.circuit.classical import expr + >>> expr.div(5.0, 2.0) + Binary(\ +Binary.Op.DIV, \ +Value(5.0, Float()), \ +Value(2.0, Float()), \ +Float()) + + Division of two durations:: + + >>> from qiskit.circuit import Duration + >>> from qiskit.circuit.classical import expr + >>> expr.div(Duration.dt(10000), Duration.dt(1000)) + Binary(\ +Binary.Op.DIV, \ +Value(Duration.dt(10000), Duration()), \ +Value(Duration.dt(1000), Duration()), \ +Float()) + + + Division of a duration by a float:: + + >>> from qiskit.circuit import Duration + >>> from qiskit.circuit.classical import expr + >>> expr.div(Duration.dt(10000), 12.0) + Binary(\ +Binary.Op.DIV, \ +Value(Duration.dt(10000), Duration()), \ +Value(12.0, types.Float()), \ +Duration()) + """ + left, right = _lift_binary_operands(left, right) + type: types.Type + if left.type.kind is right.type.kind and left.type.kind in { + types.Duration, + types.Uint, + types.Float, + }: + if left.type.kind is types.Duration: + type = types.Float() + elif types.order(left.type, right.type) is not types.Ordering.NONE: + type = types.greater(left.type, right.type) + left = _coerce_lossless(left, type) + right = _coerce_lossless(right, type) + elif left.type.kind is types.Duration and right.type.kind in {types.Uint, types.Float}: + if not right.const: + raise ValueError( + f"division of '{left}' and '{right}' would result in a non-const '{left.type}'" + ) + type = left.type + else: + raise TypeError(f"invalid types for '{Binary.Op.DIV}': '{left.type}' and '{right.type}'") + return Binary( + Binary.Op.DIV, + left, + right, + type, + ) diff --git a/qiskit/circuit/classical/expr/expr.py b/qiskit/circuit/classical/expr/expr.py index dbb0f7e3b075..d2f8a20fb091 100644 --- a/qiskit/circuit/classical/expr/expr.py +++ b/qiskit/circuit/classical/expr/expr.py @@ -314,6 +314,15 @@ class Op(enum.Enum): container types (e.g. unsigned integers) as the left operand, and any integer type as the right-hand operand. In all cases, the output bit width is the same as the input, and zeros fill in the "exposed" spaces. + + The binary arithmetic operators :data:`ADD`, :data:`SUB:, :data:`MUL`, and :data:`DIV` + can be applied to two floats or two unsigned integers, which should be made to be of + the same width during construction via a cast. + The :data:`ADD`, :data:`SUB`, and :data:`DIV` operators can be applied on two durations + yielding another duration, or a float in the case of :data:`DIV`. The :data:`MUL` operator + can also be applied to a duration and a numeric type, yielding another duration. Finally, + the :data:`DIV` operator can be used to divide a duration by a numeric type, yielding a + duration. """ # If adding opcodes, remember to add helper constructor functions in `constructors.py` @@ -345,6 +354,14 @@ class Op(enum.Enum): """Zero-padding bitshift to the left. ``lhs << rhs``.""" SHIFT_RIGHT = 13 """Zero-padding bitshift to the right. ``lhs >> rhs``.""" + ADD = 14 + """Addition. ``lhs + rhs``.""" + SUB = 15 + """Subtraction. ``lhs - rhs``.""" + MUL = 16 + """Multiplication. ``lhs * rhs``.""" + DIV = 17 + """Division. ``lhs / rhs``.""" def __str__(self): return f"Binary.{super().__str__()}" diff --git a/qiskit/qasm3/ast.py b/qiskit/qasm3/ast.py index 7e0ca0ce4723..ee4b282bb560 100644 --- a/qiskit/qasm3/ast.py +++ b/qiskit/qasm3/ast.py @@ -323,6 +323,10 @@ class Op(enum.Enum): NOT_EQUAL = "!=" SHIFT_LEFT = "<<" SHIFT_RIGHT = ">>" + ADD = "+" + SUB = "-" + MUL = "*" + DIV = "/" def __init__(self, op: Op, left: Expression, right: Expression): self.op = op diff --git a/qiskit/qasm3/printer.py b/qiskit/qasm3/printer.py index 36795ed35e96..0020982e38a1 100644 --- a/qiskit/qasm3/printer.py +++ b/qiskit/qasm3/printer.py @@ -39,8 +39,12 @@ ast.Unary.Op.LOGIC_NOT: _BindingPower(right=22), ast.Unary.Op.BIT_NOT: _BindingPower(right=22), # - # Multiplication/division/modulo: (19, 20) - # Addition/subtraction: (17, 18) + # Modulo: (19, 20) + ast.Binary.Op.MUL: _BindingPower(19, 20), + ast.Binary.Op.DIV: _BindingPower(19, 20), + # + ast.Binary.Op.ADD: _BindingPower(17, 18), + ast.Binary.Op.SUB: _BindingPower(17, 18), # ast.Binary.Op.SHIFT_LEFT: _BindingPower(15, 16), ast.Binary.Op.SHIFT_RIGHT: _BindingPower(15, 16), diff --git a/releasenotes/notes/math-expr-a71515664473fdc4.yaml b/releasenotes/notes/math-expr-a71515664473fdc4.yaml new file mode 100644 index 000000000000..263bebf52a15 --- /dev/null +++ b/releasenotes/notes/math-expr-a71515664473fdc4.yaml @@ -0,0 +1,27 @@ +--- +features_circuits: + - | + The classical realtime-expressions module :mod:`qiskit.circuit.classical` can now represent + arithmetic operations :func:`~.expr.add`, :func:`~.expr.sub`, :func:`~.expr.mul`, + and :func:`~.expr.div` on numeric and timing operands. + + For example:: + + from qiskit.circuit import QuantumCircuit, ClassicalRegister, Duration + from qiskit.circuit.classical import expr + + # Subtract two integers + cr = ClassicalRegister(4, "cr") + qc = QuantumCircuit(cr) + with qc.if_test(expr.equal(expr.sub(cr, 2), 3)): + pass + + # Multiply a Duration by a Float + with qc.if_test(expr.less(expr.mul(Duration.dt(200), 2.0), Duration.ns(500))): + pass + + # Divide a Duration by a Duration to get a Float + with qc.if_test(expr.greater(expr.div(Duration.dt(200), Duration.dt(400)), 0.5)): + pass + + For additional examples, see the module-level documentation linked above. diff --git a/test/python/circuit/classical/test_expr_constructors.py b/test/python/circuit/classical/test_expr_constructors.py index 169b2d176a47..37123c663722 100644 --- a/test/python/circuit/classical/test_expr_constructors.py +++ b/test/python/circuit/classical/test_expr_constructors.py @@ -184,18 +184,29 @@ def test_unary_logical_forbidden(self, function): (expr.less_equal, ClassicalRegister(3), 5), (expr.greater, 4, ClassicalRegister(3)), (expr.greater_equal, ClassicalRegister(3), 5), + (expr.add, ClassicalRegister(3), 6), + (expr.sub, ClassicalRegister(3), 5), + (expr.mul, 4, ClassicalRegister(3)), + (expr.div, ClassicalRegister(3), 5), (expr.equal, 8.0, 255.0), (expr.not_equal, 8.0, 255.0), (expr.less, 3.0, 6.0), (expr.less_equal, 3.0, 5.0), (expr.greater, 4.0, 3.0), (expr.greater_equal, 3.0, 5.0), + (expr.add, 3.0, 6.0), + (expr.sub, 3.0, 5.0), + (expr.mul, 4.0, 3.0), + (expr.div, 3.0, 5.0), (expr.equal, Duration.dt(1000), Duration.dt(1000)), (expr.not_equal, Duration.dt(1000), Duration.dt(1000)), (expr.less, Duration.dt(1000), Duration.dt(1000)), (expr.less_equal, Duration.dt(1000), Duration.dt(1000)), (expr.greater, Duration.dt(1000), Duration.dt(1000)), (expr.greater_equal, Duration.dt(1000), Duration.dt(1000)), + (expr.add, Duration.dt(1000), Duration.dt(1000)), + (expr.sub, Duration.dt(1000), Duration.dt(1000)), + (expr.div, Duration.dt(1000), Duration.dt(1000)), ) @ddt.unpack def test_binary_functions_lift_scalars(self, function, left, right): @@ -421,7 +432,7 @@ def test_binary_equal_explicit(self, function, opcode): types.Bool(), ), ) - self.assertFalse(function(clbit, True).const) + self.assertTrue(function(expr.lift(7.0), 7.0).const) self.assertEqual( function(expr.lift(Duration.ms(1000)), Duration.s(1)), @@ -610,3 +621,311 @@ def test_shift_forbidden(self, function): function(Duration.dt(1000), 1) with self.assertRaisesRegex(TypeError, "invalid types"): function(Duration.dt(1000), Duration.dt(1000)) + + @ddt.data( + (expr.add, expr.Binary.Op.ADD), + (expr.sub, expr.Binary.Op.SUB), + ) + @ddt.unpack + def test_binary_sum_explicit(self, function, opcode): + cr = ClassicalRegister(8, "c") + + self.assertEqual( + function(cr, 200), + expr.Binary( + opcode, expr.Var(cr, types.Uint(8)), expr.Value(200, types.Uint(8)), types.Uint(8) + ), + ) + self.assertFalse(function(cr, 200).const) + + self.assertEqual( + function(12, cr), + expr.Binary( + opcode, + expr.Value(12, types.Uint(8)), + expr.Var(cr, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(function(12, cr).const) + + self.assertEqual( + function(12.5, 2.0), + expr.Binary( + opcode, + expr.Value(12.5, types.Float()), + expr.Value(2.0, types.Float()), + types.Float(), + ), + ) + self.assertTrue(function(12.5, 2.0).const) + + self.assertEqual( + function( + expr.lift(Duration.ms(1000), types.Duration()), + expr.lift(Duration.s(1)), + ), + expr.Binary( + opcode, + expr.Value(Duration.ms(1000), types.Duration()), + expr.Value(Duration.s(1), types.Duration()), + types.Duration(), + ), + ) + self.assertTrue( + function( + expr.lift(Duration.ms(1000), types.Duration()), + expr.lift(Duration.s(1)), + ).const + ) + + @ddt.data(expr.add, expr.sub) + def test_binary_sum_forbidden(self, function): + with self.assertRaisesRegex(TypeError, "invalid types"): + function(Clbit(), ClassicalRegister(3, "c")) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(ClassicalRegister(3, "c"), False) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(Clbit(), Clbit()) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(0xFFFF, 2.0) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(255.0, 1) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(Duration.dt(1000), 1) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(Duration.dt(1000), 1.0) + with self.assertRaisesRegex(TypeError, "invalid types"): + function(Duration.dt(1000), expr.lift(1.0)) + + def test_mul_explicit(self): + cr = ClassicalRegister(8, "c") + + self.assertEqual( + expr.mul(cr, 200), + expr.Binary( + expr.Binary.Op.MUL, + expr.Var(cr, types.Uint(8)), + expr.Value(200, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.mul(cr, 200).const) + + self.assertEqual( + expr.mul(12, cr), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(12, types.Uint(8)), + expr.Var(cr, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.mul(12, cr).const) + + self.assertEqual( + expr.mul(expr.lift(12), cr), + expr.Binary( + expr.Binary.Op.MUL, + # Explicit cast required to get from Uint(4) to Uint(8) + expr.Cast(expr.Value(12, types.Uint(4)), types.Uint(8), implicit=False), + expr.Var(cr, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.mul(12, cr).const) + + self.assertEqual( + expr.mul(expr.lift(12, types.Uint(8)), expr.lift(12)), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(12, types.Uint(8)), + expr.Cast( + expr.Value(12, types.Uint(4)), + types.Uint(8), + implicit=False, + ), + types.Uint(8), + ), + ) + self.assertTrue(expr.mul(expr.lift(12, types.Uint(8)), expr.lift(12)).const) + + self.assertEqual( + expr.mul(expr.lift(12.0, types.Float()), expr.lift(12.0)), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(12.0, types.Float()), + expr.Value(12.0, types.Float()), + types.Float(), + ), + ) + self.assertTrue(expr.mul(expr.lift(12.0, types.Float()), expr.lift(12.0)).const) + + self.assertEqual( + expr.mul(Duration.ms(1000), 2.0), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(Duration.ms(1000), types.Duration()), + expr.Value(2.0, types.Float()), + types.Duration(), + ), + ) + self.assertTrue(expr.mul(Duration.ms(1000), 2.0).const) + + self.assertEqual( + expr.mul(2.0, Duration.ms(1000)), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(2.0, types.Float()), + expr.Value(Duration.ms(1000), types.Duration()), + types.Duration(), + ), + ) + self.assertTrue(expr.mul(2.0, Duration.ms(1000)).const) + + self.assertEqual( + expr.mul(2, Duration.ms(1000)), + expr.Binary( + expr.Binary.Op.MUL, + expr.Value(2, types.Uint(2)), + expr.Value(Duration.ms(1000), types.Duration()), + types.Duration(), + ), + ) + self.assertTrue(expr.mul(2, Duration.ms(1000)).const) + + def test_mul_forbidden(self): + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.mul(Clbit(), ClassicalRegister(3, "c")) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.mul(ClassicalRegister(3, "c"), False) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.mul(Clbit(), Clbit()) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.mul(0xFFFF, 2.0) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.mul(255.0, 1) + with self.assertRaisesRegex(TypeError, "cannot multiply two durations"): + expr.mul(Duration.dt(1000), Duration.dt(1000)) + + # Multiply timing expressions by non-const floats: + non_const_float = expr.Var.new("a", types.Float()) + with self.assertRaisesRegex(ValueError, "would result in a non-const"): + expr.mul(Duration.dt(1000), non_const_float) + with self.assertRaisesRegex(ValueError, "would result in a non-const"): + expr.mul(non_const_float, Duration.dt(1000)) + + def test_div_explicit(self): + cr = ClassicalRegister(8, "c") + + self.assertEqual( + expr.div(cr, 200), + expr.Binary( + expr.Binary.Op.DIV, + expr.Var(cr, types.Uint(8)), + expr.Value(200, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.div(cr, 200).const) + + self.assertEqual( + expr.div(12, cr), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(12, types.Uint(8)), + expr.Var(cr, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.div(12, cr).const) + + self.assertEqual( + expr.div(expr.lift(12), cr), + expr.Binary( + expr.Binary.Op.DIV, + # Explicit cast required to get from Uint(4) to Uint(8) + expr.Cast(expr.Value(12, types.Uint(4)), types.Uint(8), implicit=False), + expr.Var(cr, types.Uint(8)), + types.Uint(8), + ), + ) + self.assertFalse(expr.div(expr.lift(12), cr).const) + + self.assertEqual( + expr.div(expr.lift(12, types.Uint(8)), expr.lift(12)), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(12, types.Uint(8)), + expr.Cast( + expr.Value(12, types.Uint(4)), + types.Uint(8), + implicit=False, + ), + types.Uint(8), + ), + ) + self.assertTrue(expr.div(expr.lift(12, types.Uint(8)), expr.lift(12)).const) + + self.assertEqual( + expr.div(expr.lift(12.0, types.Float()), expr.lift(12.0)), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(12.0, types.Float()), + expr.Value(12.0, types.Float()), + types.Float(), + ), + ) + self.assertTrue(expr.div(expr.lift(12.0, types.Float()), expr.lift(12.0)).const) + + self.assertEqual( + expr.div(Duration.ms(1000), 2.0), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(Duration.ms(1000), types.Duration()), + expr.Value(2.0, types.Float()), + types.Duration(), + ), + ) + self.assertTrue(expr.div(Duration.ms(1000), 2.0).const) + + self.assertEqual( + expr.div(Duration.ms(1000), 2), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(Duration.ms(1000), types.Duration()), + expr.Value(2, types.Uint(2)), + types.Duration(), + ), + ) + self.assertTrue(expr.div(Duration.ms(1000), 2).const) + + self.assertEqual( + expr.div(Duration.ms(1000), Duration.ms(1000)), + expr.Binary( + expr.Binary.Op.DIV, + expr.Value(Duration.ms(1000), types.Duration()), + expr.Value(Duration.ms(1000), types.Duration()), + types.Float(), + ), + ) + self.assertTrue(expr.div(Duration.ms(1000), Duration.ms(1000)).const) + + def test_div_forbidden(self): + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(Clbit(), ClassicalRegister(3, "c")) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(ClassicalRegister(3, "c"), False) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(Clbit(), Clbit()) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(0xFFFF, 2.0) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(255.0, 1) + with self.assertRaisesRegex(TypeError, "invalid types"): + expr.div(255.0, Duration.dt(1000)) + + # Divide timing expressions by non-const floats: + non_const_float = expr.Var.new("a", types.Float()) + with self.assertRaisesRegex(ValueError, "would result in a non-const"): + expr.div(Duration.dt(1000), non_const_float) diff --git a/test/python/compiler/test_transpiler.py b/test/python/compiler/test_transpiler.py index 61cf3c60a684..eac19218580b 100644 --- a/test/python/compiler/test_transpiler.py +++ b/test/python/compiler/test_transpiler.py @@ -2032,6 +2032,22 @@ def _control_flow_expr_circuit(self): ) ): base.cx(0, 1) + with base.if_test( + expr.logic_and( + expr.logic_and( + expr.equal(expr.mul(Duration.dt(1), 2.0), expr.div(Duration.ns(2), 2.0)), + expr.equal( + expr.add(Duration.us(3), Duration.us(4)), + expr.sub(Duration.ms(5), Duration.ms(6)), + ), + ), + expr.logic_and( + expr.equal(expr.mul(expr.lift(1.0), 2.0), expr.div(4.0, 2.0)), + expr.equal(expr.add(3.0, 4.0), expr.sub(10.5, expr.lift(4.3, types.Float()))), + ), + ) + ): + base.cx(0, 1) return base def _standalone_var_circuit(self): diff --git a/test/python/qasm3/test_export.py b/test/python/qasm3/test_export.py index c95cdee4b55c..cadffbf96c9c 100644 --- a/test/python/qasm3/test_export.py +++ b/test/python/qasm3/test_export.py @@ -1560,6 +1560,10 @@ def test_expr_associativity_left(self): ) qc.if_test(expr.logic_and(expr.logic_and(cr1[0], cr1[1]), cr1[2]), body.copy(), [], []) qc.if_test(expr.logic_or(expr.logic_or(cr1[0], cr1[1]), cr1[2]), body.copy(), [], []) + qc.if_test(expr.equal(expr.add(expr.add(cr1, cr2), cr3), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.sub(expr.sub(cr1, cr2), cr3), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.mul(expr.mul(cr1, cr2), cr3), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.div(expr.div(cr1, cr2), cr3), 7), body.copy(), [], []) # Note that bitwise operations except shift have lower priority than `==` so there's extra # parentheses. All these operators are left-associative in OQ3. @@ -1585,6 +1589,14 @@ def test_expr_associativity_left(self): } if (cr1[0] || cr1[1] || cr1[2]) { } +if (cr1 + cr2 + cr3 == 7) { +} +if (cr1 - cr2 - cr3 == 7) { +} +if (cr1 * cr2 * cr3 == 7) { +} +if (cr1 / cr2 / cr3 == 7) { +} """ self.assertEqual(dumps(qc), expected) @@ -1611,6 +1623,10 @@ def test_expr_associativity_right(self): ) qc.if_test(expr.logic_and(cr1[0], expr.logic_and(cr1[1], cr1[2])), body.copy(), [], []) qc.if_test(expr.logic_or(cr1[0], expr.logic_or(cr1[1], cr1[2])), body.copy(), [], []) + qc.if_test(expr.equal(expr.add(cr1, expr.add(cr2, cr3)), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.sub(cr1, expr.sub(cr2, cr3)), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.mul(cr1, expr.mul(cr2, cr3)), 7), body.copy(), [], []) + qc.if_test(expr.equal(expr.div(cr1, expr.div(cr2, cr3)), 7), body.copy(), [], []) # Note that bitwise operations have lower priority than `==` so there's extra parentheses. # All these operators are left-associative in OQ3, so we need parentheses for them to be @@ -1638,6 +1654,14 @@ def test_expr_associativity_right(self): } if (cr1[0] || (cr1[1] || cr1[2])) { } +if (cr1 + (cr2 + cr3) == 7) { +} +if (cr1 - (cr2 - cr3) == 7) { +} +if (cr1 * (cr2 * cr3) == 7) { +} +if (cr1 / (cr2 / cr3) == 7) { +} """ self.assertEqual(dumps(qc), expected) @@ -1713,11 +1737,17 @@ def test_expr_precedence(self): ), ) + arithmetic = expr.equal( + expr.add(expr.mul(cr, expr.sub(cr, cr)), expr.div(expr.add(cr, cr), cr)), + expr.sub(expr.div(expr.mul(cr, cr), expr.add(cr, cr)), expr.mul(cr, expr.add(cr, cr))), + ) + qc = QuantumCircuit(cr) qc.if_test(inside_out, body.copy(), [], []) qc.if_test(outside_in, body.copy(), [], []) qc.if_test(logics, body.copy(), [], []) qc.if_test(bitshifts, body.copy(), [], []) + qc.if_test(arithmetic, body.copy(), [], []) expected = """\ OPENQASM 3.0; @@ -1733,6 +1763,8 @@ def test_expr_precedence(self): } if (((cr ^ cr) & cr) << (cr | cr) == (cr >> 3 ^ cr << 4 | cr << 1)) { } +if (cr * (cr - cr) + (cr + cr) / cr == cr * cr / (cr + cr) - cr * (cr + cr)) { +} """ self.assertEqual(dumps(qc), expected) diff --git a/test/qpy_compat/test_qpy.py b/test/qpy_compat/test_qpy.py index 3ff797edb7e2..7a551fbcf8eb 100755 --- a/test/qpy_compat/test_qpy.py +++ b/test/qpy_compat/test_qpy.py @@ -866,7 +866,25 @@ def generate_v14_expr(): ): pass - return [float_expr, duration_expr] + math_expr = QuantumCircuit(name="math_expr") + with math_expr.if_test( + expr.logic_and( + expr.logic_and( + expr.equal(expr.mul(Duration.dt(1), 2.0), expr.div(Duration.ns(2), 2.0)), + expr.equal( + expr.add(Duration.us(3), Duration.us(4)), + expr.sub(Duration.ms(5), Duration.ms(6)), + ), + ), + expr.logic_and( + expr.equal(expr.mul(1.0, 2.0), expr.div(4.0, 2.0)), + expr.equal(expr.add(3.0, 4.0), expr.sub(10.5, 4.3)), + ), + ) + ): + pass + + return [float_expr, duration_expr, math_expr] def generate_circuits(version_parts, current_version, load_context=False):