Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@
import mlflow.entities
import cloudpickle
import ngrok
import mlflow.exceptions as mlflow_exceptions
else:
plugin_server_module = lazy_load("dagshub.data_engine.voxel_plugin_server.server")
fo = lazy_load("fiftyone")
mlflow = lazy_load("mlflow")
mlflow_exceptions = lazy_load("mlflow.exceptions")
pandas = lazy_load("pandas")
ngrok = lazy_load("ngrok")
cloudpickle = lazy_load("cloudpickle")
Expand Down Expand Up @@ -898,11 +900,20 @@ def _log_to_mlflow(
if run is None:
run = mlflow.start_run()
client = mlflow.MlflowClient()
client.set_tag(run.info.run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
if self.assigned_dataset is not None:
client.set_tag(run.info.run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')

run_id = run.info.run_id
# Refetch the run from the backend, to prevent double-writing tags
run_info = client.get_run(run_id)

try:
if MLFLOW_DATASOURCE_TAG_NAME not in run_info.data.tags:
client.set_tag(run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
if self.assigned_dataset is not None and MLFLOW_DATASET_TAG_NAME not in run_info.data.tags:
client.set_tag(run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
except mlflow_exceptions.RestException as e:
log_message(f"Failed to save the datasource state to MLflow (run {run.info.run_name}): {e}")
return run

def _get_mlflow_artifact_name(self, prefix: str, as_of: datetime.datetime) -> str:
Expand Down