Skip to content

Commit 3e08a75

Browse files
authored
Add CSV Decoder::capacity (#3674) (#3677)
* Add CSV Decoder::capacity (#3674) * Add test * Remove unnecessary extern * Add docs
1 parent 5b1821e commit 3e08a75

File tree

1 file changed

+81
-2
lines changed

1 file changed

+81
-2
lines changed

arrow-csv/src/reader/mod.rs

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,15 @@ impl<R: BufRead> BufReader<R> {
438438
loop {
439439
let buf = self.reader.fill_buf()?;
440440
let decoded = self.decoder.decode(buf)?;
441-
if decoded == 0 {
441+
self.reader.consume(decoded);
442+
// Yield if decoded no bytes or the decoder is full
443+
//
444+
// The capacity check avoids looping around and potentially
445+
// blocking reading data in fill_buf that isn't needed
446+
// to flush the next batch
447+
if decoded == 0 || self.decoder.capacity() == 0 {
442448
break;
443449
}
444-
self.reader.consume(decoded);
445450
}
446451

447452
self.decoder.flush()
@@ -574,6 +579,11 @@ impl Decoder {
574579
self.line_number += rows.len();
575580
Ok(Some(batch))
576581
}
582+
583+
/// Returns the number of records that can be read before requiring a call to [`Self::flush`]
584+
pub fn capacity(&self) -> usize {
585+
self.batch_size - self.record_decoder.len()
586+
}
577587
}
578588

579589
/// Parses a slice of [`StringRecords`] into a [RecordBatch]
@@ -2269,4 +2279,73 @@ mod tests {
22692279
"Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
22702280
);
22712281
}
2282+
2283+
struct InstrumentedRead<R> {
2284+
r: R,
2285+
fill_count: usize,
2286+
fill_sizes: Vec<usize>,
2287+
}
2288+
2289+
impl<R> InstrumentedRead<R> {
2290+
fn new(r: R) -> Self {
2291+
Self {
2292+
r,
2293+
fill_count: 0,
2294+
fill_sizes: vec![],
2295+
}
2296+
}
2297+
}
2298+
2299+
impl<R: Seek> Seek for InstrumentedRead<R> {
2300+
fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
2301+
self.r.seek(pos)
2302+
}
2303+
}
2304+
2305+
impl<R: BufRead> Read for InstrumentedRead<R> {
2306+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
2307+
self.r.read(buf)
2308+
}
2309+
}
2310+
2311+
impl<R: BufRead> BufRead for InstrumentedRead<R> {
2312+
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
2313+
self.fill_count += 1;
2314+
let buf = self.r.fill_buf()?;
2315+
self.fill_sizes.push(buf.len());
2316+
Ok(buf)
2317+
}
2318+
2319+
fn consume(&mut self, amt: usize) {
2320+
self.r.consume(amt)
2321+
}
2322+
}
2323+
2324+
#[test]
2325+
fn test_io() {
2326+
let schema = Arc::new(Schema::new(vec![
2327+
Field::new("a", DataType::Utf8, false),
2328+
Field::new("b", DataType::Utf8, false),
2329+
]));
2330+
let csv = "foo,bar\nbaz,foo\na,b\nc,d";
2331+
let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
2332+
let reader = ReaderBuilder::new()
2333+
.with_schema(schema)
2334+
.with_batch_size(3)
2335+
.build_buffered(&mut read)
2336+
.unwrap();
2337+
2338+
let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2339+
assert_eq!(batches.len(), 2);
2340+
assert_eq!(batches[0].num_rows(), 3);
2341+
assert_eq!(batches[1].num_rows(), 1);
2342+
2343+
// Expect 4 calls to fill_buf
2344+
// 1. Read first 3 rows
2345+
// 2. Read final row
2346+
// 3. Delimit and flush final row
2347+
// 4. Iterator finished
2348+
assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
2349+
assert_eq!(read.fill_count, 4);
2350+
}
22722351
}

0 commit comments

Comments
 (0)