@@ -520,22 +520,27 @@ impl<'py> BoundListIterator<'py> {
520520 /// The caller must hold an active critical section on this list.
521521 #[ inline]
522522 #[ cfg( not( feature = "nightly" ) ) ]
523- unsafe fn nth_locked (
523+ unsafe fn nth_unsynchronized (
524524 index : & mut Index ,
525525 length : & mut Length ,
526526 list : & Bound < ' py , PyList > ,
527527 n : usize ,
528528 ) -> Option < Bound < ' py , PyAny > > {
529- let length = length. 0 . min ( list. len ( ) ) ;
530- let target_index = index. 0 + n ;
531- if target_index < length {
532- // SAFETY: target_index < list.len() guarantees in bounds
533- let item = unsafe { list. get_item_unchecked ( target_index) } ;
534- index . 0 = target_index + 1 ;
535- Some ( item )
536- } else {
537- None
529+ let current_length = length. 0 . min ( list. len ( ) ) ;
530+ if let Some ( target_index) = index. 0 . checked_add ( n ) {
531+ if target_index < current_length {
532+ // SAFETY: target_index < current_length guarantees in bounds
533+ let item = unsafe { list. get_item_unchecked ( target_index) } ;
534+ // +1 cannot overflow as target_index < current_length
535+ index . 0 = target_index + 1 ;
536+ return Some ( item ) ;
537+ }
538538 }
539+
540+ // n overflows the remaining length of the list;
541+ // nth must exhaust all remaining items
542+ index. 0 = current_length;
543+ None
539544 }
540545
541546 /// # Safety
@@ -564,22 +569,28 @@ impl<'py> BoundListIterator<'py> {
564569 /// The caller must hold an active critical section on this list.
565570 #[ inline]
566571 #[ cfg( not( feature = "nightly" ) ) ]
567- unsafe fn nth_back_locked (
572+ unsafe fn nth_back_unsynchronized (
568573 index : & mut Index ,
569574 length : & mut Length ,
570575 list : & Bound < ' py , PyList > ,
571576 n : usize ,
572577 ) -> Option < Bound < ' py , PyAny > > {
573- let length_size = length. 0 . min ( list. len ( ) ) ;
574- if index. 0 + n < length_size {
575- let target_index = length_size - n - 1 ;
576- // SAFETY: target_index < list.len() guarantees in bounds
577- let item = unsafe { list. get_item_unchecked ( target_index) } ;
578- length. 0 = target_index;
579- Some ( item)
580- } else {
581- None
578+ let current_length = length. 0 . min ( list. len ( ) ) ;
579+ if let Some ( index_after_item) = current_length. checked_sub ( n) {
580+ if index. 0 < index_after_item {
581+ // -1 cannot underflow as index_after_item > index
582+ let target_index = index_after_item - 1 ;
583+ // SAFETY: target_index is < current_length
584+ let item = unsafe { list. get_item_unchecked ( target_index) } ;
585+ length. 0 = target_index;
586+ return Some ( item) ;
587+ }
582588 }
589+
590+ // n overflows the remaining length of the tuple;
591+ // nth must exhaust all remaining items
592+ length. 0 = index. 0 ;
593+ None
583594 }
584595
585596 fn with_critical_section < R > (
@@ -611,7 +622,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
611622 fn nth ( & mut self , n : usize ) -> Option < Self :: Item > {
612623 // SAFETY: with_critical_section locks the list
613624 self . with_critical_section ( |index, length, list| unsafe {
614- Self :: nth_locked ( index, length, list, n)
625+ Self :: nth_unsynchronized ( index, length, list, n)
615626 } )
616627 }
617628
@@ -768,25 +779,17 @@ impl<'py> Iterator for BoundListIterator<'py> {
768779 #[ cfg( feature = "nightly" ) ]
769780 fn advance_by ( & mut self , n : usize ) -> Result < ( ) , NonZero < usize > > {
770781 self . with_critical_section ( |index, length, list| {
771- let max_len = length. 0 . min ( list. len ( ) ) ;
772- let currently_at = index. 0 ;
773- if currently_at >= max_len {
774- if n == 0 {
775- return Ok ( ( ) ) ;
776- } else {
777- return Err ( unsafe { NonZero :: new_unchecked ( n) } ) ;
778- }
782+ let current_length = length. 0 . min ( list. len ( ) ) ;
783+ let items_left = current_length. saturating_sub ( index. 0 ) ;
784+ if let Some ( overflow) = NonZero :: new ( n. saturating_sub ( items_left) ) {
785+ // n overflows the remaining length of the list; advance_by must exhaust all remaining items
786+ index. 0 = current_length;
787+ return Err ( overflow) ;
779788 }
780789
781- let items_left = max_len - currently_at;
782- if n <= items_left {
783- index. 0 += n;
784- Ok ( ( ) )
785- } else {
786- index. 0 = max_len;
787- let remainder = n - items_left;
788- Err ( unsafe { NonZero :: new_unchecked ( remainder) } )
789- }
790+ // cannot overflow as length - index >= n
791+ index. 0 += n;
792+ Ok ( ( ) )
790793 } )
791794 }
792795}
@@ -804,7 +807,7 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
804807 fn nth_back ( & mut self , n : usize ) -> Option < Self :: Item > {
805808 // SAFETY: with_critical_section locks the list
806809 self . with_critical_section ( |index, length, list| unsafe {
807- Self :: nth_back_locked ( index, length, list, n)
810+ Self :: nth_back_unsynchronized ( index, length, list, n)
808811 } )
809812 }
810813
@@ -847,25 +850,17 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
847850 #[ cfg( feature = "nightly" ) ]
848851 fn advance_back_by ( & mut self , n : usize ) -> Result < ( ) , NonZero < usize > > {
849852 self . with_critical_section ( |index, length, list| {
850- let max_len = length. 0 . min ( list. len ( ) ) ;
851- let currently_at = index. 0 ;
852- if currently_at >= max_len {
853- if n == 0 {
854- return Ok ( ( ) ) ;
855- } else {
856- return Err ( unsafe { NonZero :: new_unchecked ( n) } ) ;
857- }
853+ let current_length = length. 0 . min ( list. len ( ) ) ;
854+ let items_left = current_length. saturating_sub ( index. 0 ) ;
855+ if let Some ( overflow) = NonZero :: new ( n. saturating_sub ( items_left) ) {
856+ // n overflows the remaining length of the list; advance_back_by must exhaust all remaining items
857+ length. 0 = index. 0 ;
858+ return Err ( overflow) ;
858859 }
859860
860- let items_left = max_len - currently_at;
861- if n <= items_left {
862- length. 0 = max_len - n;
863- Ok ( ( ) )
864- } else {
865- length. 0 = currently_at;
866- let remainder = n - items_left;
867- Err ( unsafe { NonZero :: new_unchecked ( remainder) } )
868- }
861+ // cannot overflow as current_length - index >= n
862+ length. 0 = current_length - n;
863+ Ok ( ( ) )
869864 } )
870865 }
871866}
@@ -1594,6 +1589,19 @@ mod tests {
15941589 assert_eq ! ( iter. nth_back( 1 ) . unwrap( ) . extract:: <i32 >( ) . unwrap( ) , 9 ) ;
15951590 assert_eq ! ( iter. nth( 2 ) . unwrap( ) . extract:: <i32 >( ) . unwrap( ) , 8 ) ;
15961591 assert ! ( iter. next( ) . is_none( ) ) ;
1592+
1593+ // nth consumes all elements in the list, even on `None` return
1594+ let mut iter = list. iter ( ) ;
1595+ assert ! ( iter. nth( 100 ) . is_none( ) ) ;
1596+ assert ! ( iter. next( ) . is_none( ) ) ;
1597+ assert ! ( iter. next_back( ) . is_none( ) ) ;
1598+
1599+ // nth should not overflow the iterator
1600+ // a naive implementation of nth will overflow if number of advanced
1601+ // elements plus N overflows usize::MAX
1602+ let mut iter = list. iter ( ) ;
1603+ assert ! ( iter. next( ) . is_some( ) ) ;
1604+ assert ! ( iter. nth( usize :: MAX ) . is_none( ) ) ;
15971605 } ) ;
15981606 }
15991607
@@ -1651,6 +1659,17 @@ mod tests {
16511659 iter3. nth ( 1 ) ;
16521660 assert_eq ! ( iter3. nth_back( 2 ) . unwrap( ) . extract:: <i32 >( ) . unwrap( ) , 3 ) ;
16531661 assert ! ( iter3. nth_back( 0 ) . is_none( ) ) ;
1662+
1663+ // nth_back consumes all elements in the list, even on `None` return
1664+ let mut iter4 = list. iter ( ) ;
1665+ assert ! ( iter4. nth_back( 100 ) . is_none( ) ) ;
1666+ assert ! ( iter4. next_back( ) . is_none( ) ) ;
1667+ assert ! ( iter4. next( ) . is_none( ) ) ;
1668+
1669+ // nth_back should not overflow with usize::MAX
1670+ let mut iter5 = list. iter ( ) ;
1671+ iter5. nth ( 1 ) ; //
1672+ assert ! ( iter5. nth_back( usize :: MAX ) . is_none( ) ) ;
16541673 } ) ;
16551674 }
16561675
@@ -1677,6 +1696,19 @@ mod tests {
16771696 let mut iter4 = list. iter ( ) ;
16781697 assert_eq ! ( iter4. advance_by( 0 ) , Ok ( ( ) ) ) ;
16791698 assert_eq ! ( iter4. next( ) . unwrap( ) . extract:: <i32 >( ) . unwrap( ) , 1 ) ;
1699+
1700+ // advance_by should not overflow with usize::MAX
1701+ // - first advanc will overflow by MAX - len, and will exhaust the iterator
1702+ // - second advance will overflow by MAX
1703+ let mut iter5 = list. iter ( ) ;
1704+ assert_eq ! (
1705+ iter5. advance_by( usize :: MAX ) ,
1706+ Err ( NonZero :: new( usize :: MAX - list. len( ) ) . unwrap( ) )
1707+ ) ;
1708+ assert_eq ! (
1709+ iter5. advance_by( usize :: MAX ) ,
1710+ Err ( NonZero :: new( usize :: MAX ) . unwrap( ) )
1711+ ) ;
16801712 } )
16811713 }
16821714
@@ -1703,6 +1735,19 @@ mod tests {
17031735 let mut iter4 = list. iter ( ) ;
17041736 assert_eq ! ( iter4. advance_back_by( 0 ) , Ok ( ( ) ) ) ;
17051737 assert_eq ! ( iter4. next_back( ) . unwrap( ) . extract:: <i32 >( ) . unwrap( ) , 5 ) ;
1738+
1739+ // advance_back_by should not overflow with usize::MAX
1740+ // - first advance will overflow by MAX - len, and will exhaust the iterator
1741+ // - second advance will overflow by MAX
1742+ let mut iter5 = list. iter ( ) ;
1743+ assert_eq ! (
1744+ iter5. advance_back_by( usize :: MAX ) ,
1745+ Err ( NonZero :: new( usize :: MAX - list. len( ) ) . unwrap( ) )
1746+ ) ;
1747+ assert_eq ! (
1748+ iter5. advance_back_by( usize :: MAX ) ,
1749+ Err ( NonZero :: new( usize :: MAX ) . unwrap( ) )
1750+ ) ;
17061751 } )
17071752 }
17081753
0 commit comments