Skip to content

Commit 7673277

Browse files
committed
WIP
1 parent 6274843 commit 7673277

5 files changed

Lines changed: 289 additions & 67 deletions

File tree

qiskit/circuit/classical/expr/constructors.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,19 @@ def lift(value: typing.Any, /, type: types.Type | None = None) -> Expr:
124124
from qiskit.circuit import Clbit, ClassicalRegister # pylint: disable=cyclic-import
125125

126126
inferred: types.Type
127-
if value is True or value is False or isinstance(value, Clbit):
127+
if value is True or value is False:
128+
inferred = types.Bool(const=True)
129+
constructor = Value
130+
elif isinstance(value, Clbit):
128131
inferred = types.Bool()
129-
constructor = Value if value is True or value is False else Var
132+
constructor = Var
130133
elif isinstance(value, ClassicalRegister):
131134
inferred = types.Uint(width=value.size)
132135
constructor = Var
133136
elif isinstance(value, int):
134137
if value < 0:
135138
raise ValueError("cannot represent a negative value")
136-
inferred = types.Uint(width=value.bit_length() or 1)
139+
inferred = types.Uint(width=value.bit_length() or 1, const=True)
137140
constructor = Value
138141
else:
139142
raise TypeError(f"failed to infer a type for '{value}'")
@@ -198,41 +201,73 @@ def logic_not(operand: typing.Any, /) -> Expr:
198201
Cast(Var(ClassicalRegister(3, 'c'), Uint(3)), Bool(), implicit=True), \
199202
Bool())
200203
"""
201-
operand = _coerce_lossless(lift(operand), types.Bool())
204+
var_or_value = lift(operand)
205+
operand = _coerce_lossless(var_or_value, types.Bool(const=var_or_value.type.const))
202206
return Unary(Unary.Op.LOGIC_NOT, operand, operand.type)
203207

204208

205209
def _lift_binary_operands(left: typing.Any, right: typing.Any) -> tuple[Expr, Expr]:
206210
"""Lift two binary operands simultaneously, inferring the widths of integer literals in either
207-
position to match the other operand."""
208-
left_int = isinstance(left, int) and not isinstance(left, bool)
209-
right_int = isinstance(right, int) and not isinstance(right, bool)
211+
position to match the other operand.
212+
213+
Const-ness is handled as follows:
214+
* If neither operand is an expression, both are lifted to share the same const-ness.
215+
Both will be const, if possible. Else, neither will be.
216+
* If only one operand is an expression, the other is lifted with the same const-ness, if possible.
217+
Otherwise, the returned operands will have different const-ness, and thus require a cast node.
218+
* If both operands are expressions, they are returned as-is and may require a cast node.
219+
"""
220+
left_bool = isinstance(left, bool)
221+
left_int = isinstance(left, int) and not left_bool
222+
right_bool = isinstance(right, bool)
223+
right_int = isinstance(right, int) and not right_bool
210224
if not (left_int or right_int):
211-
left = lift(left)
212-
right = lift(right)
225+
if left_bool == right_bool:
226+
# If they're both bool, lifting them will produce const Bool.
227+
# If neither are bool, they're a mix of bits/registers (which are always
228+
# non-const) and Expr, which we can't modify the const-ness of without
229+
# a cast node.
230+
left = lift(left)
231+
right = lift(right)
232+
elif not right_bool:
233+
# Left is a bool
234+
right = lift(right)
235+
# TODO: if right.type isn't Bool, there's a type mismatch so we _should_
236+
# raise here. But, _binary_bitwise will error for us with a better msg.
237+
left = lift(left, right.type if right.type.kind is types.Bool else None)
238+
elif not left_bool:
239+
# Right is a bool.
240+
left = lift(left)
241+
# TODO: if left.type isn't Bool, there's a type mismatch so we _should_
242+
# raise here. But, _binary_bitwise will error for us with a better msg.
243+
right = lift(right, left.type if left.type.kind is types.Bool else None)
213244
elif not right_int:
245+
# Left is an int.
214246
right = lift(right)
215247
if right.type.kind is types.Uint:
216248
if left.bit_length() > right.type.width:
217249
raise TypeError(
218250
f"integer literal '{left}' is wider than the other operand '{right}'"
219251
)
252+
# Left will share const-ness of right.
220253
left = Value(left, right.type)
221254
else:
222255
left = lift(left)
223256
elif not left_int:
257+
# Right is an int.
224258
left = lift(left)
225259
if left.type.kind is types.Uint:
226260
if right.bit_length() > left.type.width:
227261
raise TypeError(
228262
f"integer literal '{right}' is wider than the other operand '{left}'"
229263
)
264+
# Right will share const-ness of left.
230265
right = Value(right, left.type)
231266
else:
232267
right = lift(right)
233268
else:
234269
# Both are `int`, so we take our best case to make things work.
235-
uint = types.Uint(max(left.bit_length(), right.bit_length(), 1))
270+
uint = types.Uint(max(left.bit_length(), right.bit_length(), 1), const=True)
236271
left = Value(left, uint)
237272
right = Value(right, uint)
238273
return left, right
@@ -242,14 +277,14 @@ def _binary_bitwise(op: Binary.Op, left: typing.Any, right: typing.Any) -> Expr:
242277
left, right = _lift_binary_operands(left, right)
243278
type: types.Type
244279
if left.type.kind is right.type.kind is types.Bool:
245-
type = types.Bool()
280+
type = types.Bool(const=(left.type.const and right.type.const))
246281
elif left.type.kind is types.Uint and right.type.kind is types.Uint:
247282
if left.type != right.type:
248283
raise TypeError(
249284
"binary bitwise operations are defined between unsigned integers of the same width,"
250285
f" but got {left.type.width} and {right.type.width}."
251286
)
252-
type = left.type
287+
type = types.Uint(width=left.type.width, const=(left.type.const and right.type.const))
253288
else:
254289
raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'")
255290
return Binary(op, left, right, type)
@@ -313,10 +348,10 @@ def bit_xor(left: typing.Any, right: typing.Any, /) -> Expr:
313348

