3535"""
3636import datetime as dt
3737import json
38+ import os
3839import re
3940import time
4041from collections import namedtuple
42+ from concurrent .futures import thread , ThreadPoolExecutor
4143
4244import yaml
4345from bioblend .galaxy .client import ConnectionError
@@ -211,7 +213,8 @@ def test_tools(self,
211213 repositories = None ,
212214 log = None ,
213215 test_user_api_key = None ,
214- test_user = "ephemeris@galaxyproject.org"
216+ test_user = "ephemeris@galaxyproject.org" ,
217+ parallel_tests = 1 ,
215218 ):
216219 """Run tool tests for all tools in each repository in supplied tool list or ``self.installed_repositories()``.
217220 """
@@ -232,34 +235,47 @@ def test_tools(self,
232235 installed_tools .extend (repo_tools )
233236
234237 all_test_results = []
238+ galaxy_interactor = self ._get_interactor (test_user , test_user_api_key )
239+ test_history = galaxy_interactor .new_history ()
235240
236- for tool in installed_tools :
237- log .info ("Testing tool '%s'" , tool ['id' ])
238- results = self ._test_tool (tool , test_user , test_user_api_key )
239- all_test_results .extend (results .tool_test_results )
240- log .info ("%s test passed, %d tests failed for tool '%s'" % (len (results .tests_passed ), len (results .tests_exceptions ), tool ['id' ]))
241- tests_passed .extend (results .tests_passed )
242- test_exceptions .extend (results .test_exceptions )
243-
244- report_obj = {
245- 'version' : '0.1' ,
246- 'tests' : all_test_results ,
247- }
248- with open (test_json , "w" ) as f :
249- json .dump (report_obj , f )
250- if log :
251- log .info ("Passed tool tests ({0}): {1}" .format (
252- len (tests_passed ),
253- [t for t in tests_passed ])
254- )
255- log .info ("Failed tool tests ({0}): {1}" .format (
256- len (test_exceptions ),
257- [t [0 ] for t in test_exceptions ])
258- )
259- log .info ("Total tool test time: {0}" .format (dt .datetime .now () - tool_test_start ))
260-
261- def _test_tool (self , tool , test_user , test_user_api_key ):
262-
241+ with ThreadPoolExecutor (max_workers = parallel_tests ) as executor :
242+ try :
243+ for tool in installed_tools :
244+ self ._test_tool (executor = executor ,
245+ tool = tool ,
246+ galaxy_interactor = galaxy_interactor ,
247+ test_history = test_history ,
248+ log = log ,
249+ tool_test_results = all_test_results ,
250+ tests_passed = tests_passed ,
251+ test_exceptions = test_exceptions ,
252+ )
253+ finally :
254+ # Always write report, even if test was cancelled.
255+ try :
256+ executor .shutdown (wait = True )
257+ except KeyboardInterrupt :
258+ executor ._threads .clear ()
259+ thread ._threads_queues .clear ()
260+ report_obj = {
261+ 'version' : '0.1' ,
262+ 'tests' : sorted (all_test_results , key = lambda el : el ['id' ]),
263+ }
264+ with open (test_json , "w" ) as f :
265+ json .dump (report_obj , f )
266+ log .info ("Report written to '%s'" , os .path .abspath (test_json ))
267+ if log :
268+ log .info ("Passed tool tests ({0}): {1}" .format (
269+ len (tests_passed ),
270+ [t for t in tests_passed ])
271+ )
272+ log .info ("Failed tool tests ({0}): {1}" .format (
273+ len (test_exceptions ),
274+ [t [0 ] for t in test_exceptions ])
275+ )
276+ log .info ("Total tool test time: {0}" .format (dt .datetime .now () - tool_test_start ))
277+
278+ def _get_interactor (self , test_user , test_user_api_key ):
263279 if test_user_api_key is None :
264280 whoami = self .gi .make_get_request (self .gi .url + "/whoami" ).json ()
265281 if whoami is not None :
@@ -273,36 +289,49 @@ def _test_tool(self, tool, test_user, test_user_api_key):
273289 if test_user_api_key is None :
274290 galaxy_interactor_kwds ["test_user" ] = test_user
275291 galaxy_interactor = GalaxyInteractorApi (** galaxy_interactor_kwds )
292+ return galaxy_interactor
293+
294+ def _test_tool (self ,
295+ executor ,
296+ tool ,
297+ galaxy_interactor ,
298+ test_history = None ,
299+ log = None ,
300+ tool_test_results = None ,
301+ tests_passed = None ,
302+ test_exceptions = None ):
303+ if test_history is None :
304+ test_history = galaxy_interactor .new_history ()
276305 tool_id = tool ["id" ]
277306 tool_version = tool ["version" ]
278307 tool_test_dicts = galaxy_interactor .get_tool_tests (tool_id , tool_version = tool_version )
279308 test_indices = list (range (len (tool_test_dicts )))
280- tool_test_results = []
281- tests_passed = []
282- test_exceptions = []
283309
284310 for test_index in test_indices :
285311 test_id = tool_id + "-" + str (test_index )
286312
287- def register (job_data ):
288- tool_test_results .append ({
289- 'id' : test_id ,
290- 'has_data' : True ,
291- 'data' : job_data ,
292- })
293-
294- try :
295- verify_tool (
296- tool_id , galaxy_interactor , test_index = test_index , tool_version = tool_version ,
297- register_job_data = register , quiet = True
298- )
299- tests_passed .append (test_id )
300- except Exception as e :
301- test_exceptions .append ((test_id , e ))
302- Results = namedtuple ("Results" , ["tool_test_results" , "tests_passed" , "test_exceptions" ])
303- return Results (tool_test_results = tool_test_results ,
304- tests_passed = tests_passed ,
305- test_exceptions = test_exceptions )
313+ def run_test (index , test_id ):
314+
315+ def register (job_data ):
316+ tool_test_results .append ({
317+ 'id' : test_id ,
318+ 'has_data' : True ,
319+ 'data' : job_data ,
320+ })
321+
322+ try :
323+ log .info ("Executing test '%s'" , test_id )
324+ verify_tool (
325+ tool_id , galaxy_interactor , test_index = index , tool_version = tool_version ,
326+ register_job_data = register , quiet = True , test_history = test_history ,
327+ )
328+ tests_passed .append (test_id )
329+ log .info ("Test '%s' passed" , test_id )
330+ except Exception as e :
331+ log .warning ("Test '%s' failed" , test_id )
332+ test_exceptions .append ((test_id , e ))
333+
334+ executor .submit (run_test , test_index , test_id )
306335
307336 def install_repository_revision (self , repository , log ):
308337 default_err_msg = ('All repositories that you are attempting to install '
@@ -508,7 +537,9 @@ def main():
508537 repositories = repos ,
509538 log = log ,
510539 test_user_api_key = args .test_user_api_key ,
511- test_user = args .test_user )
540+ test_user = args .test_user ,
541+ parallel_tests = args .parallel_tests ,
542+ )
512543 else :
513544 raise NotImplementedError ("This point in the code should not be reached. Please contact the developers." )
514545
@@ -523,7 +554,9 @@ def main():
523554 repositories = to_be_tested_repositories ,
524555 log = log ,
525556 test_user_api_key = args .test_user_api_key ,
526- test_user = args .test_user )
557+ test_user = args .test_user ,
558+ parallel_tests = args .parallel_tests ,
559+ )
527560
528561
529562if __name__ == "__main__" :
0 commit comments