Skip to content

Commit 78b2951

Browse files
committed
Address review comments.
1 parent 7fea216 commit 78b2951

5 files changed

Lines changed: 84 additions & 44 deletions

File tree

crates/circuit/src/duration.rs

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111
// that they have been altered from the originals.
1212

1313
use pyo3::prelude::*;
14+
use pyo3::IntoPyObjectExt;
1415

15-
#[pyclass(eq, module = "qiskit._accelerate.circuit")]
16-
#[derive(PartialEq, Clone, Copy, Debug)]
17-
#[allow(non_camel_case_types)]
1816
/// A length of time used to express circuit timing.
1917
///
2018
/// It defines a group of classes which are all subclasses of itself (functionally, an
@@ -30,23 +28,54 @@ use pyo3::prelude::*;
3028
/// case _:
3129
/// raise ValueError("expected dt or seconds")
3230
///
33-
/// And in Python 3.9, you can use ``isinstance`` to determine which variant
31+
/// And in Python 3.9, you can use :meth:`Duration.unit` to determine which variant
3432
/// is populated::
3533
///
36-
/// if isinstance(duration, Duration.dt):
37-
/// return duration[0]
38-
/// elif isinstance(duration, Duration.s):
39-
/// return duration[0] / 5e-7
34+
/// if duration.unit() == "dt":
35+
/// return duration.value()
36+
/// elif duration.unit() == "s":
37+
/// return duration.value() / 5e-7
4038
/// else:
4139
/// raise ValueError("expected dt or seconds")
40+
#[pyclass(eq, module = "qiskit._accelerate.circuit")]
41+
#[derive(PartialEq, Clone, Copy, Debug)]
42+
#[allow(non_camel_case_types)]
4243
pub enum Duration {
43-
dt(u64),
44+
dt(i64),
4445
ns(f64),
4546
us(f64),
4647
ms(f64),
4748
s(f64),
4849
}
4950

