@@ -438,10 +438,15 @@ impl<R: BufRead> BufReader<R> {
438438 loop {
439439 let buf = self . reader . fill_buf ( ) ?;
440440 let decoded = self . decoder . decode ( buf) ?;
441- if decoded == 0 {
441+ self . reader . consume ( decoded) ;
442+ // Yield if decoded no bytes or the decoder is full
443+ //
444+ // The capacity check avoids looping around and potentially
445+ // blocking reading data in fill_buf that isn't needed
446+ // to flush the next batch
447+ if decoded == 0 || self . decoder . capacity ( ) == 0 {
442448 break ;
443449 }
444- self . reader . consume ( decoded) ;
445450 }
446451
447452 self . decoder . flush ( )
@@ -574,6 +579,11 @@ impl Decoder {
574579 self . line_number += rows. len ( ) ;
575580 Ok ( Some ( batch) )
576581 }
582+
583+ /// Returns the number of records that can be read before requiring a call to [`Self::flush`]
584+ pub fn capacity ( & self ) -> usize {
585+ self . batch_size - self . record_decoder . len ( )
586+ }
577587}
578588
579589/// Parses a slice of [`StringRecords`] into a [RecordBatch]
@@ -2269,4 +2279,73 @@ mod tests {
22692279 "Csv error: Encountered invalid UTF-8 data for line 1 and field 1" ,
22702280 ) ;
22712281 }
2282+
2283+ struct InstrumentedRead < R > {
2284+ r : R ,
2285+ fill_count : usize ,
2286+ fill_sizes : Vec < usize > ,
2287+ }
2288+
2289+ impl < R > InstrumentedRead < R > {
2290+ fn new ( r : R ) -> Self {
2291+ Self {
2292+ r,
2293+ fill_count : 0 ,
2294+ fill_sizes : vec ! [ ] ,
2295+ }
2296+ }
2297+ }
2298+
2299+ impl < R : Seek > Seek for InstrumentedRead < R > {
2300+ fn seek ( & mut self , pos : SeekFrom ) -> std:: io:: Result < u64 > {
2301+ self . r . seek ( pos)
2302+ }
2303+ }
2304+
2305+ impl < R : BufRead > Read for InstrumentedRead < R > {
2306+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
2307+ self . r . read ( buf)
2308+ }
2309+ }
2310+
2311+ impl < R : BufRead > BufRead for InstrumentedRead < R > {
2312+ fn fill_buf ( & mut self ) -> std:: io:: Result < & [ u8 ] > {
2313+ self . fill_count += 1 ;
2314+ let buf = self . r . fill_buf ( ) ?;
2315+ self . fill_sizes . push ( buf. len ( ) ) ;
2316+ Ok ( buf)
2317+ }
2318+
2319+ fn consume ( & mut self , amt : usize ) {
2320+ self . r . consume ( amt)
2321+ }
2322+ }
2323+
2324+ #[ test]
2325+ fn test_io ( ) {
2326+ let schema = Arc :: new ( Schema :: new ( vec ! [
2327+ Field :: new( "a" , DataType :: Utf8 , false ) ,
2328+ Field :: new( "b" , DataType :: Utf8 , false ) ,
2329+ ] ) ) ;
2330+ let csv = "foo,bar\n baz,foo\n a,b\n c,d" ;
2331+ let mut read = InstrumentedRead :: new ( Cursor :: new ( csv. as_bytes ( ) ) ) ;
2332+ let reader = ReaderBuilder :: new ( )
2333+ . with_schema ( schema)
2334+ . with_batch_size ( 3 )
2335+ . build_buffered ( & mut read)
2336+ . unwrap ( ) ;
2337+
2338+ let batches = reader. collect :: < Result < Vec < _ > , _ > > ( ) . unwrap ( ) ;
2339+ assert_eq ! ( batches. len( ) , 2 ) ;
2340+ assert_eq ! ( batches[ 0 ] . num_rows( ) , 3 ) ;
2341+ assert_eq ! ( batches[ 1 ] . num_rows( ) , 1 ) ;
2342+
2343+ // Expect 4 calls to fill_buf
2344+ // 1. Read first 3 rows
2345+ // 2. Read final row
2346+ // 3. Delimit and flush final row
2347+ // 4. Iterator finished
2348+ assert_eq ! ( & read. fill_sizes, & [ 23 , 3 , 0 , 0 ] ) ;
2349+ assert_eq ! ( read. fill_count, 4 ) ;
2350+ }
22722351}
0 commit comments