-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
142 lines (108 loc) · 3.69 KB
/
utils.py
File metadata and controls
142 lines (108 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os, zipfile, shutil, subprocess, shlex, sys # noqa
from urllib.parse import urlparse
import re
import logging
def load_file_from_url(
    url: str,
    model_dir: str,
    file_name: str | None = None,
    overwrite: bool = False,
    progress: bool = True,
) -> str:
    """Fetch `url` into `model_dir`, reusing an existing local copy when possible.

    A cached copy is discarded first when `overwrite` is set or the copy is
    zero bytes (a truncated previous download). Returns the absolute path to
    the local file.
    """
    os.makedirs(model_dir, exist_ok=True)

    # Derive the target name from the URL path when none was given.
    target_name = file_name or os.path.basename(urlparse(url).path)
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    # Drop a stale or unwanted cached copy before deciding whether to download.
    if os.path.exists(cached_file) and (overwrite or os.path.getsize(cached_file) == 0):
        remove_files(cached_file)

    if os.path.exists(cached_file):
        logger.debug(cached_file)
    else:
        logger.info(f'Downloading: "{url}" to {cached_file}\n')
        # Imported lazily so torch is only required when a download happens.
        from torch.hub import download_url_to_file
        download_url_to_file(url, cached_file, progress=progress)
    return cached_file
def friendly_name(file: str):
    """Split `file` (a local path or an http(s) URL) into
    (base name without extension, extension including the dot)."""
    path = urlparse(file).path if file.startswith("http") else file
    base = os.path.basename(path)
    return os.path.splitext(base)
def download_manager(
    url: str,
    path: str,
    extension: str = "",
    overwrite: bool = False,
    progress: bool = True,
):
    """Download `url` into directory `path` and return the local file path.

    When `extension` is given it replaces the extension taken from the URL.
    A non-http `url` is treated as already local: `path` is returned as-is.
    """
    url = url.strip()

    base, url_ext = friendly_name(url)
    # Prefer the caller-supplied extension over the one found in the URL.
    suffix = f".{extension}" if extension else url_ext
    target_name = base + suffix

    if not url.startswith("http"):
        return path

    return load_file_from_url(
        url=url,
        model_dir=path,
        file_name=target_name,
        overwrite=overwrite,
        progress=progress,
    )
def remove_files(file_list):
    """Delete every path in `file_list` (a single str or an iterable of str),
    silently skipping paths that do not exist."""
    targets = [file_list] if isinstance(file_list, str) else file_list
    for target in targets:
        if os.path.exists(target):
            os.remove(target)
def remove_directory_contents(directory_path):
    """Remove every file and subdirectory inside `directory_path`.

    The directory itself is kept. A missing directory is reported via
    `logger.error` rather than raising; per-entry deletion failures are
    logged and do not stop the sweep.
    """
    if not os.path.exists(directory_path):
        logger.error(f"Directory '{directory_path}' does not exist.")
        return

    for entry in os.listdir(directory_path):
        entry_path = os.path.join(directory_path, entry)
        try:
            if os.path.isfile(entry_path):
                os.remove(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as error:
            # Best-effort cleanup: log the failure and keep going.
            logger.error(f"Failed to delete {entry_path}. Reason: {error}")
    logger.info(f"Content in '{directory_path}' removed.")
# Create directory if not exists
def create_directories(directory_path):
    """Create each directory in `directory_path` (a single str or an
    iterable of str) if it does not already exist."""
    targets = [directory_path] if isinstance(directory_path, str) else directory_path
    for target in targets:
        if os.path.exists(target):
            continue
        os.makedirs(target)
        logger.debug(f"Directory '{target}' created.")
def setup_logger(name_log):
    """Return a logger named `name_log` that writes "[LEVEL] >> message"
    lines to stderr at INFO level, without propagating to the root logger.

    Idempotent: the stream handler is attached only on the first call for a
    given name. The original attached a fresh handler on every call, so
    calling it twice made every log line print twice; it also monkey-patched
    the handler's `flush` with `sys.stderr.flush`, which is redundant since
    `StreamHandler()` already targets `sys.stderr` and flushes per record.
    """
    logger = logging.getLogger(name_log)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep messages out of the root logger

    if not logger.handlers:  # attach once; avoids duplicated output on re-setup
        handler = logging.StreamHandler()  # defaults to sys.stderr
        handler.setFormatter(logging.Formatter("[%(levelname)s] >> %(message)s"))
        logger.addHandler(handler)
    return logger
# Module-level logger shared by all helpers in this file.
logger = setup_logger("ss")
# Redundant with the level already set inside setup_logger; kept for explicitness.
logger.setLevel(logging.INFO)