From f71354b282becb539c0213f73f6464acff83a3fa Mon Sep 17 00:00:00 2001
From: KBolashev
Date: Mon, 24 Mar 2025 11:50:27 +0200
Subject: [PATCH] bug: double writing tags leads to a pk violation

---
 dagshub/data_engine/model/datasource.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py
index 91a97121..089c92e0 100644
--- a/dagshub/data_engine/model/datasource.py
+++ b/dagshub/data_engine/model/datasource.py
@@ -70,10 +70,12 @@
     import mlflow.entities
     import cloudpickle
     import ngrok
+    import mlflow.exceptions as mlflow_exceptions
 else:
     plugin_server_module = lazy_load("dagshub.data_engine.voxel_plugin_server.server")
     fo = lazy_load("fiftyone")
     mlflow = lazy_load("mlflow")
+    mlflow_exceptions = lazy_load("mlflow.exceptions")
     pandas = lazy_load("pandas")
     ngrok = lazy_load("ngrok")
     cloudpickle = lazy_load("cloudpickle")
@@ -898,11 +900,20 @@ def _log_to_mlflow(
         if run is None:
             run = mlflow.start_run()
         client = mlflow.MlflowClient()
-        client.set_tag(run.info.run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
-        if self.assigned_dataset is not None:
-            client.set_tag(run.info.run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
-        client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
-        log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
+
+        run_id = run.info.run_id
+        # Refetch the run from the backend, to prevent double-writing tags
+        run_info = client.get_run(run_id)
+
+        try:
+            if MLFLOW_DATASOURCE_TAG_NAME not in run_info.data.tags:
+                client.set_tag(run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
+            if self.assigned_dataset is not None and MLFLOW_DATASET_TAG_NAME not in run_info.data.tags:
+                client.set_tag(run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
+            client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
+            log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
+        except mlflow_exceptions.RestException as e:
+            log_message(f"Failed to save the datasource state to MLflow (run {run.info.run_name}): {e}")
         return run
 
     def _get_mlflow_artifact_name(self, prefix: str, as_of: datetime.datetime) -> str: