diff --git a/mlflow/store/artifact/db_artifact_repo.py b/mlflow/store/artifact/db_artifact_repo.py index 40a2037c2fdc3..ac511a461a83b 100644 --- a/mlflow/store/artifact/db_artifact_repo.py +++ b/mlflow/store/artifact/db_artifact_repo.py @@ -1,5 +1,6 @@ import logging import os +import posixpath from abc import ABCMeta import sqlalchemy import tempfile @@ -47,26 +48,17 @@ def extract_db_uri_and_root_path(repo_uri): error_msg = "Invalid database scheme in the URI: '%s'. %s" % (scheme, _INVALID_DB_URI_MSG) raise MlflowException(error_msg, INVALID_PARAMETER_VALUE) - if parsed_uri.query == "": - if parsed_uri.path == "": - return repo_uri, "" - else: - parsed_path = parsed_uri.path.split(ROOT_PATH_BASE, 1) - if len(parsed_path) == 2: - db_uri = os.path.dirname(os.path.dirname(repo_uri)) - path = os.path.normpath(repo_uri.split(db_uri)[1]) - path = path.split(os.sep, 1)[1] - return db_uri, path - else: - return repo_uri, "" - else: - parsed_query = parsed_uri.query.split("/", 1) - if len(parsed_query) == 2: - path = os.path.normpath(parsed_query[1]) - parsed_uri = parsed_uri._replace(query=parsed_query[0]) - else: - path = "" - return urllib.parse.urlunparse(parsed_uri), path + def get_dbname_and_path(uri_path): + head, tail = posixpath.split(uri_path) + if len(head) == 0 or head == posixpath.sep: + return tail, posixpath.sep + + dbname, path = get_dbname_and_path(head) + return dbname, posixpath.join(path, tail) + + dbname, artifact_path = get_dbname_and_path(parsed_uri.path) + parsed_root_uri = parsed_uri._replace(path=dbname) + return urllib.parse.urlunparse(parsed_root_uri), artifact_path class DBArtifactRepository(ArtifactRepository):