Skip to content

Commit 61b5763

Browse files
authored
pyarrow: Small code simplifications (#9594)
# Rationale for this change

Makes the code simpler and more readable by relying on new PyO3 and Rust features. No behavior should have changed, apart from the error message raised when `__arrow_c_array__` does not return a tuple.

# What changes are included in this PR?

- Use `.call_method0(M)?` instead of `.getattr(M)?.call0()`
- Use `.extract()`, which allows more advanced features like directly extracting tuple elements
- Remove temporary variables just before returning
- Use `&raw const` and `&raw mut` pointers instead of casting and `addr_of!`
1 parent 51bf8a4 commit 61b5763

1 file changed

Lines changed: 48 additions & 107 deletions

File tree

arrow-pyarrow/src/lib.rs

Lines changed: 48 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
6262
use std::convert::{From, TryFrom};
6363
use std::ffi::CStr;
64-
use std::ptr::{addr_of, addr_of_mut};
6564
use std::sync::Arc;
6665

6766
use arrow_array::ffi;
@@ -156,36 +155,27 @@ impl FromPyArrow for DataType {
156155
// method, so prefer it over _export_to_c.
157156
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
158157
if value.hasattr("__arrow_c_schema__")? {
159-
let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
160-
let capsule = capsule.cast::<PyCapsule>()?;
161-
validate_pycapsule(capsule, "arrow_schema")?;
158+
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
159+
validate_pycapsule(&capsule, "arrow_schema")?;
162160

163161
let schema_ptr = capsule
164162
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
165163
.cast::<FFI_ArrowSchema>();
166-
unsafe {
167-
let dtype = DataType::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
168-
return Ok(dtype);
169-
}
164+
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
170165
}
171166

172167
validate_class(data_type_class(value.py())?, value)?;
173168

174-
let c_schema = FFI_ArrowSchema::empty();
175-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
176-
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
177-
let dtype = DataType::try_from(&c_schema).map_err(to_py_err)?;
178-
Ok(dtype)
169+
let mut c_schema = FFI_ArrowSchema::empty();
170+
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
171+
DataType::try_from(&c_schema).map_err(to_py_err)
179172
}
180173
}
181174

182175
impl ToPyArrow for DataType {
183176
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184177
let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
185-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
186-
let dtype =
187-
data_type_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
188-
Ok(dtype)
178+
data_type_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
189179
}
190180
}
191181

@@ -195,36 +185,27 @@ impl FromPyArrow for Field {
195185
// method, so prefer it over _export_to_c.
196186
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
197187
if value.hasattr("__arrow_c_schema__")? {
198-
let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
199-
let capsule = capsule.cast::<PyCapsule>()?;
200-
validate_pycapsule(capsule, "arrow_schema")?;
188+
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
189+
validate_pycapsule(&capsule, "arrow_schema")?;
201190

202191
let schema_ptr = capsule
203192
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
204193
.cast::<FFI_ArrowSchema>();
205-
unsafe {
206-
let field = Field::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
207-
return Ok(field);
208-
}
194+
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
209195
}
210196

211197
validate_class(field_class(value.py())?, value)?;
212198

213-
let c_schema = FFI_ArrowSchema::empty();
214-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
215-
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
216-
let field = Field::try_from(&c_schema).map_err(to_py_err)?;
217-
Ok(field)
199+
let mut c_schema = FFI_ArrowSchema::empty();
200+
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
201+
Field::try_from(&c_schema).map_err(to_py_err)
218202
}
219203
}
220204

221205
impl ToPyArrow for Field {
222206
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
223207
let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
224-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
225-
let dtype =
226-
field_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
227-
Ok(dtype)
208+
field_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
228209
}
229210
}
230211

