11#!/usr/bin/env python
22"""Command line interface for fetching GTFS."""
33import os
4+ import threading
45from typing import Optional
56
67import typer
@@ -37,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]:
3738 return Bbox (min_x , min_y , max_x , max_y )
3839
3940
def check_sources(sources: str) -> Optional[str]:
    """Validate a comma-separated list of feed-source names.

    Used as a Typer option callback for ``--sources``.

    :param sources: Comma-separated feed-source class names
        (e.g. ``"Berlin,AlbanyNy"``), or ``None`` when the option was omitted.
    :returns: The normalized (whitespace-stripped) comma-separated string,
        or ``None`` if *sources* is ``None``.
    :raises typer.BadParameter: If any name does not match a known feed source.
    """
    if sources is None:
        return None
    # Strip surrounding whitespace so "a, b" is accepted as well as "a,b".
    names = [name.strip() for name in sources.split(",")]
    for name in names:
        # Case-insensitive match against the feed-source class names.
        if not any(src.__name__.lower() == name.lower() for src in feed_sources):
            raise typer.BadParameter(f"{name} is not a valid feed source!")

    return ",".join(names)
51+
52+
4053@app .command ()
4154def list_feeds (
4255 bbox : Annotated [
@@ -86,10 +99,8 @@ def list_feeds(
8699 ["Feed Source" , "Transit URL" , "Bounding Box" ], theme = Themes .OCEAN , hrules = 1
87100 )
88101
89-
90102 filtered_srcs = ""
91103
92-
93104 for src in feed_sources :
94105 feed_bbox : Bbox = src .bbox
95106 if bbox is not None and predicate == "contains" :
@@ -101,7 +112,6 @@ def list_feeds(
101112 ):
102113 continue
103114
104-
105115 filtered_srcs += src .__name__ + ", "
106116
107117 if pretty is True :
@@ -116,23 +126,22 @@ def list_feeds(
116126
117127 print (src .url )
118128
119-
120129 if pretty is True :
121130 print ("\n " + pretty_output .get_string ())
122131
123132 if typer .confirm ("Do you want to fetch feeds from these sources?" ):
124133 fetch_feeds (sources = filtered_srcs [:- 1 ])
125134
126135
127-
128136@app .command ()
129137def fetch_feeds (
130138 sources : Annotated [
131139 Optional [str ],
132140 typer .Option (
133141 "--sources" ,
134142 "-src" ,
135- help = "pass value as a string separated by commas like this: src1,src2,src3" ,
143+ help = "pass value as a string separated by commas like this: Berlin,AlbanyNy,..." ,
144+ callback = check_sources ,
136145 ),
137146 ] = None ,
138147 search : Annotated [
@@ -148,19 +157,27 @@ def fetch_feeds(
148157 typer .Option (
149158 "--output-dir" ,
150159 "-o" ,
151- help = "the directory where the downloaded feeds will be saved" ,
160+ help = "the directory where the downloaded feeds will be saved, default is feeds " ,
152161 ),
153162 ] = "feeds" ,
163+ concurrency : Annotated [
164+ Optional [int ],
165+ typer .Option (
166+ "--concurrency" ,
167+ "-c" ,
168+ help = "the number of concurrent downloads, default is 4" ,
169+ ),
170+ ] = 4 ,
154171) -> None :
155172 """Fetch feeds from sources.
156173
157174 :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available.
158175 :param search: Search for feeds based on a string.
159176 :param output_dir: The directory where the downloaded feeds will be saved; default is feeds.
177+ :param concurrency: The number of concurrent downloads; default is 4.
160178 """
161179 # statuses = {} # collect the statuses for all the files
162180
163-
164181 if not sources :
165182 if not search :
166183 # fetch all feeds
@@ -173,28 +190,47 @@ def fetch_feeds(
173190 if search .lower () in src .__name__ .lower () or search .lower () in src .url .lower ()
174191 ]
175192 else :
176- # fetch feeds based on sources
177- sources = [src for src in feed_sources if src .__name__ .lower () in sources .lower ()]
193+ if search :
194+ raise typer .BadParameter ("Please pass either sources or search, not both at the same time!" )
195+ else :
196+ sources = [src for src in feed_sources if src .__name__ .lower () in sources .lower ()]
178197
179198 output_dir_path = os .path .join (os .getcwd (), output_dir )
180199 if not os .path .exists (output_dir_path ):
181200 os .makedirs (output_dir_path )
182201
183202 LOG .info (f"Going to fetch feeds from sources: { sources } " )
184203
185- for src in sources :
186- LOG .debug (f"Going to start fetch for { src } ..." )
187- try :
188- if issubclass (src , FeedSource ):
189- inst = src ()
190- inst .ddir = output_dir_path
191- inst .status_file = os .path .join (inst .ddir , src .__name__ + ".pkl" )
192- inst .fetch ()
193- # statuses.update(inst.status)
194- else :
195- LOG .warning (f"Skipping class { src .__name__ } , which does not subclass FeedSource." )
196- except AttributeError :
197- LOG .error (f"Skipping feed { src } , which could not be found." )
204+ threads = []
205+
206+ def thread_worker ():
207+ while True :
208+ try :
209+ src = sources .pop (0 )
210+ except IndexError :
211+ break
212+
213+ LOG .debug (f"Going to start fetch for { src } ..." )
214+ try :
215+ if issubclass (src , FeedSource ):
216+ inst = src ()
217+ inst .ddir = output_dir_path
218+ inst .status_file = os .path .join (inst .ddir , src .__name__ + ".pkl" )
219+ inst .fetch ()
220+ # statuses.update(inst.status)
221+ else :
222+ LOG .warning (f"Skipping class { src .__name__ } , which does not subclass FeedSource." )
223+ except AttributeError :
224+ LOG .error (f"Skipping feed { src } , which could not be found." )
225+
226+ for _ in range (concurrency ):
227+ thread = threading .Thread (target = thread_worker )
228+ thread .start ()
229+ threads .append (thread )
230+
231+ # Wait for all threads to complete
232+ for thread in threads :
233+ thread .join ()
198234
199235 # ptable = ColorTable(
200236 # [
0 commit comments