1+ from MDAnalysis .core .universe import Universe
2+ from MDAnalysis .lib .log import ProgressBar
3+
14from pathlib import Path
25from abc import ABC , abstractmethod
3- import tempfile
6+ #from io import BytesIO
47
8+ import tempfile
59import requests
6- from MDAnalysis .core .universe import Universe
710
811class FileDownloadPDBError (Exception ):
912 """
@@ -34,7 +37,7 @@ def _file_format_to_topology_string(file_extension):
3437 return valid_topology_string
3538
3639class BaseDownloader (ABC ):
37- """Abstract Base Class for all Downloaders. Not meant to be directly initalized!"""
40+ """Abstract Base Class for all File-Based Downloaders. Not meant to be directly initalized!"""
3841
3942 def __str__ (self ):
4043 return f"Metadata: id={ self .id } , file_format={ self .file_format } , "
@@ -70,6 +73,19 @@ def convert_to_universe(self, **kwargs):
7073 finally :
7174 self ._file .close ()
7275
76+ def _requests_progress_bar (requests_response , file_name , file_writer , return_writer = False ):
77+ """Puts a progress bar while writing a file_like object from the web"""
78+ chunk_size = 1
79+ r = requests_response
80+
81+ with ProgressBar (total = len (r .content ), unit = 'B' , unit_scale = True , desc = file_name ) as pb :
82+ for i in r .iter_content (chunk_size = chunk_size ):
83+ file_writer .write (i )
84+ pb .update (chunk_size )
85+
86+ if return_writer :
87+ return file_writer
88+
7389class PdbDownloader (BaseDownloader ):
7490 """Class to handle download PDBs from the RCSB"""
7591
@@ -84,12 +100,12 @@ def _open_file(self, cache_path):
84100
85101 # create temporary file to save pdb file
86102 if cache_path is None :
87- self ._file = tempfile .NamedTemporaryFile (mode = 'wb' )
103+ self ._file = tempfile .NamedTemporaryFile ('wb' )
88104 self ._download = True
89105
90106 # Create/Parse download cache
91107 else :
92- named_file_path = Path (cache_path ) / f"{ self .id } .{ self .file_format } "
108+ named_file_path = Path (cache_path ) / f"{ self .id } .{ self .file_format } "
93109
94110 # Found Cache, so don't download anything and open existing file
95111 if named_file_path .exists () and named_file_path .is_file ():
@@ -101,7 +117,7 @@ def _open_file(self, cache_path):
101117 self ._download = True
102118
103119
104- def download (self , cache_path = None , timeout = None ):
120+ def download (self , cache_path = None , timeout = None , progress_bar = False ):
105121 """Downloads files from the RCSB"""
106122
107123 # Sets self._file correctly
@@ -111,15 +127,19 @@ def download(self, cache_path=None, timeout=None):
111127 if self ._download :
112128 try :
113129 r = requests .get (f"https://files.rcsb.org/download/{ self .id } .{ self .file_format } " ,
114- timeout = timeout )
130+ timeout = timeout , stream = progress_bar )
131+
115132 r .raise_for_status ()
116- self ._file .write (r .content )
133+ if progress_bar :
134+ _requests_progress_bar (r , f"{ self .id } .{ self .file_format } " , self ._file )
135+ else :
136+ self ._file .write (r .content )
117137 except requests .HTTPError :
118- # This also deletes the undownloaded file since write() hasn't been called yet
119138 raise FileDownloadPDBError
120139 finally :
121140 # Closes File safely if saving to cache
122141 if cache_path is not None :
123142 self ._file .close ()
124143
125144 return self
145+
0 commit comments