314349

315350
def _binary_logical(op: Binary.Op, left: typing.Any, right: typing.Any) -> Expr:
316-
bool_ = types.Bool()
317-
left = _coerce_lossless(lift(left), bool_)
318-
right = _coerce_lossless(lift(right), bool_)
319-
return Binary(op, left, right, bool_)
351+
left, right = _lift_binary_operands(left, right)
352+
left = _coerce_lossless(left, types.Bool(const=left.type.const))
353+
right = _coerce_lossless(right, types.Bool(const=right.type.const))
354+
return Binary(op, left, right, types.Bool(const=(left.type.const and right.type.const)))
320355

321356

322357
def logic_and(left: typing.Any, right: typing.Any, /) -> Expr:
@@ -354,7 +389,7 @@ def _equal_like(op: Binary.Op, left: typing.Any, right: typing.Any) -> Expr:
354389
if left.type.kind is not right.type.kind:
355390
raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'")
356391
type = types.greater(left.type, right.type)
357-
return Binary(op, _coerce_lossless(left, type), _coerce_lossless(right, type), types.Bool())
392+
return Binary(op, _coerce_lossless(left, type), _coerce_lossless(right, type), types.Bool(const=type.const))
358393

359394

360395
def equal(left: typing.Any, right: typing.Any, /) -> Expr:
@@ -398,7 +433,7 @@ def _binary_relation(op: Binary.Op, left: typing.Any, right: typing.Any) -> Expr
398433
if left.type.kind is not right.type.kind or left.type.kind is types.Bool:
399434
raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'")
400435
type = types.greater(left.type, right.type)
401-
return Binary(op, _coerce_lossless(left, type), _coerce_lossless(right, type), types.Bool())
436+
return Binary(op, _coerce_lossless(left, type), _coerce_lossless(right, type), types.Bool(const=type.const))
402437

403438

