Skip to content

Commit dbace5a

Browse files
committed
Added Progress bar to PdbDownloader().download()
1 parent 736cca0 commit dbace5a

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

package/MDAnalysis/web/downloaders.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from MDAnalysis.core.universe import Universe
2+
from MDAnalysis.lib.log import ProgressBar
3+
14
from pathlib import Path
25
from abc import ABC, abstractmethod
3-
import tempfile
6+
#from io import BytesIO
47

8+
import tempfile
59
import requests
6-
from MDAnalysis.core.universe import Universe
710

811
class FileDownloadPDBError(Exception):
912
"""
@@ -34,7 +37,7 @@ def _file_format_to_topology_string(file_extension):
3437
return valid_topology_string
3538

3639
class BaseDownloader(ABC):
37-
"""Abstract Base Class for all Downloaders. Not meant to be directly initalized!"""
40+
"""Abstract Base Class for all File-Based Downloaders. Not meant to be directly initialized!"""
3841

3942
def __str__(self):
4043
return f"Metadata: id={self.id}, file_format={self.file_format}, "
@@ -70,6 +73,19 @@ def convert_to_universe(self, **kwargs):
7073
finally:
7174
self._file.close()
7275

76+
def _requests_progress_bar(requests_response, file_name, file_writer, return_writer=False):
    """Write the body of a ``requests`` response to a file-like object while showing a progress bar.

    Parameters
    ----------
    requests_response : requests.Response
        Response whose body is written out. For the bar to reflect the actual
        transfer, the request should have been made with ``stream=True``.
    file_name : str
        Label displayed next to the progress bar.
    file_writer : file-like object
        Object opened for binary writing; every chunk is passed to its ``write``.
    return_writer : bool, optional
        If True, return `file_writer` once everything has been written.

    Returns
    -------
    file-like object or None
        `file_writer` when `return_writer` is True, otherwise None.
    """
    # Iterating one byte at a time costs a Python-level write() and a bar
    # refresh per byte; a few KiB per chunk keeps the bar smooth and cheap.
    chunk_size = 8192

    # Take the expected size from the headers instead of len(response.content):
    # touching .content pulls the entire body into memory before the bar even
    # starts, which defeats streaming.
    headers = requests_response.headers
    total = None
    # Content-Length counts bytes on the wire, while iter_content() yields
    # decoded bytes, so the header is only a valid total for unencoded bodies.
    if not headers.get("Content-Encoding"):
        try:
            total = int(headers["Content-Length"])
        except (KeyError, TypeError, ValueError):
            # Unknown size: the bar then shows a running byte count only.
            total = None

    with ProgressBar(total=total, unit='B', unit_scale=True, desc=file_name) as pb:
        for chunk in requests_response.iter_content(chunk_size=chunk_size):
            file_writer.write(chunk)
            # The final chunk is usually shorter than chunk_size.
            pb.update(len(chunk))

    if return_writer:
        return file_writer
    return None
88+
7389
class PdbDownloader(BaseDownloader):
7490
"""Class to handle downloading PDBs from the RCSB"""
7591

@@ -84,12 +100,12 @@ def _open_file(self, cache_path):
84100

85101
# create temporary file to save pdb file
86102
if cache_path is None:
87-
self._file = tempfile.NamedTemporaryFile(mode='wb')
103+
self._file = tempfile.NamedTemporaryFile('wb')
88104
self._download = True
89105

90106
# Create/Parse download cache
91107
else:
92-
named_file_path = Path(cache_path) / f"{self.id}.{self.file_format}"
108+
named_file_path = Path(cache_path) / f"{self.id}.{self.file_format}"
93109

94110
# Found Cache, so don't download anything and open existing file
95111
if named_file_path.exists() and named_file_path.is_file():
@@ -101,7 +117,7 @@ def _open_file(self, cache_path):
101117
self._download = True
102118

103119

104-
def download(self, cache_path=None, timeout=None):
120+
def download(self, cache_path=None, timeout=None, progress_bar=False):
105121
"""Downloads files from the RCSB"""
106122

107123
# Sets self._file correctly
@@ -111,15 +127,19 @@ def download(self, cache_path=None, timeout=None):
111127
if self._download:
112128
try:
113129
r = requests.get(f"https://files.rcsb.org/download/{self.id}.{self.file_format}",
114-
timeout=timeout)
130+
timeout=timeout, stream=progress_bar)
131+
115132
r.raise_for_status()
116-
self._file.write(r.content)
133+
if progress_bar:
134+
_requests_progress_bar(r, f"{self.id}.{self.file_format}", self._file)
135+
else:
136+
self._file.write(r.content)
117137
except requests.HTTPError:
118-
# This also deletes the undownloaded file since write() hasn't been called yet
119138
raise FileDownloadPDBError
120139
finally:
121140
# Closes File safely if saving to cache
122141
if cache_path is not None:
123142
self._file.close()
124143

125144
return self
145+

0 commit comments

Comments
 (0)