diff --git a/platon/__init__.py b/platon/__init__.py index 76f7a015..0c681c30 100644 --- a/platon/__init__.py +++ b/platon/__init__.py @@ -2,3 +2,4 @@ __version__ = "6.2.2" __md5sum__ = "2cc3d6d746c80cc38bd9fdb25c226842" __data_url__ = "https://astro.uchicago.edu/~mz/data_{}.zip".format(__md5sum__) +__gdrive__ = "https://drive.google.com/uc?id=" # Replace with the actual file ID diff --git a/platon/_get_data.py b/platon/_get_data.py index 0c2c4dd9..37506c4d 100644 --- a/platon/_get_data.py +++ b/platon/_get_data.py @@ -1,7 +1,6 @@ from urllib.request import urlopen - from pkg_resources import resource_filename -from platon import __data_url__, __md5sum__ +from platon import __data_url__, __gdrive__, __md5sum__ import sys import zipfile @@ -9,6 +8,18 @@ import hashlib import ssl +# Flag to enable or disable Google Drive as the CDN +GDRIVE_CDN = False # Set to False to use the default __data_url__ + +# Try to import gdown only if GDRIVE_CDN is True +if GDRIVE_CDN: + try: + import gdown + except ImportError: + print("Error: gdown is not installed. run 'pip install gdown'") + sys.exit(1) # Exit the program if gdown is not install + + def get_data_if_needed(): if not os.path.isdir(resource_filename(__name__, "data/")): get_data(resource_filename(__name__, "./")) @@ -17,56 +28,63 @@ def get_data_if_needed(): curr_md5sum = f.read().strip() if __md5sum__ != curr_md5sum: - print("Warning: data files are out of date. To update, remove the PLATON data directory ({}) and PLATON will automatically download the latest data files on the next run.".format(resource_filename(__name__, "data/"))) - + print("Warning: data files are out of date. To update, remove the PLATON data directory ({}) and PLATON will automatically download the latest data files on the next run.".format(resource_filename(__name__, "data/"))) + def get_data(target_dir): - MB_TO_BYTES = 2**20 - filename = "data.zip" - print("Data URL", __data_url__) - - #Bad! Dangerous! But necessary...get a real certificate, Caltech! - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - u = urlopen(__data_url__, context=ctx) - f = open(filename, 'wb') - - #Only for Python 3, because we don't support Python 2 anymore - file_size = int(u.getheader("Content-Length")) - - print("Downloading {}: {:.0f} MB".format( - filename, file_size / MB_TO_BYTES)) - - bytes_downloaded = 0 - block_sz = 8192 - while True: - buffer = u.read(block_sz) - if not buffer: - break - - bytes_downloaded += len(buffer) - f.write(buffer) - percentage = int(100 * bytes_downloaded / file_size) - status = "{:.0f} MB [{}%]".format( - bytes_downloaded / MB_TO_BYTES, percentage) - print(status, end="\r") - - f.close() - - print("\nExtracting...") - zip_ref = zipfile.ZipFile(filename, 'r') - zip_ref.extractall(target_dir) - zip_ref.close() + # Determine which CDN to use + use_gdrive = GDRIVE_CDN and bool(__gdrive__) # Use GDrive if True + url = __gdrive__ if use_gdrive else __data_url__ + filename = "data_gdrive.zip" if use_gdrive else "data.zip" + + # Check if the file already exists + if os.path.exists(filename): + print(f"'{filename}' already exists. Skipping.") + else: + if use_gdrive: + print(f"Downloading from Google Drive: {url}") + gdown.download(url, filename, quiet=False) # Use gdown for Google Drive downloads + else: + MB_TO_BYTES = 2**20 + print("Using URL:", url) + + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + u = urlopen(url, context=ctx) + with open(filename, 'wb') as f: + file_size = int(u.getheader("Content-Length")) + print("Downloading {}: {:.0f} MB".format( + filename, file_size / MB_TO_BYTES)) + + bytes_downloaded = 0 + block_sz = 8192 + while True: + buffer = u.read(block_sz) + if not buffer: + break + + bytes_downloaded += len(buffer) + f.write(buffer) + percentage = int(100 * bytes_downloaded / file_size) + status = "{:.0f} MB [{}%]".format( + bytes_downloaded / MB_TO_BYTES, percentage) + print(status, end="\r") + print("\nDownload finished!") + + print("Extracting...") + with zipfile.ZipFile(filename, 'r') as zip_ref: + zip_ref.extractall(target_dir) print("Extraction finished!") + with open(filename, "rb") as f: curr_md5sum = hashlib.md5(f.read()).hexdigest() if curr_md5sum != __md5sum__: - raise RuntimeError("Downloaded data file is corrupt (wrong md5sum). Please try again.") - + raise RuntimeError("Downloaded data file is corrupt (wrong md5sum). Please try again.") + with open(resource_filename(__name__, "md5sum"), "w") as f: f.write(curr_md5sum) - + os.remove(filename)