Skip to content

Commit 900cfa4

Browse files
committed
fixes #819
1 parent 464a16c commit 900cfa4

2 files changed

Lines changed: 92 additions & 98 deletions

File tree

fastcore/parallel.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,18 @@ def _f(f): return threaded(True, daemon=daemon)(f)()
6161

6262

6363
# %% ../nbs/03a_parallel.ipynb #44d4651b
64-
def _call(lock, pause, n, g, item):
64+
def _call(lock, pause, n, g, item, return_exceptions=False):
6565
l = False
6666
if pause:
6767
try:
6868
l = lock.acquire(timeout=pause*(n+2))
6969
time.sleep(pause)
7070
finally:
7171
if l: lock.release()
72-
return g(item)
72+
try: return g(item)
73+
except Exception as e:
74+
if return_exceptions: return e
75+
raise
7376

7477
# %% ../nbs/03a_parallel.ipynb #63a3920f
7578
def parallelable(param_name, num_workers, f=None):
@@ -90,11 +93,11 @@ def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
9093
if self.not_parallel: max_workers=1
9194
super().__init__(max_workers, **kwargs)
9295

93-
def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
96+
def map(self, f, items, *args, timeout=None, chunksize=1, return_exceptions=False, **kwargs):
9497
if self.not_parallel == False: self.lock = Manager().Lock()
9598
g = partial(f, *args, **kwargs)
9699
if self.not_parallel: return map(g, items)
97-
_g = partial(_call, self.lock, self.pause, self.max_workers, g)
100+
_g = partial(_call, self.lock, self.pause, self.max_workers, g, return_exceptions=return_exceptions)
98101
try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
99102
except Exception as e: self.on_exc(e)
100103

@@ -109,33 +112,32 @@ def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
109112
if self.not_parallel: max_workers=1
110113
super().__init__(max_workers, **kwargs)
111114

112-
def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
115+
def map(self, f, items, *args, timeout=None, chunksize=1, return_exceptions=False, **kwargs):
113116
if not parallelable('max_workers', self.max_workers, f): self.max_workers = 0
114117
self.not_parallel = self.max_workers==0
115118
if self.not_parallel: self.max_workers=1
116119

117120
if self.not_parallel == False: self.lock = Manager().Lock()
118121
g = partial(f, *args, **kwargs)
119122
if self.not_parallel: return map(g, items)
120-
_g = partial(_call, self.lock, self.pause, self.max_workers, g)
123+
_g = partial(_call, self.lock, self.pause, self.max_workers, g, return_exceptions=return_exceptions)
121124
try: return super().map(_g, items, timeout=timeout, chunksize=chunksize)
122125
except Exception as e: self.on_exc(e)
123126

124127
# %% ../nbs/03a_parallel.ipynb #529e1bb1
125128
def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None, pause=0,
             method=None, threadpool=False, timeout=None, chunksize=1, return_exceptions=False, **kwargs):
    "Applies `func` in parallel to `items`, using `n_workers`"
    pool_kw = {}
    if threadpool: pool_cls = ThreadPoolExecutor
    else:
        pool_cls = ProcessPoolExecutor
        # macOS defaults to 'spawn'; force 'fork' unless the caller chose a method
        if not method and sys.platform == 'darwin': method = 'fork'
        if method: pool_kw['mp_context'] = get_context(method)
    with pool_cls(n_workers, pause=pause, **pool_kw) as ex:
        res = ex.map(f, items, *args, timeout=timeout, chunksize=chunksize,
                     return_exceptions=return_exceptions, **kwargs)
        if progress:
            # Import lazily so fastprogress stays an optional dependency
            from fastprogress import progress_bar
            res = progress_bar(res, total=len(items) if total is None else total, leave=False)
        # Consume inside the `with` so results are gathered before pool teardown
        return L(res)
@@ -149,7 +151,7 @@ def _add_one(x, a=1):
149151

150152
# %% ../nbs/03a_parallel.ipynb #87a80e04
151153
async def parallel_async(f, items, *args, n_workers=16, pause=0,
152-
timeout=None, chunksize=1, on_exc=print, cancel_on_error=False, **kwargs):
154+
timeout=None, chunksize=1, on_exc=print, cancel_on_error=False, return_exceptions=False, **kwargs):
153155
"Applies `f` to `items` in parallel using asyncio and a semaphore to limit concurrency."
154156
import asyncio
155157
semaphore = asyncio.Semaphore(n_workers)
@@ -161,7 +163,7 @@ async def limited_task(i, item):
161163
if cancel_on_error:
162164
async with asyncio.TaskGroup() as tg: tasks = [tg.create_task(limited_task(i, item)) for i,item in enumerate(items)]
163165
return [t.result() for t in tasks]
164-
return await asyncio.gather(*[limited_task(i, item) for i,item in enumerate(items)], return_exceptions=True)
166+
return await asyncio.gather(*[limited_task(i, item) for i,item in enumerate(items)], return_exceptions=return_exceptions)
165167

166168
# %% ../nbs/03a_parallel.ipynb #6748aa27
167169
def bg_task(coro):

0 commit comments

Comments
 (0)