51+
#[pymethods]
52+
impl Duration {
53+
/// The corresponding ``unit`` of the duration.
54+
fn unit(&self) -> &'static str {
55+
match self {
56+
Duration::dt(_) => "dt",
57+
Duration::us(_) => "us",
58+
Duration::ns(_) => "ns",
59+
Duration::ms(_) => "ms",
60+
Duration::s(_) => "s",
61+
}
62+
}
63+
64+
/// The ``value`` of the duration.
65+
///
66+
/// This will be a Python ``int`` if the :meth:`~Duration.unit` is ``"dt"``,
67+
/// else a ``float``.
68+
#[pyo3(name = "value")]
69+
fn py_value(&self, py: Python) -> PyResult<PyObject> {
70+
match self {
71+
Duration::dt(v) => v.into_py_any(py),
72+
Duration::us(v) | Duration::ns(v) | Duration::ms(v) | Duration::s(v) => {
73+
v.into_py_any(py)
74+
}
75+
}
76+
}
77+
}
78+
5079
impl Duration {
5180
fn __repr__(&self) -> String {
5281
match self {

qiskit/qasm3/exporter.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,16 +1283,17 @@ def visit_value(self, node, /):
12831283
if node.type.kind is types.Float:
12841284
return ast.FloatLiteral(node.value)
12851285
if node.type.kind is types.Duration:
1286-
if isinstance(node.value, Duration.dt):
1287-
return ast.DurationLiteral(node.value[0], ast.DurationUnit.SAMPLE)
1288-
if isinstance(node.value, Duration.ns):
1289-
return ast.DurationLiteral(node.value[0], ast.DurationUnit.NANOSECOND)
1290-
if isinstance(node.value, Duration.us):
1291-
return ast.DurationLiteral(node.value[0], ast.DurationUnit.MICROSECOND)
1292-
if isinstance(node.value, Duration.ms):
1293-
return ast.DurationLiteral(node.value[0], ast.DurationUnit.MILLISECOND)
1294-
if isinstance(node.value, Duration.s):
1295-
return ast.DurationLiteral(node.value[0], ast.DurationUnit.SECOND)
1286+
unit = node.value.unit()
1287+
if unit == "dt":
1288+
return ast.DurationLiteral(node.value.value(), ast.DurationUnit.SAMPLE)
1289+
if unit == "ns":
1290+
return ast.DurationLiteral(node.value.value(), ast.DurationUnit.NANOSECOND)
1291+
if unit == "us":
1292+
return ast.DurationLiteral(node.value.value(), ast.DurationUnit.MICROSECOND)
1293+
if unit == "ms":
1294+
return ast.DurationLiteral(node.value.value(), ast.DurationUnit.MILLISECOND)
1295+
if unit == "s":
1296+
return ast.DurationLiteral(node.value.value(), ast.DurationUnit.SECOND)
12961297
raise RuntimeError(f"unhandled Value type '{node}'")
12971298

12981299
def visit_cast(self, node, /):

qiskit/qpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def open(*args):
427427
addition to the existing encodings for int and bool. The new value type encodings are below:
428428
429429
=========================== ========= ============================================================
430-
Python type / class Type code Payload
430+
Python type Type code Payload
431431
=========================== ========= ============================================================
432432
``float`` ``f`` One ``double value``.
433433

qiskit/qpy/binary_io/value.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -402,21 +402,30 @@ def _write_expr_type(file_obj, type_: types.Type, version: int):
402402

403403

404404
def _write_duration(file_obj, duration: Duration):
405-
if isinstance(duration, Duration.dt):
405+
unit = duration.unit()
406+
if unit == "dt":
406407
file_obj.write(type_keys.CircuitDuration.DT)
407-
file_obj.write(struct.pack(formats.DURATION_DT_PACK, *formats.DURATION_DT(duration[0])))
408-
elif isinstance(duration, Duration.ns):
408+
file_obj.write(
409+
struct.pack(formats.DURATION_DT_PACK, *formats.DURATION_DT(duration.value()))
410+
)
411+
elif unit == "ns":
409412
file_obj.write(type_keys.CircuitDuration.NS)
410-
file_obj.write(struct.pack(formats.DURATION_NS_PACK, *formats.DURATION_NS(duration[0])))
411-
elif isinstance(duration, Duration.us):
413+
file_obj.write(
414+
struct.pack(formats.DURATION_NS_PACK, *formats.DURATION_NS(duration.value()))
415+
)
416+
elif unit == "us":
412417
file_obj.write(type_keys.CircuitDuration.US)
413-
file_obj.write(struct.pack(formats.DURATION_US_PACK, *formats.DURATION_US(duration[0])))
414-
elif isinstance(duration, Duration.ms):
418+
file_obj.write(
419+
struct.pack(formats.DURATION_US_PACK, *formats.DURATION_US(duration.value()))
420+
)
421+
elif unit == "ms":
415422
file_obj.write(type_keys.CircuitDuration.MS)
416-
file_obj.write(struct.pack(formats.DURATION_MS_PACK, *formats.DURATION_MS(duration[0])))
417-
elif isinstance(duration, Duration.s):
423+
file_obj.write(
424+
struct.pack(formats.DURATION_MS_PACK, *formats.DURATION_MS(duration.value()))
425+
)
426+
elif unit == "s":
418427
file_obj.write(type_keys.CircuitDuration.S)
419-
file_obj.write(struct.pack(formats.DURATION_S_PACK, *formats.DURATION_S(duration[0])))
428+
file_obj.write(struct.pack(formats.DURATION_S_PACK, *formats.DURATION_S(duration.value())))
420429
else:
421430
raise exceptions.QpyError(f"unhandled Duration object '{duration};")
422431

qiskit/qpy/type_keys.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -367,20 +367,21 @@ class CircuitDuration(TypeKeyBase):
367367

368368
@classmethod
369369
def assign(cls, obj):
370-
if isinstance(obj, Duration.dt):
371-
return cls.DT
372-
elif isinstance(obj, Duration.ns):
373-
return cls.NS
374-
elif isinstance(obj, Duration.us):
375-
return cls.US
376-
elif isinstance(obj, Duration.ms):
377-
return cls.MS
378-
elif isinstance(obj, Duration.s):
379-
return cls.S
380-
else:
381-
raise exceptions.QpyError(
382-
f"Object type '{type(obj)}' is not supported in {cls.__name__} namespace."
383-
)
370+
if isinstance(obj, Duration):
371+
unit = obj.unit()
372+
if unit == "dt":
373+
return cls.DT
374+
if unit == "ns":
375+
return cls.NS
376+
if unit == "us":
377+
return cls.US
378+
if unit == "ms":
379+
return cls.MS
380+
if unit == "s":
381+
return cls.S
382+
raise exceptions.QpyError(
383+
f"Object type '{type(obj)}' is not supported in {cls.__name__} namespace."
384+
)
384385

385386
@classmethod
386387
def retrieve(cls, type_key):

0 commit comments

Comments
 (0)