Skip to content

Commit f19267e

Browse files
committed
fixes #686
1 parent c2cd643 commit f19267e

3 files changed

Lines changed: 219 additions & 14 deletions

File tree

fastcore/_modidx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,9 @@
578578
'fastcore/xtras.py'),
579579
'fastcore.xtras.ReindexCollection.reindex': ('xtras.html#reindexcollection.reindex', 'fastcore/xtras.py'),
580580
'fastcore.xtras.ReindexCollection.shuffle': ('xtras.html#reindexcollection.shuffle', 'fastcore/xtras.py'),
581+
'fastcore.xtras.SaveReturn': ('xtras.html#savereturn', 'fastcore/xtras.py'),
582+
'fastcore.xtras.SaveReturn.__init__': ('xtras.html#savereturn.__init__', 'fastcore/xtras.py'),
583+
'fastcore.xtras.SaveReturn.__iter__': ('xtras.html#savereturn.__iter__', 'fastcore/xtras.py'),
581584
'fastcore.xtras.Unset': ('xtras.html#unset', 'fastcore/xtras.py'),
582585
'fastcore.xtras.Unset.__bool__': ('xtras.html#unset.__bool__', 'fastcore/xtras.py'),
583586
'fastcore.xtras.Unset.__repr__': ('xtras.html#unset.__repr__', 'fastcore/xtras.py'),
@@ -588,13 +591,15 @@
588591
'fastcore.xtras._property_getter': ('xtras.html#_property_getter', 'fastcore/xtras.py'),
589592
'fastcore.xtras._repr_dict': ('xtras.html#_repr_dict', 'fastcore/xtras.py'),
590593
'fastcore.xtras._save_iter': ('xtras.html#_save_iter', 'fastcore/xtras.py'),
594+
'fastcore.xtras._save_iter.__aiter__': ('xtras.html#_save_iter.__aiter__', 'fastcore/xtras.py'),
591595
'fastcore.xtras._save_iter.__init__': ('xtras.html#_save_iter.__init__', 'fastcore/xtras.py'),
592596
'fastcore.xtras._save_iter.__iter__': ('xtras.html#_save_iter.__iter__', 'fastcore/xtras.py'),
593597
'fastcore.xtras._sparkchar': ('xtras.html#_sparkchar', 'fastcore/xtras.py'),
594598
'fastcore.xtras._unpack': ('xtras.html#_unpack', 'fastcore/xtras.py'),
595599
'fastcore.xtras._unwrapped_func': ('xtras.html#_unwrapped_func', 'fastcore/xtras.py'),
596600
'fastcore.xtras._unwrapped_type_dispatch_func': ( 'xtras.html#_unwrapped_type_dispatch_func',
597601
'fastcore/xtras.py'),
602+
'fastcore.xtras.asave_iter': ('xtras.html#asave_iter', 'fastcore/xtras.py'),
598603
'fastcore.xtras.asdict': ('xtras.html#asdict', 'fastcore/xtras.py'),
599604
'fastcore.xtras.autostart': ('xtras.html#autostart', 'fastcore/xtras.py'),
600605
'fastcore.xtras.bunzip': ('xtras.html#bunzip', 'fastcore/xtras.py'),

