Skip to content

Commit 3d48a54

Browse files
fix bugs with list / tuple iterator "nth" operations (#6086)
* fix bugs with list / tuple iterator "nth" operations * newsfragment * Apply suggestions from code review Co-authored-by: Nathan Goldbaum <nathan.goldbaum@gmail.com> * final tidy ups * fix `next_back` implementation * add coverage for `BorrowedTupleIterator` --------- Co-authored-by: Nathan Goldbaum <nathan.goldbaum@gmail.com>
1 parent 62841bf commit 3d48a54

3 files changed

Lines changed: 332 additions & 196 deletions

File tree

newsfragments/6086.fixed.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Fix possible out of bounds read in `BoundListIterator` and `BoundTupleIterator`'s `nth` and `nth_back` implementations.
2+
- Fix `BoundListIterator` and `BoundTupleIterator` not being exhausted when `nth` or `nth_back` is called with N larger than the remaining count of items.

src/types/list.rs

Lines changed: 101 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -520,22 +520,27 @@ impl<'py> BoundListIterator<'py> {
520520
/// The caller must hold an active critical section on this list.
521521
#[inline]
522522
#[cfg(not(feature = "nightly"))]
523-
unsafe fn nth_locked(
523+
unsafe fn nth_unsynchronized(
524524
index: &mut Index,
525525
length: &mut Length,
526526
list: &Bound<'py, PyList>,
527527
n: usize,
528528
) -> Option<Bound<'py, PyAny>> {
529-
let length = length.0.min(list.len());
530-
let target_index = index.0 + n;
531-
if target_index < length {
532-
// SAFETY: target_index < list.len() guarantees in bounds
533-
let item = unsafe { list.get_item_unchecked(target_index) };
534-
index.0 = target_index + 1;
535-
Some(item)
536-
} else {
537-
None
529+
let current_length = length.0.min(list.len());
530+
if let Some(target_index) = index.0.checked_add(n) {
531+
if target_index < current_length {
532+
// SAFETY: target_index < current_length guarantees in bounds
533+
let item = unsafe { list.get_item_unchecked(target_index) };
534+
// +1 cannot overflow as target_index < current_length
535+
index.0 = target_index + 1;
536+
return Some(item);
537+
}
538538
}
539+
540+
// n overflows the remaining length of the list;
541+
// nth must exhaust all remaining items
542+
index.0 = current_length;
543+
None
539544
}
540545

541546
/// # Safety
@@ -564,22 +569,28 @@ impl<'py> BoundListIterator<'py> {
564569
/// The caller must hold an active critical section on this list.
565570
#[inline]
566571
#[cfg(not(feature = "nightly"))]
567-
unsafe fn nth_back_locked(
572+
unsafe fn nth_back_unsynchronized(
568573
index: &mut Index,
569574
length: &mut Length,
570575
list: &Bound<'py, PyList>,
571576
n: usize,
572577
) -> Option<Bound<'py, PyAny>> {
573-
let length_size = length.0.min(list.len());
574-
if index.0 + n < length_size {
575-
let target_index = length_size - n - 1;
576-
// SAFETY: target_index < list.len() guarantees in bounds
577-
let item = unsafe { list.get_item_unchecked(target_index) };
578-
length.0 = target_index;
579-
Some(item)
580-
} else {
581-
None
578+
let current_length = length.0.min(list.len());
579+
if let Some(index_after_item) = current_length.checked_sub(n) {
580+
if index.0 < index_after_item {
581+
// -1 cannot underflow as index_after_item > index
582+
let target_index = index_after_item - 1;
583+
// SAFETY: target_index is < current_length
584+
let item = unsafe { list.get_item_unchecked(target_index) };
585+
length.0 = target_index;
586+
return Some(item);
587+
}
582588
}
589+
590+
// n overflows the remaining length of the tuple;
591+
// nth must exhaust all remaining items
592+
length.0 = index.0;
593+
None
583594
}
584595

585596
fn with_critical_section<R>(
@@ -611,7 +622,7 @@ impl<'py> Iterator for BoundListIterator<'py> {
611622
fn nth(&mut self, n: usize) -> Option<Self::Item> {
612623
// SAFETY: with_critical_section locks the list
613624
self.with_critical_section(|index, length, list| unsafe {
614-
Self::nth_locked(index, length, list, n)
625+
Self::nth_unsynchronized(index, length, list, n)
615626
})
616627
}
617628

@@ -768,25 +779,17 @@ impl<'py> Iterator for BoundListIterator<'py> {
768779
#[cfg(feature = "nightly")]
769780
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
770781
self.with_critical_section(|index, length, list| {
771-
let max_len = length.0.min(list.len());
772-
let currently_at = index.0;
773-
if currently_at >= max_len {
774-
if n == 0 {
775-
return Ok(());
776-
} else {
777-
return Err(unsafe { NonZero::new_unchecked(n) });
778-
}
782+
let current_length = length.0.min(list.len());
783+
let items_left = current_length.saturating_sub(index.0);
784+
if let Some(overflow) = NonZero::new(n.saturating_sub(items_left)) {
785+
// n overflows the remaining length of the list; advance_by must exhaust all remaining items
786+
index.0 = current_length;
787+
return Err(overflow);
779788
}
780789

781-
let items_left = max_len - currently_at;
782-
if n <= items_left {
783-
index.0 += n;
784-
Ok(())
785-
} else {
786-
index.0 = max_len;
787-
let remainder = n - items_left;
788-
Err(unsafe { NonZero::new_unchecked(remainder) })
789-
}
790+
// cannot overflow as length - index >= n
791+
index.0 += n;
792+
Ok(())
790793
})
791794
}
792795
}
@@ -804,7 +807,7 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
804807
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
805808
// SAFETY: with_critical_section locks the list
806809
self.with_critical_section(|index, length, list| unsafe {
807-
Self::nth_back_locked(index, length, list, n)
810+
Self::nth_back_unsynchronized(index, length, list, n)
808811
})
809812
}
810813

@@ -847,25 +850,17 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
847850
#[cfg(feature = "nightly")]
848851
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
849852
self.with_critical_section(|index, length, list| {
850-
let max_len = length.0.min(list.len());
851-
let currently_at = index.0;
852-
if currently_at >= max_len {
853-
if n == 0 {
854-
return Ok(());
855-
} else {
856-
return Err(unsafe { NonZero::new_unchecked(n) });
857-
}
853+
let current_length = length.0.min(list.len());
854+
let items_left = current_length.saturating_sub(index.0);
855+
if let Some(overflow) = NonZero::new(n.saturating_sub(items_left)) {
856+
// n overflows the remaining length of the list; advance_back_by must exhaust all remaining items
857+
length.0 = index.0;
858+
return Err(overflow);
858859
}
859860

860-
let items_left = max_len - currently_at;
861-
if n <= items_left {
862-
length.0 = max_len - n;
863-
Ok(())
864-
} else {
865-
length.0 = currently_at;
866-
let remainder = n - items_left;
867-
Err(unsafe { NonZero::new_unchecked(remainder) })
868-
}
861+
// cannot overflow as current_length - index >= n
862+
length.0 = current_length - n;
863+
Ok(())
869864
})
870865
}
871866
}
@@ -1594,6 +1589,19 @@ mod tests {
15941589
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 9);
15951590
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 8);
15961591
assert!(iter.next().is_none());
1592+
1593+
// nth consumes all elements in the list, even on `None` return
1594+
let mut iter = list.iter();
1595+
assert!(iter.nth(100).is_none());
1596+
assert!(iter.next().is_none());
1597+
assert!(iter.next_back().is_none());
1598+
1599+
// nth should not overflow the iterator
1600+
// a naive implementation of nth will overflow if number of advanced
1601+
// elements plus N overflows usize::MAX
1602+
let mut iter = list.iter();
1603+
assert!(iter.next().is_some());
1604+
assert!(iter.nth(usize::MAX).is_none());
15971605
});
15981606
}
15991607

@@ -1651,6 +1659,17 @@ mod tests {
16511659
iter3.nth(1);
16521660
assert_eq!(iter3.nth_back(2).unwrap().extract::<i32>().unwrap(), 3);
16531661
assert!(iter3.nth_back(0).is_none());
1662+
1663+
// nth_back consumes all elements in the list, even on `None` return
1664+
let mut iter4 = list.iter();
1665+
assert!(iter4.nth_back(100).is_none());
1666+
assert!(iter4.next_back().is_none());
1667+
assert!(iter4.next().is_none());
1668+
1669+
// nth_back should not overflow with usize::MAX
1670+
let mut iter5 = list.iter();
1671+
iter5.nth(1); //
1672+
assert!(iter5.nth_back(usize::MAX).is_none());
16541673
});
16551674
}
16561675

