Skip to content

Commit b81719e

Browse files
committed
Improve fmap for pandas
Resolves #734.
1 parent 23d4959 commit b81719e

4 files changed

Lines changed: 14 additions & 11 deletions

File tree

DOCS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3168,6 +3168,8 @@ For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the map
31683168

31693169
For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.
31703170

3171+
For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`s and element-wise for `Series`).
3172+
31713173
For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to
31723174
```coconut_python
31733175
async def fmap_over_async_iters(func, async_iter):
@@ -3198,7 +3200,7 @@ _Can't be done without a series of method definitions for each data type. See th
31983200

31993201
**call**(_func_, /, *_args_, \*\*_kwargs_)
32003202

3201-
Coconut's `call` simply implements function application. Thus, `call` is equivalent to
3203+
Coconut's `call` simply implements function application. Thus, `call` is effectively equivalent to
32023204
```coconut
32033205
def call(f, /, *args, **kwargs) = f(*args, **kwargs)
32043206
```

coconut/compiler/templates/header.py_template

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,16 +1486,15 @@ def fmap(func, obj, **kwargs):
14861486
if result is not _coconut.NotImplemented:
14871487
return result
14881488
obj_module = _coconut_get_base_module(obj)
1489+
if obj_module in _coconut.pandas_numpy_modules:
1490+
if obj.ndim <= 1:
1491+
return obj.apply(func)
1492+
return obj.apply(func, axis=obj.ndim-1)
14891493
if obj_module in _coconut.jax_numpy_modules:
14901494
import jax.numpy as jnp
14911495
return jnp.vectorize(func)(obj)
14921496
if obj_module in _coconut.numpy_modules:
1493-
got = _coconut.numpy.vectorize(func)(obj)
1494-
if obj_module in _coconut.pandas_numpy_modules:
1495-
new_obj = obj.copy()
1496-
new_obj[:] = got
1497-
return new_obj
1498-
return got
1497+
return _coconut.numpy.vectorize(func)(obj)
14991498
obj_aiter = _coconut.getattr(obj, "__aiter__", None)
15001499
if obj_aiter is not None and _coconut_amap is not None:
15011500
try:

coconut/root.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
VERSION = "3.0.0"
2727
VERSION_NAME = None
2828
# False for release, int >= 1 for develop
29-
DEVELOP = 39
29+
DEVELOP = 40
3030
ALPHA = True # for pre releases rather than post releases
3131

3232
# -----------------------------------------------------------------------------------------------------------------------

coconut/tests/src/extras.coco

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,9 @@ def test_pandas() -> bool:
472472
assert [d1; d1].keys() |> list == ["nums", "chars"] * 2 # type: ignore
473473
assert [d1;; d1].itertuples() |> list == [(0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c'), (0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c')] # type: ignore
474474
d2 = pd.DataFrame({"a": range(3) |> list, "b": range(1, 4) |> list})
475-
new_d2 = d2 |> fmap$(.+1)
476-
assert new_d2["a"] |> list == range(1, 4) |> list
477-
assert new_d2["b"] |> list == range(2, 5) |> list
475+
d3 = d2 |> fmap$(fmap$(.+1))
476+
assert d3["a"] |> list == range(1, 4) |> list
477+
assert d3["b"] |> list == range(2, 5) |> list
478478
assert multi_enumerate(d1) |> list == [((0, 0), 1), ((1, 0), 2), ((2, 0), 3), ((0, 1), 'a'), ((1, 1), 'b'), ((2, 1), 'c')]
479479
assert not all_equal(d1)
480480
assert not all_equal(d2)
@@ -489,6 +489,8 @@ def test_pandas() -> bool:
489489
3; 'b';;
490490
3; 'c';;
491491
], dtype=object) # type: ignore
492+
d4 = d1 |> fmap$(def r -> r["nums2"] = r["nums"]*2; r)
493+
assert (d4["nums"] * 2 == d4["nums2"]).all()
492494
return True
493495

494496

0 commit comments

Comments
 (0)