fastcore/xtras.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
__all__ = ['spark_chars', 'UNSET', 'walk', 'globtastic', 'maybe_open', 'mkdir', 'image_size', 'bunzip', 'loads', 'loads_multi',
1010
'dumps', 'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'parse_env',
1111
'expand_wildcards', 'dict2obj', 'obj2dict', 'repr_dict', 'is_listy', 'mapped', 'IterLen',
12-
'ReindexCollection', 'trim_wraps', 'save_iter', 'exec_eval', 'get_source_link', 'truncstr', 'sparkline',
13-
'modify_exception', 'round_multiple', 'set_num_threads', 'join_path_file', 'autostart', 'EventTimer',
14-
'stringfmt_names', 'PartialFormatter', 'partial_format', 'utc2local', 'local2utc', 'trace', 'modified_env',
15-
'ContextManagers', 'shufflish', 'console_help', 'hl_md', 'type2str', 'dataclass_src', 'Unset', 'nullable_dc',
16-
'make_nullable', 'flexiclass', 'asdict', 'is_typeddict', 'is_namedtuple', 'CachedIter', 'CachedAwaitable',
17-
'reawaitable', 'flexicache', 'time_policy', 'mtime_policy', 'timed_cache']
12+
'ReindexCollection', 'SaveReturn', 'trim_wraps', 'save_iter', 'asave_iter', 'exec_eval', 'get_source_link',
13+
'truncstr', 'sparkline', 'modify_exception', 'round_multiple', 'set_num_threads', 'join_path_file',
14+
'autostart', 'EventTimer', 'stringfmt_names', 'PartialFormatter', 'partial_format', 'utc2local', 'local2utc',
15+
'trace', 'modified_env', 'ContextManagers', 'shufflish', 'console_help', 'hl_md', 'type2str',
16+
'dataclass_src', 'Unset', 'nullable_dc', 'make_nullable', 'flexiclass', 'asdict', 'is_typeddict',
17+
'is_namedtuple', 'CachedIter', 'CachedAwaitable', 'reawaitable', 'flexicache', 'time_policy', 'mtime_policy',
18+
'timed_cache']
1819

1920
# %% ../nbs/03_xtras.ipynb
2021
from .imports import *
@@ -409,6 +410,14 @@ def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s
409410
shuffle="Randomly shuffle indices",
410411
cache_clear="Clear LRU cache")
411412

413+
# %% ../nbs/03_xtras.ipynb
414+
class SaveReturn:
415+
"Wrap an iterator such that the generator function's return value is stored in `.value`"
416+
def __init__(self, its): self.its = its
417+
def __iter__(self):
418+
self.value = yield from self.its
419+
return self.value
420+
412421
# %% ../nbs/03_xtras.ipynb
413422
def trim_wraps(f, n=1):
414423
"Like wraps, but removes the first n parameters from the signature"
@@ -425,8 +434,17 @@ def _(g):
425434
class _save_iter:
426435
def __init__(self, g, *args, **kw): self.g,self.args,self.kw = g,args,kw
427436
def __iter__(self): yield from self.g(self, *self.args, **self.kw)
437+
def __aiter__(self): return self.g(self, *self.args, **self.kw)
428438

429439
def save_iter(g):
440+
"Decorator that allows a generator function to store values in the returned iterator object"
441+
@trim_wraps(g)
442+
def _(*args, **kwargs): return _save_iter(g, *args, **kwargs)
443+
return _
444+
445+
# %% ../nbs/03_xtras.ipynb
446+
def asave_iter(g):
447+
"Like `save_iter`, but for async iterators"
430448
@trim_wraps(g)
431449
def _(*args, **kwargs): return _save_iter(g, *args, **kwargs)
432450
return _

nbs/03_xtras.ipynb

