Skip to content

Commit 23ff7eb

Browse files
Implement strict mode for zip() function to enforce equal lengths of iterables (#324)
Co-authored-by: amitrechavia <93474585+amitrechavia@users.noreply.github.com>
1 parent e78c9c3 commit 23ff7eb

File tree

3 files changed

+163
-29
lines changed

3 files changed

+163
-29
lines changed

crates/monty/src/builtins/zip.rs

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
//! Implementation of the zip() builtin function.
22
33
use crate::{
4-
args::ArgValues,
4+
args::{ArgValues, KwargsValues},
55
bytecode::VM,
6-
defer_drop_mut,
7-
exception_private::RunResult,
8-
heap::HeapData,
6+
defer_drop, defer_drop_mut,
7+
exception_private::{ExcType, RunError, RunResult, SimpleException},
8+
heap::{HeapData, HeapGuard},
99
resource::ResourceTracker,
10-
types::{List, MontyIter, allocate_tuple, tuple::TupleVec},
10+
types::{List, MontyIter, PyTrait, allocate_tuple, tuple::TupleVec},
1111
value::Value,
1212
};
1313

1414
/// Implementation of the zip() builtin function.
1515
///
1616
/// Returns a list of tuples, where the i-th tuple contains the i-th element
1717
/// from each of the argument iterables. Stops when the shortest iterable is exhausted.
18+
/// When `strict=True`, raises `ValueError` if any iterable has a different length.
1819
/// Note: In Python this returns an iterator, but we return a list for simplicity.
1920
pub fn builtin_zip(vm: &mut VM<'_, '_, impl ResourceTracker>, args: ArgValues) -> RunResult<Value> {
2021
let (positional, kwargs) = args.into_parts();
2122
defer_drop_mut!(positional, vm);
2223

23-
// TODO: support kwargs (strict)
24-
kwargs.not_supported_yet("zip", vm.heap)?;
24+
let strict = extract_zip_strict(kwargs, vm)?;
2525

2626
if positional.len() == 0 {
2727
// zip() with no arguments returns empty list
@@ -30,48 +30,113 @@ pub fn builtin_zip(vm: &mut VM<'_, '_, impl ResourceTracker>, args: ArgValues) -
3030
}
3131

3232
// Create iterators for each iterable
33-
let mut iterators: Vec<MontyIter> = Vec::with_capacity(positional.len());
33+
let iterators: Vec<MontyIter> = Vec::with_capacity(positional.len());
34+
defer_drop_mut!(iterators, vm);
3435
for iterable in positional {
35-
match MontyIter::new(iterable, vm) {
36-
Ok(iter) => iterators.push(iter),
37-
Err(e) => {
38-
// Clean up already-created iterators
39-
for iter in iterators {
40-
iter.drop_with_heap(vm);
41-
}
42-
return Err(e);
43-
}
44-
}
36+
iterators.push(MontyIter::new(iterable, vm)?);
4537
}
4638

47-
let mut result: Vec<Value> = Vec::new();
39+
let mut result_guard = HeapGuard::new(Vec::new(), vm);
40+
let (result, vm) = result_guard.as_parts_mut();
4841

4942
// Zip until shortest iterator is exhausted
5043
'outer: loop {
51-
let mut tuple_items = TupleVec::with_capacity(iterators.len());
44+
let mut items_guard = HeapGuard::new(TupleVec::with_capacity(iterators.len()), vm);
45+
let (tuple_items, vm) = items_guard.as_parts_mut();
5246

53-
for iter in &mut iterators {
47+
for (i, iter) in iterators.iter_mut().enumerate() {
5448
if let Some(item) = iter.for_next(vm)? {
5549
tuple_items.push(item);
5650
} else {
57-
// This iterator is exhausted - drop partial tuple items and stop
58-
for item in tuple_items {
59-
item.drop_with_heap(vm);
51+
// This iterator is exhausted - stop zipping
52+
53+
if strict {
54+
// In strict mode, if i > 0 then argument i+1 ran out before
55+
// the earlier ones, so it is "shorter."
56+
if i > 0 {
57+
return Err(strict_length_error(i + 1, i, "shorter"));
58+
}
59+
// i == 0: first iterator exhausted — verify every remaining
60+
// iterator is also exhausted. If any still yields a value,
61+
// that argument is "longer" than all preceding exhausted ones.
62+
// j is the 0-based index; iterators 0..j are all exhausted,
63+
// so j gives the count for the error message.
64+
for (j, remaining) in iterators.iter_mut().enumerate().skip(1) {
65+
if let Some(extra) = remaining.for_next(vm)? {
66+
extra.drop_with_heap(vm);
67+
return Err(strict_length_error(j + 1, j, "longer"));
68+
}
69+
}
6070
}
71+
6172
break 'outer;
6273
}
6374
}
6475

6576
// Create tuple from collected items
77+
let (tuple_items, vm) = items_guard.into_parts();
6678
let tuple_val = allocate_tuple(tuple_items, vm.heap)?;
6779
result.push(tuple_val);
6880
}
6981

70-
// Clean up iterators
71-
for iter in iterators {
72-
iter.drop_with_heap(vm);
73-
}
74-
82+
let (result, vm) = result_guard.into_parts();
7583
let heap_id = vm.heap.allocate(HeapData::List(List::new(result)))?;
7684
Ok(Value::Ref(heap_id))
7785
}
86+
87+
/// Extracts the `strict` keyword argument from `zip()`.
88+
///
89+
/// Accepts any truthy/falsy value for `strict`, matching CPython behavior.
90+
/// Raises `TypeError` for unexpected keyword arguments.
91+
fn extract_zip_strict(kwargs: KwargsValues, vm: &mut VM<'_, '_, impl ResourceTracker>) -> RunResult<bool> {
92+
let mut strict = false;
93+
let mut error: Option<RunError> = None;
94+
95+
for (key, value) in kwargs {
96+
defer_drop!(key, vm);
97+
defer_drop!(value, vm);
98+
99+
if error.is_some() {
100+
continue;
101+
}
102+
103+
let Some(keyword_name) = key.as_either_str(vm.heap) else {
104+
error = Some(SimpleException::new_msg(ExcType::TypeError, "keywords must be strings").into());
105+
continue;
106+
};
107+
108+
let key_str = keyword_name.as_str(vm.interns);
109+
match key_str {
110+
"strict" => {
111+
strict = value.py_bool(vm);
112+
}
113+
_ => {
114+
error = Some(ExcType::type_error_unexpected_keyword("zip", key_str));
115+
}
116+
}
117+
}
118+
119+
if let Some(error) = error {
120+
Err(error)
121+
} else {
122+
Ok(strict)
123+
}
124+
}
125+
126+
/// Builds the `ValueError` for `zip(strict=True)` when iterables have different lengths.
127+
///
128+
/// Matches CPython's error format:
129+
/// - `"zip() argument 2 is shorter than argument 1"` (singular)
130+
/// - `"zip() argument 4 is shorter than arguments 1-3"` (plural)
131+
fn strict_length_error(exhausted_arg: usize, num_longer_args: usize, relation: &str) -> RunError {
132+
let others = if num_longer_args == 1 {
133+
"argument 1".to_owned()
134+
} else {
135+
format!("arguments 1-{num_longer_args}")
136+
};
137+
SimpleException::new_msg(
138+
ExcType::ValueError,
139+
format!("zip() argument {exhausted_arg} is {relation} than {others}"),
140+
)
141+
.into()
142+
}

crates/monty/src/heap_traits.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::{
44
vec::{Drain, IntoIter},
55
};
66

7+
use smallvec::SmallVec;
8+
79
use crate::{
810
ResourceTracker,
911
heap::{Heap, HeapId, RecursionToken},
@@ -112,6 +114,17 @@ impl<U: DropWithHeap> DropWithHeap for Vec<U> {
112114
}
113115
}
114116

117+
impl<A: smallvec::Array> DropWithHeap for SmallVec<A>
118+
where
119+
A::Item: DropWithHeap,
120+
{
121+
fn drop_with_heap<H: ContainsHeap>(self, heap: &mut H) {
122+
for value in self {
123+
value.drop_with_heap(heap);
124+
}
125+
}
126+
}
127+
115128
impl<U: DropWithHeap> DropWithHeap for IntoIter<U> {
116129
fn drop_with_heap<H: ContainsHeap>(self, heap: &mut H) {
117130
for value in self {

crates/monty/test_cases/builtin__more_iter_funcs.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,59 @@ def negate(x):
359359
# zip with empty
360360
assert list(zip([1, 2], [])) == [], 'zip with empty second'
361361
assert list(zip([], [1, 2])) == [], 'zip with empty first'
362+
363+
# === zip(strict=True) ===
364+
# Equal length iterables succeed
365+
assert list(zip([1, 2], [3, 4], strict=True)) == [(1, 3), (2, 4)], 'zip strict equal lengths'
366+
assert list(zip([1], [2], [3], strict=True)) == [(1, 2, 3)], 'zip strict three single-element lists'
367+
assert list(zip([], [], strict=True)) == [], 'zip strict empty lists'
368+
assert list(zip(strict=True)) == [], 'zip strict no arguments'
369+
assert list(zip([1, 2, 3], strict=True)) == [(1,), (2,), (3,)], 'zip strict single iterable'
370+
371+
# strict=False behaves like default
372+
assert list(zip([1, 2, 3], [4, 5], strict=False)) == [(1, 4), (2, 5)], 'zip strict=False truncates'
373+
374+
# Falsy values are accepted
375+
assert list(zip([1, 2, 3], [4, 5], strict=0)) == [(1, 4), (2, 5)], 'zip strict=0 is falsy'
376+
377+
# Second argument shorter
378+
try:
379+
list(zip([1, 2, 3], [4, 5], strict=True))
380+
assert False, 'zip strict should raise for shorter arg 2'
381+
except ValueError as e:
382+
assert str(e) == 'zip() argument 2 is shorter than argument 1', 'zip strict shorter error'
383+
384+
# Second argument longer
385+
try:
386+
list(zip([1, 2], [4, 5, 6], strict=True))
387+
assert False, 'zip strict should raise for longer arg 2'
388+
except ValueError as e:
389+
assert str(e) == 'zip() argument 2 is longer than argument 1', 'zip strict longer error'
390+
391+
# Third argument shorter with plural
392+
try:
393+
list(zip([1, 2], [3, 4], [5], strict=True))
394+
assert False, 'zip strict should raise for shorter arg 3'
395+
except ValueError as e:
396+
assert str(e) == 'zip() argument 3 is shorter than arguments 1-2', 'zip strict shorter plural'
397+
398+
# Fourth argument shorter
399+
try:
400+
list(zip([1, 2], [3, 4], [5, 6], [7], strict=True))
401+
assert False, 'zip strict should raise for shorter arg 4'
402+
except ValueError as e:
403+
assert str(e) == 'zip() argument 4 is shorter than arguments 1-3', 'zip strict shorter 4 args'
404+
405+
# Third argument longer than arguments 1-2 (both exhausted)
406+
try:
407+
list(zip([1], [2], [3, 4], strict=True))
408+
assert False, 'zip strict should raise for longer arg 3'
409+
except ValueError as e:
410+
assert str(e) == 'zip() argument 3 is longer than arguments 1-2', 'zip strict longer plural'
411+
412+
# Unexpected keyword argument
413+
try:
414+
list(zip([1], foo=True))
415+
assert False, 'zip unexpected kwarg should raise TypeError'
416+
except TypeError as e:
417+
assert str(e) == "zip() got an unexpected keyword argument 'foo'", 'zip unexpected kwarg error'

0 commit comments

Comments
 (0)