@@ -234,36 +215,27 @@ impl FromPyArrow for Schema {
234215
// method, so prefer it over _export_to_c.
235216
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
236217
if value.hasattr("__arrow_c_schema__")? {
237-
let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
238-
let capsule = capsule.cast::<PyCapsule>()?;
239-
validate_pycapsule(capsule, "arrow_schema")?;
218+
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
219+
validate_pycapsule(&capsule, "arrow_schema")?;
240220

241221
let schema_ptr = capsule
242222
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
243223
.cast::<FFI_ArrowSchema>();
244-
unsafe {
245-
let schema = Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
246-
return Ok(schema);
247-
}
224+
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
248225
}
249226

250227
validate_class(schema_class(value.py())?, value)?;
251228

252-
let c_schema = FFI_ArrowSchema::empty();
253-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
254-
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
255-
let schema = Schema::try_from(&c_schema).map_err(to_py_err)?;
256-
Ok(schema)
229+
let mut c_schema = FFI_ArrowSchema::empty();
230+
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
231+
Schema::try_from(&c_schema).map_err(to_py_err)
257232
}
258233
}
259234

260235
impl ToPyArrow for Schema {
261236
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
262237
let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
263-
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
264-
let schema =
265-
schema_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
266-
Ok(schema)
238+
schema_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
267239
}
268240
}
269241