404439
def less(left: typing.Any, right: typing.Any, /) -> Expr:
@@ -485,7 +520,7 @@ def _shift_like(
485520
right = lift(right)
486521
if left.type.kind != types.Uint or right.type.kind != types.Uint:
487522
raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'")
488-
return Binary(op, left, right, left.type)
523+
return Binary(op, left, right, types.Uint(width=left.type.width, const=(left.type.const and right.type.const)))
489524

490525

491526
def shift_left(left: typing.Any, right: typing.Any, /, type: types.Type | None = None) -> Expr:
@@ -553,4 +588,4 @@ def index(target: typing.Any, index: typing.Any, /) -> Expr:
553588
target, index = lift(target), lift(index)
554589
if target.type.kind is not types.Uint or index.type.kind is not types.Uint:
555590
raise TypeError(f"invalid types for indexing: '{target.type}' and '{index.type}'")
556-
return Index(target, index, types.Bool())
591+
return Index(target, index, types.Bool(const=target.type.const))

qiskit/circuit/classical/expr/visitors.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ def is_lvalue(node: expr.Expr, /) -> bool:
276276
>>> expr.is_lvalue(expr.lift(2))
277277
False
278278
279-
:class:`~.expr.Var` nodes are always l-values, because they always have some associated
280-
memory location::
279+
:class:`~.expr.Var` nodes are l-values (unless their resolution type is `const`!), because
280+
they have some associated memory location::
281281
282282
>>> from qiskit.circuit.classical import types
283283
>>> from qiskit.circuit import Clbit
@@ -297,4 +297,8 @@ def is_lvalue(node: expr.Expr, /) -> bool:
297297
>>> expr.is_lvalue(expr.bit_and(a, b))
298298
False
299299
"""
300+
if node.type.const:
301+
# If the expression's resolution type is const, then this can never be
302+
# an l-value (even if the expression is a Var).
303+
return False
300304
return node.accept(_IS_LVALUE)

qiskit/circuit/classical/types/ordering.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import enum
2828

29-
from .types import Type, Bool, Uint
29+
from .types import Type, Bool, Uint, Duration, Stretch
3030

3131

3232
# While the type system is simple, it's overkill to represent the complete partial ordering graph of
@@ -55,7 +55,7 @@ def __repr__(self):
5555
return str(self)
5656

5757

58-
def _order_bool_bool(_a: Bool, _b: Bool, /) -> Ordering:
58+
def _order_identical(_a: Type, _b: Type, /) -> Ordering:
5959
return Ordering.EQUAL
6060

6161

@@ -68,8 +68,10 @@ def _order_uint_uint(left: Uint, right: Uint, /) -> Ordering:
6868

6969

7070
_ORDERERS = {
71-
(Bool, Bool): _order_bool_bool,
71+
(Bool, Bool): _order_identical,
7272
(Uint, Uint): _order_uint_uint,
73+
(Duration, Duration): _order_identical,
74+
(Stretch, Stretch): _order_identical,
7375
}
7476

7577

@@ -90,7 +92,13 @@ def order(left: Type, right: Type, /) -> Ordering:
9092
"""
9193
if (orderer := _ORDERERS.get((left.kind, right.kind))) is None:
9294
return Ordering.NONE
93-
return orderer(left, right)
95+
order_ = orderer(left, right)
96+
if order_ is Ordering.EQUAL:
97+
if left.const is True and right.const is False:
98+
return Ordering.LESS
99+
if right.const is True and left.const is False:
100+
return Ordering.GREATER
101+
return order_
94102

95103

96104
def is_subtype(left: Type, right: Type, /, strict: bool = False) -> bool:
@@ -213,13 +221,20 @@ def cast_kind(from_: Type, to_: Type, /) -> CastKind:
213221
>>> from qiskit.circuit.classical import types
214222
>>> types.cast_kind(types.Bool(), types.Bool())
215223
<CastKind.EQUAL: 1>
216-
>>> types.cast_kind(types.Uint(8), types.Bool())
224+
>>> types.cast_kind(types.Uint(8, const=True), types.Bool())
217225
<CastKind.IMPLICIT: 2>
218226
>>> types.cast_kind(types.Bool(), types.Uint(8))
219227
<CastKind.LOSSLESS: 3>
220228
>>> types.cast_kind(types.Uint(16), types.Uint(8))
221229
<CastKind.DANGEROUS: 4>
222230
"""
231+
if to_.const is True and from_.const is False:
232+
# we can't cast to a const type
233+
return CastKind.NONE
223234
if (coercer := _ALLOWED_CASTS.get((from_.kind, to_.kind))) is None:
224235
return CastKind.NONE
225-
return coercer(from_, to_)
236+
cast_kind_ = coercer(from_, to_)
237+
if cast_kind_ is CastKind.EQUAL and to_.const != from_.const:
238+
# we need an implicit cast to drop const
239+
return CastKind.IMPLICIT
240+
return cast_kind_

qiskit/circuit/classical/types/types.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
"Type",
2424
"Bool",
2525
"Uint",
26+
"Duration",
27+
"Stretch"
2628
]
2729

2830
import typing
@@ -81,37 +83,79 @@ def __setstate__(self, state):
8183

8284

8385
@typing.final
84-
class Bool(Type, metaclass=_Singleton):
86+
class Bool(Type):
8587
"""The Boolean type. This has exactly two values: ``True`` and ``False``."""
8688

87-
__slots__ = ()
89+
__slots__ = ("const",)
90+
91+
def __init__(self, *, const: bool = False):
92+
super(Type, self).__setattr__("const", const)
8893

8994
def __repr__(self):
90-
return "Bool()"
95+
return f"Bool(const={self.const})"
9196

9297
def __hash__(self):
93-
return hash(self.__class__)
98+
return hash((self.__class__, self.const))
9499

95100
def __eq__(self, other):
96-
return isinstance(other, Bool)
101+
return isinstance(other, Bool) and self.const == other.const
97102

98103

99104
@typing.final
100105
class Uint(Type):
101106
"""An unsigned integer of fixed bit width."""
102107

103-
__slots__ = ("width",)
108+
__slots__ = ("const", "width",)
104109

105-
def __init__(self, width: int):
110+
def __init__(self, width: int, *, const: bool = False):
106111
if isinstance(width, int) and width <= 0:
107112
raise ValueError("uint width must be greater than zero")
113+
super(Type, self).__setattr__("const", const)
108114
super(Type, self).__setattr__("width", width)
109115

110116
def __repr__(self):
111-
return f"Uint({self.width})"
117+
return f"Uint({self.width}, const={self.const})"
118+
119+
def __hash__(self):
120+
return hash((self.__class__, self.const, self.width))
121+
122+
def __eq__(self, other):
123+
return isinstance(other, Uint) and self.const == other.const and self.width == other.width
124+
125+
126+
@typing.final
127+
class Duration(Type):
128+
"""A length of time, possibly negative."""
129+
130+
__slots__ = ("const",)
131+
132+
def __init__(self, *, const: bool = False):
133+
super(Type, self).__setattr__("const", const)
134+
135+
def __repr__(self):
136+
return f"Duration(const={self.const})"
137+
138+
def __hash__(self):
139+
return hash((self.__class__, self.const))
140+
141+
def __eq__(self, other):
142+
return isinstance(other, Duration) and self.const == other.const
143+
144+
145+
@typing.final
146+
class Stretch(Type):
147+
"""A special type that denotes some not-yet-known non-negative duration."""
148+
149+
__slots__ = ("const",)
150+
151+
def __init__(self, *, const: bool = False):
152+
super(Type, self).__setattr__("const", const)
153+
154+
def __repr__(self):
155+
return f"Stretch(const={self.const})"
112156

113157
def __hash__(self):
114-
return hash((self.__class__, self.width))
158+
return hash((self.__class__, self.const))
115159

116160
def __eq__(self, other):
117-
return isinstance(other, Uint) and self.width == other.width
161+
return isinstance(other, Stretch) and self.const == other.const

0 commit comments

Comments
 (0)