@@ -1677,6 +1696,19 @@ mod tests {
16771696
let mut iter4 = list.iter();
16781697
assert_eq!(iter4.advance_by(0), Ok(()));
16791698
assert_eq!(iter4.next().unwrap().extract::<i32>().unwrap(), 1);
1699+
1700+
// advance_by should not overflow with usize::MAX
1701+
// - first advanc will overflow by MAX - len, and will exhaust the iterator
1702+
// - second advance will overflow by MAX
1703+
let mut iter5 = list.iter();
1704+
assert_eq!(
1705+
iter5.advance_by(usize::MAX),
1706+
Err(NonZero::new(usize::MAX - list.len()).unwrap())
1707+
);
1708+
assert_eq!(
1709+
iter5.advance_by(usize::MAX),
1710+
Err(NonZero::new(usize::MAX).unwrap())
1711+
);
16801712
})
16811713
}
16821714

@@ -1703,6 +1735,19 @@ mod tests {
17031735
let mut iter4 = list.iter();
17041736
assert_eq!(iter4.advance_back_by(0), Ok(()));
17051737
assert_eq!(iter4.next_back().unwrap().extract::<i32>().unwrap(), 5);
1738+
1739+
// advance_back_by should not overflow with usize::MAX
1740+
// - first advance will overflow by MAX - len, and will exhaust the iterator
1741+
// - second advance will overflow by MAX
1742+
let mut iter5 = list.iter();
1743+
assert_eq!(
1744+
iter5.advance_back_by(usize::MAX),
1745+
Err(NonZero::new(usize::MAX - list.len()).unwrap())
1746+
);
1747+
assert_eq!(
1748+
iter5.advance_back_by(usize::MAX),
1749+
Err(NonZero::new(usize::MAX).unwrap())
1750+
);
17061751
})
17071752
}
17081753

0 commit comments

Comments
 (0)