Skip to content

Commit a3c6417

Browse files
committed
chore: multithreading for I/O requests
1 parent bff9884 commit a3c6417

6 files changed

Lines changed: 81 additions & 25 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ gtfs/__version__.py
33

44
# Generated / downloaded files
55
*.zip
6-
*.p
6+
*.pkl
77

88
# virtualenv
99
.venv/

feeds/AlbanyNy.pkl

75 Bytes
Binary file not shown.

feeds/Berlin.pkl

73 Bytes
Binary file not shown.

gtfs/__main__.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
"""Command line interface for fetching GTFS."""
33
import os
4+
import threading
45
from typing import Optional
56

67
import typer
@@ -37,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]:
3738
return Bbox(min_x, min_y, max_x, max_y)
3839

3940

41+
def check_sources(sources: str) -> Optional[str]:
42+
"""Check if the sources are valid."""
43+
if sources is None:
44+
return None
45+
sources = sources.split(",")
46+
for source in sources:
47+
if not any(src.__name__.lower() == source.lower() for src in feed_sources):
48+
raise typer.BadParameter(f"{source} is not a valid feed source!")
49+
50+
return ",".join(sources)
51+
52+
4053
@app.command()
4154
def list_feeds(
4255
bbox: Annotated[
@@ -86,10 +99,8 @@ def list_feeds(
8699
["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1
87100
)
88101

89-
90102
filtered_srcs = ""
91103

92-
93104
for src in feed_sources:
94105
feed_bbox: Bbox = src.bbox
95106
if bbox is not None and predicate == "contains":
@@ -101,7 +112,6 @@ def list_feeds(
101112
):
102113
continue
103114

104-
105115
filtered_srcs += src.__name__ + ", "
106116

107117
if pretty is True:
@@ -116,23 +126,22 @@ def list_feeds(
116126

117127
print(src.url)
118128

119-
120129
if pretty is True:
121130
print("\n" + pretty_output.get_string())
122131

123132
if typer.confirm("Do you want to fetch feeds from these sources?"):
124133
fetch_feeds(sources=filtered_srcs[:-1])
125134

126135

127-
128136
@app.command()
129137
def fetch_feeds(
130138
sources: Annotated[
131139
Optional[str],
132140
typer.Option(
133141
"--sources",
134142
"-src",
135-
help="pass value as a string separated by commas like this: src1,src2,src3",
143+
help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...",
144+
callback=check_sources,
136145
),
137146
] = None,
138147
search: Annotated[
@@ -148,19 +157,27 @@ def fetch_feeds(
148157
typer.Option(
149158
"--output-dir",
150159
"-o",
151-
help="the directory where the downloaded feeds will be saved",
160+
help="the directory where the downloaded feeds will be saved, default is feeds",
152161
),
153162
] = "feeds",
163+
concurrency: Annotated[
164+
Optional[int],
165+
typer.Option(
166+
"--concurrency",
167+
"-c",
168+
help="the number of concurrent downloads, default is 4",
169+
),
170+
] = 4,
154171
) -> None:
155172
"""Fetch feeds from sources.
156173
157174
:param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available.
158175
:param search: Search for feeds based on a string.
159176
:param output_dir: The directory where the downloaded feeds will be saved; default is feeds.
177+
:param concurrency: The number of concurrent downloads; default is 4.
160178
"""
161179
# statuses = {} # collect the statuses for all the files
162180

163-
164181
if not sources:
165182
if not search:
166183
# fetch all feeds
@@ -173,28 +190,47 @@ def fetch_feeds(
173190
if search.lower() in src.__name__.lower() or search.lower() in src.url.lower()
174191
]
175192
else:
176-
# fetch feeds based on sources
177-
sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()]
193+
if search:
194+
raise typer.BadParameter("Please pass either sources or search, not both at the same time!")
195+
else:
196+
sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()]
178197

179198
output_dir_path = os.path.join(os.getcwd(), output_dir)
180199
if not os.path.exists(output_dir_path):
181200
os.makedirs(output_dir_path)
182201

183202
LOG.info(f"Going to fetch feeds from sources: {sources}")
184203

185-
for src in sources:
186-
LOG.debug(f"Going to start fetch for {src}...")
187-
try:
188-
if issubclass(src, FeedSource):
189-
inst = src()
190-
inst.ddir = output_dir_path
191-
inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl")
192-
inst.fetch()
193-
# statuses.update(inst.status)
194-
else:
195-
LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.")
196-
except AttributeError:
197-
LOG.error(f"Skipping feed {src}, which could not be found.")
204+
threads = []
205+
206+
def thread_worker():
207+
while True:
208+
try:
209+
src = sources.pop(0)
210+
except IndexError:
211+
break
212+
213+
LOG.debug(f"Going to start fetch for {src}...")
214+
try:
215+
if issubclass(src, FeedSource):
216+
inst = src()
217+
inst.ddir = output_dir_path
218+
inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl")
219+
inst.fetch()
220+
# statuses.update(inst.status)
221+
else:
222+
LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.")
223+
except AttributeError:
224+
LOG.error(f"Skipping feed {src}, which could not be found.")
225+
226+
for _ in range(concurrency):
227+
thread = threading.Thread(target=thread_worker)
228+
thread.start()
229+
threads.append(thread)
230+
231+
# Wait for all threads to complete
232+
for thread in threads:
233+
thread.join()
198234

199235
# ptable = ColorTable(
200236
# [

gtfs/feed_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo
144144
LOG.debug("No last-modified header set")
145145
posted_date = datetime.utcnow().strftime(TIMECHECK_FMT)
146146
self.set_posted_date(feed_file, posted_date)
147-
LOG.info("Download completed successfully.")
147+
LOG.info(f"Download completed successfully for {feed_file}.")
148148
return True
149149
else:
150150
self.set_error(feed_file, "Download failed")

tests/test_cli.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,23 @@ def test_contains_predicate(self, runner):
4949
def test_pretty(self, runner):
5050
result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n")
5151
assert result.exit_code == 0
52+
53+
54+
class TestFetchFeedsCommand:
55+
def test_help(self, runner):
56+
result = runner.invoke(app, ["fetch-feeds", "--help"])
57+
assert result.exit_code == 0
58+
assert "Fetch feeds from sources." in result.stdout
59+
60+
def test_bad_args(self, runner):
61+
result = runner.invoke(app, ["fetch-feeds", "-src", "berlin", "-s", "cdta"])
62+
assert result.exit_code == 2
63+
assert "Please pass either sources or search" in result.stdout
64+
65+
def test_fetch_with_sources(self, runner):
66+
result = runner.invoke(app, ["fetch-feeds", "-src", "berlin"])
67+
assert result.exit_code == 0
68+
69+
def test_fetch_with_search(self, runner):
70+
result = runner.invoke(app, ["fetch-feeds", "-s", "cdta"])
71+
assert result.exit_code == 0

0 commit comments

Comments
 (0)