Lines changed: 190 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,7 +1732,7 @@
17321732
{
17331733
"data": {
17341734
"text/plain": [
1735-
"['c', 'h', 'b', 'a', 'f', 'g', 'd', 'e']"
1735+
"['g', 'c', 'f', 'd', 'a', 'b', 'e', 'h']"
17361736
]
17371737
},
17381738
"execution_count": null,
@@ -1782,7 +1782,99 @@
17821782
"cell_type": "markdown",
17831783
"metadata": {},
17841784
"source": [
1785-
"## Other Helpers"
1785+
"## `SaveReturn` and `save_iter` Variants"
1786+
]
1787+
},
1788+
{
1789+
"cell_type": "markdown",
1790+
"metadata": {},
1791+
"source": [
1792+
"These utilities solve a common problem in Python: how to extract additional information from generator functions beyond just the yielded values.\n",
1793+
"\n",
1794+
"In Python, generator functions can `yield` values and also `return` a final value, but the return value is normally lost when you iterate over the generator:"
1795+
]
1796+
},
1797+
{
1798+
"cell_type": "code",
1799+
"execution_count": null,
1800+
"metadata": {},
1801+
"outputs": [],
1802+
"source": [
1803+
"def example_generator():\n",
1804+
" total = 0\n",
1805+
" for i in range(3):\n",
1806+
" total += i\n",
1807+
" yield i\n",
1808+
" return total # This gets lost!\n",
1809+
"\n",
1810+
"# The return value (3) is lost\n",
1811+
"values = list(example_generator()) # [0, 1, 2]"
1812+
]
1813+
},
1814+
{
1815+
"cell_type": "code",
1816+
"execution_count": null,
1817+
"metadata": {},
1818+
"outputs": [],
1819+
"source": [
1820+
"#| exports\n",
1821+
"class SaveReturn:\n",
1822+
" \"Wrap an iterator such that the generator function's return value is stored in `.value`\"\n",
1823+
" def __init__(self, its): self.its = its\n",
1824+
" def __iter__(self):\n",
1825+
" self.value = yield from self.its\n",
1826+
" return self.value"
1827+
]
1828+
},
1829+
{
1830+
"cell_type": "markdown",
1831+
"metadata": {},
1832+
"source": [
1833+
"`SaveReturn` is the simplest approach to solving this problem - it wraps any existing (non-async) generator and captures its return value. This works because `yield from` (used internally in `SaveReturn`) returns the value from the `return` of the generator function."
1834+
]
1835+
},
1836+
{
1837+
"cell_type": "code",
1838+
"execution_count": null,
1839+
"metadata": {},
1840+
"outputs": [
1841+
{
1842+
"name": "stdout",
1843+
"output_type": "stream",
1844+
"text": [
1845+
"Values: [0, 1, 2, 3, 4]\n"
1846+
]
1847+
},
1848+
{
1849+
"data": {
1850+
"text/plain": [
1851+
"10"
1852+
]
1853+
},
1854+
"execution_count": null,
1855+
"metadata": {},
1856+
"output_type": "execute_result"
1857+
}
1858+
],
1859+
"source": [
1860+
"def sum_range(n):\n",
1861+
" total = 0\n",
1862+
" for i in range(n):\n",
1863+
" total += i\n",
1864+
" yield i\n",
1865+
" return total # This value is returned by yield from\n",
1866+
"\n",
1867+
"sr = SaveReturn(sum_range(5))\n",
1868+
"values = list(sr) # This will consume the generator and get the return value\n",
1869+
"print(f\"Values: {values}\")\n",
1870+
"sr.value"
1871+
]
1872+
},
1873+
{
1874+
"cell_type": "markdown",
1875+
"metadata": {},
1876+
"source": [
1877+
"In order to provide an accurate signature for `save_iter`, we need a version of `wraps` that removes leading parameters:"
17861878
]
17871879
},
17881880
{
@@ -1846,8 +1938,91 @@
18461938
"class _save_iter:\n",
18471939
" def __init__(self, g, *args, **kw): self.g,self.args,self.kw = g,args,kw\n",
18481940
" def __iter__(self): yield from self.g(self, *self.args, **self.kw)\n",
1941+
" def __aiter__(self): return self.g(self, *self.args, **self.kw)\n",
18491942
"\n",
18501943
"def save_iter(g):\n",
1944+
" \"Decorator that allows a generator function to store values in the returned iterator object\"\n",
1945+
" @trim_wraps(g)\n",
1946+
" def _(*args, **kwargs): return _save_iter(g, *args, **kwargs)\n",
1947+
" return _"
1948+
]
1949+
},
1950+
{
1951+
"cell_type": "markdown",
1952+
"metadata": {},
1953+
"source": [
1954+
"`save_iter` modifies generator functions to store state in the iterator object itself. The generator receives an object as its first parameter, which it can use to store attributes. You can store values during iteration, not just at the end,\n",
1955+
"and you can store multiple attributes if needed."
1956+
]
1957+
},
1958+
{
1959+
"cell_type": "code",
1960+
"execution_count": null,
1961+
"metadata": {},
1962+
"outputs": [],
1963+
"source": [
1964+
"@save_iter\n",
1965+
"def sum_range(o, n): # Note: 'o' parameter added\n",
1966+
" total = 0\n",
1967+
" for i in range(n):\n",
1968+
" total += i\n",
1969+
" yield i\n",
1970+
" o.value = total # Store directly on the iterator object"
1971+
]
1972+
},
1973+
{
1974+
"cell_type": "markdown",
1975+
"metadata": {},
1976+
"source": [
1977+
"Because iternally `save_iter` uses `trim_wraps`, the signature of `sum_range` correctly shows that you should *not* pass `o` to it; it's injected by the decorating function."
1978+
]
1979+
},
1980+
{
1981+
"cell_type": "code",
1982+
"execution_count": null,
1983+
"metadata": {},
1984+
"outputs": [
1985+
{
1986+
"name": "stdout",
1987+
"output_type": "stream",
1988+
"text": [
1989+
"(n)\n"
1990+
]
1991+
}
1992+
],
1993+
"source": [
1994+
"print(sum_range.__signature__)"
1995+
]
1996+
},
1997+
{
1998+
"cell_type": "code",
1999+
"execution_count": null,
2000+
"metadata": {},
2001+
"outputs": [
2002+
{
2003+
"name": "stdout",
2004+
"output_type": "stream",
2005+
"text": [
2006+
"Values: [0, 1, 2, 3, 4]\n",
2007+
"Sum stored: 10\n"
2008+
]
2009+
}
2010+
],
2011+
"source": [
2012+
"sr = sum_range(5)\n",
2013+
"print(f\"Values: {list(sr)}\")\n",
2014+
"print(f\"Sum stored: {sr.value}\")"
2015+
]
2016+
},
2017+
{
2018+
"cell_type": "code",
2019+
"execution_count": null,
2020+
"metadata": {},
2021+
"outputs": [],
2022+
"source": [
2023+
"#| export\n",
2024+
"def asave_iter(g):\n",
2025+
" \"Like `save_iter`, but for async iterators\"\n",
18512026
" @trim_wraps(g)\n",
18522027
" def _(*args, **kwargs): return _save_iter(g, *args, **kwargs)\n",
18532028
" return _"
@@ -1857,7 +2032,7 @@
18572032
"cell_type": "markdown",
18582033
"metadata": {},
18592034
"source": [
1860-
"`save_iter` is a decorator that allows a generator function to store values in the returned iterator object. The generator receives an object as its first parameter, which it can use to store attributes."
2035+
"`asave_iter` provides the same functionality as `save_iter`, but for async generator functions. `yield from` and `return` can not be used with async generator functions, so `SaveReturn` can't be used here."
18612036
]
18622037
},
18632038
{
@@ -1875,17 +2050,24 @@
18752050
}
18762051
],
18772052
"source": [
1878-
"@save_iter\n",
1879-
"def sum_range(self, n):\n",
2053+
"@asave_iter\n",
2054+
"async def asum_range(self, n):\n",
18802055
" total = 0\n",
18812056
" for i in range(n):\n",
18822057
" total += i\n",
18832058
" yield i\n",
18842059
" self.value = total\n",
18852060
"\n",
1886-
"sr = sum_range(5)\n",
1887-
"print(f\"Values: {list(sr)}\")\n",
1888-
"print(f\"Sum stored: {sr.value}\") # Sum stored: 10"
2061+
"asr = asum_range(5)\n",
2062+
"print(f\"Values: {[o async for o in asr]}\")\n",
2063+
"print(f\"Sum stored: {asr.value}\")"
2064+
]
2065+
},
2066+
{
2067+
"cell_type": "markdown",
2068+
"metadata": {},
2069+
"source": [
2070+
"## Other Helpers"
18892071
]
18902072
},
18912073
{

0 commit comments

Comments
 (0)