From 41b396047bbe400b25eb8a064231072bd7325313 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Thu, 12 Dec 2019 23:05:38 -0800 Subject: [PATCH 1/2] Tweak URI --- mlflow/server/js/src/utils/ArtifactUtils.js | 2 + mlflow/store/artifact/db_artifact_repo.py | 55 +++++++++++++-------- mlflow/store/tracking/file_store.py | 21 ++++++-- mlflow/tracking/_tracking_service/client.py | 1 + mlflow/utils/rest_utils.py | 1 + 5 files changed, 55 insertions(+), 25 deletions(-) diff --git a/mlflow/server/js/src/utils/ArtifactUtils.js b/mlflow/server/js/src/utils/ArtifactUtils.js index 37d44df86c7da..3c31c4788541d 100644 --- a/mlflow/server/js/src/utils/ArtifactUtils.js +++ b/mlflow/server/js/src/utils/ArtifactUtils.js @@ -41,8 +41,10 @@ export class ArtifactNode { } static findChild(node, path) { + console.log(path); const parts = path.split('/'); let ret = node; + console.log(ret.children); parts.forEach((part) => { if (ret.children && ret.children[part] !== undefined) { ret = ret.children[part]; diff --git a/mlflow/store/artifact/db_artifact_repo.py b/mlflow/store/artifact/db_artifact_repo.py index 40a2037c2fdc3..4ff46e6c92464 100644 --- a/mlflow/store/artifact/db_artifact_repo.py +++ b/mlflow/store/artifact/db_artifact_repo.py @@ -1,5 +1,6 @@ import logging import os +import posixpath from abc import ABCMeta import sqlalchemy import tempfile @@ -40,33 +41,47 @@ def _relative_path_local(base_dir, subdir_path): # The repo_uri is of the form DB_URI/runID/ROOT_PATH_BASE where DB_URI: # +://:@:/?. def extract_db_uri_and_root_path(repo_uri): + print("REPO URI", repo_uri) parsed_uri = urllib.parse.urlparse(repo_uri) + print(parsed_uri) scheme = parsed_uri.scheme scheme_plus_count = scheme.count('+') if scheme_plus_count > 1: error_msg = "Invalid database scheme in the URI: '%s'. %s" % (scheme, _INVALID_DB_URI_MSG) raise MlflowException(error_msg, INVALID_PARAMETER_VALUE) - if parsed_uri.query == "": - if parsed_uri.path == "": - return repo_uri, "" - else: - parsed_path = parsed_uri.path.split(ROOT_PATH_BASE, 1) - if len(parsed_path) == 2: - db_uri = os.path.dirname(os.path.dirname(repo_uri)) - path = os.path.normpath(repo_uri.split(db_uri)[1]) - path = path.split(os.sep, 1)[1] - return db_uri, path - else: - return repo_uri, "" - else: - parsed_query = parsed_uri.query.split("/", 1) - if len(parsed_query) == 2: - path = os.path.normpath(parsed_query[1]) - parsed_uri = parsed_uri._replace(query=parsed_query[0]) - else: - path = "" - return urllib.parse.urlunparse(parsed_uri), path + def get_dbname_and_path(uri_path): + head, tail = posixpath.split(uri_path) + if len(head) == 0 or head == posixpath.sep: + return tail, posixpath.sep + + dbname, path = get_dbname_and_path(head) + return dbname, posixpath.join(path, tail) + + dbname, artifact_path = get_dbname_and_path(parsed_uri.path) + parsed_root_uri = parsed_uri._replace(path=dbname) + return urllib.parse.urlunparse(parsed_root_uri), artifact_path + # + # if parsed_uri.query == "": + # if parsed_uri.path == "": + # return repo_uri, "" + # else: + # parsed_path = parsed_uri.path.split(ROOT_PATH_BASE, 1) + # if len(parsed_path) == 2: + # db_uri = os.path.dirname(os.path.dirname(repo_uri)) + # path = os.path.normpath(repo_uri.split(db_uri)[1]) + # path = path.split(os.sep, 1)[1] + # return db_uri, path + # else: + # return repo_uri, "" + # else: + # parsed_query = parsed_uri.query.split("/", 1) + # if len(parsed_query) == 2: + # path = os.path.normpath(parsed_query[1]) + # parsed_uri = parsed_uri._replace(query=parsed_query[0]) + # else: + # path = "" + # return urllib.parse.urlunparse(parsed_uri), path class DBArtifactRepository(ArtifactRepository): diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 3dc98659dc757..797c449a9c64b 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -167,10 +167,9 @@ def _get_tag_path(self, experiment_id, run_uuid, tag_name): def _get_artifact_dir(self, experiment_id, run_uuid): _validate_run_id(run_uuid) - artifacts_dir = posixpath.join(self.get_experiment(experiment_id).artifact_location, - run_uuid, - FileStore.ARTIFACTS_FOLDER_NAME) - return artifacts_dir + artifact_location = self.get_experiment(experiment_id).artifact_location + return self._get_target_artifact_location( + artifact_location, run_uuid, FileStore.ARTIFACTS_FOLDER_NAME) def _get_active_experiments(self, full_path=False): exp_list = list_subdirs(self.root_directory, full_path) @@ -199,8 +198,20 @@ def list_experiments(self, view_type=ViewType.ACTIVE_ONLY): str(exp_id), str(rnfe), exc_info=True) return experiments + def _get_target_artifact_location(self, artifact_uri, *path): + from six.moves import urllib + parsed_artifact_uri = urllib.parse.urlparse(artifact_uri) + new_path = posixpath.join(parsed_artifact_uri.path, + *path) + parsed_artifact_location = parsed_artifact_uri._replace(path=new_path) + artifact_location = urllib.parse.urlunparse(parsed_artifact_location) + print("ARTIFACT LOCATION", artifact_location) + return artifact_location + def _create_experiment_with_id(self, name, experiment_id, artifact_uri): - artifact_uri = artifact_uri or posixpath.join(self.artifact_root_uri, str(experiment_id)) + artifact_uri = artifact_uri or self._get_target_artifact_location( + self.artifact_root_uri, str(experiment_id)) + # artifact_uri = artifact_uri or posixpath.join(self.artifact_root_uri, str(experiment_id)) self._check_root_dir() meta_dir = mkdir(self.root_directory, str(experiment_id)) experiment = Experiment(experiment_id, name, artifact_uri, LifecycleStage.ACTIVE) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index d30490512d374..7a6dcacdfc7f3 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -238,6 +238,7 @@ def log_artifact(self, run_id, local_path, artifact_path=None): :param artifact_path: If provided, the directory in ``artifact_uri`` to write to. """ run = self.get_run(run_id) + print("RUN ARTIFACT_URI", run.info.artifact_uri) artifact_repo = get_artifact_repository(run.info.artifact_uri) if os.path.isdir(local_path): dir_name = os.path.basename(os.path.normpath(local_path)) diff --git a/mlflow/utils/rest_utils.py b/mlflow/utils/rest_utils.py index fc7d80aaf76a7..1616a1a81cd76 100644 --- a/mlflow/utils/rest_utils.py +++ b/mlflow/utils/rest_utils.py @@ -65,6 +65,7 @@ def request_with_ratelimit_retries(max_rate_limit_interval, **kwargs): cleaned_hostname = strip_suffix(hostname, '/') url = "%s%s" % (cleaned_hostname, endpoint) + print(cleaned_hostname, endpoint) for i in range(retries): response = request_with_ratelimit_retries(max_rate_limit_interval, url=url, headers=headers, verify=verify, **kwargs) From c8f2bb5b219c9ab1ece7162e8c252d4c4010bc01 Mon Sep 17 00:00:00 2001 From: Corey Zumar Date: Fri, 13 Dec 2019 15:53:46 -0800 Subject: [PATCH 2/2] Simplify --- mlflow/server/js/src/utils/ArtifactUtils.js | 2 -- mlflow/store/artifact/db_artifact_repo.py | 23 --------------------- mlflow/store/tracking/file_store.py | 21 +++++-------------- mlflow/tracking/_tracking_service/client.py | 1 - mlflow/utils/rest_utils.py | 1 - 5 files changed, 5 insertions(+), 43 deletions(-) diff --git a/mlflow/server/js/src/utils/ArtifactUtils.js b/mlflow/server/js/src/utils/ArtifactUtils.js index 3c31c4788541d..37d44df86c7da 100644 --- a/mlflow/server/js/src/utils/ArtifactUtils.js +++ b/mlflow/server/js/src/utils/ArtifactUtils.js @@ -41,10 +41,8 @@ export class ArtifactNode { } static findChild(node, path) { - console.log(path); const parts = path.split('/'); let ret = node; - console.log(ret.children); parts.forEach((part) => { if (ret.children && ret.children[part] !== undefined) { ret = ret.children[part]; diff --git a/mlflow/store/artifact/db_artifact_repo.py b/mlflow/store/artifact/db_artifact_repo.py index 4ff46e6c92464..ac511a461a83b 100644 --- a/mlflow/store/artifact/db_artifact_repo.py +++ b/mlflow/store/artifact/db_artifact_repo.py @@ -41,9 +41,7 @@ def _relative_path_local(base_dir, subdir_path): # The repo_uri is of the form DB_URI/runID/ROOT_PATH_BASE where DB_URI: # +://:@:/?. def extract_db_uri_and_root_path(repo_uri): - print("REPO URI", repo_uri) parsed_uri = urllib.parse.urlparse(repo_uri) - print(parsed_uri) scheme = parsed_uri.scheme scheme_plus_count = scheme.count('+') if scheme_plus_count > 1: @@ -61,27 +59,6 @@ def get_dbname_and_path(uri_path): dbname, artifact_path = get_dbname_and_path(parsed_uri.path) parsed_root_uri = parsed_uri._replace(path=dbname) return urllib.parse.urlunparse(parsed_root_uri), artifact_path - # - # if parsed_uri.query == "": - # if parsed_uri.path == "": - # return repo_uri, "" - # else: - # parsed_path = parsed_uri.path.split(ROOT_PATH_BASE, 1) - # if len(parsed_path) == 2: - # db_uri = os.path.dirname(os.path.dirname(repo_uri)) - # path = os.path.normpath(repo_uri.split(db_uri)[1]) - # path = path.split(os.sep, 1)[1] - # return db_uri, path - # else: - # return repo_uri, "" - # else: - # parsed_query = parsed_uri.query.split("/", 1) - # if len(parsed_query) == 2: - # path = os.path.normpath(parsed_query[1]) - # parsed_uri = parsed_uri._replace(query=parsed_query[0]) - # else: - # path = "" - # return urllib.parse.urlunparse(parsed_uri), path class DBArtifactRepository(ArtifactRepository): diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index 797c449a9c64b..3dc98659dc757 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -167,9 +167,10 @@ def _get_tag_path(self, experiment_id, run_uuid, tag_name): def _get_artifact_dir(self, experiment_id, run_uuid): _validate_run_id(run_uuid) - artifact_location = self.get_experiment(experiment_id).artifact_location - return self._get_target_artifact_location( - artifact_location, run_uuid, FileStore.ARTIFACTS_FOLDER_NAME) + artifacts_dir = posixpath.join(self.get_experiment(experiment_id).artifact_location, + run_uuid, + FileStore.ARTIFACTS_FOLDER_NAME) + return artifacts_dir def _get_active_experiments(self, full_path=False): exp_list = list_subdirs(self.root_directory, full_path) @@ -198,20 +199,8 @@ def list_experiments(self, view_type=ViewType.ACTIVE_ONLY): str(exp_id), str(rnfe), exc_info=True) return experiments - def _get_target_artifact_location(self, artifact_uri, *path): - from six.moves import urllib - parsed_artifact_uri = urllib.parse.urlparse(artifact_uri) - new_path = posixpath.join(parsed_artifact_uri.path, - *path) - parsed_artifact_location = parsed_artifact_uri._replace(path=new_path) - artifact_location = urllib.parse.urlunparse(parsed_artifact_location) - print("ARTIFACT LOCATION", artifact_location) - return artifact_location - def _create_experiment_with_id(self, name, experiment_id, artifact_uri): - artifact_uri = artifact_uri or self._get_target_artifact_location( - self.artifact_root_uri, str(experiment_id)) - # artifact_uri = artifact_uri or posixpath.join(self.artifact_root_uri, str(experiment_id)) + artifact_uri = artifact_uri or posixpath.join(self.artifact_root_uri, str(experiment_id)) self._check_root_dir() meta_dir = mkdir(self.root_directory, str(experiment_id)) experiment = Experiment(experiment_id, name, artifact_uri, LifecycleStage.ACTIVE) diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 7a6dcacdfc7f3..d30490512d374 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -238,7 +238,6 @@ def log_artifact(self, run_id, local_path, artifact_path=None): :param artifact_path: If provided, the directory in ``artifact_uri`` to write to. """ run = self.get_run(run_id) - print("RUN ARTIFACT_URI", run.info.artifact_uri) artifact_repo = get_artifact_repository(run.info.artifact_uri) if os.path.isdir(local_path): dir_name = os.path.basename(os.path.normpath(local_path)) diff --git a/mlflow/utils/rest_utils.py b/mlflow/utils/rest_utils.py index 1616a1a81cd76..fc7d80aaf76a7 100644 --- a/mlflow/utils/rest_utils.py +++ b/mlflow/utils/rest_utils.py @@ -65,7 +65,6 @@ def request_with_ratelimit_retries(max_rate_limit_interval, **kwargs): cleaned_hostname = strip_suffix(hostname, '/') url = "%s%s" % (cleaned_hostname, endpoint) - print(cleaned_hostname, endpoint) for i in range(retries): response = request_with_ratelimit_retries(max_rate_limit_interval, url=url, headers=headers, verify=verify, **kwargs)