Add numpy ndarray support for LLM-generated Python#248
Add numpy ndarray support for LLM-generated Python#248trevorprater wants to merge 12 commits intopydantic:mainfrom
Conversation
bf87cce to
1dffa8a
Compare
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
0b57170 to
6bfc6ce
Compare
|
Thank you for this PR. As you note, it's absolutely massive. I can understand the use case however I'm also reluctant to add such a huge feature set to the codebase, it's also not a complete feature set of numpy / pandas so I would imagine that merging this will set precedent for:
I worry this will quickly become unsustainable to maintain. I wonder if there's a few different approaches to explore here:
|
|
I certainly would like to help here but my rust knowledge is next to nothing, I done my own sandbox project but it use pyodide which is different approach to monty here https://github.com/auto-medica-labs/vivarium |
|
Hi thanks so much for this @trevorprater I agree with one of @davidhewitt's suggestions: we should add Rationale:
LMK what you think? |
|
Thanks @samuelcolvin, that makes sense. I agree – numpy is a much cleaner fit for Monty’s scope. The pandas surface area is large and would inevitably pull toward needing a real query engine for completeness, which isn’t worth the binary bloat. I’ll split the PR to numpy-only. Plan is:
@davidhewitt – apologies for the slow response, things have been hectic on my end. I really appreciate the thoughtful review and the suggestions around a plugin/extension mechanism (great idea). |
|
Great, no worries and no rush. FYI, I recently merged #265 which was a huge refactoring that'll likely hit this PR hard (but makes it MUCH easier to avoid copying data, satisfy borrow checker etc). Happy to take pings if you need any comments about how to get this PR up-to-date with main for the numpy bindings. |
Implements a built-in numpy module with ndarray type for running LLM-generated numeric Python code in the Monty sandbox. Module functions: array, zeros, ones, arange, linspace, sum, mean, min, max, abs, sqrt, log, exp, round, clip, where, maximum, minimum, sort, unique, concatenate, cumsum, dot, ceil, floor, log10, std. NdArray methods: sum, mean, min, max, std, flatten, tolist, copy, sort, argsort, argmin, argmax, all, any, cumsum, reshape, round, clip, dot, astype. Element-wise binary ops (+, -, *, /, //, %, **) and comparisons (==, !=, >, <, >=, <=) between arrays and scalars. All tests verified against real CPython + numpy.
Parity test with ~200 assertions covering every numpy function, method, attribute, binary op, comparison, and edge case — verified against both real CPython+numpy and Monty. Bugs found and fixed: - argmax() returned last max on ties instead of first (numpy returns first) - .T attribute not resolved (single-char "T" uses ASCII interning, not StaticStrings) - Float repr wrote "1.0" instead of numpy's "1." - len() on 2D arrays returned total elements instead of shape[0]
6bfc6ce to
529777b
Compare
… functions - Pre-check allocation size in np.zeros, np.ones, np.arange, np.linspace before allocating Vec to prevent memory exhaustion from user-controlled sizes - Accept plain list arguments in call_elementwise (np.abs, np.sqrt, etc.) matching real NumPy behavior - Improve reshape error message to include actual size and requested shape - Add tests for elementwise functions on plain lists
- Fix `~` on int arrays to use bitwise NOT (e.g. ~1 = -2) instead of logical NOT - Fix `~` on float arrays to raise TypeError, matching NumPy - Fix `~` on bool arrays to correctly flip True/False with Bool dtype - Track bool vs int dtype in array creation so np.array([True, False]) gets dtype='bool' instead of 'int64' - Validate that np.where x/y array lengths match condition length, raising ValueError on mismatch instead of silently producing inconsistent arrays
NaN/Inf correctness: - Fix min()/max() to propagate NaN (was silently ignoring NaN values) - Fix sort/argsort/unique to put NaN values last, matching NumPy - Fix float repr to use lowercase 'nan'/'inf' instead of Rust's 'NaN'/'inf' Empty array handling: - Fix empty np.array([]) to default to float64 dtype (was int64) 2D array correctness: - Fix tolist() to produce nested lists for 2D+ arrays (was flattening) Dtype promotion: - Track scalar_is_float through binary operations so int_arr * 1.0 correctly promotes to float64 (was staying int64 because 1.0.fract() == 0) - Pass is_float flag from Value::Float through value_to_f64() and all scalar operation dispatch Resource safety: - Add check_array_alloc_size() to np.concatenate (was unchecked) - Use checked_mul in reshape to prevent usize overflow from user input Tests: - Add ~60 new assertions for NaN/Inf, empty arrays, dtype, and 2D ops
…overage
Add comprehensive numpy ndarray support covering all operations commonly
generated by LLMs. Implementation spans 8 categories:
Phase 1 - Math: np.sin, np.cos, np.tan, np.log2, np.power, np.diff
Phase 2 - Creation: np.full, np.eye, np.copy, np.empty, np.zeros((m,n)),
np.ones((m,n)), np.zeros_like, np.ones_like
Phase 3 - Testing: np.isnan, np.isinf, np.isfinite, np.array_equal,
np.count_nonzero, np.all, np.any (module-level)
Phase 4 - Aggregation: .prod(), np.prod, .var(), np.var, np.median,
np.argmin, np.argmax (module-level)
Phase 5 - Manipulation: np.reshape, np.transpose (module-level),
np.append, np.vstack, np.hstack, np.stack, .ravel()
Phase 6 - Indexing: np.nonzero, np.argwhere, fancy indexing with
integer arrays, slice indexing (arr[1:3], arr[::2], arr[::-1])
Phase 7 - Utilities: np.tile, np.repeat, np.split,
.astype("int32"/"float32"/"int"/"float")
Phase 8 - Validation: 560 assertions verified against NumPy 2.x,
edge cases for empty arrays, NaN/Inf, single elements
All 918 integration tests pass, clippy clean, ref-count-panic clean.
- Validate negative arguments for np.linspace (num), np.tile (reps), np.repeat (repeats), and np.split (sections) — previously these would silently wrap to huge usize values via cast_sign_loss - Fix np.split error message for sections=0 to match NumPy's wording - Use defer_drop! for reps_val in tile/repeat for cleaner ref counting - Update module-level docstring to list all ~40 supported functions - Add doc comments about np.stack/np.hstack 1D-only limitation - Add doc comment about np.array_equal NaN behavior - Add comment about ref-count leak acceptability in call_split (resource exhaustion is terminal per project convention) - Replace temporal "New functions" section comment with descriptive label
Phase 1: Constants (np.pi, np.e, np.inf, np.nan, np.newaxis) and dtype type objects (np.float64, np.int64, np.bool_, np.float32, np.int32). Phase 3: Inverse trig (arcsin, arccos, arctan, arctan2), hyperbolic (sinh, cosh, tanh, arcsinh, arccosh, arctanh), and remaining element-wise math (sign, square, cbrt, reciprocal, log1p, exp2, expm1, deg2rad, rad2deg, degrees, radians, hypot, nan_to_num, fmin, fmax, fmod, rint, fabs, positive, negative). Phase 4: NaN-aware aggregations (nansum, nanmean, nanmin, nanmax, nanstd, nanvar, nanprod, nanmedian, nanargmin, nanargmax, nancumsum, nancumprod) and statistics (ptp, cumprod, percentile, quantile, average). Phase 5: Logical functions (logical_and, logical_or, logical_not, logical_xor, allclose, isclose, isin). Phase 6: Array manipulation (flip, fliplr, flipud, roll, expand_dims, squeeze, ravel, delete, insert, diag, diagonal, trace, flatnonzero, asarray, column_stack, array_split, full_like, empty_like). Phase 7: Sorting/searching/set ops (argsort, searchsorted, extract, intersect1d, union1d, setdiff1d, setxor1d, bincount, digitize). Phase 8: Linear algebra (outer, cross). Phase 9-10: Creation functions (logspace, geomspace, tri, tril, triu, identity, meshgrid, gradient, convolve, correlate, interp, select). Also fixes ref-counting bugs in functions using into_pos_only + .next() pattern, and fixes np.sign to return 0.0 for zero (matching NumPy). Test assertions: 560 → 735 (+175 new).
… ops, matmul Implements the most critical missing operators for ndarray parity: - Bitwise operators (&, |, ^) on bool and int arrays - __setitem__ with int index, bool mask, and slice assignment - __iter__ for iterating over array elements in for loops - __contains__ for 'val in arr' membership testing - In-place operators (+=, -=, *=, /=) for scalar and array operands - @ (matmul) operator and np.matmul function for dot/matrix products Updates the matmul test expectation from NotImplementedError to TypeError since @ is now implemented for ndarray but not for plain int/float. Test assertions: 735 → 779 (+44 new).
… etc.) and attributes (nbytes, itemsize) Adds missing ndarray methods and attributes: - .item() - extract single element as scalar - .cumprod() - cumulative product - .squeeze() - remove size-1 dimensions - .take(indices) - take elements at given indices - .diagonal() - extract diagonal of 2D array - .trace() - sum of diagonal elements - .fill(value) - fill array in-place - .compress(condition) - select elements by bool mask - .swapaxes(a, b) - swap two axes - .nbytes - total bytes (size * 8) - .itemsize - bytes per element (8) Test assertions: 779 → 803 (+24 new).
…+ test assertions - Fix ndarray.sort() to mutate in-place and return None (was creating new array) - Fix slice setitem to handle array RHS values, not just scalars - Add comprehensive test assertions (total now 1002) covering: - In-place sort, setitem with slices/masks/arrays - ndarray methods: item, cumprod, squeeze, take, diagonal, trace, fill, compress, swapaxes - ndarray attributes: nbytes, itemsize - Additional edge cases for existing functions
…nzero() methods - Add FloorDiv, Mod, Pow variants to NdArrayInplaceOp for true in-place mutation - Wire InplaceFloorDiv/InplaceMod/InplacePow to new inplace methods - Add .flat attribute returning flattened 1D copy of the array - Add ndarray.repeat(n) method for element-wise repetition - Add ndarray.nonzero() method returning tuple of index arrays - Total assertions: 1014
Summary
Adds numpy ndarray support to the Monty interpreter, enabling LLM-generated Python code that uses numpy to run in the sandbox.
crates/monty/src/types/ndarray.rs): Core ndarray implementation with shape, dtype, and all standard methods (sum, mean, min, max, std, reshape, flatten, tolist, argmin, argmax, cumsum, round, clip, dot, astype, copy, sort, argsort, all, any) and attributes (shape, dtype, size, ndim, T)crates/monty/src/modules/numpy.rs): 24+ functions including array creation (array, zeros, ones, arange, linspace), aggregation (sum, mean, min, max, std), element-wise math (abs, sqrt, log, exp, ceil, floor, log10, round, clip), and utilities (where, maximum, minimum, sort, unique, concatenate, cumsum, dot)binary.rs): +, -, *, /, //, %, ** for array+array, array+scalar, scalar+arraycompare.rs): ==, !=, >, <, >=, <= returning boolean ndarraysAll behavior verified against CPython+numpy via ~200 assertion parity test.
Test plan
make test-cases)numpy__parity.py— ~200 assertions verified against real CPython+numpynumpy__methods.py— ndarray methods and element-wise list testsnumpy__arithmetic.py— element-wise binary opsnumpy__array_creation.py— array, zeros, ones, arange, linspacenumpy__comparison.py— element-wise comparisonsnumpy__math_functions.py— np.abs, sqrt, log, exp, etc.np.zeros(10**9)from allocatingmake lint-rspasses clean