@@ -273,21 +245,11 @@ impl FromPyArrow for ArrayData {
273245
// method, so prefer it over _export_to_c.
274246
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
275247
if value.hasattr("__arrow_c_array__")? {
276-
let tuple = value.getattr("__arrow_c_array__")?.call0()?;
277-
278-
if !tuple.is_instance_of::<PyTuple>() {
279-
return Err(PyTypeError::new_err(
280-
"Expected __arrow_c_array__ to return a tuple.",
281-
));
282-
}
283-
284-
let schema_capsule = tuple.get_item(0)?;
285-
let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
286-
let array_capsule = tuple.get_item(1)?;
287-
let array_capsule = array_capsule.cast::<PyCapsule>()?;
248+
let (schema_capsule, array_capsule) =
249+
value.call_method0("__arrow_c_array__")?.extract()?;
288250

289-
validate_pycapsule(schema_capsule, "arrow_schema")?;
290-
validate_pycapsule(array_capsule, "arrow_array")?;
251+
validate_pycapsule(&schema_capsule, "arrow_schema")?;
252+
validate_pycapsule(&array_capsule, "arrow_array")?;
291253

292254
let schema_ptr = schema_capsule
293255
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
@@ -315,8 +277,8 @@ impl FromPyArrow for ArrayData {
315277
value.call_method1(
316278
"_export_to_c",
317279
(
318-
addr_of_mut!(array) as Py_uintptr_t,
319-
addr_of_mut!(schema) as Py_uintptr_t,
280+
&raw mut array as Py_uintptr_t,
281+
&raw mut schema as Py_uintptr_t,
320282
),
321283
)?;
322284

@@ -328,15 +290,13 @@ impl ToPyArrow for ArrayData {
328290
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
329291
let array = FFI_ArrowArray::new(self);
330292
let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
331-
332-
let array = array_class(py)?.call_method1(
293+
array_class(py)?.call_method1(
333294
"_import_from_c",
334295
(
335-
addr_of!(array) as Py_uintptr_t,
336-
addr_of!(schema) as Py_uintptr_t,
296+
&raw const array as Py_uintptr_t,
297+
&raw const schema as Py_uintptr_t,
337298
),
338-
)?;
339-
Ok(array)
299+
)
340300
}
341301
}
342302

@@ -364,21 +324,11 @@ impl FromPyArrow for RecordBatch {
364324
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
365325

366326
if value.hasattr("__arrow_c_array__")? {
367-
let tuple = value.getattr("__arrow_c_array__")?.call0()?;
327+
let (schema_capsule, array_capsule) =
328+
value.call_method0("__arrow_c_array__")?.extract()?;
368329

369-
if !tuple.is_instance_of::<PyTuple>() {
370-
return Err(PyTypeError::new_err(
371-
"Expected __arrow_c_array__ to return a tuple.",
372-
));
373-
}
374-
375-
let schema_capsule = tuple.get_item(0)?;
376-
let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
377-
let array_capsule = tuple.get_item(1)?;
378-
let array_capsule = array_capsule.cast::<PyCapsule>()?;
379-
380-
validate_pycapsule(schema_capsule, "arrow_schema")?;
381-
validate_pycapsule(array_capsule, "arrow_array")?;
330+
validate_pycapsule(&schema_capsule, "arrow_schema")?;
331+
validate_pycapsule(&array_capsule, "arrow_array")?;
382332

383333
let schema_ptr = schema_capsule
384334
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
@@ -455,9 +405,9 @@ impl FromPyArrow for ArrowArrayStreamReader {
455405
// method, so prefer it over _export_to_c.
456406
// See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
457407
if value.hasattr("__arrow_c_stream__")? {
458-
let capsule = value.getattr("__arrow_c_stream__")?.call0()?;
459-
let capsule = capsule.cast::<PyCapsule>()?;
460-
validate_pycapsule(capsule, "arrow_array_stream")?;
408+
let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
409+
410+
validate_pycapsule(&capsule, "arrow_array_stream")?;
461411

462412
let stream = unsafe {
463413
FFI_ArrowArrayStream::from_raw(
@@ -476,20 +426,17 @@ impl FromPyArrow for ArrowArrayStreamReader {
476426

477427
validate_class(record_batch_reader_class(value.py())?, value)?;
478428

479-
// prepare a pointer to receive the stream struct
429+
// prepare the stream struct to receive the content
480430
let mut stream = FFI_ArrowArrayStream::empty();
481-
let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;
482431

483432
// make the conversion through PyArrow's private API
484433
// this changes the pointer's memory and is thus unsafe.
485434
// In particular, `_export_to_c` can go out of bounds
486-
let args = PyTuple::new(value.py(), [stream_ptr as Py_uintptr_t])?;
435+
let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
487436
value.call_method1("_export_to_c", args)?;
488437

489-
let stream_reader = ArrowArrayStreamReader::try_new(stream)
490-
.map_err(|err| PyValueError::new_err(err.to_string()))?;
491-
492-
Ok(stream_reader)
438+
ArrowArrayStreamReader::try_new(stream)
439+
.map_err(|err| PyValueError::new_err(err.to_string()))
493440
}
494441
}
495442

@@ -498,13 +445,9 @@ impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
498445
// We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because
499446
// there is already a blanket implementation for `T: ToPyArrow`.
500447
fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
501-
let mut stream = FFI_ArrowArrayStream::new(self);
502-
503-
let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
504-
let reader = record_batch_reader_class(py)?
505-
.call_method1("_import_from_c", (stream_ptr as Py_uintptr_t,))?;
506-
507-
Ok(reader)
448+
let stream = FFI_ArrowArrayStream::new(self);
449+
record_batch_reader_class(py)?
450+
.call_method1("_import_from_c", (&raw const stream as Py_uintptr_t,))
508451
}
509452
}
510453

@@ -588,7 +531,7 @@ impl FromPyArrow for Table {
588531
fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
589532
let reader: Box<dyn RecordBatchReader> =
590533
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
591-
Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
534+
Self::try_from(reader).map_err(|err| PyValueError::new_err(err.to_string()))
592535
}
593536
}
594537

@@ -601,9 +544,7 @@ impl IntoPyArrow for Table {
601544
let kwargs = PyDict::new(py);
602545
kwargs.set_item("schema", py_schema)?;
603546

604-
let reader = table_class(py)?.call_method("from_batches", (py_batches,), Some(&kwargs))?;
605-
606-
Ok(reader)
547+
table_class(py)?.call_method("from_batches", (py_batches,), Some(&kwargs))
607548
}
608549
}
609550

@@ -664,7 +605,7 @@ impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
664605

665606
type Error = PyErr;
666607

667-
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, PyErr> {
608+
fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
668609
self.0.into_pyarrow(py)
669610
}
670611
}

0 commit comments

Comments
 (0)