diff --git a/.idea/dictionaries/project.xml b/.idea/dictionaries/project.xml index 53ae4b6b9..4e318d48a 100644 --- a/.idea/dictionaries/project.xml +++ b/.idea/dictionaries/project.xml @@ -2,6 +2,7 @@ cunyue + expid expname groupname leveldblog diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 9ce193332..bc31c0f6f 100755 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -23,7 +23,7 @@ SwanLabEnv.check() # 导出 OpenApi 接口,必须要等待上述的 import 语句执行完毕以后才能导出,否则会触发循环引用 -from .api import OpenApi +from .api import OpenApi, Api __version__ = get_package_version() @@ -48,6 +48,7 @@ "get_config", "config", "OpenApi", + "Api", "sync_wandb", "sync_mlflow", "sync_tensorboardX", diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index e99aa9c47..617cdffe1 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -1,15 +1,167 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/4/29 9:40 -@File: __init__.py -@IDE: pycharm -@Description: - SwanLab OpenAPI包 """ +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/5 17:58 +@description: SwanLab OpenAPI包 +""" + +from typing import Optional, List, Dict + +from swanlab.core_python import auth, Client +from swanlab.core_python.api.experiment import get_single_experiment, get_project_experiments +from swanlab.error import KeyFileError +from swanlab.log import swanlog +from swanlab.package import HostFormatter, get_key +from .deprecated import OpenApi +from .experiment import Experiment +from .experiments import Experiments +from .project import Project +from .projects import Projects +from .user import User +from .users import Users +from .utils import self_hosted +from .workspace import Workspace +from .workspaces import Workspaces +from ..core_python.api.project import get_project_info +from ..core_python.api.user import get_workspace_info + + +class Api: + def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, web_host: Optional[str] = None): + """ + 初始化 OpenApi 实例,用户需提前登录,或者提供API密钥 + :param api_key: API 密钥,可选 + :param host: API 主机地址,可选 + :param web_host: Web 主机地址,可选 + """ + if host or web_host: + HostFormatter(host, web_host)() + if api_key: + swanlog.debug("Using API key", api_key) + else: + swanlog.debug("Using existing key") + try: + api_key = get_key() + except KeyFileError as e: + swanlog.error("To use SwanLab OpenAPI, please login first.") + raise RuntimeError("Not logged in.") from e + + self._login_info = auth.code_login(api_key, save_key=False) + # 一个OpenApi对应一个client,可创建多个api获取从不同的client获取不同账号下的实验信息 + self._client: Client = Client(self._login_info) + self._web_host = self._login_info.web_host + self._login_user = self._login_info.username + + def user(self, username: str = None) -> User: + """ + 获取用户实例,用于操作用户相关信息 + :param username: 指定用户名,如果为 None,则返回当前登录用户 + :return: User 实例,可对当前/指定用户进行操作 + """ + return User(client=self._client, login_user=self._login_user, username=username) + + @self_hosted("root") + def users(self) -> Users: + """ + 超级管理员获取所有用户 + :return: User 实例,可对当前/指定用户进行操作 + """ + return Users(self._client, login_user=self._login_user) + + def projects( + self, + path: str, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> Projects: + """ + 获取指定工作空间(组织)下的所有项目信息 + :param path: 工作空间(组织)名称 'username' + :param sort: 排序方式,可选 + :param search: 搜索关键词,可选 + :param detail: 是否返回详细信息,可选 + :return: Projects 实例,可遍历获取项目信息 + """ + return Projects( + self._client, + web_host=self._web_host, + path=path, + sort=sort, + search=search, + detail=detail, + ) + + def project( + self, + path: str, + ) -> Project: + """ + 获取指定工作空间(组织)下的指定项目信息 + :param path: 项目路径 'username/project' + :return: Project 实例,单个项目的信息 + """ + data = get_project_info(self._client, path=path) + return Project(self._client, web_host=self._web_host, data=data) + + def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments: + """ + 获取指定项目下的所有实验信息 + :param path: 项目路径,格式为 'username/project' + :return: Experiments 实例,可遍历获取实验信息 + :param filters: 筛选实验的条件,可选 + """ + return Experiments(self._client, path=path, login_info=self._login_info, filters=filters) + + def run( + self, + path: str, + ) -> Experiment: + """ + 获取指定实验的信息 + :param path: 实验路径,格式为 'username/project/run_id' + :return: Experiment 实例,包含实验信息 + """ + # TODO: 待后端完善后替换成专用的接口 + if len(path.split('/')) != 3: + raise ValueError(f"User's {path} is invaded. Correct path should be like 'username/project/run_id'") + _data = get_single_experiment(self._client, path=path) + proj_path = path.rsplit('/', 1)[0] + data = get_project_experiments( + self._client, path=proj_path, filters={'name': _data['name'], 'created_at': _data['createdAt']} + ) + return Experiment( + self._client, + data=data[0], + path=proj_path, + web_host=self._web_host, + login_user=self._login_user, + line_count=1, + ) + + def workspaces( + self, + username: str = None, + ): + """ + 获取当前登录用户的工作空间迭代器 + 当username为其他用户时,可以作为visitor访问其工作空间 + """ + if username is None: + username = self._login_user + return Workspaces(self._client, username=username) + + def workspace( + self, + username: str = None, + ): + """ + 获取当前登录用户的工作空间 + """ + if username is None: + username = self._login_user + data = get_workspace_info(self._client, path=username) + return Workspace(self._client, data=data) -from swanlab.api.main import OpenApi -__all__ = [ - "OpenApi" -] +__all__ = ["Api", "OpenApi"] diff --git a/swanlab/api/deprecated/__init__.py b/swanlab/api/deprecated/__init__.py new file mode 100644 index 000000000..81840912f --- /dev/null +++ b/swanlab/api/deprecated/__init__.py @@ -0,0 +1,10 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2025/12/30 20:54 +@description: 旧版OpenApi,即将遗弃 +""" + +from .main import OpenApi + +__all__ = ["OpenApi"] \ No newline at end of file diff --git a/swanlab/api/base.py b/swanlab/api/deprecated/base.py similarity index 99% rename from swanlab/api/base.py rename to swanlab/api/deprecated/base.py index 2fd90e0a4..ab6a25896 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/deprecated/base.py @@ -13,9 +13,9 @@ import requests -from swanlab.api.types import ApiResponse from swanlab.core_python import auth, create_session from swanlab.log.log import SwanLog +from .types import ApiResponse _logger: Optional[SwanLog] = None @@ -64,7 +64,7 @@ def __init__(self, login_info: auth.LoginInfo): self.__login_info: auth.LoginInfo = login_info self.__session: requests.Session = self.__init_session() self.service: OpenApiService = OpenApiService(self) - + @property def session(self) -> requests.Session: """ @@ -131,7 +131,7 @@ def get_project_info(self, username: str, projname: str) -> ApiResponse[dict]: 获取项目详情 """ return self.http.get(f"/project/{username}/{projname}", params={}) - + @staticmethod def fetch_paginated_api( api_func: Callable[..., ApiResponse], # 分页 API 请求函数 diff --git a/swanlab/api/experiment.py b/swanlab/api/deprecated/experiment.py similarity index 98% rename from swanlab/api/experiment.py rename to swanlab/api/deprecated/experiment.py index 74db3d995..ebf0ad635 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/deprecated/experiment.py @@ -9,8 +9,8 @@ """ from typing import List -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse, Experiment, Pagination +from .base import ApiBase, ApiHTTP +from .types import ApiResponse, Experiment, Pagination try: from pandas import DataFrame diff --git a/swanlab/api/group.py b/swanlab/api/deprecated/group.py similarity index 88% rename from swanlab/api/group.py rename to swanlab/api/deprecated/group.py index fd48ce087..ab6d73b87 100644 --- a/swanlab/api/group.py +++ b/swanlab/api/deprecated/group.py @@ -8,8 +8,8 @@ 组织相关的开放API """ -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse +from .base import ApiBase, ApiHTTP +from .types import ApiResponse class GroupAPI(ApiBase): diff --git a/swanlab/api/main.py b/swanlab/api/deprecated/main.py similarity index 95% rename from swanlab/api/main.py rename to swanlab/api/deprecated/main.py index 1029e4ba9..711a47764 100644 --- a/swanlab/api/main.py +++ b/swanlab/api/deprecated/main.py @@ -9,15 +9,15 @@ """ from typing import Dict, List, Union -from swanlab.api.base import ApiHTTP, get_logger -from swanlab.api.experiment import ExperimentAPI -from swanlab.api.group import GroupAPI -from swanlab.api.project import ProjectAPI -from swanlab.api.types import ApiResponse, Experiment, Project from swanlab.core_python import auth from swanlab.error import KeyFileError from swanlab.log.log import SwanLog from swanlab.package import get_key +from .base import ApiHTTP, get_logger +from .experiment import ExperimentAPI +from .group import GroupAPI +from .project import ProjectAPI +from .types import ApiResponse, Experiment, Project try: from pandas import DataFrame @@ -28,6 +28,7 @@ class OpenApi: def __init__(self, api_key: str = "", log_level: str = "info"): self.__logger: SwanLog = get_logger(log_level) + self.__logger.warning("OpenApi will be soon deprecated in swanlab 0.8.0. Please use swanlab.Api() instead.") if api_key: self.__logger.debug("Using API key", api_key) @@ -210,7 +211,7 @@ def get_summary( return self.experiment.get_summary( exp_id=exp_id, pro_id=project_cuid, - root_exp_id=exp.data.get("rootExpId", ""), + root_exp_id=exp.data.get("rootExpId", ""), root_pro_id=exp.data.get("rootProId", "") ) diff --git a/swanlab/api/project.py b/swanlab/api/deprecated/project.py similarity index 95% rename from swanlab/api/project.py rename to swanlab/api/deprecated/project.py index b41eba4c0..f20e5f612 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/deprecated/project.py @@ -8,8 +8,8 @@ 项目相关的开放API """ -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse, Pagination, Project +from .base import ApiBase, ApiHTTP +from .types import ApiResponse, Pagination, Project class ProjectAPI(ApiBase): diff --git a/swanlab/api/types.py b/swanlab/api/deprecated/types.py similarity index 100% rename from swanlab/api/types.py rename to swanlab/api/deprecated/types.py diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py new file mode 100644 index 000000000..91423b911 --- /dev/null +++ b/swanlab/api/experiment/__init__.py @@ -0,0 +1,222 @@ +""" +@author: Zhou QiYang +@file: experiment.py +@time: 2026/1/11 16:36 +@description: OpenApi 的单个实验对象 +""" + +from typing import List, Dict, Any + +from swanlab.api.user import User +from swanlab.api.utils import Label, get_properties +from swanlab.core_python.api.type import RunType +from swanlab.core_python.client import Client +from swanlab.log import swanlog + + + +class Experiment: + def __init__( + self, client: Client, *, data: RunType, path: str, web_host: str, login_user: str, line_count: int + ) -> None: + self._client = client + self._data = data + self._path = path + self._web_host = web_host + self._login_user = login_user + self._line_count = line_count + + @property + def name(self) -> str: + """ + Experiment name. + """ + return self._data['name'] + + @property + def id(self) -> str: + """ + Experiment CUID. + """ + return self._data['cuid'] + + @property + def url(self) -> str: + """ + Full URL to access the experiment. + """ + return f"{self._web_host}/@{self._path}/runs/{self.id}/chart" + + @property + def created_at(self) -> str: + """ + Experiment creation timestamp + """ + return self._data['createdAt'] + + @property + def description(self) -> str: + """ + Experiment description. + """ + return self._data['description'] + + @property + def labels(self) -> List[Label]: + """ + List of Label attached to this experiment. + """ + return [Label(label['name']) for label in self._data['labels']] + + @property + def state(self) -> str: + """ + Experiment state. + """ + return self._data['state'] + + @property + def group(self) -> str: + """ + Experiment group. + """ + return self._data['cluster'] + + @property + def job(self) -> str: + """ + Experiment job type. + """ + return self._data['job'] + + @property + def user(self) -> User: + """ + Experiment user. + """ + return User(client=self._client, login_user=self._login_user, username=self._data['user']['username']) + + @property + def metric_keys(self) -> List[str]: + """ + List of metric keys. + """ + return list(self.summary.keys()) + + @property + def history_line_count(self) -> int: + """ + The number of historical experiments in this project. + """ + return self._line_count + + @property + def root_exp_id(self) -> str: + """ + Root experiment cuid. If the experiment is a root experiment, it will be None. + """ + return self._data['rootExpId'] + + @property + def root_pro_id(self) -> str: + """ + Root project cuid. If the experiment is a root experiment, it will be None. + """ + return self._data['rootProId'] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) + + def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: + """ + Get metric data from the experiment. + + Args: + keys: List of metric keys to fetch. Required. A single string is also accepted. + x_axis: Metric to use as x-axis. Defaults to 'step'. + sample: Number of rows to return from the start. + pandas: Reserved parameter (always returns DataFrame). + + Returns: + pandas.DataFrame with metric data. + + Example: + >>> exp.metrics(keys=['loss', 'accuracy'], sample=20, x_axis='t/accuracy') + t/accuracy loss + step + 0 0.310770 0.525776 + 1 0.642817 0.479186 + ... + """ + try: + import pandas as pd + except ImportError: + raise TypeError("pandas is required for metrics(). Install with 'pip install pandas'.") + + # Normalize keys: must be a non-empty list of strings + if keys is None: + swanlog.warning('keys cannot be None') + return pd.DataFrame() + if not isinstance(keys, list): + swanlog.warning('keys must be a list') + return pd.DataFrame() + if not keys: + swanlog.warning('keys cannot be empty') + return pd.DataFrame() + if not all(isinstance(k, str) for k in keys): + swanlog.warning('keys must be a list of strings') + return pd.DataFrame() + + # Determine if x_axis needs to be included + use_x_axis = x_axis is not None and x_axis != "step" + if use_x_axis: + keys.append(x_axis) + + # Fetch and process each metric CSV + dfs = [] + prefix = "" + for idx, key in enumerate(keys): + resp = self._client.get(f"/experiment/{self.id}/column/csv", params={"key": key}) + url = resp[0].get("url", "") + df = pd.read_csv(url, index_col=0) + + # Extract prefix from first column (e.g., "t0707-02:17-loss_step" → "t0707-02:17-") + if idx == 0: + first_col = df.columns[0] + suffix = f"{key}_" + prefix = first_col.split(suffix)[0] if suffix in first_col else "" + + # Strip "_step" suffix from column names (Python 3.8 compatible) + def strip_suffix(col, suffix="_step"): + return col[:-len(suffix)] if col.endswith(suffix) else col + + # Apply prefix removal and suffix stripping + df.columns = [ + strip_suffix(col[len(prefix):]) if prefix and col.startswith(prefix) else strip_suffix(col) + for col in df.columns + ] + dfs.append(df) + + # Merge all DataFrames + result_df = dfs[0].join(dfs[1:], how='outer') if len(dfs) > 1 else dfs[0] + result_df = result_df.sort_index() + + # Handle x_axis: drop timestamp columns, reorder, filter nulls + if use_x_axis: + result_df = result_df.drop(columns=[c for c in result_df.columns if c.endswith("_timestamp")], errors='ignore') + if x_axis not in result_df.columns: + raise ValueError(f"x_axis '{x_axis}' not found in result DataFrame") + cols = [x_axis] + [c for c in result_df.columns if c != x_axis] + result_df = result_df[cols].dropna(subset=[x_axis]) + + # Apply sample limit + if sample is not None: + result_df = result_df.head(sample) + + return result_df + + +__all__ = ['Experiment'] diff --git a/swanlab/api/experiments/__init__.py b/swanlab/api/experiments/__init__.py new file mode 100644 index 000000000..ea04d2653 --- /dev/null +++ b/swanlab/api/experiments/__init__.py @@ -0,0 +1,68 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2025/12/30 15:08 +@description: OpenApi 的实验对象迭代器 +""" + +from typing import List, Dict, Iterator + +from swanlab.api.experiment import Experiment +from swanlab.core_python.api.experiment import get_project_experiments +from swanlab.core_python.api.type import RunType +from swanlab.core_python.auth.providers.api_key import LoginInfo +from swanlab.core_python.client import Client + + +def flatten_runs(runs: Dict) -> List: + """ + 展开分组后的实验数据,返回一个包含所有实验的列表 + """ + flat_runs = [] + for group in runs.values(): + if isinstance(group, Dict): + flat_runs.extend(flatten_runs(group)) + else: + flat_runs.extend(group) + return flat_runs + + +class Experiments: + """ + Container for a collection of Experiment objects. + You can iterate over the experiments by for-in loop. + """ + + def __init__(self, client: Client, *, path: str, login_info: LoginInfo, filters: Dict[str, object] = None) -> None: + if len(path.split('/')) != 2: + raise ValueError(f"User's {path} is invaded. Correct path should be like 'username/project'") + self._client = client + self._path = path + self._web_host = login_info.web_host + self._login_user = login_info.username + self._filters = filters + + def __iter__(self) -> Iterator[Experiment]: + # TODO: 完善filter的功能(正则、条件判断) + resp = get_project_experiments(self._client, path=self._path, filters=self._filters) + runs: List[RunType] = [] + if isinstance(resp, List): + runs = resp + # 分组时需展平实验数据 + elif isinstance(resp, Dict): + runs = flatten_runs(resp) + line_count = len(runs) + yield from iter( + Experiment( + self._client, + data=run, + path=self._path, + web_host=self._web_host, + login_user=self._login_user, + line_count=line_count, + ) + for run in runs + ) + + +__all__ = ["Experiments"] diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py new file mode 100644 index 000000000..66985d2a0 --- /dev/null +++ b/swanlab/api/project/__init__.py @@ -0,0 +1,104 @@ +""" +@author: Zhou QiYang +@file: project.py +@time: 2026/1/5 17:58 +@description: OpenApi 中的项目对象 +""" + +from functools import cached_property +from typing import List, Dict + +from swanlab.api.utils import Label, get_properties +from swanlab.api.workspace import Workspace +from swanlab.core_python.api.type import ProjectType +from swanlab.core_python.api.user import get_workspace_info +from swanlab.core_python.client import Client + + +class Project: + """ + Representing a single project with some of its properties. + """ + + def __init__(self, client: Client, *, web_host: str, data: ProjectType) -> None: + self._client = client + self._web_host = web_host + self._data = data + + @property + def name(self) -> str: + """ + Project name. + """ + return self._data['name'] + + @property + def path(self) -> str: + """ + Project path in the format 'username/project-name'. + """ + return self._data['path'] + + @property + def url(self) -> str: + """ + Full URL to access the project. + """ + return f"{self._web_host}/@{self._data['path']}" + + @property + def description(self) -> str: + """ + Project description. + """ + return self._data['description'] + + @property + def visibility(self) -> str: + """ + Project visibility, either 'PUBLIC' or 'PRIVATE'. + """ + return self._data['visibility'] + + @property + def created_at(self) -> str: + """ + Project creation timestamp + """ + return self._data['createdAt'] + + @property + def updated_at(self) -> str: + """ + Project last update timestamp + """ + return self._data['updatedAt'] + + @cached_property + def workspace(self) -> Workspace: + """ + Project workspace object. + """ + data = get_workspace_info(self._client, path=self._data["group"]["username"]) + return Workspace(self._client, data=data) + + @property + def labels(self) -> List[Label]: + """ + List of Label attached to this project. + """ + return [Label(label['name']) for label in self._data['projectLabels']] + + @property + def count(self) -> Dict[str, int]: + """ + Project statistics dictionary containing: + experiments, contributors, children, collaborators, runningExps. + """ + return self._data['_count'] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py new file mode 100644 index 000000000..eb3ac297b --- /dev/null +++ b/swanlab/api/projects/__init__.py @@ -0,0 +1,60 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 16:31 +@description: OpenApi 中的项目对象迭代器 +""" + +from typing import List, Optional, Iterator + +from swanlab.api.project import Project +from swanlab.core_python.api.project import get_workspace_projects +from swanlab.core_python.api.type import ProjResponseType +from swanlab.core_python.client import Client + + +class Projects: + """ + Container for a collection of Project objects. + You can iterate over the projects by for-in loop. + """ + + def __init__( + self, + client: Client, + *, + web_host: str, + path: str, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> None: + self._client = client + self._web_host = web_host + self._path = path + self._sort = sort + self._search = search + self._detail = detail + + def __iter__(self) -> Iterator[Project]: + # 按用户遍历情况获取项目信息 + cur_page = 0 + while True: + cur_page += 1 + resp: ProjResponseType = get_workspace_projects( + self._client, + path=self._path, + page=cur_page, + size=20, + sort=self._sort, + search=self._search, + detail=self._detail, + ) + for p in resp['list']: + yield Project(self._client, web_host=self._web_host, data=p) + + if cur_page >= resp['pages']: + break + + +__all__ = ["Projects"] diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py new file mode 100644 index 000000000..073212925 --- /dev/null +++ b/swanlab/api/user/__init__.py @@ -0,0 +1,130 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 16:40 +@description: OpenApi 中的用户对象 +""" + +import re +from functools import cached_property +from typing import List, Optional + +from swanlab.api.utils import self_hosted, get_properties +from swanlab.core_python.api.type import ApiKeyType +from swanlab.core_python.api.user import ( + get_user_groups, + get_api_keys, + create_api_key, + get_latest_api_key, + delete_api_key, + create_user, +) +from swanlab.core_python.client import Client + + +def check_create_info(username: str, password: str) -> bool: + # 用户名为大小写字母、数字及-、_组成 + # 密码必须包含数字+英文且至少8位 + if not re.match(r'^[a-zA-Z0-9_-]+$', username): + raise ValueError("Username must be alphanumeric and can contain - and _") + if not re.match(r'^(?=.*[0-9])(?=.*[a-zA-Z]).{8,}$', password): + raise ValueError("Password must contain at least 8 characters and include numbers and letters") + else: + return True + + +class User: + def __init__(self, client: Client, login_user: str = None, username: str = None) -> None: + if login_user is None and username is None: + raise ValueError("login_user or username are required") + + self._client = client + self._api_keys: List[ApiKeyType] = [] + self._login_user = login_user + self._cur_username = username or self._login_user + + @property + def username(self) -> str: + """ + User name. (if username is not None, return username, otherwise return login_user) + """ + return self._cur_username + + @property + def is_self(self) -> bool: + """ + Check if the user is the current login user. + """ + return self._cur_username == self._login_user + + @cached_property + def teams(self) -> List[str]: + """ + List of teams the user belongs to. + """ + resp = get_user_groups(self._client, username=self._cur_username) + return [r['name'] for r in resp] + + # TODO: 管理员可以对指定用户的api_key进行操作 + @cached_property + def api_keys(self) -> List[str]: + """ + List of api keys the user has. + """ + if not self.is_self: + raise ValueError("Getting api keys of other users has not been supported yet.") + else: + self._api_keys = get_api_keys(self._client) + return [r['key'] for r in self._api_keys] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) + + def _refresh_api_keys(self): + """ + Refresh the list of api keys. + """ + del self.api_keys + self._api_keys = get_api_keys(self._client) + + def generate_api_key(self, description: str = None) -> Optional[str]: + """ + Generate a new api key. + """ + if not self.is_self: + raise ValueError("Generating api key of other users has not been supported yet.") + else: + create_api_key(self._client, name=description) + api_key = get_latest_api_key(self._client) + return api_key['key'] if api_key else None + + def delete_api_key(self, api_key: str) -> bool: + """ + Delete an api key. + """ + if not self.is_self: + raise ValueError("Deleting api key of other users has not been supported yet.") + else: + self._refresh_api_keys() + for key in self._api_keys: + if key['key'] == api_key: + delete_api_key(self._client, key_id=key['id']) + return True + return False + + @self_hosted("root") + def create(self, username: str, password: str) -> Optional[bool]: + """ + Create a new user. (Only root user can create other user) + """ + if not self.is_self: + raise ValueError(f"{self._cur_username} is not allowed to create other user.") + check_create_info(username, password) + create_user(self._client, username=username, password=password) + return True + + +__all__ = ["User"] diff --git a/swanlab/api/users/__init__.py b/swanlab/api/users/__init__.py new file mode 100644 index 000000000..b205347b8 --- /dev/null +++ b/swanlab/api/users/__init__.py @@ -0,0 +1,38 @@ +""" +@author: cunyue +@file: __init__.py +@time: 2026/2/3 13:30 +@description: OpenApi 中的用户对象迭代器 +""" + +from typing import Iterator + +from swanlab.api.user import User +from swanlab.core_python import Client +from swanlab.core_python.api.user import get_users + + +class Users: + """ + Container for a collection of User objects. + You can iterate over the users by for-in loop. + """ + + def __init__(self, client: Client, *, login_user: str) -> None: + self._client = client + self._login_user = login_user + + def __iter__(self) -> Iterator[User]: + cur_page = 0 + while True: + cur_page += 1 + resp = get_users( + self._client, + page=cur_page, + size=20, + ) + for u in resp['list']: + yield User(self._client, login_user=self._login_user, username=u['username']) + + if cur_page >= resp['pages']: + break diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py new file mode 100644 index 000000000..196d564b7 --- /dev/null +++ b/swanlab/api/utils.py @@ -0,0 +1,82 @@ +""" +@author: Zhou QiYang +@file: utils.py +@time: 2026/1/11 23:44 +@description: OpenApi 中的基础对象与通用工具 +""" + +from dataclasses import dataclass +from functools import wraps +from typing import Dict + +from swanlab.core_python.api.type import IdentityType +from swanlab.core_python.api.user import get_self_hosted_init +from swanlab.core_python.client import Client +from swanlab.error import ApiError + + +@dataclass +class Label: + """ + Project label object + you can get the label name by str(label) + """ + + name: str + + def __str__(self) -> str: + return self.name + + +def self_hosted(identity: IdentityType = "user"): + """ + 用于需要在私有化环境下使用的接口的装饰器。 + :param identity: 用户身份,默认为 "user",如果为 "root",则会额外验证是否为根用户。 + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + client = getattr(self, '_client', None) + if not isinstance(client, Client): + raise AttributeError("There is no SwanLab client instance.") + + # 1. 尝试获取私有化服务信息 + try: + self_hosted_info = get_self_hosted_init(client) + except ApiError: + raise ValueError("You haven't launched a swanlab self-hosted instance. This usages are not available.") + + if not self_hosted_info.get("enabled", False): + raise ValueError("SwanLab self-hosted instance hasn't been ready yet.") + if self_hosted_info.get("expired", True): + raise ValueError("SwanLab self-hosted instance has expired.") + + # 2. 检测用户权限(商业版root用户功能) + if identity == 'root': + if not self_hosted_info.get('root', False): + raise ValueError("You don't have permission to perform this action. Please login as a root user") + if not getattr(self, 'is_self', True): + raise ValueError('This root-only action can only be performed by the logged-in root user.') + + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +def get_properties(obj: object) -> Dict[str, object]: + """递归获取实例中所有property""" + result = dict() + for name in dir(obj): + if name.startswith("_"): + continue + if isinstance(getattr(type(obj), name, None), property): + value = getattr(obj, name, None) + result[name] = value if type(value).__module__ == 'builtins' else get_properties(value) + + return result + + +__all__ = ['Label', 'self_hosted', 'get_properties'] diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py new file mode 100644 index 000000000..8cc15295a --- /dev/null +++ b/swanlab/api/workspace/__init__.py @@ -0,0 +1,69 @@ +""" +@author: Zhou Qiyang +@file: __init__.py +@time: 2026/1/27 20:43 +@description: 工作空间 +""" + +from typing import Literal, Dict + +from swanlab.api.utils import get_properties +from swanlab.core_python import Client +from swanlab.core_python.api.type import WorkspaceType, RoleType + + +class Workspace: + def __init__(self, client: Client, *, data: WorkspaceType): + self._client = client + self._data = data + + @property + def name(self) -> str: + """ + Workspace display name. + """ + return self._data['name'] + + @property + def username(self) -> str: + """ + Workspace name. + """ + return self._data['username'] + + @property + def workspace_type(self) -> Literal['TEAM', 'PERSON']: + """ + Workspace type. + """ + return self._data['type'] + + @property + def profile(self) -> Dict[str, str]: + """ + Workspace profile. + """ + return self._data.get('profile', dict()) + + @property + def comment(self) -> str: + """ + Workspace comment. + """ + return self._data['comment'] + + @property + def role(self) -> RoleType: + """ + Current login user's role in the workspace (only display when type=TEAM). + """ + return self._data['role'] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) + + +__all__ = ['Workspace'] diff --git a/swanlab/api/workspaces/__init__.py b/swanlab/api/workspaces/__init__.py new file mode 100644 index 000000000..dbe7c108f --- /dev/null +++ b/swanlab/api/workspaces/__init__.py @@ -0,0 +1,33 @@ +""" +@author: Zhou Qiyang +@file: __init__.py +@time: 2026/1/27 18:13 +@description: 工作空间迭代器 +""" + +from typing import Iterator + +from swanlab.api.workspace import Workspace +from swanlab.core_python import Client +from swanlab.core_python.api.user import get_user_groups, get_workspace_info + + +class Workspaces: + def __init__(self, client: Client, *, username: str) -> None: + self._client = client + self._username = username + + def get_all_workspaces(self, username: str = None): + """Get all workspaces of specific user (defaults to current user)""" + cur_username = username if username else self._username + resp = get_user_groups(self._client, username=cur_username) + groups = [r['username'] for r in resp] + return [cur_username] + groups + + def __iter__(self) -> Iterator[Workspace]: + for space in self.get_all_workspaces(): + data = get_workspace_info(self._client, path=space) + yield Workspace(self._client, data=data) + + +__all__ = ['Workspaces'] diff --git a/swanlab/core_python/api/experiment.py b/swanlab/core_python/api/experiment.py deleted file mode 100644 index 95d92db0c..000000000 --- a/swanlab/core_python/api/experiment.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -@author: cunyue -@file: experiment.py -@time: 2025/12/11 18:37 -@description: 定义实验相关的后端API接口 -""" - -from typing import Literal, TYPE_CHECKING - -if TYPE_CHECKING: - from swanlab.core_python.client import Client - - -def send_experiment_heartbeat( - client: "Client", - *, - cuid: str, - flag_id: str, -): - """ - 发送实验心跳,保持实验处于活跃状态 - :param client: 已登录的客户端实例 - :param cuid: 实验唯一标识符 - :param flag_id: 实验标记ID - """ - client.post(f"/house/experiments/{cuid}/heartbeat", {"flagId": flag_id}) - - -def update_experiment_state( - client: "Client", - *, - username: str, - projname: str, - cuid: str, - state: Literal['FINISHED', 'CRASHED', 'ABORTED'], - finished_at: str = None, -): - """ - 更新实验状态,注意此接口会将客户端标记为 pending 状态,表示实验已结束 - :param client: 已登录的客户端实例 - :param username: 实验所属用户名 - :param projname: 实验所属项目名称 - :param cuid: 实验唯一标识符 - :param state: 实验状态 - :param finished_at: 实验结束时间,格式为 ISO 8601,如果不提供则使用当前时间 - """ - put_data = { - "state": state, - "finishedAt": finished_at, - "from": "sdk", - } - put_data = {k: v for k, v in put_data.items() if v is not None} # 移除值为None的键 - client.put(f"/project/{username}/{projname}/runs/{cuid}/state", put_data) - client.pending = True diff --git a/swanlab/core_python/api/experiment/__init__.py b/swanlab/core_python/api/experiment/__init__.py new file mode 100644 index 000000000..662d7d258 --- /dev/null +++ b/swanlab/core_python/api/experiment/__init__.py @@ -0,0 +1,120 @@ +""" +@author: cunyue +@file: experiment.py +@time: 2025/12/11 18:37 +@description: 定义实验相关的后端API接口 +""" + +from typing import Literal, Dict, TYPE_CHECKING, List, Union + +from swanlab.core_python.api.type import RunType +from .utils import to_camel_case, parse_column_type + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + + +def send_experiment_heartbeat( + client: "Client", + *, + cuid: str, + flag_id: str, +): + """ + 发送实验心跳,保持实验处于活跃状态 + :param client: 已登录的客户端实例 + :param cuid: 实验唯一标识符 + :param flag_id: 实验标记ID + """ + client.post(f"/house/experiments/{cuid}/heartbeat", {"flagId": flag_id}) + + +def update_experiment_state( + client: "Client", + *, + username: str, + projname: str, + cuid: str, + state: Literal['FINISHED', 'CRASHED', 'ABORTED'], + finished_at: str = None, +): + """ + 更新实验状态,注意此接口会将客户端标记为 pending 状态,表示实验已结束 + :param client: 已登录的客户端实例 + :param username: 实验所属用户名 + :param projname: 实验所属项目名称 + :param cuid: 实验唯一标识符 + :param state: 实验状态 + :param finished_at: 实验结束时间,格式为 ISO 8601,如果不提供则使用当前时间 + """ + put_data = { + "state": state, + "finishedAt": finished_at, + "from": "sdk", + } + put_data = {k: v for k, v in put_data.items() if v is not None} # 移除值为None的键 + client.put(f"/project/{username}/{projname}/runs/{cuid}/state", put_data) + client.pending = True + + +def get_project_experiments( + client: "Client", + *, + path: str, + filters: Dict[str, object] = None, +) -> Union[List[RunType], Dict[str, List[RunType]]]: + """ + 获取指定项目下的所有实验信息 + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 + :param client: 已登录的客户端实例 + :param path: 项目路径 username/project + :param filters: 筛选实验的条件,可选 + """ + parsed_filters = ( + [ + { + "key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1], + "active": True, + "value": [value], + "op": 'EQ', + "type": parse_column_type(key), + } + for key, value in filters.items() + ] + if filters + else [] + ) + res = client.post(f"/project/{path}/runs/shows", data={'filters': parsed_filters}) + return res[0] + + +def get_single_experiment(client: "Client", *, path: str) -> RunType: + """ + 获取指定项目下的所有实验信息 + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 + :param client: 已登录的客户端实例 + :param path: 实验路径 username/project/expid + """ + proj_path, expid = path.rsplit('/', 1) + res = client.get(f"/project/{proj_path}/runs/{expid}") + return res[0] + + +def get_experiment_metrics(client: "Client", *, expid: str, key: str) -> Dict[str, str]: + """ + 获取指定字段的指标数据,返回csv网址 + :param client: 已登录的客户端实例 + :param expid: 实验cuid + :param key: 指定字段列表 + """ + res = client.get(f"/experiment/{expid}/column/csv", params={'key': key}) + return res[0] + + +__all__ = [ + "send_experiment_heartbeat", + "update_experiment_state", + "get_project_experiments", + "get_single_experiment", + "get_experiment_metrics", +] diff --git a/swanlab/core_python/api/experiment/utils.py b/swanlab/core_python/api/experiment/utils.py new file mode 100644 index 000000000..a35088f97 --- /dev/null +++ b/swanlab/core_python/api/experiment/utils.py @@ -0,0 +1,24 @@ +""" +@author: Zhou QiYang +@file: utils.py +@time: 2026/1/10 22:09 +@description: 实验相关的后端API接口中的工具函数 +""" + +from swanlab.core_python.api.type import ColumnType + + +# 从前缀中获取指标类型 +def parse_column_type(column: str) -> ColumnType: + column_type = column.split('.', 1)[0] + if column_type == 'summary': + return 'SCALAR' + elif column_type == 'config': + return 'CONFIG' + else: + return 'STABLE' + + +# 将下划线命名转化为驼峰命名 +def to_camel_case(name: str) -> str: + return ''.join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split('_'))]) diff --git a/swanlab/core_python/api/project/__init__.py b/swanlab/core_python/api/project/__init__.py new file mode 100644 index 000000000..09c4a0f47 --- /dev/null +++ b/swanlab/core_python/api/project/__init__.py @@ -0,0 +1,57 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2025/12/19 23:49 +@description: 定义项目相关的后端API接口 +""" + +from typing import Optional, List, TYPE_CHECKING + +from swanlab.core_python.api.type import ProjResponseType, ProjectType + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + + +def get_workspace_projects( + client: "Client", + *, + path: str, + page: int = 1, + size: int = 20, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, +) -> ProjResponseType: + """ + 获取指定页数和条件下的项目信息 + :param client: 已登录的客户端实例 + :param path: 工作空间名称 + :param page: 页码 + :param size: 每页项目数量 + :param sort: 排序规则, 可选 + :param search: 搜索的项目名称关键字, 可选 + :param detail: 是否包含项目下实验的相关信息, 可选, 默认为true + """ + params = { + 'page': page, + 'size': size, + 'sort': sort, + 'search': search, + 'detail': detail, + } + res = client.get(f"/project/{path}", params=dict(params)) + return res[0] + + +def get_project_info(client: "Client", *, path: str) -> ProjectType: + """ + 获取指定路径的项目信息 + :param client: 已登录的客户端实例 + :param path: 项目路径 'username/project' + """ + res = client.get(f"/project/{path}") + return res[0] + + +__all__ = ["get_workspace_projects", "get_project_info"] diff --git a/swanlab/core_python/api/type/__init__.py b/swanlab/core_python/api/type/__init__.py new file mode 100644 index 000000000..4bed69f2a --- /dev/null +++ b/swanlab/core_python/api/type/__init__.py @@ -0,0 +1,24 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 23:36 +@description: 后端API接口相关类型 +""" + +from .experiment import RunType, ColumnType +from .project import ProjectType, ProjResponseType +from .user import GroupType, IdentityType, ApiKeyType, SelfHostedInfoType +from .workspace import WorkspaceType, RoleType + +__all__ = [ + "RunType", + "ColumnType", + "ProjectType", + "ProjResponseType", + "GroupType", + "IdentityType", + "ApiKeyType", + "SelfHostedInfoType", + "WorkspaceType", + "RoleType", +] diff --git a/swanlab/core_python/api/type/experiment.py b/swanlab/core_python/api/type/experiment.py new file mode 100644 index 000000000..56d17455e --- /dev/null +++ b/swanlab/core_python/api/type/experiment.py @@ -0,0 +1,33 @@ +""" +@author: Zhou QiYang +@file: type.py +@time: 2026/1/10 22:12 +@description: 实验相关后端API接口类型 +""" + +from typing import TypedDict, List, Dict, Optional, Literal + +ColumnType = Literal['STABLE', 'SCALAR', 'CONFIG'] # 列类型 +StateType = Literal['FINISHED', 'CRASHED', 'ABORTED', 'RUNNING'] # 实验状态 + + +# ------------------------------------- 通用类型 ------------------------------------- +class UserType(TypedDict): + username: str # 用户名 + name: str # 用户显示名称 + + +class RunType(TypedDict): + cuid: str # 实验CUID, 唯一标识符 + name: str # 实验名称 + createdAt: str # 创建时间, e.g., '2024-11-23T12:28:04.286Z' + description: str # 实验描述 + labels: List[Dict[str, str]] # 实验标签列表 + profile: Dict[str, Dict[str, object]] # 实验配置和摘要信息,包含 'config' 和 'scalar' + state: StateType # 实验状态 + cluster: str # 实验组 + job: str # 任务类型 + runtime: str # 运行时间 + user: UserType # 实验所属用户 + rootExpId: Optional[str] # 祖宗实验对应的实验 cuid,如果为克隆实验则必传 + rootProId: Optional[str] # 祖宗实验对应的项目 cuid,如果为克隆实验则必传 diff --git a/swanlab/core_python/api/type/project.py b/swanlab/core_python/api/type/project.py new file mode 100644 index 000000000..edc66392d --- /dev/null +++ b/swanlab/core_python/api/type/project.py @@ -0,0 +1,36 @@ +""" +@author: Zhou QiYang +@file: projects.py +@time: 2026/1/10 21:48 +@description: 项目相关后端API接口类型 +""" + +from typing import TypedDict, List, Dict + + +# ------------------------------------- 通用类型 ------------------------------------- +class ProjectLabelType(TypedDict): + name: str # 项目标签名称 + + +# 项目信息 +class ProjectType(TypedDict): + cuid: str # 项目CUID, 唯一标识符 + name: str # 项目名 + path: str # 项目路径 + url: str # 项目URL + description: str # 项目描述 + visibility: str # 可见性, 'PUBLIC' 或 'PRIVATE' + createdAt: str # e.g., '2024-11-23T12:28:04.286Z' + updatedAt: str # e.g., '2024-11-23T12:28:04.286Z' + group: Dict[str, str] # 包含项目所属工作空间名称 (workspace) + projectLabels: List[ProjectLabelType] # 项目标签 + _count: Dict[str, int] # 项目的统计信息 + + +# ------------------------------------- 后端返回信息 ------------------------------------- +class ProjResponseType(TypedDict): + list: List[ProjectType] # 项目列表 + size: int # 每页项目数量 + pages: int # 总页数 + total: int # 总项目数量 diff --git a/swanlab/core_python/api/type/user.py b/swanlab/core_python/api/type/user.py new file mode 100644 index 000000000..f56d55d83 --- /dev/null +++ b/swanlab/core_python/api/type/user.py @@ -0,0 +1,35 @@ +""" +@author: Zhou QiYang +@file: user.py +@time: 2026/1/10 21:46 +@description: 用户相关后端API接口类型 +""" + +from typing import Literal, TypedDict + + +# ------------------------------------- 通用类型 ------------------------------------- +IdentityType = Literal['user', 'root'] + + +# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 +class GroupType(TypedDict): + name: str # 组织名称 (用于user.teams) + username: str + + +# ------------------------------------- 后端返回信息 ------------------------------------- +class ApiKeyType(TypedDict): + id: int + name: str + createdAt: str + key: str + + +# 私有化部署信息 +class SelfHostedInfoType(TypedDict): + enabled: bool # 是否成功部署 + expired: bool # licence是否过期 + root: bool # 是否为根用户 + plan: Literal["free", "commercial"] # 私有化版本(免费、商业) + seats: int # 余剩席位 diff --git a/swanlab/core_python/api/type/workspace.py b/swanlab/core_python/api/type/workspace.py new file mode 100644 index 000000000..a689b1d28 --- /dev/null +++ b/swanlab/core_python/api/type/workspace.py @@ -0,0 +1,20 @@ +""" +@author: Zhou Qiyang +@file: workspace +@time: 2026/1/27 20:46 +@description: 工作空间相关类型 +""" + +from typing import TypedDict, Literal, Dict + +RoleType = Literal['VISITOR', 'VIEWER', 'MEMBER', 'OWNER'] + + +class WorkspaceType(TypedDict): + name: str + username: str + profile: Dict[str, str] + type: Literal['TEAM', 'PERSON'] # 返回的信息类型 + comment: str # 组织或者个人的描述 + # 组织信息特有 + role: RoleType # 组织成员的角色 diff --git a/swanlab/core_python/api/user/__init__.py b/swanlab/core_python/api/user/__init__.py new file mode 100644 index 000000000..1aac2a8d4 --- /dev/null +++ b/swanlab/core_python/api/user/__init__.py @@ -0,0 +1,83 @@ +""" +@author: Zhou QiYang +@file: __init__.py.py +@time: 2026/1/10 21:44 +@description: 定义用户相关的后端API接口 +""" + +from typing import TYPE_CHECKING, List + +from swanlab.core_python.api.type import GroupType, ApiKeyType, WorkspaceType +from .self_hosted import get_self_hosted_init, create_user, get_users + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + + +def create_api_key(client: "Client", *, name: str = None) -> None: + """ + 创建一个api_key,完成后返回成功信息 + :param client: 已登录的客户端实例 + :param name: api_key 的名称 + """ + client.post(f"/user/key", data={'name': name} if name else None) + + +def delete_api_key(client: "Client", *, key_id: int) -> None: + """ + 删除指定id的api_key + :param client: 已登录的客户端实例 + :param key_id: api_key的id + """ + client.delete(f"/user/key/{key_id}") + + +def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: + """ + 获取用户加入的组织 + :param client: 已登录的客户端实例 + :param username: 用户名称 + """ + res = client.get(f"/user/{username}/groups") + return res[0] + + +def get_workspace_info(client: "Client", *, path: str) -> WorkspaceType: + """ + 获取指定工作空间的信息 + :param client: 已登录的客户端实例 + :param path: 工作空间名称 + """ + res = client.get(f"/group/{path}") + return res[0] + + +def get_api_keys(client: "Client") -> List[ApiKeyType]: + """ + 获取当前全部的api_key + :param client: 已登录的客户端实例 + """ + res = client.get(f"/user/key") + return res[0] + + +def get_latest_api_key(client: "Client") -> ApiKeyType: + """ + 获取最新的api_key + :param client: 已登录的客户端实例 + """ + res = client.get(f"/user/key/latest") + return res[0] + + +__all__ = [ + "create_api_key", + "delete_api_key", + "get_user_groups", + "get_workspace_info", + "get_api_keys", + "get_latest_api_key", + "get_self_hosted_init", + "create_user", + "get_users", +] diff --git a/swanlab/core_python/api/user/self_hosted.py b/swanlab/core_python/api/user/self_hosted.py new file mode 100644 index 000000000..fee2da394 --- /dev/null +++ b/swanlab/core_python/api/user/self_hosted.py @@ -0,0 +1,45 @@ +""" +@author: Zhou QiYang +@file: self_hosted.py +@time: 2026/1/5 17:42 +@description: 私有化相关API接口 +""" + +from typing import TYPE_CHECKING + +from swanlab.core_python.api.type import SelfHostedInfoType + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + + +def get_self_hosted_init(client: "Client") -> SelfHostedInfoType: + """ + 获取私有化部署信息 + :param client: 已登录的客户端实例 + """ + res = client.get(f"/self_hosted/info") + return res[0] + + +def create_user(client: "Client", *, username: str, password: str) -> None: + """ + 添加用户(私有化管理员限定) + :param client: 已登录的客户端实例 + :param username: 用户名 + :param password: 用户密码 + """ + data = {"users": [{"username": username, "password": password}]} + client.post("/self_hosted/users", data=data) + + +def get_users(client: "Client", *, page: int = 1, size: int = 20): + """ + 分页获取用户(管理员限定) + :param client: 已登录的客户端实例 + :param page: 页码 + :param size: 每页大小 + """ + params = {"page": page, "size": size} + res = client.get("/self_hosted/users", params=params) + return res[0] diff --git a/swanlab/core_python/client/__init__.py b/swanlab/core_python/client/__init__.py index a8b3a5f0d..b90d1153d 100644 --- a/swanlab/core_python/client/__init__.py +++ b/swanlab/core_python/client/__init__.py @@ -201,6 +201,15 @@ def patch(self, url: str, data: dict = None): resp = self.__session.patch(url, json=data) return decode_response(resp), resp + def delete(self, url: str, retries: Optional[int] = None): + """ + delete请求 + """ + url = self.__login_info.api_host + url + self.__before_request() + resp = self.__session.delete(url, retries=retries) + return decode_response(resp), resp + # ---------------------------------- 训练相关接口 ---------------------------------- def mount_project(self, name: str, username: str = None, public: bool = None): diff --git a/test/unit/api/test_experiment.py b/test/unit/api/test_experiment.py index 81214084a..2f57b9e83 100644 --- a/test/unit/api/test_experiment.py +++ b/test/unit/api/test_experiment.py @@ -1,101 +1,21 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/7 09:47 -@File: test_experiment.py -@IDE: pycharm -@Description: - 测试开放API的实验相关接口 -""" - -import pandas as pd -import pytest - -import tutils as T -from swanlab import OpenApi -from swanlab.api.types import ApiResponse, Experiment - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_experiment(): - """ - 获取一个实验的详细信息 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.get_experiment(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, Experiment) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_list_experiments(): - """ - 获取一个项目下的实验列表 - """ - api = OpenApi() - res = api.list_experiments(project="SwanLab") - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, list) - for item in res.data: - assert isinstance(item, Experiment) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_summary(): - """ - 获取一个实验的Summary信息 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.get_summary(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, dict) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_delete_experiment(): - """ - 删除一个实验 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.delete_experiment(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - assert res.code in [204, 404] - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_metrics(): - """ - 获取实验的指标数据 - """ - api = OpenApi() - exp_cuid = "test_cuid" - keys = ["accuracy", "loss"] - res = api.get_metrics(exp_id=exp_cuid, keys=keys) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, pd.DataFrame) - assert not res.data.empty - assert all(key in res.data.columns for key in keys) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_metrics_duplicate_keys(): - """ - 获取实验的指标数据 - """ - api = OpenApi() - exp_cuid = "test_cuid" - keys = ["accuracy", "loss", "loss"] - res = api.get_metrics(exp_id=exp_cuid, keys=keys) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, pd.DataFrame) - assert not res.data.empty - assert len(res.data.columns.to_list()) == 2 - assert all(key in res.data.columns for key in ["accuracy", "loss"]) +from unittest.mock import patch, MagicMock + +from swanlab.api.experiments import Experiments +from swanlab.core_python import Client +from tutils.setup import mock_login_info +from utils import create_nested_exps + + +def test_folded_exps(): + """测试嵌套的实验数据能够正确展平""" + mock_exps = Experiments( + MagicMock(spec=Client), + path='test_user/test-project', + login_info=mock_login_info(), + ) + + nested_data = create_nested_exps(groups=2, num_per_group=2) + with patch('swanlab.api.experiments.get_project_experiments') as mock_get_exps: + mock_get_exps.return_value = nested_data + experiments = list(mock_exps) + assert len(experiments) == 4 diff --git a/test/unit/api/test_group.py b/test/unit/api/test_group.py deleted file mode 100644 index 12269cb8e..000000000 --- a/test/unit/api/test_group.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/4/30 21:16 -@File: test_group.py -@IDE: pycharm -@Description: - 测试开放API的组织相关接口 -""" - -import pytest - -import tutils as T -from swanlab import OpenApi -from swanlab.api.types import ApiResponse - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_workspaces(): - """ - 获取用户的所有工作空间 - """ - api = OpenApi() - r = api.list_workspaces() - - assert isinstance(r, ApiResponse) - assert isinstance(r.data, list) diff --git a/test/unit/api/test_metrics.py b/test/unit/api/test_metrics.py new file mode 100644 index 000000000..6fcf07fa8 --- /dev/null +++ b/test/unit/api/test_metrics.py @@ -0,0 +1,128 @@ +""" +@author: Zhou QiYang +@file: test_metrics.py +@time: 2026/1/11 16:36 +@description: 测试 Experiment.metrics() 方法,使用 MagicMock 和 monkeypatch 模拟网络请求 +""" + +from unittest.mock import patch, MagicMock + +import pytest +import pandas as pd + +from swanlab.api.experiment import Experiment +from swanlab.core_python.client import Client +from swanlab.package import get_host_web, get_host_api +from utils import create_csv_data, create_run_type_data + + +@pytest.fixture +def experiment(): + """创建 Experiment 实例""" + data = create_run_type_data() + return Experiment( + MagicMock(spec=Client), + data=data, + path='test_user/test_project/test_exp', + web_host=get_host_web(), + login_user='test_user', + line_count=100, + ) + + +@pytest.fixture +def metrics_data(): + return [ + (list(range(10)), 'loss', [0.5 - i * 0.05 for i in range(10)]), + (list(range(10)), 'accuracy', [0.5 + i * 0.05 for i in range(10)]), + ] + + +class MockSetup: + """ + 模拟网络请求 + 直接 mock client.get 返回 URL,然后 mock pd.read_csv 返回 DataFrame + """ + + def __init__(self, metrics_data, experiment): + self.metrics_data = metrics_data + self.experiment = experiment + # Create a lookup dict for metrics data + self._metric_lookup = {m[1]: m for m in metrics_data} + + def _create_df(self, key): + """Helper to create DataFrame from metric data""" + if key not in self._metric_lookup: + return pd.DataFrame() + step_values, metric_name, metric_values = self._metric_lookup[key] + return pd.DataFrame({ + 'step': step_values, + f'{metric_name}_step': metric_values + }).set_index('step') + + def __enter__(self): + # Mock client.get to return CSV URL + self.mock_get = patch.object(self.experiment._client, 'get').start() + self.mock_get.side_effect = lambda path, params: [{'url': f'{get_host_api()}/{params["key"]}'}] + + # Mock pd.read_csv to return DataFrame directly + self.mock_read_csv = patch('pandas.read_csv').start() + self.mock_read_csv.side_effect = lambda url, index_col: self._create_df(url.split('/')[-1]) + return self + + def __exit__(self, *args): + patch.stopall() + + +def test_metrics_basic(experiment, metrics_data): + """测试使用指定 keys 获取历史数据""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics(keys=['loss', 'accuracy']) + + assert len(result) == 10 + assert 'loss' in result.columns + assert 'accuracy' in result.columns + + +def test_metrics_with_x_axis(experiment, metrics_data): + """测试使用 x_axis 参数""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics(keys=['loss'], x_axis='accuracy') + + # x_axis 列应该在第一列 + assert result.columns[0] == 'accuracy' + + +def test_metrics_with_sample(experiment, metrics_data): + """测试使用 sample 参数限制返回行数""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics(keys=['loss'], sample=5) + + # 只返回前 5 行 + assert len(result) == 5 + + +def test_metrics_dict_mode(experiment, metrics_data): + """测试 pandas=False 时返回 DataFrame(当前实现只支持 DataFrame)""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics(keys=['loss'], pandas=False) + + # 当前实现始终返回 DataFrame + assert isinstance(result, pd.DataFrame) + + +def test_full_metrics(experiment, metrics_data): + """测试 keys=None 时返回空 DataFrame(当前实现不支持)""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics() + + # keys=None 时返回空 DataFrame + assert len(result) == 0 + + +@pytest.mark.parametrize("keys", ('invalid_keys', ['loss', 123, 'accuracy'])) +def test_metrics_invalid_keys(experiment, metrics_data, keys): + """测试 keys 参数类型错误的情况,返回空 DataFrame""" + with MockSetup(metrics_data, experiment): + result = experiment.metrics(keys=keys) + assert len(result) == 0 diff --git a/test/unit/api/test_project.py b/test/unit/api/test_project.py index 39289bace..3c07fa991 100644 --- a/test/unit/api/test_project.py +++ b/test/unit/api/test_project.py @@ -1,40 +1,27 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/8 16:14 -@File: test_project.py -@IDE: VSCode -@Description: - 测试开放API的项目相关接口 -""" +from unittest.mock import patch, MagicMock -import pytest +from swanlab.api.projects import Projects +from swanlab.core_python import Client +from swanlab.package import get_host_web +from utils import create_project_data -import tutils as T -from swanlab import OpenApi -from swanlab.api.types import ApiResponse, Project, Pagination +def test_projects(): + """测试能否分页获取所有项目""" + with patch('swanlab.api.projects.get_workspace_projects') as mock_get_projects: + total = 80 + page_size = 20 -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_list_projects(): - """ - 测试列出一个 workspace 下的所有项目 - """ - api = OpenApi() - resp = api.list_projects(detail=False) - assert isinstance(resp, ApiResponse) - if resp.code == 200: - assert isinstance(resp.data, list) - for item in resp.data: - assert isinstance(item, Project) + def side_effect(*args, **kwargs): + return create_project_data(page=kwargs.get("page", 1), total=total) -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_delete_project(): - """ - 测试删除一个项目 - """ - api = OpenApi() - project_name = "test_project" - resp = api.delete_project(project=project_name) - assert isinstance(resp, ApiResponse) - assert resp.code in [204, 404] + mock_get_projects.side_effect = side_effect + + mock_projects = Projects( + MagicMock(spec=Client), + web_host=get_host_web(), + path='test_user', + ) + projects = list(mock_projects) + assert len(projects) == total + assert mock_get_projects.call_count == (total + page_size - 1) // page_size diff --git a/test/unit/api/test_types.py b/test/unit/api/test_types.py deleted file mode 100644 index 53cd00cda..000000000 --- a/test/unit/api/test_types.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/14 22:25 -@File: test_types.py -@IDE: pycharm -@Description: - 测试开放API模型定义 -""" - -from swanlab.api.types import Experiment - - -def test_dict_style_access(): - """ - 测试字典风格访问对象字段 - """ - exp = Experiment.model_validate({ - "cuid": "test_cuid", - "name": "test_experiment", - "description": "test_description", - "state": "FINISHED", - "show": True, - "createdAt": "2025-05-14T22:25:00Z", - "finishedAt": "2025-05-14T22:30:00Z", - "user": { - "username": "test_user", - "name": "Test User" - }, - "profile": { - "requirements": "test_requirements", - } - }) - assert exp["cuid"] == "test_cuid" - \ No newline at end of file diff --git a/test/unit/api/test_user.py b/test/unit/api/test_user.py new file mode 100644 index 000000000..24231610f --- /dev/null +++ b/test/unit/api/test_user.py @@ -0,0 +1,109 @@ +from unittest.mock import patch, MagicMock + +import pytest + +from swanlab.api.user import User +from swanlab.api.users import Users +from swanlab.core_python import Client +from swanlab.error import ApiError +from utils import create_user_data + + +def create_user(username=None): + """创建用户对象的辅助函数""" + return User(MagicMock(spec=Client), login_user="test_user", username=username) + + +class SelfHosted: + """私有化部署的上下文管理器""" + + def __init__(self, start=False, **kwargs): + self.start = start # 是否启动私有化部署 + self.enabled = kwargs.get('enabled', True) + self.expired = kwargs.get('expired', False) + self.root = kwargs.get('root', False) + self.plan = kwargs.get('plan', 'free') + self.seats = kwargs.get('seats', 99) + + def __enter__(self): + self.mock_get_metrics = patch('swanlab.api.utils.get_self_hosted_init').start() + if not self.start: + self.mock_get_metrics.side_effect = ApiError() + else: + self.mock_get_metrics.return_value = { + 'enabled': self.enabled, # 是否成功部署 + 'expired': self.expired, # licence是否过期 + 'root': self.root, # 是否为根用户 + 'plan': self.plan, # 私有化版本(免费、商业) + 'seats': self.seats, # 余剩席位 + } + return self + + def __exit__(self, *args): + patch.stopall() + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, # 私有化未启动 + {'start': True, 'enabled': False}, # 启动,但未启用 + {'start': True, 'expired': True}, # licence过期 + {'start': True}, # 启动,正常,但不是root + {'username': 'test_other_user'}, + ], +) +def test_create_permission(kwargs): + """测试尝试创建用户是否会被拦截""" + user = create_user(kwargs.get('username', None)) + with SelfHosted(**kwargs): + with pytest.raises(ValueError): + user.create(username='test_user', password='123456aa') + + +@pytest.mark.parametrize( + ("username", "password"), + [ + ('user@name', 'password123'), # 无效的用户名 + ('test_user', 'short'), # 无效密码(密码长度小于8) + ('test_user', '12345678'), # 无效密码(全是数字) + ('test_user', 'ABCDEFGH'), # 有效密码(全是字母) + ], +) +def test_check_create_info(username, password): + """测试无效的用户名或密码""" + root_user = create_user() + with SelfHosted(start=True, root=True): + with pytest.raises(ValueError): + root_user.create(username, password) + + +def test_other_user(): + """测试是否对未开发的功能进行拦截""" + other_user = create_user(username="other_user") + with SelfHosted(start=True, root=True): + with pytest.raises(ValueError): + assert other_user.generate_api_key() is None + with pytest.raises(ValueError): + assert other_user.delete_api_key(api_key='test_api_key') == False + + +def test_users(): + """测试能否分页获取所有用户""" + with patch('swanlab.api.users.get_users') as mock_get_users: + total = 80 + page_size = 20 + + def side_effect(*args, **kwargs): + return create_user_data(page=kwargs.get("page", 1), total=total) + + mock_get_users.side_effect = side_effect + client = MagicMock(spec=Client) + users = Users(client, login_user="test_user") + + user_list = list(users) + assert len(user_list) == total + for i, user in enumerate(user_list): + assert user.username == f'user_{i}' + + assert mock_get_users.call_count == (total + page_size - 1) // page_size diff --git a/test/unit/api/utils.py b/test/unit/api/utils.py new file mode 100644 index 000000000..7f06bf693 --- /dev/null +++ b/test/unit/api/utils.py @@ -0,0 +1,125 @@ +import csv +from io import StringIO +from typing import Dict, List + +import nanoid + +from swanlab.core_python.api.type import RunType, ProjResponseType, ProjectType +from swanlab.package import get_host_web + + +def create_run_type_data(cuid=None) -> RunType: + """ + 创建 RunType 类型的测试数据 + """ + return { + 'cuid': cuid if cuid is not None else nanoid.generate('0123456789', 10), + 'name': '', + 'createdAt': '', + 'description': '', + 'labels': [], + 'profile': {'config': {}, 'scalar': {'loss': [], 'accuracy': []}}, + 'state': 'FINISHED', + 'cluster': '', + 'job': '', + 'runtime': '', + 'user': {'username': '', 'name': ''}, + 'rootExpId': None, + 'rootProId': None, + } + + +def create_nested_exps(groups: int = 2, num_per_group: int = 2) -> Dict: + """ + 创建嵌套的实验数据(用于模拟 get_project_experiments 的返回值) + + :param groups: 分组数量 + :param num_per_group: 每组(页)中的实验数量 + :return: 分组情况下 Dict 格式的数据 + """ + result = {} + for i in range(groups): + group_key = f'group_{i}' + runs = [] + for j in range(num_per_group): + run = create_run_type_data(cuid=f'exp_{i}_{j}') + runs.append(run) + result[group_key] = runs + return result + + +def create_project_data(page: int = 1, total: int = 20) -> ProjResponseType: + """ + 创建分页项目数据(用于模拟 get_workspace_projects 的返回值) + + :param page: 当前页数 + :param total: 项目总数 + :return: ProjResponseType 格式的数据 + """ + page_size = 20 + pages = (total + page_size - 1) // page_size + project_list: List[ProjectType] = [] + + for j in range(page_size): + project: ProjectType = { + 'cuid': f'proj_{page}_{j}', + 'name': f'project_{page}_{j}', + 'path': f'test_user/project_{page}_{j}', + 'url': f'{get_host_web()}/test_user/project_{page}_{j}', + 'description': '', + 'visibility': 'PRIVATE', + 'createdAt': '', + 'updatedAt': '', + 'group': {'workspace': 'test_user'}, + 'projectLabels': [], + '_count': {}, + } + project_list.append(project) + + return { + 'list': project_list, + 'size': page_size, + 'pages': pages, + 'total': total, + } + + +def create_user_data(page: int = 1, total: int = 20) -> Dict: + """ + 创建分页用户数据(用于模拟 get_users 的返回值) + + :param page: 当前页数 + :param total: 用户总数 + :return: 包含 list, pages, total 等字段的字典 + """ + page_size = 20 + pages = (total + page_size - 1) // page_size + user_list = [] + + for j in range(page_size): + user_list.append({ + 'username': f'user_{ (page - 1) * page_size + j }' + }) + + return { + 'list': user_list, + 'pages': pages, + 'total': total, + } + + +def create_csv_data(step_values, metric_name, metric_values): + """ + 创建 CSV 格式的数据 + :param step_values: step 列的值列表 + :param metric_name: 指标名称 + :param metric_values: 指标值列表 + :return: CSV 格式的字节数据 + """ + output = StringIO() + writer = csv.writer(output) + writer.writerow(['step', metric_name]) + for step, value in zip(step_values, metric_values): + writer.writerow([step, value]) + csv_text = output.getvalue() + return csv_text.encode('utf-8')