From dca98b1992cf59b39674e0e370862739ae03037d Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:45:13 +0800 Subject: [PATCH 01/21] feat(init): api login (#1380) --- swanlab/api/base.py | 175 --------------------- swanlab/api/experiment.py | 210 ------------------------- swanlab/api/group.py | 32 ---- swanlab/api/main.py | 258 +++---------------------------- swanlab/api/project.py | 71 --------- swanlab/api/types.py | 57 ------- test/api/project.py | 12 ++ test/unit/api/test_experiment.py | 101 ------------ test/unit/api/test_group.py | 27 ---- test/unit/api/test_project.py | 40 ----- test/unit/api/test_types.py | 35 ----- 11 files changed, 32 insertions(+), 986 deletions(-) delete mode 100644 swanlab/api/base.py delete mode 100644 swanlab/api/experiment.py delete mode 100644 swanlab/api/group.py delete mode 100644 swanlab/api/project.py delete mode 100644 swanlab/api/types.py create mode 100644 test/api/project.py delete mode 100644 test/unit/api/test_experiment.py delete mode 100644 test/unit/api/test_group.py delete mode 100644 test/unit/api/test_project.py delete mode 100644 test/unit/api/test_types.py diff --git a/swanlab/api/base.py b/swanlab/api/base.py deleted file mode 100644 index 2fd90e0a4..000000000 --- a/swanlab/api/base.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/1 17:36 -@File: base.py -@IDE: pycharm -@Description: - SwanLab OpenAPI API基类 -""" -import json -from datetime import datetime, timezone -from typing import Any, Callable, List, Optional, Union - -import requests - -from swanlab.api.types import ApiResponse -from swanlab.core_python import auth, create_session -from swanlab.log.log import SwanLog - -_logger: Optional[SwanLog] = None - -def get_logger(log_level: str = "info") -> SwanLog: - global _logger - if _logger is None: - _logger = SwanLog("swanlab.openapi", log_level) - else: - _logger.level = log_level - return 
_logger - -def handle_response(resp: requests.Response) -> ApiResponse: - try: - data = resp.json() if resp.content else {} - except (json.decoder.JSONDecodeError, requests.JSONDecodeError): - return ApiResponse[str]( - code=resp.status_code, - errmsg="sdk decode json error", - data=resp.text - ) - - if not isinstance(data, (dict, list)): - return ApiResponse[Any]( - code=resp.status_code, - errmsg="sdk decode dict error", - data=data - ) - - code = resp.status_code - if 200 <= code < 300: - message = "" - else: - message = f"api error: {resp.reason}. Trace id: {resp.headers.get('traceid')}" - return ApiResponse( - code=code, - errmsg=message, - data=data - ) - - -class ApiHTTP: - REFRESH_TIME = 60 * 60 * 24 * 7 # 7天 - - def __init__(self, login_info: auth.LoginInfo): - self.__logger = get_logger() - self.__login_info: auth.LoginInfo = login_info - self.__session: requests.Session = self.__init_session() - self.service: OpenApiService = OpenApiService(self) - - @property - def session(self) -> requests.Session: - """ - 获取当前的requests.Session对象 - """ - return self.__session - - @property - def username(self) -> str: - """ - 当前登录的用户名 - """ - return self.__login_info.username or "" - - @property - def base_url(self): - return self.__login_info.api_host - - @property - def sid_expired_at(self): - """ - 获取sid的过期时间 - """ - return datetime.strptime(self.__login_info.expired_at or "", "%Y-%m-%dT%H:%M:%S.%fZ") - - def __init_session(self) -> requests.Session: - session = create_session() - session.cookies.update({"sid": self.__login_info.sid or ""}) - return session - - def __before_request(self): - if (self.sid_expired_at - datetime.now(timezone.utc).replace(tzinfo=None)).total_seconds() < self.REFRESH_TIME: - self.__logger.debug("Refreshing sid...") - self.__login_info = auth.login_by_key(self.__login_info.api_key or "", save=False) - self.__session.headers["cookie"] = f"sid={self.__login_info.sid}" - - def get(self, url: str, params: dict) -> ApiResponse: - 
self.__before_request() - resp = self.__session.get(self.base_url + url, params=params) - return handle_response(resp) - - def post(self, url: str, data: Union[dict, list], params: dict) -> ApiResponse: - self.__before_request() - resp = self.__session.post(self.base_url + url, json=data, params=params) - return handle_response(resp) - - def delete(self, url: str, params: dict) -> ApiResponse: - self.__before_request() - resp = self.__session.delete(self.base_url + url, params=params) - return handle_response(resp) - -class OpenApiService: - def __init__(self, http: ApiHTTP): - self.http: ApiHTTP = http - - def get_exp_info(self, username: str, project: str, exp_id: str) -> ApiResponse[dict]: - """ - 获取实验信息 - """ - return self.http.get(f"/project/{username}/{project}/runs/{exp_id}", params={}) - - def get_project_info(self, username: str, projname: str) -> ApiResponse[dict]: - """ - 获取项目详情 - """ - return self.http.get(f"/project/{username}/{projname}", params={}) - - @staticmethod - def fetch_paginated_api( - api_func: Callable[..., ApiResponse], # 分页 API 请求函数 - page_field: str = "page", - size_field: str = "size", - page_size: int = 10, - *args, **kwargs - ) -> ApiResponse[List]: - """ - 通用分页全量拉取函数 - - Args: - api_func (Callable): 分页 API 请求函数,应返回 ApiResponse[Pagination] - page_field (str): 页码参数名,默认为 "page" - size_field (str): 每页大小参数名,默认为 "size" - page_size (int): 每页条数,默认为 10 - *args: 传递给 api_func 的位置参数 - **kwargs: 传递给 api_func 的关键字参数 - - Returns: - ApiResponse[list]: 返回所有分页数据组成的 ApiResponse - """ - page = 1 - objs = [] - while True: - kwargs.update({page_field: page, size_field: page_size}) - resp: ApiResponse = api_func(*args, **kwargs) - if resp.errmsg: - break - - objs.extend(resp.data.list) - - if len(objs) >= resp.data.total: - break - page += 1 - - return ApiResponse[List](code=resp.code, errmsg=resp.errmsg, data=objs) - -class ApiBase: - def __init__(self, http: ApiHTTP): - self.http: ApiHTTP = http diff --git a/swanlab/api/experiment.py 
b/swanlab/api/experiment.py deleted file mode 100644 index 74db3d995..000000000 --- a/swanlab/api/experiment.py +++ /dev/null @@ -1,210 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/1 17:30 -@File: experiment.py -@IDE: pycharm -@Description: - 实验相关的开放API -""" -from typing import List - -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse, Experiment, Pagination - -try: - from pandas import DataFrame -except ImportError: - DataFrame = None - - -class ExperimentAPI(ApiBase): - def __init__(self, http: ApiHTTP): - super().__init__(http) - - @classmethod - def parse(cls, body: dict) -> Experiment: - return Experiment.model_validate({ - "cuid": body.get("cuid") or "", - "name": body.get("name") or "", - "description": body.get("description") or "", - "state": body.get("state") or "", - "show": bool(body.get("show")), - "createdAt": body.get("createdAt") or "", - "finishedAt": body.get("finishedAt") or "", - "user": { - "username": (body.get("user") or {}).get("username") or "", - "name": (body.get("user") or {}).get("name") or "", - }, - "profile": body.get("profile") or {} - }) - - def get_experiment( - self, - username: str, - projname: str, - exp_id: str - ) -> ApiResponse[Experiment]: - """ - 获取实验信息 - - Args: - username (str): 工作空间名 - projname (str): 项目名 - exp_id (str): 实验CUID - """ - resp: ApiResponse = self.http.service.get_exp_info(username=username, project=projname, exp_id=exp_id) - if resp.errmsg: - return resp - resp.data = ExperimentAPI.parse(resp.data) - return resp - - def delete_experiment( - self, - username: str, - projname: str, - exp_id: str - ): - """ - 删除实验 - - Args: - username (str): 工作空间名 - projname (str): 项目名 - exp_id (str): 实验CUID - """ - return self.http.delete(f"/project/{username}/{projname}/runs/{exp_id}", params={}) - - def list_experiments( - self, - username: str, - projname: str, - page: int = 1, - size: int = 10 - ) -> ApiResponse[Pagination[Experiment]]: - """ - 
分页获取项目下的实验列表 - 该接口返回的实验profile只包含config(用户自定义的配置) - - Args: - username (str): 工作空间名 - projname (str): 项目名 - page (int): 页码, 默认为1 - size (int): 每页大小, 默认为10 - """ - resp = self.http.get(f"/project/{username}/{projname}/runs", params={"page": page, "size": size}) - - if resp.errmsg: - return resp - exps = resp.data - resp.data = Pagination[Experiment].model_validate({ - "total": exps.get("total", 0), - "list": [ExperimentAPI.parse(e) for e in exps.get("list", [])] - }) - return resp - - def get_summary( - self, - exp_id: str, - pro_id: str, - root_exp_id: str, - root_pro_id: str - ) -> ApiResponse[dict]: - """ - 获取实验的summary信息 - 从House获取, 需要考虑克隆实验 - - Args: - exp_id (str): 实验CUID - pro_id (str): 项目CUID - root_exp_id (str): 根实验CUID - root_pro_id (str): 根项目CUID - """ - data = { - "experimentId": exp_id, - "projectId": pro_id, - } - if root_exp_id and root_pro_id: - data["rootExpId"] = root_exp_id - data["rootProId"] = root_pro_id - - resp = self.http.post("/house/metrics/summaries", data=[data], params={}) - if resp.errmsg: - return resp - - resp.data = list(resp.data.values())[0] - resp.data = { - k: { - "step": v.get("step"), - "value": v.get("value"), - "min": { - "step": v.get("min").get("index"), - "value": v.get("min").get("data"), - }, - "max": { - "step": v.get("max").get("index"), - "value": v.get("max").get("data"), - } - } - for k, v in resp.data.items() - } - return resp - - def get_metrics( - self, - exp_id: str, - keys: List[str], - ) -> ApiResponse[DataFrame]: - """ - 获取实验的指标数据, 可选择若干由用户自定义的列 - - Args: - exp_id (str): 实验CUID - keys (list[str]): 指标key列表 - - Returns: - ApiResponse[DataFrame]: - """ - try: - import pandas as pd - except ImportError: - raise ImportError("OpenApi.get_metrics requires pandas module. 
Install with 'pip install pandas'.") - - # 去重 keys - keys = list(set(keys)) - dfs = [] - prefix = "" - for idx, key in enumerate(keys): - resp = self.http.get(f"/experiment/{exp_id}/column/csv", params={"key": key}) - if resp.errmsg: - continue - - url:str = resp.data.get("url", "") - df = pd.read_csv(url, index_col=0) - - if idx == 0: - # 从第一列名提取 prefix,例如 "t0707-02:17-loss_step" 中提取 "t0707-02:17-" - first_col = df.columns[0] - suffix = f"{key}_" - if suffix in first_col: - prefix = first_col.split(suffix)[0] # 结果为 "t0707-02:17-" - else: - prefix = "" - - if prefix: - df.columns = [ - col[len(prefix):].removesuffix("_step") if col.startswith(prefix) else col.removesuffix("_step") - for col in df.columns - ] - else: - df.columns = [col.removesuffix("_step") for col in df.columns] - - dfs.append(df) - - if not dfs: - return ApiResponse[DataFrame](code=404, errmsg="No data found", data=pd.DataFrame()) - - # 按列合并,使用 inner join 保证对齐 index - result_df = pd.concat(dfs, axis=1, join="inner") - - return ApiResponse[DataFrame](code=200, errmsg="", data=result_df) diff --git a/swanlab/api/group.py b/swanlab/api/group.py deleted file mode 100644 index fd48ce087..000000000 --- a/swanlab/api/group.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/4/30 12:11 -@File: group.py -@IDE: pycharm -@Description: - 组织相关的开放API -""" - -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse - - -class GroupAPI(ApiBase): - def __init__(self, http: ApiHTTP): - super().__init__(http) - - def list_workspaces(self) -> ApiResponse[list]: - resp = self.http.get("/group/", params={}) - if resp.errmsg: - return resp - groups = resp.data.get("list", []) - resp.data = [ - { - "name": item["name"], - "username": item["username"], - "role": item["role"] - } - for item in groups - ] - return resp diff --git a/swanlab/api/main.py b/swanlab/api/main.py index 1029e4ba9..5971a8103 100644 --- a/swanlab/api/main.py +++ 
b/swanlab/api/main.py @@ -1,23 +1,16 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/1 17:00 -@File: main.py -@IDE: pycharm -@Description: - SwanLab OpenAPI模块 """ -from typing import Dict, List, Union +@author: Zhou Qiyang +@file: main.py +@time: 2025/12/17 11:39 +@description: OpenApi 模块 +""" + +from typing import Optional -from swanlab.api.base import ApiHTTP, get_logger -from swanlab.api.experiment import ExperimentAPI -from swanlab.api.group import GroupAPI -from swanlab.api.project import ProjectAPI -from swanlab.api.types import ApiResponse, Experiment, Project -from swanlab.core_python import auth +from swanlab.core_python import auth, Client from swanlab.error import KeyFileError -from swanlab.log.log import SwanLog -from swanlab.package import get_key +from swanlab.log import swanlog +from swanlab.package import get_key, HostFormatter try: from pandas import DataFrame @@ -26,230 +19,19 @@ class OpenApi: - def __init__(self, api_key: str = "", log_level: str = "info"): - self.__logger: SwanLog = get_logger(log_level) - + def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, web_host: Optional[str] = None): + if host or web_host: + HostFormatter(host, web_host)() if api_key: - self.__logger.debug("Using API key", api_key) - self.__key = api_key - self.login_info = auth.code_login(self.__key, False) + swanlog.debug("Using API key", api_key) else: - self.__logger.debug("Using existing key") + swanlog.debug("Using existing key") try: - self.__key = get_key() + api_key = get_key() except KeyFileError as e: - self.__logger.error("To use SwanLab OpenAPI, please login first.") + swanlog.error("To use SwanLab OpenAPI, please login first.") raise RuntimeError("Not logged in.") from e - self.login_info = auth.code_login(self.__key, False) - - self.username = self.login_info.username - self.http: ApiHTTP = ApiHTTP(self.login_info) - self.service = self.http.service - - self.group = GroupAPI(self.http) - self.experiment = 
ExperimentAPI(self.http) - self.project = ProjectAPI(self.http) - - def list_workspaces(self) -> ApiResponse[List[Dict]]: - """ - 获取当前用户的所有工作空间(Group) - - Returns: - ApiResponse[List]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (List[Dict]): 一个列表, 其中每个元素是一个字典, 包含相应工作空间的基础信息: - - name (str): 工作空间名称 - - username (str): 工作空间名(用于组织相关的 URL) - - role (str): 用户在该工作空间中的角色,如 'OWNER' 或 'MEMBER' - """ - return self.group.list_workspaces() - - def get_experiment( - self, - project: str, - exp_id: str, - username: str = "" - ) -> ApiResponse[Experiment]: - """ - 获取实验信息 - - Args: - project (str): 项目名 - exp_id (str): 实验CUID - username (str): 工作空间名, 默认为用户个人空间 - - Returns: - ApiResponse[Experiment]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (Dict): 实验信息的字典, 包含实验信息 - """ - return self.experiment.get_experiment( - username=username if username else self.http.username, projname=project, exp_id=exp_id - ) - - def delete_experiment( - self, - project: str, - exp_id: str, - username: str = "" - ) -> ApiResponse[None]: - """ - 删除实验 - - Args: - project (str): 项目名 - exp_id (str): 实验CUID - username (str): 工作空间名, 默认为用户个人空间 - - Returns: - ApiResponse[None]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (None): 无数据返回 - """ - return self.experiment.delete_experiment( - username=username if username else self.http.username, projname=project, exp_id=exp_id - ) - - def list_experiments( - self, - project: str, - username: str = "" - ) -> ApiResponse[List[Experiment]]: - """ - 获取一个项目下的所有实验 - - Args: - project (str): 项目名 - username (str): 工作空间名, 默认为用户个人空间 - - Returns: - ApiResponse[Experiment]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (List[Experiment]): 实验列表, 每个元素包含一个实验的信息 - - 此实验的 profile 只包含 config (实验自定义配置) - """ - return self.service.fetch_paginated_api( - api_func=self.experiment.list_experiments, - projname=project, - username=username if username else self.http.username - ) - - def 
delete_project( - self, - project: str, - username: str = "" - ) -> ApiResponse[None]: - """ - 删除一个项目 - - Args: - project (str): 项目名 - username (str): 工作空间名, 默认为用户个人空间 - - Returns: - ApiResponse[None]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (None): 无数据返回 - """ - return self.project.delete_project( - username=username if username else self.http.username, project=project - ) - - def list_projects( - self, - username: str = "", - detail: bool = True - ) -> ApiResponse[List[Project]]: - """ - 获取一个工作空间下的所有项目 - - Args: - username (str): 工作空间名, 默认为用户个人空间 - detail (bool): 是否包含详细统计信息,默认为 True - - Returns: - ApiResponse[List[Project]]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (List[Project]): 项目列表, 每个元素包含一个项目的信息 - """ - return self.service.fetch_paginated_api( - api_func=self.project.list_projects, - username=username if username else self.http.username, - detail=detail - ) - - def get_summary( - self, - project: str, - exp_id: str, - username: str = "" - ) -> ApiResponse[Dict]: - """ - 获取实验的概要信息 - - Args: - project (str): 项目名 - exp_id (str): 实验CUID - username (str): 工作空间名, 默认为用户个人空间 - - Returns: - ApiResponse[Dict]: - - code (int): HTTP 状态码 - - errmsg (str): 错误信息, 仅在请求有错误时非空 - - data (Dict): 实验的概要信息字典, 包含用户训练各指标的最大最小值, 及其对应步数 - """ - username = username if username else self.http.username - project_cuid = self.service.get_project_info(username=username, projname=project).data.get("cuid", "") - exp = self.service.get_exp_info(username=username, project=project, exp_id=exp_id) - return self.experiment.get_summary( - exp_id=exp_id, - pro_id=project_cuid, - root_exp_id=exp.data.get("rootExpId", ""), - root_pro_id=exp.data.get("rootProId", "") - ) - - def get_metrics( - self, - exp_id: str, - keys: Union[str, List[str]], - ) -> ApiResponse[DataFrame]: - """ - 获取实验的指标数据 - - Args: - exp_id (str): 实验CUID - keys (str | List[str]): 指标key, 单个字符串或字符串列表 - - Returns: - ApiResponse[DataFrame]: 包含指标数据的响应, 指标数据以 DataFrame 格式返回 - 
在DataFrame中, 每个key对应两个列, 分别为key和key_timestamp, 表示指标值和时间戳 - """ - if isinstance(keys, str): - keys = [keys] - return self.experiment.get_metrics(exp_id, keys) - - def get_exp_summary(self, *args, **kwargs) -> ApiResponse[Dict]: - """ - 获取实验的概要信息 - @deprecated, 请使用 `get_experiment_summary` - - Returns: - ApiResponse[Dict]: 包含实验概要信息的响应 - """ - return self.get_summary(*args, **kwargs) - - def list_project_exps(self, *args, **kwargs) -> ApiResponse[List[Experiment]]: - """ - 获取一个项目下的所有实验 - @deprecated, 请使用 `list_experiments` - Returns: - ApiResponse[List[Experiment]]: 包含实验列表的响应 - """ - return self.list_experiments(*args, **kwargs) + login_info = auth.code_login(api_key, save_key=False) + # 一个OpenApi对应一个client,可创建多个api获取从不同的client获取不同账号下的实验信息 + self._client: Client = Client(login_info) diff --git a/swanlab/api/project.py b/swanlab/api/project.py deleted file mode 100644 index b41eba4c0..000000000 --- a/swanlab/api/project.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/1 17:30 -@File: project.py -@IDE: pycharm -@Description: - 项目相关的开放API -""" - -from swanlab.api.base import ApiBase, ApiHTTP -from swanlab.api.types import ApiResponse, Pagination, Project - - -class ProjectAPI(ApiBase): - def __init__(self, http: ApiHTTP): - super().__init__(http) - - @classmethod - def parse(cls, body: dict, detail=True) -> Project: - project_parser = { - "cuid": body.get("cuid") or "", - "name": body.get("name") or "", - "description": body.get("description") or "", - "visibility": body.get("visibility") or "", - "createdAt": body.get("createdAt") or "", - "updatedAt": body.get("updatedAt") or "", - "group": { - "type": (body.get("group") or {}).get("type") or "", - "username": (body.get("group") or {}).get("username") or "", - "name": (body.get("group") or {}).get("name") or "", - }, - } - if detail: - project_parser["count"] = body.get("_count") or {} - return Project.model_validate(project_parser) - - def delete_project( - self, 
username: str, project: str - ) -> ApiResponse[None]: - """ - 删除一个项目 - - Args: - username (str): 工作空间名 - project (str): 项目名 - """ - return self.http.delete(f"/project/{username}/{project}", params={}) - - def list_projects( - self, username: str, detail = True, page: int = 1, size: int = 10 - ) -> ApiResponse[Pagination[Project]]: - """ - 列出一个 workspace 下的所有项目 - - Args: - username (str): 工作空间名, 默认为用户个人空间 - detail (bool): 是否返回详细统计信息,默认为 True - page (int): 页码,默认为 1 - size (int): 每页数量,默认为 10 - """ - resp = self.http.get(f"/project/{username}", params={"detail": detail, "page": page, "size": size}) - if resp.errmsg: - return resp - - resp.data = Pagination[Project].model_validate({ - "total": resp.data.get("total", 0), - "list": [ProjectAPI.parse(project, detail=detail) for project in resp.data.get("list", [])] - }) - - return resp diff --git a/swanlab/api/types.py b/swanlab/api/types.py deleted file mode 100644 index a3806e177..000000000 --- a/swanlab/api/types.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/8 14:36 -@File: types.py -@IDE: pycharm -@Description: - OpenAPI 相关数据结构 -""" - -from typing import Dict, Generic, List, TypeVar - -from pydantic import BaseModel as PydanticBaseModel -from pydantic import ConfigDict - - -class BaseModel(PydanticBaseModel): - def __getitem__(self, key): - return getattr(self, key) - - -class Experiment(BaseModel): - cuid: str # 实验CUID, 唯一标识符 - name: str # 实验名 - description: str = "" # 实验描述 - state: str # 实验状态, 'FINISHED' 或 'RUNNING' - show: bool # 显示状态 - createdAt: str # e.g., '2024-11-23T12:28:04.286Z' - finishedAt: str = "" # e.g., '2024-11-23T12:28:04.286Z' - user: Dict[str, str] # 实验创建者, 包含 'username' 与 'name' - profile: Dict # 实验相关配置 - - -class Project(BaseModel): - cuid: str # 项目CUID, 唯一标识符 - name: str # 项目名 - description: str = "" # 项目描述 - visibility: str # 可见性, 'PUBLIC' 或 'PRIVATE' - createdAt: str # e.g., '2024-11-23T12:28:04.286Z' - updatedAt: str # e.g., 
'2024-11-23T12:28:04.286Z' - group: Dict[str, str] # 工作空间信息, 包含 'type', 'username', 'name' - count: Dict[str, int] = {} # 项目的统计信息 - - -D = TypeVar("D") - - -class ApiResponse(BaseModel, Generic[D]): - code: int # HTTP状态码 - errmsg: str # API错误消息, 只有请求错误时非空 - data: D # 返回数据 - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class Pagination(BaseModel, Generic[D]): - total: int # 总数 - list: List[D] # 列表数据,泛型 diff --git a/test/api/project.py b/test/api/project.py new file mode 100644 index 000000000..bdc994332 --- /dev/null +++ b/test/api/project.py @@ -0,0 +1,12 @@ +""" +@author: Zhou Qiyang +@file: project.py +@time: 2025/12/17 10:45 +@description: 用于测试api登录功能 +""" +import swanlab + + +# swanlab.login() +api = swanlab.OpenApi() +# print(api) diff --git a/test/unit/api/test_experiment.py b/test/unit/api/test_experiment.py deleted file mode 100644 index 81214084a..000000000 --- a/test/unit/api/test_experiment.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/7 09:47 -@File: test_experiment.py -@IDE: pycharm -@Description: - 测试开放API的实验相关接口 -""" - -import pandas as pd -import pytest - -import tutils as T -from swanlab import OpenApi -from swanlab.api.types import ApiResponse, Experiment - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_experiment(): - """ - 获取一个实验的详细信息 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.get_experiment(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, Experiment) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_list_experiments(): - """ - 获取一个项目下的实验列表 - """ - api = OpenApi() - res = api.list_experiments(project="SwanLab") - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, list) - for item in res.data: - assert isinstance(item, Experiment) - - 
-@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_summary(): - """ - 获取一个实验的Summary信息 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.get_summary(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, dict) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_delete_experiment(): - """ - 删除一个实验 - """ - api = OpenApi() - exp_cuid = "test_cuid" - res = api.delete_experiment(project="test_project", exp_id=exp_cuid) - assert isinstance(res, ApiResponse) - assert res.code in [204, 404] - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_metrics(): - """ - 获取实验的指标数据 - """ - api = OpenApi() - exp_cuid = "test_cuid" - keys = ["accuracy", "loss"] - res = api.get_metrics(exp_id=exp_cuid, keys=keys) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, pd.DataFrame) - assert not res.data.empty - assert all(key in res.data.columns for key in keys) - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_metrics_duplicate_keys(): - """ - 获取实验的指标数据 - """ - api = OpenApi() - exp_cuid = "test_cuid" - keys = ["accuracy", "loss", "loss"] - res = api.get_metrics(exp_id=exp_cuid, keys=keys) - assert isinstance(res, ApiResponse) - if res.code == 200: - assert isinstance(res.data, pd.DataFrame) - assert not res.data.empty - assert len(res.data.columns.to_list()) == 2 - assert all(key in res.data.columns for key in ["accuracy", "loss"]) diff --git a/test/unit/api/test_group.py b/test/unit/api/test_group.py deleted file mode 100644 index 12269cb8e..000000000 --- a/test/unit/api/test_group.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/4/30 21:16 -@File: test_group.py -@IDE: pycharm -@Description: - 测试开放API的组织相关接口 -""" - -import pytest - -import tutils as T -from swanlab import OpenApi 
-from swanlab.api.types import ApiResponse - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_get_workspaces(): - """ - 获取用户的所有工作空间 - """ - api = OpenApi() - r = api.list_workspaces() - - assert isinstance(r, ApiResponse) - assert isinstance(r.data, list) diff --git a/test/unit/api/test_project.py b/test/unit/api/test_project.py deleted file mode 100644 index 39289bace..000000000 --- a/test/unit/api/test_project.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/8 16:14 -@File: test_project.py -@IDE: VSCode -@Description: - 测试开放API的项目相关接口 -""" - -import pytest - -import tutils as T -from swanlab import OpenApi -from swanlab.api.types import ApiResponse, Project, Pagination - - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_list_projects(): - """ - 测试列出一个 workspace 下的所有项目 - """ - api = OpenApi() - resp = api.list_projects(detail=False) - assert isinstance(resp, ApiResponse) - if resp.code == 200: - assert isinstance(resp.data, list) - for item in resp.data: - assert isinstance(item, Project) - -@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") -def test_delete_project(): - """ - 测试删除一个项目 - """ - api = OpenApi() - project_name = "test_project" - resp = api.delete_project(project=project_name) - assert isinstance(resp, ApiResponse) - assert resp.code in [204, 404] diff --git a/test/unit/api/test_types.py b/test/unit/api/test_types.py deleted file mode 100644 index 53cd00cda..000000000 --- a/test/unit/api/test_types.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/5/14 22:25 -@File: test_types.py -@IDE: pycharm -@Description: - 测试开放API模型定义 -""" - -from swanlab.api.types import Experiment - - -def test_dict_style_access(): - """ - 测试字典风格访问对象字段 - """ - exp = Experiment.model_validate({ - "cuid": "test_cuid", - "name": "test_experiment", - "description": "test_description", - "state": 
"FINISHED", - "show": True, - "createdAt": "2025-05-14T22:25:00Z", - "finishedAt": "2025-05-14T22:30:00Z", - "user": { - "username": "test_user", - "name": "Test User" - }, - "profile": { - "requirements": "test_requirements", - } - }) - assert exp["cuid"] == "test_cuid" - \ No newline at end of file From 974ee4f4efad702f727b5d6dd75c1d7dc30c7cf2 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:51:10 +0800 Subject: [PATCH 02/21] feat: get all projects through OpenApi (#1382) * feat: get all project of a group * update unit test using magicMock * update unit test & fix bug * opt import and rename some file - move api from folder core_python to api * fix wrong comment * remove pending status when getting entity projects * opt Project class - add EN comment for Project properties - add __str__ method for project label - adapt snake_case naming for project properties - change some project property name: projectLabels -> label - change api.projects() param name: entity -> workspace - delete some project properties (group, cuid) - prase url inside Project class * dynamically request project data according to the traversal of the project, rather than requesting all project information at once * opt import * iteratively request project data inside the Projects class - move api form api package to core_python package * fix bug & add type * move api's type.py to core_python --- swanlab/api/__init__.py | 6 +- swanlab/api/{main.py => api.py} | 20 +++- swanlab/api/model.py | 156 +++++++++++++++++++++++++++++ swanlab/core_python/api/project.py | 43 ++++++++ swanlab/core_python/api/type.py | 47 +++++++++ test/api/project.py | 106 +++++++++++++++++++- 6 files changed, 370 insertions(+), 8 deletions(-) rename swanlab/api/{main.py => api.py} (68%) create mode 100644 swanlab/api/model.py create mode 100644 swanlab/core_python/api/project.py create mode 100644 swanlab/core_python/api/type.py diff --git 
a/swanlab/api/__init__.py b/swanlab/api/__init__.py index e99aa9c47..f7c75dc19 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -8,8 +8,6 @@ SwanLab OpenAPI包 """ -from swanlab.api.main import OpenApi +from .api import OpenApi -__all__ = [ - "OpenApi" -] +__all__ = ["OpenApi"] diff --git a/swanlab/api/main.py b/swanlab/api/api.py similarity index 68% rename from swanlab/api/main.py rename to swanlab/api/api.py index 5971a8103..08c2283b6 100644 --- a/swanlab/api/main.py +++ b/swanlab/api/api.py @@ -5,12 +5,13 @@ @description: OpenApi 模块 """ -from typing import Optional +from typing import Optional, List from swanlab.core_python import auth, Client from swanlab.error import KeyFileError from swanlab.log import swanlog from swanlab.package import get_key, HostFormatter +from .model import Projects try: from pandas import DataFrame @@ -35,3 +36,20 @@ def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, we login_info = auth.code_login(api_key, save_key=False) # 一个OpenApi对应一个client,可创建多个api获取从不同的client获取不同账号下的实验信息 self._client: Client = Client(login_info) + self._web_host = login_info.web_host + + def projects( + self, + workspace: str, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> Projects: + return Projects( + client=self._client, + web_host=self._web_host, + workspace=workspace, + sort=sort, + search=search, + detail=detail, + ) diff --git a/swanlab/api/model.py b/swanlab/api/model.py new file mode 100644 index 000000000..ccd628e19 --- /dev/null +++ b/swanlab/api/model.py @@ -0,0 +1,156 @@ +""" +@author: Zhou Qiyang +@file: model.py +@time: 2025/12/18 20:10 +@description: OpenApi查询结果将以对象返回,并且对后端的返回字段进行一些筛选 +""" + +from typing import List, Dict, Optional + +from swanlab.core_python import Client +from swanlab.core_python.api.project import get_entity_projects +from swanlab.core_python.api.type import ProjectType, ProjectLabelType, ProjResponseType + + +class 
Label: + """ + Project label object + you can get the label name by str(label) + """ + + def __init__(self, data: ProjectLabelType): + self._data = data + + @property + def name(self): + """ + Label name. + """ + return self._data['name'] + + def __str__(self): + return str(self.name) + + +class Project: + """ + Representing a single project with some of its properties. + """ + + def __init__(self, data: ProjectType, web_host: str): + self._data = data + self._web_host = web_host + + @property + def name(self) -> str: + """ + Project name. + """ + return self._data['name'] + + @property + def path(self) -> str: + """ + Project path in the format 'username/project-name'. + """ + return self._data['path'] + + @property + def url(self) -> str: + """ + Full URL to access the project. + """ + return f"{self._web_host}/@{self._data['path']}" + + @property + def description(self) -> str: + """ + Project description. + """ + return self._data['description'] + + @property + def visibility(self) -> str: + """ + Project visibility, either 'PUBLIC' or 'PRIVATE'. + """ + return self._data['visibility'] + + @property + def created_at(self) -> str: + """ + Project creation timestamp + """ + return self._data['createdAt'] + + @property + def updated_at(self) -> str: + """ + Project last update timestamp + """ + return self._data['updatedAt'] + + @property + def workspace(self) -> str: + """ + Project workspace name. + """ + return self._data["group"]["username"] + + @property + def labels(self) -> List[Label]: + """ + List of Label attached to this project. + """ + return [Label(label) for label in self._data['projectLabels']] + + @property + def count(self) -> Dict[str, int]: + """ + Project statistics dictionary containing: + experiments, contributors, children, collaborators, runningExps. + """ + return self._data['_count'] + + +class Projects: + """ + Container for a collection of Project objects. + You can iterate over the projects by for-in loop. 
+ """ + + def __init__( + self, + client: Client, + web_host: str, + workspace: str, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ): + self._client = client + self._web_host = web_host + self._workspace = workspace + self._sort = sort + self._search = search + self._detail = detail + + def __iter__(self): + # 按用户遍历情况获取项目信息 + cur_page = 0 + page_size = 20 + while True: + cur_page += 1 + projects_info: ProjResponseType = get_entity_projects( + self._client, + workspace=self._workspace, + page=cur_page, + size=page_size, + sort=self._sort, + search=self._search, + detail=self._detail, + ) + if cur_page * page_size >= projects_info['total']: + break + + yield from iter(Project(project, self._web_host) for project in projects_info['list']) diff --git a/swanlab/core_python/api/project.py b/swanlab/core_python/api/project.py new file mode 100644 index 000000000..0b5f7393a --- /dev/null +++ b/swanlab/core_python/api/project.py @@ -0,0 +1,43 @@ +""" +@author: Zhou QiYang +@file: project.py +@time: 2025/12/19 23:49 +@description: 定义项目相关的后端API接口 +""" + +from typing import Optional, List + +from swanlab.core_python.api.type import ProjParamType, ProjResponseType +from .. 
def get_entity_projects(
    client: Client,
    *,
    workspace: str,
    page: int = 1,
    size: int = 20,
    sort: Optional[List[str]] = None,
    search: Optional[str] = None,
    detail: Optional[bool] = True,
):
    """
    Fetch one page of project information for a workspace.

    Issues ``client.get`` on ``/project/{workspace}`` with the paging and filter
    values as query parameters and returns the first element of its result.

    :param client: logged-in client instance
    :param workspace: workspace name
    :param page: page number
    :param size: number of projects per page
    :param sort: sort rules, optional
    :param search: keyword to search project names for, optional
    :param detail: whether to include experiment-related information of each project,
        optional, defaults to True
    :return: the first element of the value returned by ``client.get``
    """
    query = ProjParamType(page=page, size=size, sort=sort, search=search, detail=detail)
    # A plain-dict copy of the typed query is what gets handed to the client.
    response = client.get(f"/project/{workspace}", params=dict(query))
    return response[0]
def make_fake_projects(start, count):
    """
    Build ``count`` fake project payloads numbered ``start`` .. ``start + count - 1``.

    Each payload is a fresh dict (with its own label list and group dict) using the
    same keys the project model reads.
    """
    fake_projects = []
    for index in range(start, start + count):
        payload = {
            "cuid": f"c{index}",
            "name": f"proj-{index}",
            "path": f"user/proj-{index}",
            "url": f"https://dev001.swanlab.cn/@user/proj-{index}",
            "description": f"desc-{index}",
            "visibility": "PUBLIC",
            "createdAt": "2025-01-01T00:00:00Z",
            "updatedAt": "2025-01-01T00:00:00Z",
            "projectLabels": [{"name": "Nvidia"}],
            "group": {"username": "user", "status": "ENABLED", "type": "TEAM"},
            "_count": {},
        }
        fake_projects.append(payload)
    return fake_projects
字符串类型的字段 + raw_list = params_test_projects[0]["list"] + fields = { + "name": "name", + "path": "path", + "url": "url", + "description": "description", + "visibility": "visibility", + "created_at": "createdAt", + "updated_at": "updatedAt", + } + for field in fields: + assert [getattr(p, field) for p in result] == [r[fields[field]] for r in raw_list] + + # 1.1 workspace + assert [p.workspace for p in result] == [r["group"]["username"] for r in raw_list] + + # 2. labels + assert [[l.__str__() for l in p.labels] for p in result] == [ + [l["name"] for l in r["projectLabels"]] for r in raw_list + ] + assert [[l.name for l in p.labels] for p in result] == [ + [l["name"] for l in r["projectLabels"]] for r in raw_list + ] + + # 3. count + assert [p.count for p in result] == [r["_count"] for r in raw_list] From e84f8199ccba4fd122f5faa9aa83b621c77f483e Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:38:19 +0800 Subject: [PATCH 03/21] Feat/api runs user (#1403) * feat: get runs metadata - add api in core_python to get all exps in a project - add new type and class to parse exps * update unit test & fix bugs * opt code style & add comment * raise value error if user's path is invaded * fix bug of Incorrectly raised ValueError * get single exp through OpenApi * resolve conflicts: filtering exp (basic implement) - get full single exp info through filter func # Conflicts: # swanlab/core_python/api/experiment.py * opt __dict__ for project and experiment * resolve conflict: update run.history() * feat: update run.scan_history() - add example codes * resolve conflict: fix bugs when importing pandas * import pandas only when used * fix: pandas warning * feat: run and user features of OpenApi - get metric data in batch - recover old OpenApi version for smooth transition - create & delete api_key through Api.user * opt Client importing in core_python's api * opt logic of run.history() - get latest api_keys when user 
delete api_key * rename the deprecated openApi folder * refactored run.history() & add return type - add return type for the backend interfaces - simplified get_experiment_metrics() and handle the csv inside HistoryPool * removed redundant types * update unit test using mock * feat: get user's team * feat: create user if user is the root user in self-hosted swanlab - refactor the models in OpenApi * feat: update unit test for create user and get user teams * delete test code of OpenApi * use ThreadPoolExecutor in HistoryPool * Refactor Api class and cleanup api module structure Moved the Api class implementation from swanlab/api/api.py to swanlab/api/__init__.py and deleted the redundant api.py file. Updated imports and references accordingly. Minor docstring and comment improvements, and fixed a message in thread.py to reference 'Api' instead of 'OpenApi'. * Fix circular import by updating Api export Combined the import of OpenApi and Api from .api to prevent circular import issues and simplify the export process. 
--------- Co-authored-by: ZeYi Lin <944270057@qq.com> Co-authored-by: Kang Li --- .idea/dictionaries/project.xml | 1 + swanlab/__init__.py | 3 +- swanlab/api/__init__.py | 126 +++++++++- swanlab/api/api.py | 55 ----- swanlab/api/deprecated/__init__.py | 10 + swanlab/api/deprecated/base.py | 175 ++++++++++++++ swanlab/api/deprecated/experiment.py | 210 +++++++++++++++++ swanlab/api/deprecated/group.py | 31 +++ swanlab/api/deprecated/main.py | 256 +++++++++++++++++++++ swanlab/api/deprecated/project.py | 71 ++++++ swanlab/api/deprecated/types.py | 57 +++++ swanlab/api/model/__init__.py | 12 + swanlab/api/model/base.py | 60 +++++ swanlab/api/model/experiment.py | 242 +++++++++++++++++++ swanlab/api/{model.py => model/project.py} | 50 ++-- swanlab/api/model/user.py | 83 +++++++ swanlab/api/thread.py | 97 ++++++++ swanlab/api/utils.py | 24 ++ swanlab/core_python/api/experiment.py | 59 ++++- swanlab/core_python/api/project.py | 19 +- swanlab/core_python/api/self_hosted.py | 34 +++ swanlab/core_python/api/type.py | 60 ++++- swanlab/core_python/api/user.py | 65 ++++++ swanlab/core_python/api/utils.py | 24 ++ swanlab/core_python/client/__init__.py | 9 + test/api/project.py | 112 --------- 26 files changed, 1720 insertions(+), 225 deletions(-) delete mode 100644 swanlab/api/api.py create mode 100644 swanlab/api/deprecated/__init__.py create mode 100644 swanlab/api/deprecated/base.py create mode 100644 swanlab/api/deprecated/experiment.py create mode 100644 swanlab/api/deprecated/group.py create mode 100644 swanlab/api/deprecated/main.py create mode 100644 swanlab/api/deprecated/project.py create mode 100644 swanlab/api/deprecated/types.py create mode 100644 swanlab/api/model/__init__.py create mode 100644 swanlab/api/model/base.py create mode 100644 swanlab/api/model/experiment.py rename swanlab/api/{model.py => model/project.py} (75%) create mode 100644 swanlab/api/model/user.py create mode 100644 swanlab/api/thread.py create mode 100644 swanlab/api/utils.py create mode 
class Api:
    """
    Entry point for querying SwanLab users, projects and experiments.

    Every instance logs in with its own API key and owns its own ``Client``, so
    several instances can be created to read data from different accounts.
    """

    def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, web_host: Optional[str] = None):
        """
        Initialise the Api instance. The user must have logged in beforehand or
        must provide an API key.

        :param api_key: API key, optional
        :param host: API host address, optional
        :param web_host: web host address, optional
        :raises RuntimeError: if no ``api_key`` is given and no saved key can be loaded
        """
        if host or web_host:
            HostFormatter(host, web_host)()
        if api_key:
            # The key is a secret: announce that one was supplied, never its value.
            swanlog.debug("Using API key")
        else:
            swanlog.debug("Using existing key")
            try:
                api_key = get_key()
            except KeyFileError as e:
                swanlog.error("To use SwanLab OpenAPI, please login first.")
                raise RuntimeError("Not logged in.") from e

        self._login_info = auth.code_login(api_key, save_key=False)
        # One Api owns one client; create several Api objects to read experiment
        # information of different accounts through different clients.
        self._client: Client = Client(self._login_info)
        self._web_host = self._login_info.web_host

    def user(self, username: Optional[str] = None) -> Optional[Union[ApiUser, SuperUser]]:
        """
        Return a user object for the current login.

        :param username: name of the user the caller is interested in, optional.
            Selecting another user is only meaningful on a self-hosted instance.
        :return: ``SuperUser`` for the root user of a 'commercial' self-hosted plan,
            ``ApiUser`` for the logged-in account otherwise, or ``None`` when the
            self-hosted plan is neither 'free' nor 'commercial'
        :raises ApiError: if ``username`` is given but the self-hosted information
            cannot be fetched
        :raises RuntimeError: if the self-hosted instance is not enabled or has expired
        """
        # Try to fetch the self-hosted information. When that fails the service is
        # treated as not self-hosted; selecting a user is a self-hosted-only feature,
        # so the error is re-raised if a username was requested.
        try:
            self_hosted_info = get_self_hosted_init(self._client)
        except ApiError:
            if username is not None:
                swanlog.error(
                    "You haven't launched a swanlab self-hosted instance. Please check your login status using 'swanlab verify'."
                )
                raise
            return ApiUser(self._client, self._login_info)

        if not self_hosted_info["enabled"]:
            raise RuntimeError("SwanLab self-hosted instance hasn't been ready yet.")
        if self_hosted_info["expired"]:
            raise RuntimeError("SwanLab self-hosted instance has expired. Please refresh your licence.")

        # Only warn when the caller explicitly asked for somebody else; leaving
        # ``username`` at its default must not trigger the warning.
        asked_for_other_user = username is not None and username != self._login_info.username

        # The free plan can only access the user the current api_key belongs to.
        if self_hosted_info["plan"] == 'free':
            if asked_for_other_user:
                swanlog.warning("Your self-hosted plan is 'free', You will be access to your own account.")
            return ApiUser(self._client, self._login_info)
        # On the commercial plan the root user gets a SuperUser object.
        elif self_hosted_info["plan"] == 'commercial':
            if self_hosted_info['root']:
                return SuperUser(self._client, self._login_info, self_hosted=self_hosted_info)
            elif asked_for_other_user:
                swanlog.warning("Your are not the root user, You will be access to your own account.")
            return ApiUser(self._client, self._login_info)
        # Reserved for other plans (e.g. an education plan).
        else:
            swanlog.warning("The self-hosted plan hasn't been supported yet.")
            return None

    def projects(
        self,
        workspace: str,
        sort: Optional[List[str]] = None,
        search: Optional[str] = None,
        detail: Optional[bool] = True,
    ) -> Projects:
        """
        Get all projects of the given workspace (organisation).

        :param workspace: workspace (organisation) name
        :param sort: sort rules, optional
        :param search: search keyword, optional
        :param detail: whether to return detailed information, optional
        :return: a ``Projects`` instance that can be iterated to obtain the projects
        """
        return Projects(
            client=self._client,
            web_host=self._web_host,
            workspace=workspace,
            sort=sort,
            search=search,
            detail=detail,
        )

    def runs(self, path: str, filters: Optional[Dict[str, object]] = None) -> Experiments:
        """
        Get all experiments of the given project.

        :param path: project path in the form 'username/project'
        :param filters: conditions used to filter the experiments, optional
        :return: an ``Experiments`` instance that can be iterated to obtain the experiments
        """
        return Experiments(client=self._client, path=path, web_host=self._web_host, filters=filters)

    def run(
        self,
        path: str,
    ) -> Experiment:
        """
        Get the information of a single experiment.

        :param path: experiment path in the form 'username/project/run_id'
        :return: an ``Experiment`` instance holding the experiment information
        :raises ValueError: if ``path`` does not consist of exactly three '/'-separated parts
        """
        # TODO: replace with a dedicated endpoint once the backend provides one
        if len(path.split('/')) != 3:
            raise ValueError(f"User's {path} is invalid. Correct path should be like 'username/project/run_id'")
        _data = get_single_experiment(self._client, path=path)
        proj_path = path.rsplit('/', 1)[0]
        # Look the experiment up again through the project-level listing, filtered
        # by the name and creation time just obtained, and use the first match.
        data = get_project_experiments(
            self._client, path=proj_path, filters={'name': _data['name'], 'created_at': _data['createdAt']}
        )
        return Experiment(data=data[0], client=self._client, path=proj_path, web_host=self._web_host, line_count=1)
def get_logger(log_level: str = "info") -> SwanLog:
    """
    Return the module-wide "swanlab.openapi" logger, creating it on first use.

    When the logger already exists its ``level`` is set to ``log_level``;
    otherwise a new ``SwanLog`` is created with that level and cached in the
    module-level ``_logger``.
    """
    global _logger
    if _logger is not None:
        _logger.level = log_level
        return _logger
    _logger = SwanLog("swanlab.openapi", log_level)
    return _logger
class OpenApiService:
    """
    Shared request helpers used by the deprecated OpenApi classes.
    """

    def __init__(self, http: ApiHTTP):
        self.http: ApiHTTP = http

    def get_exp_info(self, username: str, project: str, exp_id: str) -> ApiResponse[dict]:
        """
        Get the information of one experiment.
        """
        return self.http.get(f"/project/{username}/{project}/runs/{exp_id}", params={})

    def get_project_info(self, username: str, projname: str) -> ApiResponse[dict]:
        """
        Get the details of one project.
        """
        return self.http.get(f"/project/{username}/{projname}", params={})

    @staticmethod
    def fetch_paginated_api(
        api_func: Callable[..., ApiResponse],  # paginated API request function
        page_field: str = "page",
        size_field: str = "size",
        page_size: int = 10,
        *args, **kwargs
    ) -> ApiResponse[List]:
        """
        Generic helper that pulls every page of a paginated API.

        Pages are requested starting from page 1 until the number of collected
        items reaches the reported total, a request reports an error, or a page
        comes back empty.

        Args:
            api_func (Callable): paginated request function; its response data must
                expose ``list`` and ``total`` attributes
            page_field (str): name of the page-number keyword argument, defaults to "page"
            size_field (str): name of the page-size keyword argument, defaults to "size"
            page_size (int): number of items per page, defaults to 10
            *args: positional arguments passed to api_func
            **kwargs: keyword arguments passed to api_func

        Returns:
            ApiResponse[list]: code and errmsg of the last response, with the items
            collected so far as data
        """
        page = 1
        objs = []
        while True:
            kwargs.update({page_field: page, size_field: page_size})
            resp: ApiResponse = api_func(*args, **kwargs)
            if resp.errmsg:
                break

            page_items = resp.data.list
            objs.extend(page_items)

            # An empty page means there is nothing more to fetch; without this
            # check a total larger than the real item count would loop forever.
            if not page_items or len(objs) >= resp.data.total:
                break
            page += 1

        return ApiResponse[List](code=resp.code, errmsg=resp.errmsg, data=objs)
super().__init__(http) + + @classmethod + def parse(cls, body: dict) -> Experiment: + return Experiment.model_validate({ + "cuid": body.get("cuid") or "", + "name": body.get("name") or "", + "description": body.get("description") or "", + "state": body.get("state") or "", + "show": bool(body.get("show")), + "createdAt": body.get("createdAt") or "", + "finishedAt": body.get("finishedAt") or "", + "user": { + "username": (body.get("user") or {}).get("username") or "", + "name": (body.get("user") or {}).get("name") or "", + }, + "profile": body.get("profile") or {} + }) + + def get_experiment( + self, + username: str, + projname: str, + exp_id: str + ) -> ApiResponse[Experiment]: + """ + 获取实验信息 + + Args: + username (str): 工作空间名 + projname (str): 项目名 + exp_id (str): 实验CUID + """ + resp: ApiResponse = self.http.service.get_exp_info(username=username, project=projname, exp_id=exp_id) + if resp.errmsg: + return resp + resp.data = ExperimentAPI.parse(resp.data) + return resp + + def delete_experiment( + self, + username: str, + projname: str, + exp_id: str + ): + """ + 删除实验 + + Args: + username (str): 工作空间名 + projname (str): 项目名 + exp_id (str): 实验CUID + """ + return self.http.delete(f"/project/{username}/{projname}/runs/{exp_id}", params={}) + + def list_experiments( + self, + username: str, + projname: str, + page: int = 1, + size: int = 10 + ) -> ApiResponse[Pagination[Experiment]]: + """ + 分页获取项目下的实验列表 + 该接口返回的实验profile只包含config(用户自定义的配置) + + Args: + username (str): 工作空间名 + projname (str): 项目名 + page (int): 页码, 默认为1 + size (int): 每页大小, 默认为10 + """ + resp = self.http.get(f"/project/{username}/{projname}/runs", params={"page": page, "size": size}) + + if resp.errmsg: + return resp + exps = resp.data + resp.data = Pagination[Experiment].model_validate({ + "total": exps.get("total", 0), + "list": [ExperimentAPI.parse(e) for e in exps.get("list", [])] + }) + return resp + + def get_summary( + self, + exp_id: str, + pro_id: str, + root_exp_id: str, + root_pro_id: str + ) -> 
ApiResponse[dict]: + """ + 获取实验的summary信息 + 从House获取, 需要考虑克隆实验 + + Args: + exp_id (str): 实验CUID + pro_id (str): 项目CUID + root_exp_id (str): 根实验CUID + root_pro_id (str): 根项目CUID + """ + data = { + "experimentId": exp_id, + "projectId": pro_id, + } + if root_exp_id and root_pro_id: + data["rootExpId"] = root_exp_id + data["rootProId"] = root_pro_id + + resp = self.http.post("/house/metrics/summaries", data=[data], params={}) + if resp.errmsg: + return resp + + resp.data = list(resp.data.values())[0] + resp.data = { + k: { + "step": v.get("step"), + "value": v.get("value"), + "min": { + "step": v.get("min").get("index"), + "value": v.get("min").get("data"), + }, + "max": { + "step": v.get("max").get("index"), + "value": v.get("max").get("data"), + } + } + for k, v in resp.data.items() + } + return resp + + def get_metrics( + self, + exp_id: str, + keys: List[str], + ) -> ApiResponse[DataFrame]: + """ + 获取实验的指标数据, 可选择若干由用户自定义的列 + + Args: + exp_id (str): 实验CUID + keys (list[str]): 指标key列表 + + Returns: + ApiResponse[DataFrame]: + """ + try: + import pandas as pd + except ImportError: + raise ImportError("OpenApi.get_metrics requires pandas module. 
class GroupAPI(ApiBase):
    """
    Workspace (group) related calls of the deprecated OpenApi.
    """

    def __init__(self, http: ApiHTTP):
        super().__init__(http)

    def list_workspaces(self) -> ApiResponse[list]:
        """
        List the workspaces returned by the ``/group/`` endpoint.

        Returns:
            ApiResponse[list]: on error the response is returned unchanged; otherwise
            its data is replaced by a list of dicts with the keys
            'name', 'username' and 'role', one per workspace.
        """
        resp = self.http.get("/group/", params={})
        if resp.errmsg:
            return resp
        groups = resp.data.get("list", [])
        resp.data = [
            {
                "name": item["name"],
                "username": item["username"],
                "role": item["role"]
            }
            for item in groups
        ]
        # The reshaped response must be handed back; without this return the
        # success path yields None.
        return resp
class OpenApi:
    """
    Legacy SwanLab OpenAPI client; emits a deprecation warning on construction.
    """

    def __init__(self, api_key: str = "", log_level: str = "info"):
        """
        Log in with ``api_key`` or, when it is empty, with the locally saved key.

        Args:
            api_key (str): API key; when empty the saved key is loaded
            log_level (str): level applied to the shared OpenAPI logger

        Raises:
            RuntimeError: if no api_key is given and no saved key can be loaded
        """
        self.__logger: SwanLog = get_logger(log_level)
        self.__logger.warning("OpenApi will be soon deprecated in swanlab 0.8.0. Please use swanlab.Api() instead.")

        if api_key:
            # The key is a secret: announce that one was supplied, never its value.
            self.__logger.debug("Using API key")
            self.__key = api_key
        else:
            self.__logger.debug("Using existing key")
            try:
                self.__key = get_key()
            except KeyFileError as e:
                self.__logger.error("To use SwanLab OpenAPI, please login first.")
                raise RuntimeError("Not logged in.") from e
        self.login_info = auth.code_login(self.__key, False)

        self.username = self.login_info.username
        self.http: ApiHTTP = ApiHTTP(self.login_info)
        self.service = self.http.service

        self.group = GroupAPI(self.http)
        self.experiment = ExperimentAPI(self.http)
        self.project = ProjectAPI(self.http)

    def list_workspaces(self) -> ApiResponse[List[Dict]]:
        """
        Get all workspaces (groups) of the current user.

        Returns:
            ApiResponse[List]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (List[Dict]): one dict per workspace with its basic information:
                    - name (str): workspace display name
                    - username (str): workspace name (used in organisation URLs)
                    - role (str): role of the user in the workspace, e.g. 'OWNER' or 'MEMBER'
        """
        return self.group.list_workspaces()

    def get_experiment(
        self,
        project: str,
        exp_id: str,
        username: str = ""
    ) -> ApiResponse[Experiment]:
        """
        Get the information of one experiment.

        Args:
            project (str): project name
            exp_id (str): experiment CUID
            username (str): workspace name, defaults to the user's personal space

        Returns:
            ApiResponse[Experiment]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (Experiment): the experiment information
        """
        return self.experiment.get_experiment(
            username=username if username else self.http.username, projname=project, exp_id=exp_id
        )

    def delete_experiment(
        self,
        project: str,
        exp_id: str,
        username: str = ""
    ) -> ApiResponse[None]:
        """
        Delete one experiment.

        Args:
            project (str): project name
            exp_id (str): experiment CUID
            username (str): workspace name, defaults to the user's personal space

        Returns:
            ApiResponse[None]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (None): no data is returned
        """
        return self.experiment.delete_experiment(
            username=username if username else self.http.username, projname=project, exp_id=exp_id
        )

    def list_experiments(
        self,
        project: str,
        username: str = ""
    ) -> ApiResponse[List[Experiment]]:
        """
        Get all experiments of a project.

        Args:
            project (str): project name
            username (str): workspace name, defaults to the user's personal space

        Returns:
            ApiResponse[List[Experiment]]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (List[Experiment]): one entry per experiment; the profile of
                  these entries only contains config (the user-defined configuration)
        """
        return self.service.fetch_paginated_api(
            api_func=self.experiment.list_experiments,
            projname=project,
            username=username if username else self.http.username
        )

    def delete_project(
        self,
        project: str,
        username: str = ""
    ) -> ApiResponse[None]:
        """
        Delete one project.

        Args:
            project (str): project name
            username (str): workspace name, defaults to the user's personal space

        Returns:
            ApiResponse[None]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (None): no data is returned
        """
        return self.project.delete_project(
            username=username if username else self.http.username, project=project
        )

    def list_projects(
        self,
        username: str = "",
        detail: bool = True
    ) -> ApiResponse[List[Project]]:
        """
        Get all projects of a workspace.

        Args:
            username (str): workspace name, defaults to the user's personal space
            detail (bool): whether to include detailed statistics, defaults to True

        Returns:
            ApiResponse[List[Project]]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (List[Project]): one entry per project
        """
        return self.service.fetch_paginated_api(
            api_func=self.project.list_projects,
            username=username if username else self.http.username,
            detail=detail
        )

    def get_summary(
        self,
        project: str,
        exp_id: str,
        username: str = ""
    ) -> ApiResponse[Dict]:
        """
        Get the summary information of an experiment.

        Args:
            project (str): project name
            exp_id (str): experiment CUID
            username (str): workspace name, defaults to the user's personal space

        Returns:
            ApiResponse[Dict]:
                - code (int): HTTP status code
                - errmsg (str): error message, non-empty only when the request failed
                - data (Dict): summary dict with the min/max value of every metric
                  and the step at which it occurred
            When looking up the project or the experiment fails, that failed
            response is returned instead.
        """
        username = username if username else self.http.username

        # Propagate lookup failures instead of querying summaries with empty ids;
        # this mirrors the "if resp.errmsg: return resp" convention of the API classes.
        project_resp = self.service.get_project_info(username=username, projname=project)
        if project_resp.errmsg:
            return project_resp
        exp_resp = self.service.get_exp_info(username=username, project=project, exp_id=exp_id)
        if exp_resp.errmsg:
            return exp_resp

        return self.experiment.get_summary(
            exp_id=exp_id,
            pro_id=project_resp.data.get("cuid", ""),
            root_exp_id=exp_resp.data.get("rootExpId", ""),
            root_pro_id=exp_resp.data.get("rootProId", "")
        )

    def get_metrics(
        self,
        exp_id: str,
        keys: Union[str, List[str]],
    ) -> ApiResponse[DataFrame]:
        """
        Get the metric data of an experiment.

        Args:
            exp_id (str): experiment CUID
            keys (str | List[str]): metric key, a single string or a list of strings

        Returns:
            ApiResponse[DataFrame]: response whose data is a DataFrame of metric values.
            In the DataFrame every key maps to two columns, key and key_timestamp,
            holding the metric value and its timestamp.
        """
        if isinstance(keys, str):
            keys = [keys]
        return self.experiment.get_metrics(exp_id, keys)

    def get_exp_summary(self, *args, **kwargs) -> ApiResponse[Dict]:
        """
        Get the summary information of an experiment.
        @deprecated, use `get_summary` instead

        Returns:
            ApiResponse[Dict]: response containing the experiment summary
        """
        return self.get_summary(*args, **kwargs)

    def list_project_exps(self, *args, **kwargs) -> ApiResponse[List[Experiment]]:
        """
        Get all experiments of a project.
        @deprecated, use `list_experiments` instead

        Returns:
            ApiResponse[List[Experiment]]: response containing the experiment list
        """
        return self.list_experiments(*args, **kwargs)
return resp + + resp.data = Pagination[Project].model_validate({ + "total": resp.data.get("total", 0), + "list": [ProjectAPI.parse(project, detail=detail) for project in resp.data.get("list", [])] + }) + + return resp diff --git a/swanlab/api/deprecated/types.py b/swanlab/api/deprecated/types.py new file mode 100644 index 000000000..a3806e177 --- /dev/null +++ b/swanlab/api/deprecated/types.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2025/5/8 14:36 +@File: types.py +@IDE: pycharm +@Description: + OpenAPI 相关数据结构 +""" + +from typing import Dict, Generic, List, TypeVar + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ConfigDict + + +class BaseModel(PydanticBaseModel): + def __getitem__(self, key): + return getattr(self, key) + + +class Experiment(BaseModel): + cuid: str # 实验CUID, 唯一标识符 + name: str # 实验名 + description: str = "" # 实验描述 + state: str # 实验状态, 'FINISHED' 或 'RUNNING' + show: bool # 显示状态 + createdAt: str # e.g., '2024-11-23T12:28:04.286Z' + finishedAt: str = "" # e.g., '2024-11-23T12:28:04.286Z' + user: Dict[str, str] # 实验创建者, 包含 'username' 与 'name' + profile: Dict # 实验相关配置 + + +class Project(BaseModel): + cuid: str # 项目CUID, 唯一标识符 + name: str # 项目名 + description: str = "" # 项目描述 + visibility: str # 可见性, 'PUBLIC' 或 'PRIVATE' + createdAt: str # e.g., '2024-11-23T12:28:04.286Z' + updatedAt: str # e.g., '2024-11-23T12:28:04.286Z' + group: Dict[str, str] # 工作空间信息, 包含 'type', 'username', 'name' + count: Dict[str, int] = {} # 项目的统计信息 + + +D = TypeVar("D") + + +class ApiResponse(BaseModel, Generic[D]): + code: int # HTTP状态码 + errmsg: str # API错误消息, 只有请求错误时非空 + data: D # 返回数据 + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class Pagination(BaseModel, Generic[D]): + total: int # 总数 + list: List[D] # 列表数据,泛型 diff --git a/swanlab/api/model/__init__.py b/swanlab/api/model/__init__.py new file mode 100644 index 000000000..e838c7eb7 --- /dev/null +++ b/swanlab/api/model/__init__.py @@ -0,0 
+1,12 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/5 17:59 +@description: OpenApi 中包含的对象 +""" + +from .experiment import Experiment, Experiments +from .project import Projects +from .user import ApiUser, SuperUser + +__all__ = ['Experiment', 'Experiments', 'Projects', 'ApiUser', 'SuperUser'] diff --git a/swanlab/api/model/base.py b/swanlab/api/model/base.py new file mode 100644 index 000000000..5d2efc8d8 --- /dev/null +++ b/swanlab/api/model/base.py @@ -0,0 +1,60 @@ +""" +@author: Zhou Qiyang +@file: model.py +@time: 2025/12/18 20:10 +@description: OpenApi 中的基础对象 +""" + +from typing import Dict + +from swanlab.core_python.api.type import ProjectLabelType, UserType + + +class ApiBase: + @property + def __dict__(self) -> Dict[str, object]: + """ + Return a dictionary containing all @property fields. + """ + result = {} + cls = type(self) + for attr_name in dir(cls): + if attr_name.startswith('_'): + continue + attr = getattr(cls, attr_name, None) + if isinstance(attr, property): + result[attr_name] = self.__getattribute__(attr_name) + return result + + +class Label(ApiBase): + """ + Project label object + you can get the label name by str(label) + """ + + def __init__(self, data: ProjectLabelType) -> None: + self._data = data + + @property + def name(self) -> str: + """ + Label name. 
+ """ + return self._data['name'] + + def __str__(self) -> str: + return str(self.name) + + +class User(ApiBase): + def __init__(self, data: UserType) -> None: + self._data = data + + @property + def name(self) -> str: + return self._data['name'] + + @property + def username(self) -> str: + return self._data['username'] diff --git a/swanlab/api/model/experiment.py b/swanlab/api/model/experiment.py new file mode 100644 index 000000000..355a72332 --- /dev/null +++ b/swanlab/api/model/experiment.py @@ -0,0 +1,242 @@ +""" +@author: Zhou QiYang +@file: experiment.py +@time: 2026/1/5 17:58 +@description: OpenApi 中的实验对象 +""" + +from typing import TYPE_CHECKING, List, Dict, Any, Iterator + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + +from swanlab.log import swanlog +from swanlab.core_python.api.experiment import get_project_experiments +from swanlab.core_python.api.type import RunType +from swanlab.api.thread import HistoryPool +from swanlab.api.utils import flatten_runs + +from .base import ApiBase, Label, User + + +class Experiment(ApiBase): + def __init__(self, data: RunType, client: "Client", path: str, web_host: str, line_count: int) -> None: + self._data = data + self._client = client + self._path = path + self._web_host = web_host + self._line_count = line_count + + @property + def name(self) -> str: + """ + Experiment name. + """ + return self._data['name'] + + @property + def id(self) -> str: + """ + Experiment CUID. + """ + return self._data['cuid'] + + @property + def url(self) -> str: + """ + Full URL to access the experiment. + """ + return f"{self._web_host}/@{self._path}/runs/{self.id}/chart" + + @property + def created_at(self) -> str: + """ + Experiment creation timestamp + """ + return self._data['createdAt'] + + @property + def description(self) -> str: + """ + Experiment description. + """ + return self._data['description'] + + @property + def labels(self) -> List[Label]: + """ + List of Label attached to this experiment. 
+ """ + return [Label(label) for label in self._data['labels']] + + @property + def config(self) -> Dict[str, object]: + """ + Experiment configuration. Can be used as filter in the format of 'config.' + """ + return self._data['profile']['config'] + + @property + def summary(self) -> Dict[str, object]: + """ + Experiment metrics data. Can be used as filter in the format of 'summary.' + """ + return self._data['profile']['scalar'] + + @property + def state(self) -> str: + """ + Experiment state. + """ + return self._data['state'] + + @property + def group(self) -> str: + """ + Experiment group. + """ + return self._data['cluster'] + + @property + def job(self) -> str: + """ + Experiment job type. + """ + return self._data['job'] + + @property + def user(self) -> User: + """ + Experiment user. + """ + return User(self._data['user']) + + @property + def metric_keys(self) -> List[str]: + """ + List of metric keys. + """ + summary_keys = self.summary.keys() + return list(summary_keys) + + @property + def history_line_count(self) -> int: + """ + The number of historical experiments in this project. + """ + return self._line_count + + @property + def root_exp_id(self) -> str: + """ + Root experiment cuid. If the experiment is a root experiment, it will be None. + """ + return self._data['rootExpId'] + + @property + def root_pro_id(self) -> str: + """ + Root project cuid. If the experiment is a root experiment, it will be None. + """ + return self._data['rootProId'] + + def __full_history(self) -> Any: + """ + Get all metric keys' data of the experiment with timestamp. + """ + try: + import pandas as pd + except ImportError: + raise TypeError( + "OpenApi requires pandas to implement the run.history(). Please install with 'pip install pandas'." 
+ ) + + df = pd.DataFrame() + if len(self.metric_keys) >= 1: + pool = HistoryPool(self._client, self.id, keys=self.metric_keys) + df = pool.execute() + + return df + + def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: + """ + Get specific metric data of the experiment. + :param keys: List of metric keys to obtain. If None, all metrics keys will be used. + :param x_axis: The metric to be used as x-axis. If None, '_step' will be used as the x-axis. + :param sample: Number of rows to select from the beginning. + :param pandas: Whether to return a pandas DataFrame. If False, returns dict format: {key: [values], ...} + :return: Metric data. + + Example: + ```python + api = swanlab.OpenApi() + exp = api.run(path="username/project/expid") # You can get expid from api.runs() + print(exp.history(keys=['loss'], sample=20, x_axis='t/accuracy')) + + Returns: + t/accuracy loss + 0 0.310770 0.525776 + 1 0.642817 0.479186 + 2 0.646031 0.362428 + 3 0.608820 0.230555 + ... + 19 0.791999 0.180106 + ``` + """ + try: + import pandas as pd + except ImportError: + raise TypeError( + "OpenApi requires pandas to implement the run.history(). Please install with 'pip install pandas'." + ) + + if keys is not None and not isinstance(keys, list): + swanlog.warning('keys must be specified as a list') + return pd.DataFrame() + elif keys is not None and len(keys) and not all(isinstance(k, str) for k in keys): + swanlog.warning('keys must be a list of string') + return pd.DataFrame() + + if keys is None and x_axis is None: + # x轴与keys都未指定时,获取所有指标数据 + df = self.__full_history() + else: + # 使用线程池并发获取所有的key的指标数据 + pool = HistoryPool(self._client, self.id, keys=keys, x_axis=x_axis) + df = pool.execute() + + # 截取前sample行 + if sample is not None: + df = df.head(sample) + + return df if pandas else df.to_dict(orient='records') + + + + +class Experiments(ApiBase): + """ + Container for a collection of Experiment objects. 
+ You can iterate over the experiments by for-in loop. + """ + + def __init__(self, client: "Client", path: str, web_host: str, filters: Dict[str, object] = None) -> None: + if len(path.split('/')) != 2: + raise ValueError(f"User's {path} is invaded. Correct path should be like 'username/project'") + self._client = client + self._path = path + self._web_host = web_host + self._filters = filters + + def __iter__(self) -> Iterator[Experiment]: + # TODO: 完善filter的功能(正则、条件判断) + resp = get_project_experiments(self._client, path=self._path, filters=self._filters) + runs: List[RunType] = [] + if isinstance(resp, List): + runs = resp + # 分组时需展平实验数据 + elif isinstance(resp, Dict): + runs = flatten_runs(resp) + line_count = len(runs) + yield from iter(Experiment(run, self._client, self._path, self._web_host, line_count) for run in runs) + diff --git a/swanlab/api/model.py b/swanlab/api/model/project.py similarity index 75% rename from swanlab/api/model.py rename to swanlab/api/model/project.py index ccd628e19..438a2b4a9 100644 --- a/swanlab/api/model.py +++ b/swanlab/api/model/project.py @@ -1,43 +1,27 @@ """ -@author: Zhou Qiyang -@file: model.py -@time: 2025/12/18 20:10 -@description: OpenApi查询结果将以对象返回,并且对后端的返回字段进行一些筛选 +@author: Zhou QiYang +@file: project.py +@time: 2026/1/5 17:58 +@description: OpenApi 中的项目对象 """ -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Iterator, TYPE_CHECKING -from swanlab.core_python import Client -from swanlab.core_python.api.project import get_entity_projects -from swanlab.core_python.api.type import ProjectType, ProjectLabelType, ProjResponseType +if TYPE_CHECKING: + from swanlab.core_python.client import Client +from swanlab.core_python.api.project import get_workspace_projects +from swanlab.core_python.api.type import ProjectType, ProjResponseType -class Label: - """ - Project label object - you can get the label name by str(label) - """ - - def __init__(self, data: ProjectLabelType): - self._data = data - 
- @property - def name(self): - """ - Label name. - """ - return self._data['name'] - - def __str__(self): - return str(self.name) +from .base import ApiBase, Label -class Project: +class Project(ApiBase): """ Representing a single project with some of its properties. """ - def __init__(self, data: ProjectType, web_host: str): + def __init__(self, data: ProjectType, web_host: str) -> None: self._data = data self._web_host = web_host @@ -113,7 +97,7 @@ def count(self) -> Dict[str, int]: return self._data['_count'] -class Projects: +class Projects(ApiBase): """ Container for a collection of Project objects. You can iterate over the projects by for-in loop. @@ -121,13 +105,13 @@ class Projects: def __init__( self, - client: Client, + client: "Client", web_host: str, workspace: str, sort: Optional[List[str]] = None, search: Optional[str] = None, detail: Optional[bool] = True, - ): + ) -> None: self._client = client self._web_host = web_host self._workspace = workspace @@ -135,13 +119,13 @@ def __init__( self._search = search self._detail = detail - def __iter__(self): + def __iter__(self) -> Iterator[Project]: # 按用户遍历情况获取项目信息 cur_page = 0 page_size = 20 while True: cur_page += 1 - projects_info: ProjResponseType = get_entity_projects( + projects_info: ProjResponseType = get_workspace_projects( self._client, workspace=self._workspace, page=cur_page, diff --git a/swanlab/api/model/user.py b/swanlab/api/model/user.py new file mode 100644 index 000000000..b184e30bc --- /dev/null +++ b/swanlab/api/model/user.py @@ -0,0 +1,83 @@ +""" +@author: Zhou QiYang +@file: user.py +@time: 2026/1/5 17:58 +@description: OpenApi 中的用户对象 +""" + +import re +from typing import TYPE_CHECKING, List, Optional + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + +from swanlab.api.utils import STATUS_CREATED, STATUS_OK +from swanlab.core_python.api.self_hosted import create_user +from swanlab.core_python.api.type import ApiKeyType, SelfHostedInfoType +from 
swanlab.core_python.api.user import ( + get_user_groups, + get_api_keys, + create_api_key, + get_latest_api_key, + delete_api_key, +) +from swanlab.core_python.auth.providers.api_key import LoginInfo + +from .base import ApiBase + + +class ApiUser(ApiBase): + def __init__(self, client: "Client", login_info: LoginInfo) -> None: + super().__init__() + self._client = client + self._login_info = login_info + self._api_keys: List[ApiKeyType] = [] + + @property + def username(self) -> str: + return self._login_info.username + + @property + def teams(self) -> List[str]: + resp = get_user_groups(self._client, username=self.username) + return [r['name'] for r in resp] + + @property + def api_keys(self) -> List[str]: + self._api_keys = get_api_keys(self._client) + return [r['key'] for r in self._api_keys] + + def generate_api_key(self, description: str = None) -> Optional[str]: + api_key: Optional[ApiKeyType] = None + res = create_api_key(self._client, name=description) + if res == STATUS_CREATED: + api_key = get_latest_api_key(self._client) + return api_key['key'] if api_key else None + + def delete_api_key(self, api_key: str) -> bool: + self._api_keys = get_api_keys(self._client) + for key in self._api_keys: + if key['key'] == api_key: + res = delete_api_key(self._client, key_id=key['id']) + if res == STATUS_OK: + return True + return False + + +class SuperUser(ApiUser): + def __init__(self, client: "Client", login_info: LoginInfo, self_hosted: SelfHostedInfoType) -> None: + super().__init__(client, login_info) + self._self_hosted_info = self_hosted + + def create(self, username: str, password: str) -> bool: + # 用户名为大小写字母、数字及-、_组成 + # 密码必须包含数字+英文且至少8位 + if not re.match(r'^[a-zA-Z0-9_-]+$', username): + raise ValueError("Username must be alphanumeric and can contain - and _") + if not re.match(r'^(?=.*[0-9])(?=.*[a-zA-Z]).{8,}$', password): + raise ValueError("Password must contain at least 8 characters and include numbers and letters") + resp = create_user(self._client, 
username=username, password=password) + if resp == STATUS_CREATED: + return True + else: + raise False diff --git a/swanlab/api/thread.py b/swanlab/api/thread.py new file mode 100644 index 000000000..eaedd7d75 --- /dev/null +++ b/swanlab/api/thread.py @@ -0,0 +1,97 @@ +""" +@author: Zhou QiYang +@file: thread.py +@time: 2025/12/30 15:08 +@description: 用于api并发请求的封装类 +""" + +from concurrent.futures import ThreadPoolExecutor +from io import BytesIO +from typing import List, Any, TYPE_CHECKING + +import requests + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + +from swanlab.core_python.api.experiment import get_experiment_metrics +from swanlab.log import swanlog + + +class HistoryPool: + + def __init__(self, client: "Client", expid: str, *, keys: List[str], x_axis: str = None, num_threads: int = 10): + try: + import pandas as pd + except ImportError: + raise TypeError("Api requires pandas to init the HistoryPool. Please install with 'pip install pandas'.") + + self._client = client + self._expid = expid + self._keys = keys + self._x_axis = x_axis + if self._x_axis is not None: + self._keys = [self._x_axis] + [k for k in self._keys if k != self._x_axis] + self._num_threads = num_threads + + # 使用 _results 字典收集每个 key 的 DataFrame,最后统一按顺序合并到 _history + self._executor = ThreadPoolExecutor(max_workers=self._num_threads) + self._futures = [] + self._results = dict() + self._history = pd.DataFrame() + + def _task(self, key: str): + """ + 处理单个key,获取对应csv + """ + import pandas as pd + + try: + csv_df = pd.DataFrame() + resp = get_experiment_metrics(self._client, expid=self._expid, key=key) + # 从返回网址中解析csv内容 + with requests.get(resp['url']) as response: + csv_df = pd.read_csv(BytesIO(response.content)) + return key, csv_df + except Exception as e: + swanlog.warning(f'Error processing key {key} in experiment {self._expid}: {e}') + return key, pd.DataFrame() + + def execute(self) -> Any: + if not self._keys: + return self._history + + # 将所有key提交到线程池 + for key 
in self._keys: + future = self._executor.submit(self._task, key) + self._futures.append((key, future)) + + # 等待所有任务完成并收集结果 + for key, future in self._futures: + try: + result_key, csv_df = future.result() + self._results[result_key] = csv_df + except Exception as e: + swanlog.warning(f'Error getting result for key {key} in experiment {self._expid}: {e}') + self._executor.shutdown(wait=True) + + # 按照 keys 的顺序统一合并 + for key in self._keys: + if key not in self._results: + continue + key_df = self._results[key] + step_col, value_col = key_df.columns[:2] # step 列, 指标值列 + + # 将 step 设为索引,其后基于索引自动对齐 + if self._history.empty: + self._history = key_df.set_index(step_col) + else: + self._history[value_col] = key_df.set_index(step_col)[value_col] + + # 若指定x轴,重置索引 + if self._x_axis is not None: + self._history = self._history.reset_index().iloc[:, 1:] + self._history = self._history.set_index(self._history.columns[0]) + else: + self._history.rename(columns={'step': '_step'}, inplace=True) + return self._history diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py new file mode 100644 index 000000000..5114184e1 --- /dev/null +++ b/swanlab/api/utils.py @@ -0,0 +1,24 @@ +""" +@author: Zhou QiYang +@file: utils.py +@time: 2026/1/4 18:03 +@description: OpenApi 使用的常量和工具函数 +""" + +from typing import Dict, List + +STATUS_OK = "OK" +STATUS_CREATED = "Created" + + +def flatten_runs(runs: Dict) -> List: + """ + 展开分组后的实验数据,返回一个包含所有实验的列表 + """ + flat_runs = [] + for group in runs.values(): + if isinstance(group, Dict): + flat_runs.extend(flatten_runs(group)) + else: + flat_runs.extend(group) + return flat_runs diff --git a/swanlab/core_python/api/experiment.py b/swanlab/core_python/api/experiment.py index 95d92db0c..8e7299ffc 100644 --- a/swanlab/core_python/api/experiment.py +++ b/swanlab/core_python/api/experiment.py @@ -5,11 +5,14 @@ @description: 定义实验相关的后端API接口 """ -from typing import Literal, TYPE_CHECKING +from typing import Literal, Dict, TYPE_CHECKING, List, Union if 
TYPE_CHECKING: from swanlab.core_python.client import Client +from .type import RunType +from .utils import to_camel_case, parse_column_type + def send_experiment_heartbeat( client: "Client", @@ -52,3 +55,57 @@ def update_experiment_state( put_data = {k: v for k, v in put_data.items() if v is not None} # 移除值为None的键 client.put(f"/project/{username}/{projname}/runs/{cuid}/state", put_data) client.pending = True + + +def get_project_experiments( + client: "Client", + *, + path: str, + filters: Dict[str, object] = None, +) -> Union[List[RunType], Dict[str, List[RunType]]]: + """ + 获取指定项目下的所有实验信息 + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 + :param client: 已登录的客户端实例 + :param path: 项目路径 username/project + :param filters: 筛选实验的条件,可选 + """ + parsed_filters = ( + [ + { + "key": to_camel_case(key) if parse_column_type(key) == 'STABLE' else key.split('.', 1)[-1], + "active": True, + "value": [value], + "op": 'EQ', + "type": parse_column_type(key), + } + for key, value in filters.items() + ] + if filters + else [] + ) + res = client.post(f"/project/{path}/runs/shows", data={'filters': parsed_filters}) + return res[0] + + +def get_single_experiment(client: "Client", *, path: str) -> RunType: + """ + 获取指定项目下的所有实验信息 + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 + :param client: 已登录的客户端实例 + :param path: 实验路径 username/project/expid + """ + proj_path, expid = path.rsplit('/', 1) + res = client.get(f"/project/{proj_path}/runs/{expid}") + return res[0] + + +def get_experiment_metrics(client: "Client", *, expid: str, key: str) -> Dict[str, str]: + """ + 获取指定字段的指标数据,返回csv网址 + :param client: 已登录的客户端实例 + :param expid: 实验cuid + :param key: 指定字段列表 + """ + res = client.get(f"/experiment/{expid}/column/csv", params={'key': key}) + return res[0] diff --git a/swanlab/core_python/api/project.py b/swanlab/core_python/api/project.py index 0b5f7393a..ade2e2e6c 100644 --- a/swanlab/core_python/api/project.py +++ b/swanlab/core_python/api/project.py @@ -5,14 +5,16 @@ @description: 定义项目相关的后端API接口 """ -from typing import Optional, 
List +from typing import Optional, List, TYPE_CHECKING -from swanlab.core_python.api.type import ProjParamType, ProjResponseType -from .. import Client +if TYPE_CHECKING: + from swanlab.core_python.client import Client +from .type import ProjResponseType -def get_entity_projects( - client: Client, + +def get_workspace_projects( + client: "Client", *, workspace: str, page: int = 1, @@ -20,7 +22,7 @@ def get_entity_projects( sort: Optional[List[str]] = None, search: Optional[str] = None, detail: Optional[bool] = True, -): +) -> ProjResponseType: """ 获取指定页数和条件下的项目信息 :param client: 已登录的客户端实例 @@ -31,7 +33,7 @@ def get_entity_projects( :param search: 搜索的项目名称关键字, 可选 :param detail: 是否包含项目下实验的相关信息, 可选, 默认为true """ - params: ProjParamType = { + params = { 'page': page, 'size': size, 'sort': sort, @@ -39,5 +41,4 @@ def get_entity_projects( 'detail': detail, } res = client.get(f"/project/{workspace}", params=dict(params)) - projects_info: ProjResponseType = res[0] - return projects_info + return res[0] diff --git a/swanlab/core_python/api/self_hosted.py b/swanlab/core_python/api/self_hosted.py new file mode 100644 index 000000000..3a252cab9 --- /dev/null +++ b/swanlab/core_python/api/self_hosted.py @@ -0,0 +1,34 @@ +""" +@author: Zhou QiYang +@file: self_hosted.py +@time: 2026/1/5 17:42 +@description: 私有化相关API接口 +""" + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + +from .type import SelfHostedInfoType + + +def get_self_hosted_init(client: "Client") -> SelfHostedInfoType: + """ + 获取私有化部署信息 + :param client: 已登录的客户端实例 + """ + res = client.get(f"/self_hosted/info") + return res[0] + + +def create_user(client: "Client", *, username: str, password: str) -> str: + """ + 根用户添加用户 + :param client: 已登录的客户端实例 + :param username: 用户名 + :param password: 用户密码 + """ + data = {"users": [{"username": username, "password": password}]} + res = client.post("/self_hosted/users", data=data) + return res[0] diff --git 
a/swanlab/core_python/api/type.py b/swanlab/core_python/api/type.py index 7f13ec6ff..ec9becbe4 100644 --- a/swanlab/core_python/api/type.py +++ b/swanlab/core_python/api/type.py @@ -5,26 +5,30 @@ @description: OpenApi 用到的类型文件 """ -from typing import TypedDict, Optional, List, Dict - - -# 发送到后端查询项目信息的字段 -class ProjParamType(TypedDict): - page: int # 页码 - size: int # 每页项目数量 - sort: Optional[List[str]] # 排序方式(包含多个条件的列表) - search: Optional[str] # 搜索关键词 - detail: Optional[bool] # 是否返回详细信息(_count) +from typing import TypedDict, Optional, List, Dict, Literal +# ------------------------------------- 通用类型 ------------------------------------- +# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 class GroupType(TypedDict): - username: str # 工作空间名称 (workspace) + name: str # 组织名称 (用于user.teams) + username: str # 工作空间名称 (用于project.workspace) class ProjectLabelType(TypedDict): name: str # 项目标签名称 +class UserType(TypedDict): + username: str # 用户名 + name: str # 用户显示名称 + + +StateType = Literal['FINISHED', 'CRASHED', 'ABORTED', 'RUNNING'] # 实验状态 +ColumnType = Literal['STABLE', 'SCALAR', 'CONFIG'] # 列类型 + + +# 项目信息 class ProjectType(TypedDict): cuid: str # 项目CUID, 唯一标识符 name: str # 项目名 @@ -39,9 +43,41 @@ class ProjectType(TypedDict): _count: Dict[str, int] # 项目的统计信息 -# 后端返回的项目信息 +class RunType(TypedDict): + cuid: str # 实验CUID, 唯一标识符 + name: str # 实验名称 + createdAt: str # 创建时间, e.g., '2024-11-23T12:28:04.286Z' + description: str # 实验描述 + labels: List[ProjectLabelType] # 实验标签列表 + profile: Dict[str, Dict[str, object]] # 实验配置和摘要信息,包含 'config' 和 'scalar' + state: StateType # 实验状态 + cluster: str # 实验组 + job: str # 任务类型 + runtime: str # 运行时间 + user: UserType # 实验所属用户 + rootExpId: Optional[str] # 祖宗实验对应的实验 cuid,如果为克隆实验则必传 + rootProId: Optional[str] # 祖宗实验对应的项目 cuid,如果为克隆实验则必传 + + +# ------------------------------------- 后端返回信息 ------------------------------------- class ProjResponseType(TypedDict): list: List[ProjectType] # 项目列表 size: int # 每页项目数量 pages: int # 总页数 total: int # 总项目数量 + + +class 
ApiKeyType(TypedDict): + id: int + name: str + createdAt: str + key: str + + +# 私有化部署信息 +class SelfHostedInfoType(TypedDict): + enabled: bool # 是否成功部署 + expired: bool # licence是否过期 + root: bool # 是否为根用户 + plan: Literal["free", "commercial"] # 私有化版本(免费、商业) + seats: int # 余剩席位 diff --git a/swanlab/core_python/api/user.py b/swanlab/core_python/api/user.py new file mode 100644 index 000000000..8e0b6e097 --- /dev/null +++ b/swanlab/core_python/api/user.py @@ -0,0 +1,65 @@ +""" +@author: Zhou QiYang +@file: user.py +@time: 2026/1/2 21:01 +@description: 定义用户相关的后端API接口 +""" + +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from swanlab.core_python.client import Client + +from swanlab.core_python.api.type import ApiKeyType, GroupType + + +def create_api_key(client: "Client", *, name: str = None) -> str: + """ + 创建一个api_key,完成后返回成功信息 + :param client: 已登录的客户端实例 + :param name: api_key 的名称 + """ + if name is not None: + data = {'name': name} + res = client.post(f"/user/key", data=data) + else: + res = client.post(f"/user/key") + return res[0] + + +def delete_api_key(client: "Client", *, key_id: int) -> str: + """ + 删除指定id的api_key + :param client: 已登录的客户端实例 + :param key_id: api_key的id + """ + res = client.delete(f"/user/key/{key_id}") + return res[0] + + +def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: + """ + 获取当前全部的api_key + :param client: 已登录的客户端实例 + :param username: 用户名称 + """ + res = client.get(f"/user/{username}/groups") + return res[0] + + +def get_api_keys(client: "Client") -> List[ApiKeyType]: + """ + 获取当前全部的api_key + :param client: 已登录的客户端实例 + """ + res = client.get(f"/user/key") + return res[0] + + +def get_latest_api_key(client: "Client") -> ApiKeyType: + """ + 获取最新的api_key + :param client: 已登录的客户端实例 + """ + res = client.get(f"/user/key/latest") + return res[0] diff --git a/swanlab/core_python/api/utils.py b/swanlab/core_python/api/utils.py new file mode 100644 index 000000000..3ff966c5a --- /dev/null +++ 
b/swanlab/core_python/api/utils.py @@ -0,0 +1,24 @@ +""" +@author: Zhou QiYang +@file: utils.py +@time: 2025/12/27 18:53 +@description: 与后端交互时所需的工具函数 +""" + +from .type import ColumnType + + +# 从前缀中获取指标类型 +def parse_column_type(column: str) -> ColumnType: + column_type = column.split('.', 1)[0] + if column_type == 'summary': + return 'SCALAR' + elif column_type == 'config': + return 'CONFIG' + else: + return 'STABLE' + + +# 将下划线命名转化为驼峰命名 +def to_camel_case(name: str) -> str: + return ''.join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split('_'))]) diff --git a/swanlab/core_python/client/__init__.py b/swanlab/core_python/client/__init__.py index a8b3a5f0d..9adac78e3 100644 --- a/swanlab/core_python/client/__init__.py +++ b/swanlab/core_python/client/__init__.py @@ -192,6 +192,15 @@ def get(self, url: str, params: dict = None): resp = self.__session.get(url, params=params) return decode_response(resp), resp + def delete(self, url: str): + """ + delete请求 + """ + url = self.__login_info.api_host + url + self.__before_request() + resp = self.__session.delete(url) + return decode_response(resp), resp + def patch(self, url: str, data: dict = None): """ patch请求 diff --git a/test/api/project.py b/test/api/project.py deleted file mode 100644 index 25a600bfa..000000000 --- a/test/api/project.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -@author: Zhou Qiyang -@file: project.py -@time: 2025/12/17 10:45 -@description: 用于测试api登录功能 -""" - -from unittest.mock import patch - -import swanlab - - -# 测试数据 -def make_fake_projects(start, count): - return [ - { - "cuid": f"c{n}", - "name": f"proj-{n}", - "path": f"user/proj-{n}", - "url": f"https://dev001.swanlab.cn/@user/proj-{n}", - "description": f"desc-{n}", - "visibility": "PUBLIC", - "createdAt": "2025-01-01T00:00:00Z", - "updatedAt": "2025-01-01T00:00:00Z", - "projectLabels": [{"name": "Nvidia"}], - "group": {"username": "user", "status": "ENABLED", "type": "TEAM"}, - "_count": {}, - } - for n in range(start, start + 
count) - ] - - -performance_test_projects = [ - [ - { - "total": 40, - "pages": 1, - "size": 20, - "list": make_fake_projects(0, 20), - }, - ], - [ - { - "total": 40, - "pages": 2, - "size": 20, - "list": make_fake_projects(20, 20), - }, - ], -] - -params_test_projects = [ - { - "total": 20, - "pages": 1, - "size": 20, - "list": make_fake_projects(0, 20), - }, -] - - -# 性能测试:是否按照当前遍历的项目动态获取 -def test_api_projects_performance(): - # patch: client.get 返回 fake_projects_raw - with patch("swanlab.core_python.client.Client.get", side_effect=performance_test_projects) as mock_get: - api = swanlab.OpenApi() - - result = api.projects(workspace="bainiantest", detail=True) - - # 断言请求调用次数 - assert mock_get.call_count == 0 - for project in result: - if project.name == "proj-19": - assert mock_get.call_count == 1 - if project.name == "proj-20": - assert mock_get.call_count == 2 - - -# 功能测试:获取到的项目的属性是否齐全且正确 -def test_api_projects_params(): - with patch("swanlab.core_python.client.Client.get", return_value=params_test_projects): - api = swanlab.OpenApi() - - result = api.projects(workspace="bainiantest", detail=True) - - # 1. 字符串类型的字段 - raw_list = params_test_projects[0]["list"] - fields = { - "name": "name", - "path": "path", - "url": "url", - "description": "description", - "visibility": "visibility", - "created_at": "createdAt", - "updated_at": "updatedAt", - } - for field in fields: - assert [getattr(p, field) for p in result] == [r[fields[field]] for r in raw_list] - - # 1.1 workspace - assert [p.workspace for p in result] == [r["group"]["username"] for r in raw_list] - - # 2. labels - assert [[l.__str__() for l in p.labels] for p in result] == [ - [l["name"] for l in r["projectLabels"]] for r in raw_list - ] - assert [[l.name for l in p.labels] for p in result] == [ - [l["name"] for l in r["projectLabels"]] for r in raw_list - ] - - # 3. 
count - assert [p.count for p in result] == [r["_count"] for r in raw_list] From 6480312d97d5fac4a921def5fa192274b5f09471 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Tue, 13 Jan 2026 11:46:47 +0800 Subject: [PATCH 04/21] refactor open api (#1411) * Refactor the OpenAPI codes, managing the code with a modular approach * combine 3 kinds of user into one - use @cached_property to cache some property of user * Place the API's custom type into a separate module - place class ApiBase into a single model * get the user identity at the init of the OpenApi - only the root user get by api.user() can create new account * refactor module name & structure in api and core_python * Modified parameter passing convention & add client property to ApiBase - pass param through kwargs except client in OpenApi and core_python * update code commit - opt str() of Label class --- swanlab/api/__init__.py | 105 ++++++++-------- .../experiment.py => experiment/__init__.py} | 60 +++------- swanlab/api/{ => experiment}/thread.py | 10 +- swanlab/api/experiments/__init__.py | 69 +++++++++++ swanlab/api/model/__init__.py | 12 -- swanlab/api/model/base.py | 60 ---------- swanlab/api/model/user.py | 83 ------------- .../{model/project.py => project/__init__.py} | 59 +-------- swanlab/api/projects/__init__.py | 61 ++++++++++ swanlab/api/user/__init__.py | 113 ++++++++++++++++++ swanlab/api/utils.py | 53 +++++--- .../{experiment.py => experiment/__init__.py} | 15 ++- .../core_python/api/{ => experiment}/utils.py | 6 +- .../api/{project.py => project/__init__.py} | 9 +- swanlab/core_python/api/type.py | 83 ------------- swanlab/core_python/api/type/__init__.py | 21 ++++ swanlab/core_python/api/type/experiment.py | 33 +++++ swanlab/core_python/api/type/project.py | 36 ++++++ swanlab/core_python/api/type/user.py | 35 ++++++ .../api/{user.py => user/__init__.py} | 20 +++- .../core_python/api/{ => user}/self_hosted.py | 4 +- 21 files changed, 524 
insertions(+), 423 deletions(-) rename swanlab/api/{model/experiment.py => experiment/__init__.py} (76%) rename swanlab/api/{ => experiment}/thread.py (91%) create mode 100644 swanlab/api/experiments/__init__.py delete mode 100644 swanlab/api/model/__init__.py delete mode 100644 swanlab/api/model/base.py delete mode 100644 swanlab/api/model/user.py rename swanlab/api/{model/project.py => project/__init__.py} (51%) create mode 100644 swanlab/api/projects/__init__.py create mode 100644 swanlab/api/user/__init__.py rename swanlab/core_python/api/{experiment.py => experiment/__init__.py} (93%) rename swanlab/core_python/api/{ => experiment}/utils.py (78%) rename swanlab/core_python/api/{project.py => project/__init__.py} (90%) delete mode 100644 swanlab/core_python/api/type.py create mode 100644 swanlab/core_python/api/type/__init__.py create mode 100644 swanlab/core_python/api/type/experiment.py create mode 100644 swanlab/core_python/api/type/project.py create mode 100644 swanlab/core_python/api/type/user.py rename swanlab/core_python/api/{user.py => user/__init__.py} (81%) rename swanlab/core_python/api/{ => user}/self_hosted.py (93%) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index fd67c18f2..0abd51a5b 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -1,25 +1,24 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2025/4/29 9:40 -@File: __init__.py -@IDE: pycharm -@Description: - SwanLab OpenAPI包 """ -from typing import Optional, Union, List, Dict +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/5 17:58 +@description: SwanLab OpenAPI包 +""" -from .deprecated import OpenApi -from .model import ApiUser, SuperUser, Projects, Experiments, Experiment -from .model import ApiUser, SuperUser, Projects, Experiments, Experiment -from ..core_python import auth, Client -from ..core_python.api.experiment import get_single_experiment, get_project_experiments -from ..core_python.api.self_hosted import get_self_hosted_init 
-from ..error import KeyFileError, ApiError -from ..log import swanlog -from ..package import HostFormatter, get_key +from typing import Optional, List, Dict -__all__ = ["Api", "OpenApi"] +from swanlab.core_python import auth, Client +from swanlab.core_python.api.experiment import get_single_experiment, get_project_experiments +from swanlab.core_python.api.type import IdentityType +from swanlab.core_python.api.user import get_self_hosted_init +from swanlab.error import KeyFileError, ApiError +from swanlab.log import swanlog +from swanlab.package import HostFormatter, get_key +from .deprecated import OpenApi +from .experiment import Experiment +from .experiments import Experiments +from .projects import Projects +from .user import User class Api: @@ -47,40 +46,32 @@ def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, we self._client: Client = Client(self._login_info) self._web_host = self._login_info.web_host - def user(self, username: str = None) -> Optional[Union[ApiUser, SuperUser]]: # 尝试获取私有化服务信息,如果不是私有化服务,则会报错退出,因为指定user功能仅供私有化用户使用 try: - self_hosted_info = get_self_hosted_init(self._client) - except ApiError as e: - if username is not None: - swanlog.error( - "You haven't launched a swanlab self-hosted instance. Please check your login status using 'swanlab verify'." - ) - raise e - else: - return ApiUser(self._client, self._login_info) + self._self_hosted_info = get_self_hosted_init(self._client) + except ApiError: + swanlog.warning("You haven't launched a swanlab self-hosted instance. Some usages are not available.") + self._self_hosted_info = None - if not self_hosted_info["enabled"]: - raise RuntimeError("SwanLab self-hosted instance hasn't been ready yet.") - if self_hosted_info["expired"]: - raise RuntimeError("SwanLab self-hosted instance has expired. 
Please refresh your licence.") + self._identity: IdentityType = 'user' + if self._self_hosted_info is not None and self._self_hosted_info["plan"] == 'commercial': + self._identity = 'root' if self._self_hosted_info['root'] else 'user' - # 免费版仅能获取当前api_key登录的用户 - if self_hosted_info["plan"] == 'free': - if username != self._login_info.username: - swanlog.warning("Your self-hosted plan is 'free', You will be access to your own account.") - return ApiUser(self._client, self._login_info) - # 商业版的根用户可以获取到任何一个用户 - elif self_hosted_info["plan"] == 'commercial': - if self_hosted_info['root']: - return SuperUser(self._client, self._login_info, self_hosted=self_hosted_info) - elif username != self._login_info.username: - swanlog.warning("Your are not the root user, You will be access to your own account.") - return ApiUser(self._client, self._login_info) - # 为教育版预留功能 - else: - swanlog.warning("The self-hosted plan hasn't been supported yet.") - return None + if self._self_hosted_info is not None: + if not self._self_hosted_info["enabled"]: + swanlog.warning("SwanLab self-hosted instance hasn't been ready yet.") + if self._self_hosted_info["expired"]: + swanlog.warning("SwanLab self-hosted instance has expired.") + + def user(self, username: str = None) -> User: + """ + 获取用户实例,用于操作用户相关信息 + :param username: 指定用户名,如果为 None,则返回当前登录用户 + :return: User 实例,可对当前/指定用户进行操作 + """ + return User( + client=self._client, login_user=self._login_info.username, username=username, identity=self._identity + ) def projects( self, @@ -98,7 +89,7 @@ def projects( :return: Projects 实例,可遍历获取项目信息 """ return Projects( - client=self._client, + self._client, web_host=self._web_host, workspace=workspace, sort=sort, @@ -113,7 +104,7 @@ def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments: :return: Experiments 实例,可遍历获取实验信息 :param filters: 筛选实验的条件,可选 """ - return Experiments(client=self._client, path=path, web_host=self._web_host, filters=filters) + return Experiments(self._client, 
path=path, login_info=self._login_info, filters=filters) def run( self, @@ -132,4 +123,14 @@ def run( data = get_project_experiments( self._client, path=proj_path, filters={'name': _data['name'], 'created_at': _data['createdAt']} ) - return Experiment(data=data[0], client=self._client, path=proj_path, web_host=self._web_host, line_count=1) + return Experiment( + self._client, + data=data[0], + path=proj_path, + web_host=self._web_host, + login_user=self._login_info.username, + line_count=1, + ) + + +__all__ = ["Api", "OpenApi"] diff --git a/swanlab/api/model/experiment.py b/swanlab/api/experiment/__init__.py similarity index 76% rename from swanlab/api/model/experiment.py rename to swanlab/api/experiment/__init__.py index 355a72332..0f5873ac2 100644 --- a/swanlab/api/model/experiment.py +++ b/swanlab/api/experiment/__init__.py @@ -1,30 +1,29 @@ """ @author: Zhou QiYang @file: experiment.py -@time: 2026/1/5 17:58 -@description: OpenApi 中的实验对象 +@time: 2026/1/11 16:36 +@description: OpenApi 的单个实验对象 """ -from typing import TYPE_CHECKING, List, Dict, Any, Iterator +from typing import List, Dict, Any -if TYPE_CHECKING: - from swanlab.core_python.client import Client - -from swanlab.log import swanlog -from swanlab.core_python.api.experiment import get_project_experiments +from swanlab.api.user import User +from swanlab.api.utils import ApiBase, Label from swanlab.core_python.api.type import RunType -from swanlab.api.thread import HistoryPool -from swanlab.api.utils import flatten_runs - -from .base import ApiBase, Label, User +from swanlab.core_python.client import Client +from swanlab.log import swanlog +from .thread import HistoryPool class Experiment(ApiBase): - def __init__(self, data: RunType, client: "Client", path: str, web_host: str, line_count: int) -> None: + def __init__( + self, client: Client, *, data: RunType, path: str, web_host: str, login_user: str, line_count: int + ) -> None: + super().__init__(client) self._data = data - self._client = client 
self._path = path self._web_host = web_host + self._login_user = login_user self._line_count = line_count @property @@ -67,7 +66,7 @@ def labels(self) -> List[Label]: """ List of Label attached to this experiment. """ - return [Label(label) for label in self._data['labels']] + return [Label(label['name']) for label in self._data['labels']] @property def config(self) -> Dict[str, object]: @@ -109,7 +108,7 @@ def user(self) -> User: """ Experiment user. """ - return User(self._data['user']) + return User(client=self._client, login_user=self._login_user, username=self._data['user']['username']) @property def metric_keys(self) -> List[str]: @@ -212,31 +211,4 @@ def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None return df if pandas else df.to_dict(orient='records') - - -class Experiments(ApiBase): - """ - Container for a collection of Experiment objects. - You can iterate over the experiments by for-in loop. - """ - - def __init__(self, client: "Client", path: str, web_host: str, filters: Dict[str, object] = None) -> None: - if len(path.split('/')) != 2: - raise ValueError(f"User's {path} is invaded. 
Correct path should be like 'username/project'") - self._client = client - self._path = path - self._web_host = web_host - self._filters = filters - - def __iter__(self) -> Iterator[Experiment]: - # TODO: 完善filter的功能(正则、条件判断) - resp = get_project_experiments(self._client, path=self._path, filters=self._filters) - runs: List[RunType] = [] - if isinstance(resp, List): - runs = resp - # 分组时需展平实验数据 - elif isinstance(resp, Dict): - runs = flatten_runs(resp) - line_count = len(runs) - yield from iter(Experiment(run, self._client, self._path, self._web_host, line_count) for run in runs) - +__all__ = ['Experiment'] diff --git a/swanlab/api/thread.py b/swanlab/api/experiment/thread.py similarity index 91% rename from swanlab/api/thread.py rename to swanlab/api/experiment/thread.py index eaedd7d75..50e12fc9f 100644 --- a/swanlab/api/thread.py +++ b/swanlab/api/experiment/thread.py @@ -2,25 +2,23 @@ @author: Zhou QiYang @file: thread.py @time: 2025/12/30 15:08 -@description: 用于api并发请求的封装类 +@description: 用于并发请求实验指标数据的封装类 """ from concurrent.futures import ThreadPoolExecutor from io import BytesIO -from typing import List, Any, TYPE_CHECKING +from typing import List, Any import requests -if TYPE_CHECKING: - from swanlab.core_python.client import Client - from swanlab.core_python.api.experiment import get_experiment_metrics +from swanlab.core_python.client import Client from swanlab.log import swanlog class HistoryPool: - def __init__(self, client: "Client", expid: str, *, keys: List[str], x_axis: str = None, num_threads: int = 10): + def __init__(self, client: Client, expid: str, *, keys: List[str], x_axis: str = None, num_threads: int = 10): try: import pandas as pd except ImportError: diff --git a/swanlab/api/experiments/__init__.py b/swanlab/api/experiments/__init__.py new file mode 100644 index 000000000..383d2503a --- /dev/null +++ b/swanlab/api/experiments/__init__.py @@ -0,0 +1,69 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2025/12/30 15:08 +@description: 
OpenApi 的实验对象迭代器 +""" + +from typing import List, Dict, Iterator + +from swanlab.api.experiment import Experiment +from swanlab.api.utils import ApiBase +from swanlab.core_python.api.experiment import get_project_experiments +from swanlab.core_python.api.type import RunType +from swanlab.core_python.auth.providers.api_key import LoginInfo +from swanlab.core_python.client import Client + + +def flatten_runs(runs: Dict) -> List: + """ + 展开分组后的实验数据,返回一个包含所有实验的列表 + """ + flat_runs = [] + for group in runs.values(): + if isinstance(group, Dict): + flat_runs.extend(flatten_runs(group)) + else: + flat_runs.extend(group) + return flat_runs + + +class Experiments(ApiBase): + """ + Container for a collection of Experiment objects. + You can iterate over the experiments by for-in loop. + """ + + def __init__(self, client: Client, *, path: str, login_info: LoginInfo, filters: Dict[str, object] = None) -> None: + if len(path.split('/')) != 2: + raise ValueError(f"User's {path} is invaded. Correct path should be like 'username/project'") + super().__init__(client) + self._path = path + self._web_host = login_info.web_host + self._login_user = login_info.username + self._filters = filters + + def __iter__(self) -> Iterator[Experiment]: + # TODO: 完善filter的功能(正则、条件判断) + resp = get_project_experiments(self._client, path=self._path, filters=self._filters) + runs: List[RunType] = [] + if isinstance(resp, List): + runs = resp + # 分组时需展平实验数据 + elif isinstance(resp, Dict): + runs = flatten_runs(resp) + line_count = len(runs) + yield from iter( + Experiment( + self._client, + data=run, + path=self._path, + web_host=self._web_host, + login_user=self._login_user, + line_count=line_count, + ) + for run in runs + ) + + +__all__ = ["Experiments"] diff --git a/swanlab/api/model/__init__.py b/swanlab/api/model/__init__.py deleted file mode 100644 index e838c7eb7..000000000 --- a/swanlab/api/model/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -@author: Zhou QiYang -@file: __init__.py -@time: 
2026/1/5 17:59 -@description: OpenApi 中包含的对象 -""" - -from .experiment import Experiment, Experiments -from .project import Projects -from .user import ApiUser, SuperUser - -__all__ = ['Experiment', 'Experiments', 'Projects', 'ApiUser', 'SuperUser'] diff --git a/swanlab/api/model/base.py b/swanlab/api/model/base.py deleted file mode 100644 index 5d2efc8d8..000000000 --- a/swanlab/api/model/base.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -@author: Zhou Qiyang -@file: model.py -@time: 2025/12/18 20:10 -@description: OpenApi 中的基础对象 -""" - -from typing import Dict - -from swanlab.core_python.api.type import ProjectLabelType, UserType - - -class ApiBase: - @property - def __dict__(self) -> Dict[str, object]: - """ - Return a dictionary containing all @property fields. - """ - result = {} - cls = type(self) - for attr_name in dir(cls): - if attr_name.startswith('_'): - continue - attr = getattr(cls, attr_name, None) - if isinstance(attr, property): - result[attr_name] = self.__getattribute__(attr_name) - return result - - -class Label(ApiBase): - """ - Project label object - you can get the label name by str(label) - """ - - def __init__(self, data: ProjectLabelType) -> None: - self._data = data - - @property - def name(self) -> str: - """ - Label name. 
- """ - return self._data['name'] - - def __str__(self) -> str: - return str(self.name) - - -class User(ApiBase): - def __init__(self, data: UserType) -> None: - self._data = data - - @property - def name(self) -> str: - return self._data['name'] - - @property - def username(self) -> str: - return self._data['username'] diff --git a/swanlab/api/model/user.py b/swanlab/api/model/user.py deleted file mode 100644 index b184e30bc..000000000 --- a/swanlab/api/model/user.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -@author: Zhou QiYang -@file: user.py -@time: 2026/1/5 17:58 -@description: OpenApi 中的用户对象 -""" - -import re -from typing import TYPE_CHECKING, List, Optional - -if TYPE_CHECKING: - from swanlab.core_python.client import Client - -from swanlab.api.utils import STATUS_CREATED, STATUS_OK -from swanlab.core_python.api.self_hosted import create_user -from swanlab.core_python.api.type import ApiKeyType, SelfHostedInfoType -from swanlab.core_python.api.user import ( - get_user_groups, - get_api_keys, - create_api_key, - get_latest_api_key, - delete_api_key, -) -from swanlab.core_python.auth.providers.api_key import LoginInfo - -from .base import ApiBase - - -class ApiUser(ApiBase): - def __init__(self, client: "Client", login_info: LoginInfo) -> None: - super().__init__() - self._client = client - self._login_info = login_info - self._api_keys: List[ApiKeyType] = [] - - @property - def username(self) -> str: - return self._login_info.username - - @property - def teams(self) -> List[str]: - resp = get_user_groups(self._client, username=self.username) - return [r['name'] for r in resp] - - @property - def api_keys(self) -> List[str]: - self._api_keys = get_api_keys(self._client) - return [r['key'] for r in self._api_keys] - - def generate_api_key(self, description: str = None) -> Optional[str]: - api_key: Optional[ApiKeyType] = None - res = create_api_key(self._client, name=description) - if res == STATUS_CREATED: - api_key = get_latest_api_key(self._client) - return 
api_key['key'] if api_key else None - - def delete_api_key(self, api_key: str) -> bool: - self._api_keys = get_api_keys(self._client) - for key in self._api_keys: - if key['key'] == api_key: - res = delete_api_key(self._client, key_id=key['id']) - if res == STATUS_OK: - return True - return False - - -class SuperUser(ApiUser): - def __init__(self, client: "Client", login_info: LoginInfo, self_hosted: SelfHostedInfoType) -> None: - super().__init__(client, login_info) - self._self_hosted_info = self_hosted - - def create(self, username: str, password: str) -> bool: - # 用户名为大小写字母、数字及-、_组成 - # 密码必须包含数字+英文且至少8位 - if not re.match(r'^[a-zA-Z0-9_-]+$', username): - raise ValueError("Username must be alphanumeric and can contain - and _") - if not re.match(r'^(?=.*[0-9])(?=.*[a-zA-Z]).{8,}$', password): - raise ValueError("Password must contain at least 8 characters and include numbers and letters") - resp = create_user(self._client, username=username, password=password) - if resp == STATUS_CREATED: - return True - else: - raise False diff --git a/swanlab/api/model/project.py b/swanlab/api/project/__init__.py similarity index 51% rename from swanlab/api/model/project.py rename to swanlab/api/project/__init__.py index 438a2b4a9..926345c27 100644 --- a/swanlab/api/model/project.py +++ b/swanlab/api/project/__init__.py @@ -5,15 +5,10 @@ @description: OpenApi 中的项目对象 """ -from typing import List, Dict, Optional, Iterator, TYPE_CHECKING +from typing import List, Dict -if TYPE_CHECKING: - from swanlab.core_python.client import Client - -from swanlab.core_python.api.project import get_workspace_projects -from swanlab.core_python.api.type import ProjectType, ProjResponseType - -from .base import ApiBase, Label +from swanlab.api.utils import ApiBase, Label +from swanlab.core_python.api.type import ProjectType class Project(ApiBase): @@ -21,7 +16,8 @@ class Project(ApiBase): Representing a single project with some of its properties. 
""" - def __init__(self, data: ProjectType, web_host: str) -> None: + def __init__(self, *, data: ProjectType, web_host: str) -> None: + super().__init__() self._data = data self._web_host = web_host @@ -86,7 +82,7 @@ def labels(self) -> List[Label]: """ List of Label attached to this project. """ - return [Label(label) for label in self._data['projectLabels']] + return [Label(label['name']) for label in self._data['projectLabels']] @property def count(self) -> Dict[str, int]: @@ -95,46 +91,3 @@ def count(self) -> Dict[str, int]: experiments, contributors, children, collaborators, runningExps. """ return self._data['_count'] - - -class Projects(ApiBase): - """ - Container for a collection of Project objects. - You can iterate over the projects by for-in loop. - """ - - def __init__( - self, - client: "Client", - web_host: str, - workspace: str, - sort: Optional[List[str]] = None, - search: Optional[str] = None, - detail: Optional[bool] = True, - ) -> None: - self._client = client - self._web_host = web_host - self._workspace = workspace - self._sort = sort - self._search = search - self._detail = detail - - def __iter__(self) -> Iterator[Project]: - # 按用户遍历情况获取项目信息 - cur_page = 0 - page_size = 20 - while True: - cur_page += 1 - projects_info: ProjResponseType = get_workspace_projects( - self._client, - workspace=self._workspace, - page=cur_page, - size=page_size, - sort=self._sort, - search=self._search, - detail=self._detail, - ) - if cur_page * page_size >= projects_info['total']: - break - - yield from iter(Project(project, self._web_host) for project in projects_info['list']) diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py new file mode 100644 index 000000000..ead54c51b --- /dev/null +++ b/swanlab/api/projects/__init__.py @@ -0,0 +1,61 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 16:31 +@description: OpenApi 中的项目对象迭代器 +""" + +from typing import List, Optional, Iterator + +from swanlab.api.project import 
Project +from swanlab.api.utils import ApiBase +from swanlab.core_python.api.project import get_workspace_projects +from swanlab.core_python.api.type import ProjResponseType +from swanlab.core_python.client import Client + + +class Projects(ApiBase): + """ + Container for a collection of Project objects. + You can iterate over the projects by for-in loop. + """ + + def __init__( + self, + client: Client, + *, + web_host: str, + workspace: str, + sort: Optional[List[str]] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> None: + super().__init__(client) + self._web_host = web_host + self._workspace = workspace + self._sort = sort + self._search = search + self._detail = detail + + def __iter__(self) -> Iterator[Project]: + # 按用户遍历情况获取项目信息 + cur_page = 0 + page_size = 20 + while True: + cur_page += 1 + projects_info: ProjResponseType = get_workspace_projects( + self._client, + workspace=self._workspace, + page=cur_page, + size=page_size, + sort=self._sort, + search=self._search, + detail=self._detail, + ) + if cur_page * page_size >= projects_info['total']: + break + + yield from iter(Project(data=p, web_host=self._web_host) for p in projects_info['list']) + + +__all__ = ["Projects"] diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py new file mode 100644 index 000000000..a7419108e --- /dev/null +++ b/swanlab/api/user/__init__.py @@ -0,0 +1,113 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 16:40 +@description: OpenApi 中的用户对象 +""" + +import re +from functools import cached_property +from typing import List, Optional + +from swanlab.api.utils import ApiBase +from swanlab.core_python.api.type import ApiKeyType +from swanlab.core_python.api.type.user import IdentityType +from swanlab.core_python.api.user import ( + get_user_groups, + get_api_keys, + create_api_key, + get_latest_api_key, + delete_api_key, +) +from swanlab.core_python.api.user.self_hosted import create_user +from 
swanlab.core_python.client import Client +from swanlab.log import swanlog + +STATUS_OK = "OK" +STATUS_CREATED = "Created" + + +def check_create_info(username: str, password: str) -> bool: + # 用户名为大小写字母、数字及-、_组成 + # 密码必须包含数字+英文且至少8位 + if not re.match(r'^[a-zA-Z0-9_-]+$', username): + raise ValueError("Username must be alphanumeric and can contain - and _") + if not re.match(r'^(?=.*[0-9])(?=.*[a-zA-Z]).{8,}$', password): + raise ValueError("Password must contain at least 8 characters and include numbers and letters") + else: + return True + + +class User(ApiBase): + def __init__( + self, client: Client, login_user: str = None, username: str = None, identity: IdentityType = 'user' + ) -> None: + if login_user is None and username is None: + raise ValueError("login_user or username are required") + + super().__init__(client) + self._identity = identity + self._api_keys: List[ApiKeyType] = [] + self._cur_username = username or login_user + self._is_other_user = username is not None and username != login_user + + @property + def username(self) -> str: + return self._cur_username + + @cached_property + def teams(self) -> List[str]: + resp = get_user_groups(self._client, username=self._cur_username) + return [r['name'] for r in resp] + + # TODO: 管理员可以对指定用户的api_key进行操作 + @cached_property + def api_keys(self) -> List[str]: + if self._is_other_user: + swanlog.warning("Getting api keys of other users has not been supported yet.") + return [] + else: + self._api_keys = get_api_keys(self._client) + return [r['key'] for r in self._api_keys] + + def _refresh_api_keys(self): + del self.api_keys + self._api_keys = get_api_keys(self._client) + + def generate_api_key(self, description: str = None) -> Optional[str]: + if self._is_other_user: + swanlog.warning("Generating api key of other users has not been supported yet.") + return None + else: + api_key: Optional[ApiKeyType] = None + res = create_api_key(self._client, name=description) + if res == STATUS_CREATED: + api_key = 
get_latest_api_key(self._client) + return api_key['key'] if api_key else None + + def delete_api_key(self, api_key: str) -> bool: + if self._is_other_user: + swanlog.warning("Deleting api key of other users has not been supported yet.") + return False + else: + self._refresh_api_keys() + for key in self._api_keys: + if key['key'] == api_key: + res = delete_api_key(self._client, key_id=key['id']) + if res == STATUS_OK: + return True + return False + + def create(self, username: str, password: str) -> bool: + if self._identity != "root" or self._is_other_user: + swanlog.warning(f"{self._cur_username} is not allowed to create other user.") + return False + check_create_info(username, password) + resp = create_user(self._client, username=username, password=password) + if resp == STATUS_CREATED: + return True + else: + raise False + + +__all__ = ["User"] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 5114184e1..3eb87337b 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -1,24 +1,47 @@ """ @author: Zhou QiYang -@file: utils.py -@time: 2026/1/4 18:03 -@description: OpenApi 使用的常量和工具函数 +@file: __init__.py +@time: 2026/1/11 23:44 +@description: OpenApi 中的基础对象 """ -from typing import Dict, List +from dataclasses import dataclass +from typing import Dict -STATUS_OK = "OK" -STATUS_CREATED = "Created" +from swanlab.core_python import Client -def flatten_runs(runs: Dict) -> List: +@dataclass +class Label: """ - 展开分组后的实验数据,返回一个包含所有实验的列表 + Project label object + you can get the label name by str(label) """ - flat_runs = [] - for group in runs.values(): - if isinstance(group, Dict): - flat_runs.extend(flatten_runs(group)) - else: - flat_runs.extend(group) - return flat_runs + + name: str + + def __str__(self) -> str: + return self.name + + +class ApiBase: + def __init__(self, client: Client = None): + self._client = client + + @property + def __dict__(self) -> Dict[str, object]: + """ + Return a dictionary containing all @property fields. 
+ """ + result = {} + cls = type(self) + for attr_name in dir(cls): + if attr_name.startswith('_'): + continue + attr = getattr(cls, attr_name, None) + if isinstance(attr, property): + result[attr_name] = self.__getattribute__(attr_name) + return result + + +__all__ = ['ApiBase', 'Label'] diff --git a/swanlab/core_python/api/experiment.py b/swanlab/core_python/api/experiment/__init__.py similarity index 93% rename from swanlab/core_python/api/experiment.py rename to swanlab/core_python/api/experiment/__init__.py index 8e7299ffc..662d7d258 100644 --- a/swanlab/core_python/api/experiment.py +++ b/swanlab/core_python/api/experiment/__init__.py @@ -7,12 +7,12 @@ from typing import Literal, Dict, TYPE_CHECKING, List, Union +from swanlab.core_python.api.type import RunType +from .utils import to_camel_case, parse_column_type + if TYPE_CHECKING: from swanlab.core_python.client import Client -from .type import RunType -from .utils import to_camel_case, parse_column_type - def send_experiment_heartbeat( client: "Client", @@ -109,3 +109,12 @@ def get_experiment_metrics(client: "Client", *, expid: str, key: str) -> Dict[st """ res = client.get(f"/experiment/{expid}/column/csv", params={'key': key}) return res[0] + + +__all__ = [ + "send_experiment_heartbeat", + "update_experiment_state", + "get_project_experiments", + "get_single_experiment", + "get_experiment_metrics", +] diff --git a/swanlab/core_python/api/utils.py b/swanlab/core_python/api/experiment/utils.py similarity index 78% rename from swanlab/core_python/api/utils.py rename to swanlab/core_python/api/experiment/utils.py index 3ff966c5a..a35088f97 100644 --- a/swanlab/core_python/api/utils.py +++ b/swanlab/core_python/api/experiment/utils.py @@ -1,11 +1,11 @@ """ @author: Zhou QiYang @file: utils.py -@time: 2025/12/27 18:53 -@description: 与后端交互时所需的工具函数 +@time: 2026/1/10 22:09 +@description: 实验相关的后端API接口中的工具函数 """ -from .type import ColumnType +from swanlab.core_python.api.type import ColumnType # 从前缀中获取指标类型 diff 
--git a/swanlab/core_python/api/project.py b/swanlab/core_python/api/project/__init__.py similarity index 90% rename from swanlab/core_python/api/project.py rename to swanlab/core_python/api/project/__init__.py index ade2e2e6c..8e807b658 100644 --- a/swanlab/core_python/api/project.py +++ b/swanlab/core_python/api/project/__init__.py @@ -1,17 +1,17 @@ """ @author: Zhou QiYang -@file: project.py +@file: __init__.py @time: 2025/12/19 23:49 @description: 定义项目相关的后端API接口 """ from typing import Optional, List, TYPE_CHECKING +from swanlab.core_python.api.type import ProjResponseType + if TYPE_CHECKING: from swanlab.core_python.client import Client -from .type import ProjResponseType - def get_workspace_projects( client: "Client", @@ -42,3 +42,6 @@ def get_workspace_projects( } res = client.get(f"/project/{workspace}", params=dict(params)) return res[0] + + +__all__ = ["get_workspace_projects"] diff --git a/swanlab/core_python/api/type.py b/swanlab/core_python/api/type.py deleted file mode 100644 index ec9becbe4..000000000 --- a/swanlab/core_python/api/type.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -@author: Zhou Qiyang -@file: types.py -@time: 2025/12/17 16:35 -@description: OpenApi 用到的类型文件 -""" - -from typing import TypedDict, Optional, List, Dict, Literal - - -# ------------------------------------- 通用类型 ------------------------------------- -# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 -class GroupType(TypedDict): - name: str # 组织名称 (用于user.teams) - username: str # 工作空间名称 (用于project.workspace) - - -class ProjectLabelType(TypedDict): - name: str # 项目标签名称 - - -class UserType(TypedDict): - username: str # 用户名 - name: str # 用户显示名称 - - -StateType = Literal['FINISHED', 'CRASHED', 'ABORTED', 'RUNNING'] # 实验状态 -ColumnType = Literal['STABLE', 'SCALAR', 'CONFIG'] # 列类型 - - -# 项目信息 -class ProjectType(TypedDict): - cuid: str # 项目CUID, 唯一标识符 - name: str # 项目名 - path: str # 项目路径 - url: str # 项目URL - description: str # 项目描述 - visibility: str # 可见性, 'PUBLIC' 或 'PRIVATE' - createdAt: str # e.g., 
'2024-11-23T12:28:04.286Z' - updatedAt: str # e.g., '2024-11-23T12:28:04.286Z' - group: GroupType # 项目所属工作空间名称 (workspace) - projectLabels: List[ProjectLabelType] # 项目标签 - _count: Dict[str, int] # 项目的统计信息 - - -class RunType(TypedDict): - cuid: str # 实验CUID, 唯一标识符 - name: str # 实验名称 - createdAt: str # 创建时间, e.g., '2024-11-23T12:28:04.286Z' - description: str # 实验描述 - labels: List[ProjectLabelType] # 实验标签列表 - profile: Dict[str, Dict[str, object]] # 实验配置和摘要信息,包含 'config' 和 'scalar' - state: StateType # 实验状态 - cluster: str # 实验组 - job: str # 任务类型 - runtime: str # 运行时间 - user: UserType # 实验所属用户 - rootExpId: Optional[str] # 祖宗实验对应的实验 cuid,如果为克隆实验则必传 - rootProId: Optional[str] # 祖宗实验对应的项目 cuid,如果为克隆实验则必传 - - -# ------------------------------------- 后端返回信息 ------------------------------------- -class ProjResponseType(TypedDict): - list: List[ProjectType] # 项目列表 - size: int # 每页项目数量 - pages: int # 总页数 - total: int # 总项目数量 - - -class ApiKeyType(TypedDict): - id: int - name: str - createdAt: str - key: str - - -# 私有化部署信息 -class SelfHostedInfoType(TypedDict): - enabled: bool # 是否成功部署 - expired: bool # licence是否过期 - root: bool # 是否为根用户 - plan: Literal["free", "commercial"] # 私有化版本(免费、商业) - seats: int # 余剩席位 diff --git a/swanlab/core_python/api/type/__init__.py b/swanlab/core_python/api/type/__init__.py new file mode 100644 index 000000000..b20bcd6e3 --- /dev/null +++ b/swanlab/core_python/api/type/__init__.py @@ -0,0 +1,21 @@ +""" +@author: Zhou QiYang +@file: __init__.py +@time: 2026/1/11 23:36 +@description: 后端API接口相关类型 +""" + +from .experiment import RunType, ColumnType +from .project import ProjectType, ProjResponseType +from .user import GroupType, IdentityType, ApiKeyType, SelfHostedInfoType + +__all__ = [ + "RunType", + "ColumnType", + "ProjectType", + "ProjResponseType", + "GroupType", + "IdentityType", + "ApiKeyType", + "SelfHostedInfoType", +] diff --git a/swanlab/core_python/api/type/experiment.py b/swanlab/core_python/api/type/experiment.py new file mode 100644 
index 000000000..56d17455e --- /dev/null +++ b/swanlab/core_python/api/type/experiment.py @@ -0,0 +1,33 @@ +""" +@author: Zhou QiYang +@file: type.py +@time: 2026/1/10 22:12 +@description: 实验相关后端API接口类型 +""" + +from typing import TypedDict, List, Dict, Optional, Literal + +ColumnType = Literal['STABLE', 'SCALAR', 'CONFIG'] # 列类型 +StateType = Literal['FINISHED', 'CRASHED', 'ABORTED', 'RUNNING'] # 实验状态 + + +# ------------------------------------- 通用类型 ------------------------------------- +class UserType(TypedDict): + username: str # 用户名 + name: str # 用户显示名称 + + +class RunType(TypedDict): + cuid: str # 实验CUID, 唯一标识符 + name: str # 实验名称 + createdAt: str # 创建时间, e.g., '2024-11-23T12:28:04.286Z' + description: str # 实验描述 + labels: List[Dict[str, str]] # 实验标签列表 + profile: Dict[str, Dict[str, object]] # 实验配置和摘要信息,包含 'config' 和 'scalar' + state: StateType # 实验状态 + cluster: str # 实验组 + job: str # 任务类型 + runtime: str # 运行时间 + user: UserType # 实验所属用户 + rootExpId: Optional[str] # 祖宗实验对应的实验 cuid,如果为克隆实验则必传 + rootProId: Optional[str] # 祖宗实验对应的项目 cuid,如果为克隆实验则必传 diff --git a/swanlab/core_python/api/type/project.py b/swanlab/core_python/api/type/project.py new file mode 100644 index 000000000..edc66392d --- /dev/null +++ b/swanlab/core_python/api/type/project.py @@ -0,0 +1,36 @@ +""" +@author: Zhou QiYang +@file: projects.py +@time: 2026/1/10 21:48 +@description: 项目相关后端API接口类型 +""" + +from typing import TypedDict, List, Dict + + +# ------------------------------------- 通用类型 ------------------------------------- +class ProjectLabelType(TypedDict): + name: str # 项目标签名称 + + +# 项目信息 +class ProjectType(TypedDict): + cuid: str # 项目CUID, 唯一标识符 + name: str # 项目名 + path: str # 项目路径 + url: str # 项目URL + description: str # 项目描述 + visibility: str # 可见性, 'PUBLIC' 或 'PRIVATE' + createdAt: str # e.g., '2024-11-23T12:28:04.286Z' + updatedAt: str # e.g., '2024-11-23T12:28:04.286Z' + group: Dict[str, str] # 包含项目所属工作空间名称 (workspace) + projectLabels: List[ProjectLabelType] # 项目标签 + _count: Dict[str, 
int] # 项目的统计信息 + + +# ------------------------------------- 后端返回信息 ------------------------------------- +class ProjResponseType(TypedDict): + list: List[ProjectType] # 项目列表 + size: int # 每页项目数量 + pages: int # 总页数 + total: int # 总项目数量 diff --git a/swanlab/core_python/api/type/user.py b/swanlab/core_python/api/type/user.py new file mode 100644 index 000000000..f56d55d83 --- /dev/null +++ b/swanlab/core_python/api/type/user.py @@ -0,0 +1,35 @@ +""" +@author: Zhou QiYang +@file: user.py +@time: 2026/1/10 21:46 +@description: 用户相关后端API接口类型 +""" + +from typing import Literal, TypedDict + + +# ------------------------------------- 通用类型 ------------------------------------- +IdentityType = Literal['user', 'root'] + + +# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 +class GroupType(TypedDict): + name: str # 组织名称 (用于user.teams) + username: str + + +# ------------------------------------- 后端返回信息 ------------------------------------- +class ApiKeyType(TypedDict): + id: int + name: str + createdAt: str + key: str + + +# 私有化部署信息 +class SelfHostedInfoType(TypedDict): + enabled: bool # 是否成功部署 + expired: bool # licence是否过期 + root: bool # 是否为根用户 + plan: Literal["free", "commercial"] # 私有化版本(免费、商业) + seats: int # 余剩席位 diff --git a/swanlab/core_python/api/user.py b/swanlab/core_python/api/user/__init__.py similarity index 81% rename from swanlab/core_python/api/user.py rename to swanlab/core_python/api/user/__init__.py index 8e0b6e097..53e6257f3 100644 --- a/swanlab/core_python/api/user.py +++ b/swanlab/core_python/api/user/__init__.py @@ -1,17 +1,18 @@ """ @author: Zhou QiYang -@file: user.py -@time: 2026/1/2 21:01 +@file: __init__.py.py +@time: 2026/1/10 21:44 @description: 定义用户相关的后端API接口 """ from typing import TYPE_CHECKING, List +from swanlab.core_python.api.type import GroupType, ApiKeyType +from .self_hosted import get_self_hosted_init, create_user + if TYPE_CHECKING: from swanlab.core_python.client import Client -from swanlab.core_python.api.type import ApiKeyType, GroupType - def 
create_api_key(client: "Client", *, name: str = None) -> str: """ @@ -63,3 +64,14 @@ def get_latest_api_key(client: "Client") -> ApiKeyType: """ res = client.get(f"/user/key/latest") return res[0] + + +__all__ = [ + "create_api_key", + "delete_api_key", + "get_user_groups", + "get_api_keys", + "get_latest_api_key", + "get_self_hosted_init", + "create_user", +] diff --git a/swanlab/core_python/api/self_hosted.py b/swanlab/core_python/api/user/self_hosted.py similarity index 93% rename from swanlab/core_python/api/self_hosted.py rename to swanlab/core_python/api/user/self_hosted.py index 3a252cab9..55ad620c0 100644 --- a/swanlab/core_python/api/self_hosted.py +++ b/swanlab/core_python/api/user/self_hosted.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING +from swanlab.core_python.api.type import SelfHostedInfoType + if TYPE_CHECKING: from swanlab.core_python.client import Client -from .type import SelfHostedInfoType - def get_self_hosted_init(client: "Client") -> SelfHostedInfoType: """ From 126e29c5506d90b2c93ffb5c96a22788a7d588d7 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Mon, 19 Jan 2026 13:06:40 +0800 Subject: [PATCH 05/21] feat: open api unit tests (#1420) * add unit test - add unit test and fix bugs of getting all projects through OpenApi - add unit tests for OpenApi runs and run.history() * add unit test for api.user() * accept suggestions from gemini --- swanlab/api/projects/__init__.py | 8 ++- test/unit/api/test_experiment.py | 21 ++++++ test/unit/api/test_history.py | 117 +++++++++++++++++++++++++++++++ test/unit/api/test_project.py | 27 +++++++ test/unit/api/test_user.py | 41 +++++++++++ test/unit/api/utils.py | 100 ++++++++++++++++++++++++++ 6 files changed, 311 insertions(+), 3 deletions(-) create mode 100644 test/unit/api/test_experiment.py create mode 100644 test/unit/api/test_history.py create mode 100644 test/unit/api/test_project.py create mode 100644 test/unit/api/test_user.py create mode 
100644 test/unit/api/utils.py diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index ead54c51b..6a8ba1cfc 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -41,9 +41,10 @@ def __iter__(self) -> Iterator[Project]: # 按用户遍历情况获取项目信息 cur_page = 0 page_size = 20 + projects = [] while True: cur_page += 1 - projects_info: ProjResponseType = get_workspace_projects( + resp: ProjResponseType = get_workspace_projects( self._client, workspace=self._workspace, page=cur_page, @@ -52,10 +53,11 @@ def __iter__(self) -> Iterator[Project]: search=self._search, detail=self._detail, ) - if cur_page * page_size >= projects_info['total']: + projects.extend(resp['list']) + if cur_page * page_size >= resp['total']: break - yield from iter(Project(data=p, web_host=self._web_host) for p in projects_info['list']) + yield from iter(Project(data=p, web_host=self._web_host) for p in projects) __all__ = ["Projects"] diff --git a/test/unit/api/test_experiment.py b/test/unit/api/test_experiment.py new file mode 100644 index 000000000..2f57b9e83 --- /dev/null +++ b/test/unit/api/test_experiment.py @@ -0,0 +1,21 @@ +from unittest.mock import patch, MagicMock + +from swanlab.api.experiments import Experiments +from swanlab.core_python import Client +from tutils.setup import mock_login_info +from utils import create_nested_exps + + +def test_folded_exps(): + """测试嵌套的实验数据能够正确展平""" + mock_exps = Experiments( + MagicMock(spec=Client), + path='test_user/test-project', + login_info=mock_login_info(), + ) + + nested_data = create_nested_exps(groups=2, num_per_group=2) + with patch('swanlab.api.experiments.get_project_experiments') as mock_get_exps: + mock_get_exps.return_value = nested_data + experiments = list(mock_exps) + assert len(experiments) == 4 diff --git a/test/unit/api/test_history.py b/test/unit/api/test_history.py new file mode 100644 index 000000000..0efb3d92a --- /dev/null +++ b/test/unit/api/test_history.py @@ -0,0 +1,117 @@ 
+""" +@author: Zhou QiYang +@file: test_history.py +@time: 2026/1/11 16:36 +@description: 测试 Experiment.history() 方法,使用 MagicMock 和 monkeypatch 模拟网络请求 +""" + +from unittest.mock import patch, MagicMock + +import pytest +import requests_mock + +from swanlab.api.experiment import Experiment +from swanlab.core_python.client import Client +from swanlab.package import get_host_web, get_host_api +from utils import create_csv_data, create_run_type_data + + +@pytest.fixture +def experiment(): + """创建 Experiment 实例""" + data = create_run_type_data() + return Experiment( + MagicMock(spec=Client), + data=data, + path='test_user/test_project/test_exp', + web_host=get_host_web(), + login_user='test_user', + line_count=100, + ) + + +@pytest.fixture +def metrics_data(): + return [ + (list(range(10)), 'loss', [0.5 - i * 0.05 for i in range(10)]), + (list(range(10)), 'accuracy', [0.5 + i * 0.05 for i in range(10)]), + ] + + +class MockSetup: + """ + 模拟网络请求 + 分别模拟获取 csv 网址和文件内容 + """ + + def __init__(self, metrics_data): + self.metrics_data = metrics_data + + def __enter__(self): + self.mock_get_metrics = patch('swanlab.api.experiment.thread.get_experiment_metrics').start() + self.mock_get_metrics.side_effect = lambda client, expid, key: {'url': f'{get_host_api()}/{key}'} + + self.m = requests_mock.Mocker() + self.m.start() + for metric in self.metrics_data: + self.m.get(f'{get_host_api()}/{metric[1]}', content=create_csv_data(*metric)) + return self + + def __exit__(self, *args): + patch.stopall() + self.m.stop() + + +def test_history_basic(experiment, metrics_data): + """测试使用指定 keys 获取历史数据""" + with MockSetup(metrics_data): + result = experiment.history(keys=['loss', 'accuracy']) + + assert len(result) == 10 + assert 'loss' in result.columns + assert 'accuracy' in result.columns + + +def test_history_with_x_axis(experiment, metrics_data): + """测试使用 x_axis 参数""" + with MockSetup(metrics_data): + result = experiment.history(keys=['loss'], x_axis='accuracy') + + # x_axis 应该作为索引 + 
assert result.index.name == 'accuracy' + + +def test_history_with_sample(experiment, metrics_data): + """测试使用 sample 参数限制返回行数""" + with MockSetup(metrics_data): + result = experiment.history(keys=['loss'], sample=5) + + # 只返回前 5 行 + assert len(result) == 5 + + +def test_history_dict_mode(experiment, metrics_data): + """测试 pandas=False 时返回 dict 格式""" + with MockSetup(metrics_data): + result = experiment.history(keys=['loss'], pandas=False) + + # 应该返回字典列表 + assert all(isinstance(item, dict) for item in result) + + +def test_full_history(experiment, metrics_data): + """测试 keys 和 x_axis 都为 None 时调用 __full_history""" + with MockSetup(metrics_data): + result = experiment.history() + + assert len(result) == 10 + assert 'loss' in result.columns + assert 'accuracy' in result.columns + + +@pytest.mark.parametrize("keys", ('invalid_keys', ['loss', 123, 'accuracy'])) +def test_history_invalid_keys(experiment, metrics_data, keys): + """测试 keys 参数类型错误的情况,返回空 DataFrame""" + with MockSetup(metrics_data): + result = experiment.history(keys=keys) + assert len(result) == 0 diff --git a/test/unit/api/test_project.py b/test/unit/api/test_project.py new file mode 100644 index 000000000..d41d26732 --- /dev/null +++ b/test/unit/api/test_project.py @@ -0,0 +1,27 @@ +from unittest.mock import patch, MagicMock + +from swanlab.api.projects import Projects +from swanlab.core_python import Client +from swanlab.package import get_host_web +from utils import create_project_data + + +def test_projects(): + """测试能否分页获取所有项目""" + with patch('swanlab.api.projects.get_workspace_projects') as mock_get_projects: + total_pages = 4 + page_size = 20 + + def side_effect(*args, **kwargs): + return create_project_data(size=page_size, pages=kwargs.get("page", 1), total=page_size * total_pages) + + mock_get_projects.side_effect = side_effect + + mock_projects = Projects( + MagicMock(spec=Client), + web_host=get_host_web(), + workspace='test_user', + ) + projects = list(mock_projects) + assert len(projects) == 
page_size * total_pages + assert mock_get_projects.call_count == total_pages diff --git a/test/unit/api/test_user.py b/test/unit/api/test_user.py new file mode 100644 index 000000000..0f05dfd59 --- /dev/null +++ b/test/unit/api/test_user.py @@ -0,0 +1,41 @@ +from unittest.mock import MagicMock + +import pytest + +from swanlab.api.user import User +from swanlab.core_python import Client +from swanlab.core_python.api.type import IdentityType + + +def create_user(username=None, identity: IdentityType = "user"): + """创建用户对象的辅助函数""" + return User(MagicMock(spec=Client), login_user="test_user", username=username, identity=identity) + + +def test_create_permission(): + """测试普通用户尝试创建用户是否会被拦截""" + user = create_user(identity="user") + assert user.create(username='test_user', password='123456aa') == False + + +@pytest.mark.parametrize( + ("username", "password"), + [ + ('user@name', 'password123'), + ('test_user', 'short'), + ('test_user', '12345678'), + ('test_user', 'ABCDEFGH'), + ], +) +def test_check_create_info(username, password): + """测试无效的用户名或密码""" + root_user = create_user(identity="root") + with pytest.raises(ValueError): + root_user.create(username, password) + + +def test_other_user(): + """测试是否对未开发的功能进行拦截""" + other_user = create_user(identity="root", username="other_user") + assert other_user.generate_api_key() is None + assert other_user.delete_api_key(api_key='test_api_key') == False diff --git a/test/unit/api/utils.py b/test/unit/api/utils.py new file mode 100644 index 000000000..b11ba1dee --- /dev/null +++ b/test/unit/api/utils.py @@ -0,0 +1,100 @@ +import csv +from io import StringIO +from typing import Dict, List + +import nanoid + +from swanlab.core_python.api.type import RunType, ProjResponseType, ProjectType +from swanlab.package import get_host_web + + +def create_run_type_data(cuid=None) -> RunType: + """ + 创建 RunType 类型的测试数据 + """ + return { + 'cuid': cuid if cuid is not None else nanoid.generate('0123456789', 10), + 'name': '', + 'createdAt': '', + 
'description': '', + 'labels': [], + 'profile': {'config': {}, 'scalar': {'loss': [], 'accuracy': []}}, + 'state': 'FINISHED', + 'cluster': '', + 'job': '', + 'runtime': '', + 'user': {'username': '', 'name': ''}, + 'rootExpId': None, + 'rootProId': None, + } + + +def create_nested_exps(groups: int = 2, num_per_group: int = 2) -> Dict: + """ + 创建嵌套的实验数据(用于模拟 get_project_experiments 的返回值) + + :param groups: 分组数量 + :param num_per_group: 每组(页)中的实验数量 + :return: 分组情况下 Dict 格式的数据 + """ + result = {} + for i in range(groups): + group_key = f'group_{i}' + runs = [] + for j in range(num_per_group): + run = create_run_type_data(cuid=f'exp_{i}_{j}') + runs.append(run) + result[group_key] = runs + return result + + +def create_project_data(size: int = 20, pages: int = 1, total: int = 20) -> ProjResponseType: + """ + 创建单页项目数据(用于模拟 get_workspace_projects 的返回值,仅生成一页) + + :param size: 项目数 + :param pages: 当前页数 + :param total: 项目总数 + :return: ProjResponseType 格式的数据 + """ + project_list: List[ProjectType] = [] + + for j in range(size): + project: ProjectType = { + 'cuid': f'proj_{pages}_{j}', + 'name': f'project_{pages}_{j}', + 'path': f'test_user/project_{pages}_{j}', + 'url': f'{get_host_web()}/test_user/project_{pages}_{j}', + 'description': '', + 'visibility': 'PRIVATE', + 'createdAt': '', + 'updatedAt': '', + 'group': {'workspace': 'test_user'}, + 'projectLabels': [], + '_count': {}, + } + project_list.append(project) + + return { + 'list': project_list, + 'size': size, + 'pages': pages, + 'total': total, + } + + +def create_csv_data(step_values, metric_name, metric_values): + """ + 创建 CSV 格式的数据 + :param step_values: step 列的值列表 + :param metric_name: 指标名称 + :param metric_values: 指标值列表 + :return: CSV 格式的字节数据 + """ + output = StringIO() + writer = csv.writer(output) + writer.writerow(['step', metric_name]) + for step, value in zip(step_values, metric_values): + writer.writerow([step, value]) + csv_text = output.getvalue() + return csv_text.encode('utf-8') From 
5325e28adaba7c64d775606a5fe78557cb794983 Mon Sep 17 00:00:00 2001 From: Bainianzzz <3036349123@qq.com> Date: Tue, 20 Jan 2026 12:16:38 +0800 Subject: [PATCH 06/21] fix bugs in projects --- swanlab/api/projects/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index 6a8ba1cfc..952ddd09f 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -41,7 +41,6 @@ def __iter__(self) -> Iterator[Project]: # 按用户遍历情况获取项目信息 cur_page = 0 page_size = 20 - projects = [] while True: cur_page += 1 resp: ProjResponseType = get_workspace_projects( @@ -53,11 +52,13 @@ def __iter__(self) -> Iterator[Project]: search=self._search, detail=self._detail, ) - projects.extend(resp['list']) - if cur_page * page_size >= resp['total']: + # 立即 yield 当前页的数据 + for p in resp['list']: + yield Project(data=p, web_host=self._web_host) + + # 检查是否已获取所有数据:当前页数据少于 page_size 或已达到总数 + if len(resp['list']) < page_size or cur_page * page_size >= resp['total']: break - yield from iter(Project(data=p, web_host=self._web_host) for p in projects) - __all__ = ["Projects"] From 943f2279ba5dad326e7b7de3681fe229889b5896 Mon Sep 17 00:00:00 2001 From: Bainianzzz <3036349123@qq.com> Date: Tue, 20 Jan 2026 12:20:15 +0800 Subject: [PATCH 07/21] revert changes --- swanlab/api/projects/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index 952ddd09f..6a8ba1cfc 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -41,6 +41,7 @@ def __iter__(self) -> Iterator[Project]: # 按用户遍历情况获取项目信息 cur_page = 0 page_size = 20 + projects = [] while True: cur_page += 1 resp: ProjResponseType = get_workspace_projects( @@ -52,13 +53,11 @@ def __iter__(self) -> Iterator[Project]: search=self._search, detail=self._detail, ) - # 立即 yield 当前页的数据 - for p in 
resp['list']: - yield Project(data=p, web_host=self._web_host) - - # 检查是否已获取所有数据:当前页数据少于 page_size 或已达到总数 - if len(resp['list']) < page_size or cur_page * page_size >= resp['total']: + projects.extend(resp['list']) + if cur_page * page_size >= resp['total']: break + yield from iter(Project(data=p, web_host=self._web_host) for p in projects) + __all__ = ["Projects"] From c820aab5456678c297accf035fb47fe26e7cffe3 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Wed, 21 Jan 2026 11:08:09 +0800 Subject: [PATCH 08/21] Opt open api code (#1424) * delete class ApiBase * accept suggestions - place yield inside the loop of the projects - check if creating & deleting user or api_keys inside the api func - add property is_self in user * opt user api func * fix bugs in projects * accept gemini suggestions - fix type error in self_hosted.py and user's api func - add a constant for project page size * opt over-encapsulated code - discard constant PAGE_SIZE - let adapter handle the exception when creating & deleting user info - opt unit test for projects * delete unused utils --- swanlab/api/experiment/__init__.py | 6 +-- swanlab/api/experiments/__init__.py | 5 +- swanlab/api/project/__init__.py | 5 +- swanlab/api/projects/__init__.py | 17 +++--- swanlab/api/user/__init__.py | 60 +++++++++++++-------- swanlab/api/utils.py | 25 +-------- swanlab/core_python/api/user/__init__.py | 14 ++--- swanlab/core_python/api/user/self_hosted.py | 5 +- test/unit/api/test_project.py | 8 +-- test/unit/api/utils.py | 21 ++++---- 10 files changed, 75 insertions(+), 91 deletions(-) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index 0f5873ac2..e2a70978b 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -8,18 +8,18 @@ from typing import List, Dict, Any from swanlab.api.user import User -from swanlab.api.utils import ApiBase, Label +from swanlab.api.utils import Label from 
swanlab.core_python.api.type import RunType from swanlab.core_python.client import Client from swanlab.log import swanlog from .thread import HistoryPool -class Experiment(ApiBase): +class Experiment: def __init__( self, client: Client, *, data: RunType, path: str, web_host: str, login_user: str, line_count: int ) -> None: - super().__init__(client) + self._client = client self._data = data self._path = path self._web_host = web_host diff --git a/swanlab/api/experiments/__init__.py b/swanlab/api/experiments/__init__.py index 383d2503a..ea04d2653 100644 --- a/swanlab/api/experiments/__init__.py +++ b/swanlab/api/experiments/__init__.py @@ -8,7 +8,6 @@ from typing import List, Dict, Iterator from swanlab.api.experiment import Experiment -from swanlab.api.utils import ApiBase from swanlab.core_python.api.experiment import get_project_experiments from swanlab.core_python.api.type import RunType from swanlab.core_python.auth.providers.api_key import LoginInfo @@ -28,7 +27,7 @@ def flatten_runs(runs: Dict) -> List: return flat_runs -class Experiments(ApiBase): +class Experiments: """ Container for a collection of Experiment objects. You can iterate over the experiments by for-in loop. @@ -37,7 +36,7 @@ class Experiments(ApiBase): def __init__(self, client: Client, *, path: str, login_info: LoginInfo, filters: Dict[str, object] = None) -> None: if len(path.split('/')) != 2: raise ValueError(f"User's {path} is invaded. 
Correct path should be like 'username/project'") - super().__init__(client) + self._client = client self._path = path self._web_host = login_info.web_host self._login_user = login_info.username diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py index 926345c27..91f3f295f 100644 --- a/swanlab/api/project/__init__.py +++ b/swanlab/api/project/__init__.py @@ -7,17 +7,16 @@ from typing import List, Dict -from swanlab.api.utils import ApiBase, Label +from swanlab.api.utils import Label from swanlab.core_python.api.type import ProjectType -class Project(ApiBase): +class Project: """ Representing a single project with some of its properties. """ def __init__(self, *, data: ProjectType, web_host: str) -> None: - super().__init__() self._data = data self._web_host = web_host diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index 6a8ba1cfc..a6afcdd20 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -8,13 +8,12 @@ from typing import List, Optional, Iterator from swanlab.api.project import Project -from swanlab.api.utils import ApiBase from swanlab.core_python.api.project import get_workspace_projects from swanlab.core_python.api.type import ProjResponseType from swanlab.core_python.client import Client -class Projects(ApiBase): +class Projects: """ Container for a collection of Project objects. You can iterate over the projects by for-in loop. 
@@ -30,7 +29,7 @@ def __init__( search: Optional[str] = None, detail: Optional[bool] = True, ) -> None: - super().__init__(client) + self._client = client self._web_host = web_host self._workspace = workspace self._sort = sort @@ -40,24 +39,22 @@ def __init__( def __iter__(self) -> Iterator[Project]: # 按用户遍历情况获取项目信息 cur_page = 0 - page_size = 20 - projects = [] while True: cur_page += 1 resp: ProjResponseType = get_workspace_projects( self._client, workspace=self._workspace, page=cur_page, - size=page_size, + size=20, sort=self._sort, search=self._search, detail=self._detail, ) - projects.extend(resp['list']) - if cur_page * page_size >= resp['total']: - break + for p in resp['list']: + yield Project(data=p, web_host=self._web_host) - yield from iter(Project(data=p, web_host=self._web_host) for p in projects) + if cur_page >= resp['pages']: + break __all__ = ["Projects"] diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py index a7419108e..0c1cb80be 100644 --- a/swanlab/api/user/__init__.py +++ b/swanlab/api/user/__init__.py @@ -9,7 +9,6 @@ from functools import cached_property from typing import List, Optional -from swanlab.api.utils import ApiBase from swanlab.core_python.api.type import ApiKeyType from swanlab.core_python.api.type.user import IdentityType from swanlab.core_python.api.user import ( @@ -23,9 +22,6 @@ from swanlab.core_python.client import Client from swanlab.log import swanlog -STATUS_OK = "OK" -STATUS_CREATED = "Created" - def check_create_info(username: str, password: str) -> bool: # 用户名为大小写字母、数字及-、_组成 @@ -38,32 +34,48 @@ def check_create_info(username: str, password: str) -> bool: return True -class User(ApiBase): +class User: def __init__( self, client: Client, login_user: str = None, username: str = None, identity: IdentityType = 'user' ) -> None: if login_user is None and username is None: raise ValueError("login_user or username are required") - super().__init__(client) + self._client = client self._identity = identity 
self._api_keys: List[ApiKeyType] = [] - self._cur_username = username or login_user - self._is_other_user = username is not None and username != login_user + self._login_user = login_user + self._cur_username = username or self._login_user @property def username(self) -> str: + """ + User name. (if username is not None, return username, otherwise return login_user) + """ return self._cur_username + @property + def is_self(self) -> bool: + """ + Check if the user is the current login user. + """ + return self._cur_username == self._login_user + @cached_property def teams(self) -> List[str]: + """ + List of teams the user belongs to. + """ resp = get_user_groups(self._client, username=self._cur_username) return [r['name'] for r in resp] # TODO: 管理员可以对指定用户的api_key进行操作 @cached_property def api_keys(self) -> List[str]: - if self._is_other_user: + """ + List of api keys the user has. + """ + if not self.is_self: swanlog.warning("Getting api keys of other users has not been supported yet.") return [] else: @@ -71,43 +83,49 @@ def api_keys(self) -> List[str]: return [r['key'] for r in self._api_keys] def _refresh_api_keys(self): + """ + Refresh the list of api keys. + """ del self.api_keys self._api_keys = get_api_keys(self._client) def generate_api_key(self, description: str = None) -> Optional[str]: - if self._is_other_user: + """ + Generate a new api key. + """ + if not self.is_self: swanlog.warning("Generating api key of other users has not been supported yet.") return None else: api_key: Optional[ApiKeyType] = None res = create_api_key(self._client, name=description) - if res == STATUS_CREATED: + if res: api_key = get_latest_api_key(self._client) return api_key['key'] if api_key else None def delete_api_key(self, api_key: str) -> bool: - if self._is_other_user: + """ + Delete an api key. 
+ """ + if not self.is_self: swanlog.warning("Deleting api key of other users has not been supported yet.") return False else: self._refresh_api_keys() for key in self._api_keys: if key['key'] == api_key: - res = delete_api_key(self._client, key_id=key['id']) - if res == STATUS_OK: - return True + return delete_api_key(self._client, key_id=key['id']) return False def create(self, username: str, password: str) -> bool: - if self._identity != "root" or self._is_other_user: + """ + Create a new user. (Only root user can create other user) + """ + if self._identity != "root" or not self.is_self: swanlog.warning(f"{self._cur_username} is not allowed to create other user.") return False check_create_info(username, password) - resp = create_user(self._client, username=username, password=password) - if resp == STATUS_CREATED: - return True - else: - raise False + return create_user(self._client, username=username, password=password) __all__ = ["User"] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 3eb87337b..637cab0b5 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -6,9 +6,6 @@ """ from dataclasses import dataclass -from typing import Dict - -from swanlab.core_python import Client @dataclass @@ -24,24 +21,4 @@ def __str__(self) -> str: return self.name -class ApiBase: - def __init__(self, client: Client = None): - self._client = client - - @property - def __dict__(self) -> Dict[str, object]: - """ - Return a dictionary containing all @property fields. 
- """ - result = {} - cls = type(self) - for attr_name in dir(cls): - if attr_name.startswith('_'): - continue - attr = getattr(cls, attr_name, None) - if isinstance(attr, property): - result[attr_name] = self.__getattribute__(attr_name) - return result - - -__all__ = ['ApiBase', 'Label'] +__all__ = ['Label'] diff --git a/swanlab/core_python/api/user/__init__.py b/swanlab/core_python/api/user/__init__.py index 53e6257f3..1603b6415 100644 --- a/swanlab/core_python/api/user/__init__.py +++ b/swanlab/core_python/api/user/__init__.py @@ -14,28 +14,22 @@ from swanlab.core_python.client import Client -def create_api_key(client: "Client", *, name: str = None) -> str: +def create_api_key(client: "Client", *, name: str = None) -> None: """ 创建一个api_key,完成后返回成功信息 :param client: 已登录的客户端实例 :param name: api_key 的名称 """ - if name is not None: - data = {'name': name} - res = client.post(f"/user/key", data=data) - else: - res = client.post(f"/user/key") - return res[0] + client.post(f"/user/key", data={'name': name} if name else None) -def delete_api_key(client: "Client", *, key_id: int) -> str: +def delete_api_key(client: "Client", *, key_id: int) -> None: """ 删除指定id的api_key :param client: 已登录的客户端实例 :param key_id: api_key的id """ - res = client.delete(f"/user/key/{key_id}") - return res[0] + client.delete(f"/user/key/{key_id}") def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: diff --git a/swanlab/core_python/api/user/self_hosted.py b/swanlab/core_python/api/user/self_hosted.py index 55ad620c0..e6ad58b41 100644 --- a/swanlab/core_python/api/user/self_hosted.py +++ b/swanlab/core_python/api/user/self_hosted.py @@ -22,7 +22,7 @@ def get_self_hosted_init(client: "Client") -> SelfHostedInfoType: return res[0] -def create_user(client: "Client", *, username: str, password: str) -> str: +def create_user(client: "Client", *, username: str, password: str) -> None: """ 根用户添加用户 :param client: 已登录的客户端实例 @@ -30,5 +30,4 @@ def create_user(client: "Client", *, username: 
str, password: str) -> str: :param password: 用户密码 """ data = {"users": [{"username": username, "password": password}]} - res = client.post("/self_hosted/users", data=data) - return res[0] + client.post("/self_hosted/users", data=data) diff --git a/test/unit/api/test_project.py b/test/unit/api/test_project.py index d41d26732..c91ef97ee 100644 --- a/test/unit/api/test_project.py +++ b/test/unit/api/test_project.py @@ -9,11 +9,11 @@ def test_projects(): """测试能否分页获取所有项目""" with patch('swanlab.api.projects.get_workspace_projects') as mock_get_projects: - total_pages = 4 + total = 80 page_size = 20 def side_effect(*args, **kwargs): - return create_project_data(size=page_size, pages=kwargs.get("page", 1), total=page_size * total_pages) + return create_project_data(page=kwargs.get("page", 1), total=total) mock_get_projects.side_effect = side_effect @@ -23,5 +23,5 @@ def side_effect(*args, **kwargs): workspace='test_user', ) projects = list(mock_projects) - assert len(projects) == page_size * total_pages - assert mock_get_projects.call_count == total_pages + assert len(projects) == total + assert mock_get_projects.call_count == (total + page_size - 1) // page_size diff --git a/test/unit/api/utils.py b/test/unit/api/utils.py index b11ba1dee..63a4b61dc 100644 --- a/test/unit/api/utils.py +++ b/test/unit/api/utils.py @@ -48,23 +48,24 @@ def create_nested_exps(groups: int = 2, num_per_group: int = 2) -> Dict: return result -def create_project_data(size: int = 20, pages: int = 1, total: int = 20) -> ProjResponseType: +def create_project_data(page: int = 1, total: int = 20) -> ProjResponseType: """ - 创建单页项目数据(用于模拟 get_workspace_projects 的返回值,仅生成一页) + 创建分页项目数据(用于模拟 get_workspace_projects 的返回值) - :param size: 项目数 - :param pages: 当前页数 + :param page: 当前页数 :param total: 项目总数 :return: ProjResponseType 格式的数据 """ + page_size = 20 + pages = (total + page_size - 1) // page_size project_list: List[ProjectType] = [] - for j in range(size): + for j in range(page_size): project: ProjectType = 
{ - 'cuid': f'proj_{pages}_{j}', - 'name': f'project_{pages}_{j}', - 'path': f'test_user/project_{pages}_{j}', - 'url': f'{get_host_web()}/test_user/project_{pages}_{j}', + 'cuid': f'proj_{page}_{j}', + 'name': f'project_{page}_{j}', + 'path': f'test_user/project_{page}_{j}', + 'url': f'{get_host_web()}/test_user/project_{page}_{j}', 'description': '', 'visibility': 'PRIVATE', 'createdAt': '', @@ -77,7 +78,7 @@ def create_project_data(size: int = 20, pages: int = 1, total: int = 20) -> Proj return { 'list': project_list, - 'size': size, + 'size': page_size, 'pages': pages, 'total': total, } From fc07300b627af24b84170f632bd6f29c3294159f Mon Sep 17 00:00:00 2001 From: Bainianzzz <3036349123@qq.com> Date: Wed, 21 Jan 2026 11:34:14 +0800 Subject: [PATCH 09/21] resolve conflict --- swanlab/core_python/client/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/swanlab/core_python/client/__init__.py b/swanlab/core_python/client/__init__.py index 9adac78e3..b90d1153d 100644 --- a/swanlab/core_python/client/__init__.py +++ b/swanlab/core_python/client/__init__.py @@ -192,22 +192,22 @@ def get(self, url: str, params: dict = None): resp = self.__session.get(url, params=params) return decode_response(resp), resp - def delete(self, url: str): + def patch(self, url: str, data: dict = None): """ - delete请求 + patch请求 """ url = self.__login_info.api_host + url self.__before_request() - resp = self.__session.delete(url) + resp = self.__session.patch(url, json=data) return decode_response(resp), resp - def patch(self, url: str, data: dict = None): + def delete(self, url: str, retries: Optional[int] = None): """ - patch请求 + delete请求 """ url = self.__login_info.api_host + url self.__before_request() - resp = self.__session.patch(url, json=data) + resp = self.__session.delete(url, retries=retries) return decode_response(resp), resp # ---------------------------------- 训练相关接口 ---------------------------------- From 
c43a8a4544c4cdfe45215a2e75b2f8c6db3e8d63 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Sun, 25 Jan 2026 13:01:04 +0800 Subject: [PATCH 10/21] check self hosted info using a wrapper instead of checking when init (#1425) * use a wrapper to check self_hosted info for some functions - update unit tests for user (add self_hosted context) * fix bug in OpenApi.list_workspaces() * accept gemini suggestions * raise value error when not being self_hosted * raise value error when self_hosted is not available - raise error when accessing unsupported user functions --- swanlab/api/__init__.py | 25 +--------- swanlab/api/deprecated/group.py | 1 + swanlab/api/user/__init__.py | 37 ++++++--------- swanlab/api/utils.py | 47 +++++++++++++++++-- test/unit/api/test_user.py | 81 +++++++++++++++++++++++++-------- 5 files changed, 125 insertions(+), 66 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 0abd51a5b..0f6f2ff51 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -9,9 +9,7 @@ from swanlab.core_python import auth, Client from swanlab.core_python.api.experiment import get_single_experiment, get_project_experiments -from swanlab.core_python.api.type import IdentityType -from swanlab.core_python.api.user import get_self_hosted_init -from swanlab.error import KeyFileError, ApiError +from swanlab.error import KeyFileError from swanlab.log import swanlog from swanlab.package import HostFormatter, get_key from .deprecated import OpenApi @@ -46,32 +44,13 @@ def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, we self._client: Client = Client(self._login_info) self._web_host = self._login_info.web_host - # 尝试获取私有化服务信息,如果不是私有化服务,则会报错退出,因为指定user功能仅供私有化用户使用 - try: - self._self_hosted_info = get_self_hosted_init(self._client) - except ApiError: - swanlog.warning("You haven't launched a swanlab self-hosted instance. 
Some usages are not available.") - self._self_hosted_info = None - - self._identity: IdentityType = 'user' - if self._self_hosted_info is not None and self._self_hosted_info["plan"] == 'commercial': - self._identity = 'root' if self._self_hosted_info['root'] else 'user' - - if self._self_hosted_info is not None: - if not self._self_hosted_info["enabled"]: - swanlog.warning("SwanLab self-hosted instance hasn't been ready yet.") - if self._self_hosted_info["expired"]: - swanlog.warning("SwanLab self-hosted instance has expired.") - def user(self, username: str = None) -> User: """ 获取用户实例,用于操作用户相关信息 :param username: 指定用户名,如果为 None,则返回当前登录用户 :return: User 实例,可对当前/指定用户进行操作 """ - return User( - client=self._client, login_user=self._login_info.username, username=username, identity=self._identity - ) + return User(client=self._client, login_user=self._login_info.username, username=username) def projects( self, diff --git a/swanlab/api/deprecated/group.py b/swanlab/api/deprecated/group.py index 71ac04b9d..ab6d73b87 100644 --- a/swanlab/api/deprecated/group.py +++ b/swanlab/api/deprecated/group.py @@ -29,3 +29,4 @@ def list_workspaces(self) -> ApiResponse[list]: } for item in groups ] + return resp diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py index 0c1cb80be..4b907c8eb 100644 --- a/swanlab/api/user/__init__.py +++ b/swanlab/api/user/__init__.py @@ -9,8 +9,8 @@ from functools import cached_property from typing import List, Optional +from swanlab.api.utils import self_hosted from swanlab.core_python.api.type import ApiKeyType -from swanlab.core_python.api.type.user import IdentityType from swanlab.core_python.api.user import ( get_user_groups, get_api_keys, @@ -20,7 +20,6 @@ ) from swanlab.core_python.api.user.self_hosted import create_user from swanlab.core_python.client import Client -from swanlab.log import swanlog def check_create_info(username: str, password: str) -> bool: @@ -35,14 +34,11 @@ def check_create_info(username: str, password: str) 
-> bool: class User: - def __init__( - self, client: Client, login_user: str = None, username: str = None, identity: IdentityType = 'user' - ) -> None: + def __init__(self, client: Client, login_user: str = None, username: str = None) -> None: if login_user is None and username is None: raise ValueError("login_user or username are required") self._client = client - self._identity = identity self._api_keys: List[ApiKeyType] = [] self._login_user = login_user self._cur_username = username or self._login_user @@ -76,8 +72,7 @@ def api_keys(self) -> List[str]: List of api keys the user has. """ if not self.is_self: - swanlog.warning("Getting api keys of other users has not been supported yet.") - return [] + raise ValueError("Getting api keys of other users has not been supported yet.") else: self._api_keys = get_api_keys(self._client) return [r['key'] for r in self._api_keys] @@ -94,13 +89,10 @@ def generate_api_key(self, description: str = None) -> Optional[str]: Generate a new api key. """ if not self.is_self: - swanlog.warning("Generating api key of other users has not been supported yet.") - return None + raise ValueError("Generating api key of other users has not been supported yet.") else: - api_key: Optional[ApiKeyType] = None - res = create_api_key(self._client, name=description) - if res: - api_key = get_latest_api_key(self._client) + create_api_key(self._client, name=description) + api_key = get_latest_api_key(self._client) return api_key['key'] if api_key else None def delete_api_key(self, api_key: str) -> bool: @@ -108,24 +100,25 @@ def delete_api_key(self, api_key: str) -> bool: Delete an api key. 
""" if not self.is_self: - swanlog.warning("Deleting api key of other users has not been supported yet.") - return False + raise ValueError("Deleting api key of other users has not been supported yet.") else: self._refresh_api_keys() for key in self._api_keys: if key['key'] == api_key: - return delete_api_key(self._client, key_id=key['id']) + delete_api_key(self._client, key_id=key['id']) + return True return False - def create(self, username: str, password: str) -> bool: + @self_hosted("root") + def create(self, username: str, password: str) -> Optional[bool]: """ Create a new user. (Only root user can create other user) """ - if self._identity != "root" or not self.is_self: - swanlog.warning(f"{self._cur_username} is not allowed to create other user.") - return False + if not self.is_self: + raise ValueError(f"{self._cur_username} is not allowed to create other user.") check_create_info(username, password) - return create_user(self._client, username=username, password=password) + create_user(self._client, username=username, password=password) + return True __all__ = ["User"] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 637cab0b5..6b630ea61 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -1,11 +1,17 @@ """ @author: Zhou QiYang -@file: __init__.py +@file: utils.py @time: 2026/1/11 23:44 -@description: OpenApi 中的基础对象 +@description: OpenApi 中的基础对象与通用工具 """ from dataclasses import dataclass +from functools import wraps + +from swanlab.core_python.api.type import IdentityType +from swanlab.core_python.api.user import get_self_hosted_init +from swanlab.core_python.client import Client +from swanlab.error import ApiError @dataclass @@ -21,4 +27,39 @@ def __str__(self) -> str: return self.name -__all__ = ['Label'] +def self_hosted(identity: IdentityType = "user"): + """ + 用于需要在私有化环境下使用的接口的装饰器。 + :param identity: 用户身份,默认为 "user",如果为 "root",则会额外验证是否为根用户。 + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + 
client = getattr(self, '_client', None) + if not isinstance(client, Client): + raise AttributeError("There is no SwanLab client instance.") + + # 1. 尝试获取私有化服务信息 + try: + self_hosted_info = get_self_hosted_init(client) + except ApiError: + raise ValueError("You haven't launched a swanlab self-hosted instance. This usages are not available.") + + if not self_hosted_info.get("enabled", False): + raise ValueError("SwanLab self-hosted instance hasn't been ready yet.") + if self_hosted_info.get("expired", True): + raise ValueError("SwanLab self-hosted instance has expired.") + + # 2. 检测用户权限(商业版root用户功能) + if identity == "root" and not self_hosted_info.get("root", False): + raise ValueError("You don't have permission to perform this action. Please login as a root user") + + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +__all__ = ['Label', 'self_hosted'] diff --git a/test/unit/api/test_user.py b/test/unit/api/test_user.py index 0f05dfd59..fb06ab682 100644 --- a/test/unit/api/test_user.py +++ b/test/unit/api/test_user.py @@ -1,41 +1,86 @@ -from unittest.mock import MagicMock +from unittest.mock import patch, MagicMock import pytest from swanlab.api.user import User from swanlab.core_python import Client -from swanlab.core_python.api.type import IdentityType +from swanlab.error import ApiError -def create_user(username=None, identity: IdentityType = "user"): +def create_user(username=None): """创建用户对象的辅助函数""" - return User(MagicMock(spec=Client), login_user="test_user", username=username, identity=identity) + return User(MagicMock(spec=Client), login_user="test_user", username=username) -def test_create_permission(): - """测试普通用户尝试创建用户是否会被拦截""" - user = create_user(identity="user") - assert user.create(username='test_user', password='123456aa') == False +class SelfHosted: + """私有化部署的上下文管理器""" + + def __init__(self, start=False, **kwargs): + self.start = start # 是否启动私有化部署 + self.enabled = kwargs.get('enabled', True) + self.expired = 
kwargs.get('expired', False) + self.root = kwargs.get('root', False) + self.plan = kwargs.get('plan', 'free') + self.seats = kwargs.get('seats', 99) + + def __enter__(self): + self.mock_get_metrics = patch('swanlab.api.utils.get_self_hosted_init').start() + if not self.start: + self.mock_get_metrics.side_effect = ApiError() + else: + self.mock_get_metrics.return_value = { + 'enabled': self.enabled, # 是否成功部署 + 'expired': self.expired, # licence是否过期 + 'root': self.root, # 是否为根用户 + 'plan': self.plan, # 私有化版本(免费、商业) + 'seats': self.seats, # 余剩席位 + } + return self + + def __exit__(self, *args): + patch.stopall() + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, # 私有化未启动 + {'start': True, 'enabled': False}, # 启动,但未启用 + {'start': True, 'expired': True}, # licence过期 + {'start': True}, # 启动,正常,但不是root + {'username': 'test_other_user'}, + ], +) +def test_create_permission(kwargs): + """测试尝试创建用户是否会被拦截""" + user = create_user(kwargs.get('username', None)) + with SelfHosted(**kwargs): + with pytest.raises(ValueError): + user.create(username='test_user', password='123456aa') @pytest.mark.parametrize( ("username", "password"), [ - ('user@name', 'password123'), - ('test_user', 'short'), - ('test_user', '12345678'), - ('test_user', 'ABCDEFGH'), + ('user@name', 'password123'), # 无效的用户名 + ('test_user', 'short'), # 无效密码(密码长度小于8) + ('test_user', '12345678'), # 无效密码(全是数字) + ('test_user', 'ABCDEFGH'), # 有效密码(全是字母) ], ) def test_check_create_info(username, password): """测试无效的用户名或密码""" - root_user = create_user(identity="root") - with pytest.raises(ValueError): - root_user.create(username, password) + root_user = create_user() + with SelfHosted(start=True, root=True): + with pytest.raises(ValueError): + root_user.create(username, password) def test_other_user(): """测试是否对未开发的功能进行拦截""" - other_user = create_user(identity="root", username="other_user") - assert other_user.generate_api_key() is None - assert other_user.delete_api_key(api_key='test_api_key') == False + other_user = 
create_user(username="other_user") + with SelfHosted(start=True, root=True): + with pytest.raises(ValueError): + assert other_user.generate_api_key() is None + with pytest.raises(ValueError): + assert other_user.delete_api_key(api_key='test_api_key') == False From bfc03334bf81d5240d95714e0672efcdeedfce88 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Tue, 27 Jan 2026 22:26:13 +0800 Subject: [PATCH 11/21] feat: workspace and json() (#1428) --- swanlab/api/__init__.py | 30 +++++++++- swanlab/api/experiment/__init__.py | 8 ++- swanlab/api/project/__init__.py | 8 ++- swanlab/api/user/__init__.py | 8 ++- swanlab/api/utils.py | 16 ++++- swanlab/api/workspace/__init__.py | 72 +++++++++++++++++++++++ swanlab/api/workspaces/__init__.py | 30 ++++++++++ swanlab/core_python/api/type/__init__.py | 3 + swanlab/core_python/api/type/workspace.py | 20 +++++++ swanlab/core_python/api/user/__init__.py | 15 ++++- 10 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 swanlab/api/workspace/__init__.py create mode 100644 swanlab/api/workspaces/__init__.py create mode 100644 swanlab/core_python/api/type/workspace.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 0f6f2ff51..02ce044b3 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -17,6 +17,8 @@ from .experiments import Experiments from .projects import Projects from .user import User +from .workspace import Workspace +from .workspaces import Workspaces class Api: @@ -43,6 +45,7 @@ def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None, we # 一个OpenApi对应一个client,可创建多个api获取从不同的client获取不同账号下的实验信息 self._client: Client = Client(self._login_info) self._web_host = self._login_info.web_host + self._login_user = self._login_info.username def user(self, username: str = None) -> User: """ @@ -50,7 +53,7 @@ def user(self, username: str = None) -> User: :param username: 指定用户名,如果为 None,则返回当前登录用户 :return: User 
实例,可对当前/指定用户进行操作 """ - return User(client=self._client, login_user=self._login_info.username, username=username) + return User(client=self._client, login_user=self._login_user, username=username) def projects( self, @@ -107,9 +110,32 @@ def run( data=data[0], path=proj_path, web_host=self._web_host, - login_user=self._login_info.username, + login_user=self._login_user, line_count=1, ) + def workspaces( + self, + username: str = None, + ): + """ + 获取当前登录用户的工作空间迭代器 + 当username为其他用户时,可以作为visitor访问其工作空间 + """ + if username is None: + username = self._login_user + return Workspaces(self._client, username=username) + + def workspace( + self, + username: str = None, + ): + """ + 获取当前登录用户的工作空间 + """ + if username is None: + username = self._login_user + return Workspace(client=self._client, workspace=username) + __all__ = ["Api", "OpenApi"] diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index e2a70978b..6c415ab7c 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -8,7 +8,7 @@ from typing import List, Dict, Any from swanlab.api.user import User -from swanlab.api.utils import Label +from swanlab.api.utils import Label, get_properties from swanlab.core_python.api.type import RunType from swanlab.core_python.client import Client from swanlab.log import swanlog @@ -139,6 +139,12 @@ def root_pro_id(self) -> str: """ return self._data['rootProId'] + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) + def __full_history(self) -> Any: """ Get all metric keys' data of the experiment with timestamp. 
diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py index 91f3f295f..4e4cf62d1 100644 --- a/swanlab/api/project/__init__.py +++ b/swanlab/api/project/__init__.py @@ -7,7 +7,7 @@ from typing import List, Dict -from swanlab.api.utils import Label +from swanlab.api.utils import Label, get_properties from swanlab.core_python.api.type import ProjectType @@ -90,3 +90,9 @@ def count(self) -> Dict[str, int]: experiments, contributors, children, collaborators, runningExps. """ return self._data['_count'] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py index 4b907c8eb..587473b28 100644 --- a/swanlab/api/user/__init__.py +++ b/swanlab/api/user/__init__.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import List, Optional -from swanlab.api.utils import self_hosted +from swanlab.api.utils import self_hosted, get_properties from swanlab.core_python.api.type import ApiKeyType from swanlab.core_python.api.user import ( get_user_groups, @@ -77,6 +77,12 @@ def api_keys(self) -> List[str]: self._api_keys = get_api_keys(self._client) return [r['key'] for r in self._api_keys] + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) + def _refresh_api_keys(self): """ Refresh the list of api keys. 
diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 6b630ea61..d9b536fdb 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from functools import wraps +from typing import Dict from swanlab.core_python.api.type import IdentityType from swanlab.core_python.api.user import get_self_hosted_init @@ -62,4 +63,17 @@ def wrapper(self, *args, **kwargs): return decorator -__all__ = ['Label', 'self_hosted'] +def get_properties(obj: object) -> Dict[str, object]: + """递归获取实例中所有property""" + result = dict() + for name in dir(obj): + if name.startswith("_"): + continue + if isinstance(getattr(type(obj), name, None), property): + value = getattr(obj, name, None) + result[name] = value if type(value).__module__ == 'builtins' else get_properties(value) + + return result + + +__all__ = ['Label', 'self_hosted', 'get_properties'] diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py new file mode 100644 index 000000000..f47e975c2 --- /dev/null +++ b/swanlab/api/workspace/__init__.py @@ -0,0 +1,72 @@ +""" +@author: Zhou Qiyang +@file: __init__.py +@time: 2026/1/27 20:43 +@description: 工作空间 +""" + +from typing import Literal, Dict + +from swanlab.api.utils import get_properties +from swanlab.core_python import Client +from swanlab.core_python.api.type import WorkspaceType, RoleType +from swanlab.core_python.api.user import get_workspace_info + + +class Workspace: + def __init__(self, *, data: WorkspaceType = None, client: Client = None, workspace: str = None) -> None: + self._client = client + + if data is None: + if workspace is None or client is None: + raise ValueError('workspace or client cannot both None') + data = get_workspace_info(self._client, workspace=workspace) + self._data = data + + @property + def name(self) -> str: + """ + Workspace display name. + """ + return self._data['name'] + + @property + def username(self) -> str: + """ + Workspace name. 
+ """ + return self._data['username'] + + @property + def workspace_type(self) -> Literal['TEAM', 'PERSON']: + """ + Workspace type. + """ + return self._data['type'] + + @property + def profile(self) -> Dict[str, str]: + """ + Workspace profile. + """ + return self._data['profile'] + + @property + def comment(self) -> str: + """ + Workspace comment. + """ + return self._data['comment'] + + @property + def role(self) -> RoleType: + """ + Current login user's role in the workspace (only display when type=TEAM). + """ + return self._data['role'] + + def json(self): + """ + JSON-serializable dict of all @property values. + """ + return get_properties(self) diff --git a/swanlab/api/workspaces/__init__.py b/swanlab/api/workspaces/__init__.py new file mode 100644 index 000000000..6aaa2537a --- /dev/null +++ b/swanlab/api/workspaces/__init__.py @@ -0,0 +1,30 @@ +""" +@author: Zhou Qiyang +@file: __init__.py +@time: 2026/1/27 18:13 +@description: 工作空间迭代器 +""" + +from typing import Iterator + +from swanlab.api.workspace import Workspace +from swanlab.core_python import Client +from swanlab.core_python.api.user import get_user_groups, get_workspace_info + + +class Workspaces: + def __init__(self, client: Client, *, username: str) -> None: + self._client = client + self._username = username + + def get_all_workspaces(self, username: str = None): + """Get all workspaces of specific user (defaults to current user)""" + cur_username = username if username else self._username + resp = get_user_groups(self._client, username=cur_username) + groups = [r['username'] for r in resp] + return [cur_username] + groups + + def __iter__(self) -> Iterator[Workspace]: + for space in self.get_all_workspaces(): + data = get_workspace_info(self._client, workspace=space) + yield Workspace(data=data) diff --git a/swanlab/core_python/api/type/__init__.py b/swanlab/core_python/api/type/__init__.py index b20bcd6e3..4bed69f2a 100644 --- a/swanlab/core_python/api/type/__init__.py +++ 
b/swanlab/core_python/api/type/__init__.py @@ -8,6 +8,7 @@ from .experiment import RunType, ColumnType from .project import ProjectType, ProjResponseType from .user import GroupType, IdentityType, ApiKeyType, SelfHostedInfoType +from .workspace import WorkspaceType, RoleType __all__ = [ "RunType", @@ -18,4 +19,6 @@ "IdentityType", "ApiKeyType", "SelfHostedInfoType", + "WorkspaceType", + "RoleType", ] diff --git a/swanlab/core_python/api/type/workspace.py b/swanlab/core_python/api/type/workspace.py new file mode 100644 index 000000000..a689b1d28 --- /dev/null +++ b/swanlab/core_python/api/type/workspace.py @@ -0,0 +1,20 @@ +""" +@author: Zhou Qiyang +@file: workspace +@time: 2026/1/27 20:46 +@description: 工作空间相关类型 +""" + +from typing import TypedDict, Literal, Dict + +RoleType = Literal['VISITOR', 'VIEWER', 'MEMBER', 'OWNER'] + + +class WorkspaceType(TypedDict): + name: str + username: str + profile: Dict[str, str] + type: Literal['TEAM', 'PERSON'] # 返回的信息类型 + comment: str # 组织或者个人的描述 + # 组织信息特有 + role: RoleType # 组织成员的角色 diff --git a/swanlab/core_python/api/user/__init__.py b/swanlab/core_python/api/user/__init__.py index 1603b6415..7ccaf18d4 100644 --- a/swanlab/core_python/api/user/__init__.py +++ b/swanlab/core_python/api/user/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, List -from swanlab.core_python.api.type import GroupType, ApiKeyType +from swanlab.core_python.api.type import GroupType, ApiKeyType, WorkspaceType from .self_hosted import get_self_hosted_init, create_user if TYPE_CHECKING: @@ -34,7 +34,7 @@ def delete_api_key(client: "Client", *, key_id: int) -> None: def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: """ - 获取当前全部的api_key + 获取用户加入的组织 :param client: 已登录的客户端实例 :param username: 用户名称 """ @@ -42,6 +42,16 @@ def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: return res[0] +def get_workspace_info(client: "Client", *, workspace: str) -> WorkspaceType: + """ + 获取指定工作空间的信息 + :param 
client: 已登录的客户端实例 + :param workspace: 工作空间名称 + """ + res = client.get(f"/group/{workspace}") + return res[0] + + def get_api_keys(client: "Client") -> List[ApiKeyType]: """ 获取当前全部的api_key @@ -64,6 +74,7 @@ def get_latest_api_key(client: "Client") -> ApiKeyType: "create_api_key", "delete_api_key", "get_user_groups", + "get_workspace_info", "get_api_keys", "get_latest_api_key", "get_self_hosted_init", From 620330dadf74be25ab8d50cabf2a15306b3afae2 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Fri, 30 Jan 2026 21:35:31 +0800 Subject: [PATCH 12/21] feat: replace the workspace field in the project object with a workspace object (#1430) * Replace the workspace field in the project object with a workspace object * cached the workspace object --- swanlab/api/project/__init__.py | 14 +++++++++----- swanlab/api/projects/__init__.py | 2 +- swanlab/api/workspace/__init__.py | 3 +++ swanlab/api/workspaces/__init__.py | 3 +++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py index 4e4cf62d1..12c24c4a7 100644 --- a/swanlab/api/project/__init__.py +++ b/swanlab/api/project/__init__.py @@ -5,10 +5,13 @@ @description: OpenApi 中的项目对象 """ +from functools import cached_property from typing import List, Dict from swanlab.api.utils import Label, get_properties +from swanlab.api.workspace import Workspace from swanlab.core_python.api.type import ProjectType +from swanlab.core_python.client import Client class Project: @@ -16,7 +19,8 @@ class Project: Representing a single project with some of its properties. 
""" - def __init__(self, *, data: ProjectType, web_host: str) -> None: + def __init__(self, client: Client, *, data: ProjectType, web_host: str) -> None: + self._client = client self._data = data self._web_host = web_host @@ -69,12 +73,12 @@ def updated_at(self) -> str: """ return self._data['updatedAt'] - @property - def workspace(self) -> str: + @cached_property + def workspace(self) -> Workspace: """ - Project workspace name. + Project workspace object. """ - return self._data["group"]["username"] + return Workspace(client=self._client, workspace=self._data["group"]["username"]) @property def labels(self) -> List[Label]: diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index a6afcdd20..bba2b5523 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -51,7 +51,7 @@ def __iter__(self) -> Iterator[Project]: detail=self._detail, ) for p in resp['list']: - yield Project(data=p, web_host=self._web_host) + yield Project(self._client, data=p, web_host=self._web_host) if cur_page >= resp['pages']: break diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py index f47e975c2..7a6ce77fb 100644 --- a/swanlab/api/workspace/__init__.py +++ b/swanlab/api/workspace/__init__.py @@ -70,3 +70,6 @@ def json(self): JSON-serializable dict of all @property values. 
""" return get_properties(self) + + +__all__ = ['Workspace'] diff --git a/swanlab/api/workspaces/__init__.py b/swanlab/api/workspaces/__init__.py index 6aaa2537a..6f28127b3 100644 --- a/swanlab/api/workspaces/__init__.py +++ b/swanlab/api/workspaces/__init__.py @@ -28,3 +28,6 @@ def __iter__(self) -> Iterator[Workspace]: for space in self.get_all_workspaces(): data = get_workspace_info(self._client, workspace=space) yield Workspace(data=data) + + +__all__ = ['Workspaces'] From 27c3baf54868d9b3cce4f6da9268fd0b95c3b180 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Mon, 2 Feb 2026 12:26:38 +0800 Subject: [PATCH 13/21] fix: fix bugs when getting info when 'profile' param is None (#1440) * fix bugs when getting exp info when profile is None - use getattr() to get profile and related info * accept gemini suggestions --- swanlab/api/experiment/__init__.py | 7 +++---- swanlab/api/workspace/__init__.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index 6c415ab7c..e5f34f92e 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -73,14 +73,14 @@ def config(self) -> Dict[str, object]: """ Experiment configuration. Can be used as filter in the format of 'config.' """ - return self._data['profile']['config'] + return self._data.get('profile', dict()).get('config', dict()) @property def summary(self) -> Dict[str, object]: """ Experiment metrics data. Can be used as filter in the format of 'summary.' """ - return self._data['profile']['scalar'] + return self._data.get('profile', dict()).get('scalar', dict()) @property def state(self) -> str: @@ -115,8 +115,7 @@ def metric_keys(self) -> List[str]: """ List of metric keys. 
""" - summary_keys = self.summary.keys() - return list(summary_keys) + return list(self.summary.keys()) @property def history_line_count(self) -> int: diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py index 7a6ce77fb..371178af6 100644 --- a/swanlab/api/workspace/__init__.py +++ b/swanlab/api/workspace/__init__.py @@ -49,7 +49,7 @@ def profile(self) -> Dict[str, str]: """ Workspace profile. """ - return self._data['profile'] + return self._data.get('profile', dict()) @property def comment(self) -> str: From bdfca37d60a145d8b076e374cda2d60888841163 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:42:43 +0800 Subject: [PATCH 14/21] feat: get all user when login as root user (#1438) * get all user when login as root user - opt @self_hosted decorator identity checking - add api func get_users() * accept gemini suggestions * move get all users function from user to api - opt identity checking logic * Refactor user listing to use Users iterator class Replaced the users() method in Api to return a new Users iterator class instead of yielding User objects directly. Added swanlab/api/users/__init__.py to encapsulate user iteration logic, improving code organization and separation of concerns. * Add user pagination test and utility function Added a test for paginated user retrieval in test_user.py and implemented create_user_data in utils.py to simulate paginated user data. Also clarified the docstring in create_user for self-hosted admin restriction. 
--------- Co-authored-by: Kang Li --- swanlab/api/__init__.py | 10 ++++++ swanlab/api/user/__init__.py | 2 +- swanlab/api/users/__init__.py | 38 +++++++++++++++++++++ swanlab/api/utils.py | 7 ++-- swanlab/core_python/api/user/__init__.py | 3 +- swanlab/core_python/api/user/self_hosted.py | 14 +++++++- test/unit/api/test_user.py | 23 +++++++++++++ test/unit/api/utils.py | 24 +++++++++++++ 8 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 swanlab/api/users/__init__.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 02ce044b3..24fd935cc 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -17,6 +17,8 @@ from .experiments import Experiments from .projects import Projects from .user import User +from .users import Users +from .utils import self_hosted from .workspace import Workspace from .workspaces import Workspaces @@ -55,6 +57,14 @@ def user(self, username: str = None) -> User: """ return User(client=self._client, login_user=self._login_user, username=username) + @self_hosted("root") + def users(self) -> Users: + """ + 超级管理员获取所有用户 + :return: User 实例,可对当前/指定用户进行操作 + """ + return Users(self._client, login_user=self._login_user) + def projects( self, workspace: str, diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py index 587473b28..073212925 100644 --- a/swanlab/api/user/__init__.py +++ b/swanlab/api/user/__init__.py @@ -17,8 +17,8 @@ create_api_key, get_latest_api_key, delete_api_key, + create_user, ) -from swanlab.core_python.api.user.self_hosted import create_user from swanlab.core_python.client import Client diff --git a/swanlab/api/users/__init__.py b/swanlab/api/users/__init__.py new file mode 100644 index 000000000..b205347b8 --- /dev/null +++ b/swanlab/api/users/__init__.py @@ -0,0 +1,38 @@ +""" +@author: cunyue +@file: __init__.py +@time: 2026/2/3 13:30 +@description: OpenApi 中的用户对象迭代器 +""" + +from typing import Iterator + +from swanlab.api.user import User +from 
swanlab.core_python import Client +from swanlab.core_python.api.user import get_users + + +class Users: + """ + Container for a collection of User objects. + You can iterate over the users by for-in loop. + """ + + def __init__(self, client: Client, *, login_user: str) -> None: + self._client = client + self._login_user = login_user + + def __iter__(self) -> Iterator[User]: + cur_page = 0 + while True: + cur_page += 1 + resp = get_users( + self._client, + page=cur_page, + size=20, + ) + for u in resp['list']: + yield User(self._client, login_user=self._login_user, username=u['username']) + + if cur_page >= resp['pages']: + break diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index d9b536fdb..196d564b7 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -53,8 +53,11 @@ def wrapper(self, *args, **kwargs): raise ValueError("SwanLab self-hosted instance has expired.") # 2. 检测用户权限(商业版root用户功能) - if identity == "root" and not self_hosted_info.get("root", False): - raise ValueError("You don't have permission to perform this action. Please login as a root user") + if identity == 'root': + if not self_hosted_info.get('root', False): + raise ValueError("You don't have permission to perform this action. 
Please login as a root user") + if not getattr(self, 'is_self', True): + raise ValueError('This root-only action can only be performed by the logged-in root user.') return func(self, *args, **kwargs) diff --git a/swanlab/core_python/api/user/__init__.py b/swanlab/core_python/api/user/__init__.py index 7ccaf18d4..9756620c6 100644 --- a/swanlab/core_python/api/user/__init__.py +++ b/swanlab/core_python/api/user/__init__.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, List from swanlab.core_python.api.type import GroupType, ApiKeyType, WorkspaceType -from .self_hosted import get_self_hosted_init, create_user +from .self_hosted import get_self_hosted_init, create_user, get_users if TYPE_CHECKING: from swanlab.core_python.client import Client @@ -79,4 +79,5 @@ def get_latest_api_key(client: "Client") -> ApiKeyType: "get_latest_api_key", "get_self_hosted_init", "create_user", + "get_users", ] diff --git a/swanlab/core_python/api/user/self_hosted.py b/swanlab/core_python/api/user/self_hosted.py index e6ad58b41..fee2da394 100644 --- a/swanlab/core_python/api/user/self_hosted.py +++ b/swanlab/core_python/api/user/self_hosted.py @@ -24,10 +24,22 @@ def get_self_hosted_init(client: "Client") -> SelfHostedInfoType: def create_user(client: "Client", *, username: str, password: str) -> None: """ - 根用户添加用户 + 添加用户(私有化管理员限定) :param client: 已登录的客户端实例 :param username: 用户名 :param password: 用户密码 """ data = {"users": [{"username": username, "password": password}]} client.post("/self_hosted/users", data=data) + + +def get_users(client: "Client", *, page: int = 1, size: int = 20): + """ + 分页获取用户(管理员限定) + :param client: 已登录的客户端实例 + :param page: 页码 + :param size: 每页大小 + """ + params = {"page": page, "size": size} + res = client.get("/self_hosted/users", params=params) + return res[0] diff --git a/test/unit/api/test_user.py b/test/unit/api/test_user.py index fb06ab682..24231610f 100644 --- a/test/unit/api/test_user.py +++ b/test/unit/api/test_user.py @@ -3,8 +3,10 @@ import pytest from 
swanlab.api.user import User +from swanlab.api.users import Users from swanlab.core_python import Client from swanlab.error import ApiError +from utils import create_user_data def create_user(username=None): @@ -84,3 +86,24 @@ def test_other_user(): assert other_user.generate_api_key() is None with pytest.raises(ValueError): assert other_user.delete_api_key(api_key='test_api_key') == False + + +def test_users(): + """测试能否分页获取所有用户""" + with patch('swanlab.api.users.get_users') as mock_get_users: + total = 80 + page_size = 20 + + def side_effect(*args, **kwargs): + return create_user_data(page=kwargs.get("page", 1), total=total) + + mock_get_users.side_effect = side_effect + client = MagicMock(spec=Client) + users = Users(client, login_user="test_user") + + user_list = list(users) + assert len(user_list) == total + for i, user in enumerate(user_list): + assert user.username == f'user_{i}' + + assert mock_get_users.call_count == (total + page_size - 1) // page_size diff --git a/test/unit/api/utils.py b/test/unit/api/utils.py index 63a4b61dc..7f06bf693 100644 --- a/test/unit/api/utils.py +++ b/test/unit/api/utils.py @@ -84,6 +84,30 @@ def create_project_data(page: int = 1, total: int = 20) -> ProjResponseType: } +def create_user_data(page: int = 1, total: int = 20) -> Dict: + """ + 创建分页用户数据(用于模拟 get_users 的返回值) + + :param page: 当前页数 + :param total: 用户总数 + :return: 包含 list, pages, total 等字段的字典 + """ + page_size = 20 + pages = (total + page_size - 1) // page_size + user_list = [] + + for j in range(page_size): + user_list.append({ + 'username': f'user_{ (page - 1) * page_size + j }' + }) + + return { + 'list': user_list, + 'pages': pages, + 'total': total, + } + + def create_csv_data(step_values, metric_name, metric_values): """ 创建 CSV 格式的数据 From 3c117eda4e3ae0fff54817c4f4109c78d208a6b8 Mon Sep 17 00:00:00 2001 From: Bainianzzz <95992848+Bainianzzz@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:18:09 +0800 Subject: [PATCH 15/21] feat: get single project through 
openApi (#1439) * get single project through openApi - add api func get_project_info * accept gemini suggestions * Refactor workspace and project param names to 'path' Replaces 'workspace' parameters with 'path' across API, core_python, and test modules for consistency. Updates related function calls, class initializations, and docstrings to reflect the new naming convention. * Refactor Workspace initialization and usage Updated Workspace class to require client and data parameters, removing path-based initialization logic. Refactored related API, Project, and Workspaces classes to fetch workspace data before instantiating Workspace, ensuring consistent and explicit data handling. --------- Co-authored-by: Kang Li --- swanlab/api/__init__.py | 24 +++++++++++++++++---- swanlab/api/project/__init__.py | 8 ++++--- swanlab/api/projects/__init__.py | 8 +++---- swanlab/api/workspace/__init__.py | 8 +------ swanlab/api/workspaces/__init__.py | 4 ++-- swanlab/core_python/api/project/__init__.py | 20 ++++++++++++----- swanlab/core_python/api/user/__init__.py | 6 +++--- test/unit/api/test_project.py | 2 +- 8 files changed, 51 insertions(+), 29 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 24fd935cc..617cdffe1 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -15,12 +15,15 @@ from .deprecated import OpenApi from .experiment import Experiment from .experiments import Experiments +from .project import Project from .projects import Projects from .user import User from .users import Users from .utils import self_hosted from .workspace import Workspace from .workspaces import Workspaces +from ..core_python.api.project import get_project_info +from ..core_python.api.user import get_workspace_info class Api: @@ -67,14 +70,14 @@ def users(self) -> Users: def projects( self, - workspace: str, + path: str, sort: Optional[List[str]] = None, search: Optional[str] = None, detail: Optional[bool] = True, ) -> Projects: """ 
获取指定工作空间(组织)下的所有项目信息 - :param workspace: 工作空间(组织)名称 + :param path: 工作空间(组织)名称 'username' :param sort: 排序方式,可选 :param search: 搜索关键词,可选 :param detail: 是否返回详细信息,可选 @@ -83,12 +86,24 @@ def projects( return Projects( self._client, web_host=self._web_host, - workspace=workspace, + path=path, sort=sort, search=search, detail=detail, ) + def project( + self, + path: str, + ) -> Project: + """ + 获取指定工作空间(组织)下的指定项目信息 + :param path: 项目路径 'username/project' + :return: Project 实例,单个项目的信息 + """ + data = get_project_info(self._client, path=path) + return Project(self._client, web_host=self._web_host, data=data) + def runs(self, path: str, filters: Dict[str, object] = None) -> Experiments: """ 获取指定项目下的所有实验信息 @@ -145,7 +160,8 @@ def workspace( """ if username is None: username = self._login_user - return Workspace(client=self._client, workspace=username) + data = get_workspace_info(self._client, path=username) + return Workspace(self._client, data=data) __all__ = ["Api", "OpenApi"] diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py index 12c24c4a7..66985d2a0 100644 --- a/swanlab/api/project/__init__.py +++ b/swanlab/api/project/__init__.py @@ -11,6 +11,7 @@ from swanlab.api.utils import Label, get_properties from swanlab.api.workspace import Workspace from swanlab.core_python.api.type import ProjectType +from swanlab.core_python.api.user import get_workspace_info from swanlab.core_python.client import Client @@ -19,10 +20,10 @@ class Project: Representing a single project with some of its properties. """ - def __init__(self, client: Client, *, data: ProjectType, web_host: str) -> None: + def __init__(self, client: Client, *, web_host: str, data: ProjectType) -> None: self._client = client - self._data = data self._web_host = web_host + self._data = data @property def name(self) -> str: @@ -78,7 +79,8 @@ def workspace(self) -> Workspace: """ Project workspace object. 
""" - return Workspace(client=self._client, workspace=self._data["group"]["username"]) + data = get_workspace_info(self._client, path=self._data["group"]["username"]) + return Workspace(self._client, data=data) @property def labels(self) -> List[Label]: diff --git a/swanlab/api/projects/__init__.py b/swanlab/api/projects/__init__.py index bba2b5523..eb3ac297b 100644 --- a/swanlab/api/projects/__init__.py +++ b/swanlab/api/projects/__init__.py @@ -24,14 +24,14 @@ def __init__( client: Client, *, web_host: str, - workspace: str, + path: str, sort: Optional[List[str]] = None, search: Optional[str] = None, detail: Optional[bool] = True, ) -> None: self._client = client self._web_host = web_host - self._workspace = workspace + self._path = path self._sort = sort self._search = search self._detail = detail @@ -43,7 +43,7 @@ def __iter__(self) -> Iterator[Project]: cur_page += 1 resp: ProjResponseType = get_workspace_projects( self._client, - workspace=self._workspace, + path=self._path, page=cur_page, size=20, sort=self._sort, @@ -51,7 +51,7 @@ def __iter__(self) -> Iterator[Project]: detail=self._detail, ) for p in resp['list']: - yield Project(self._client, data=p, web_host=self._web_host) + yield Project(self._client, web_host=self._web_host, data=p) if cur_page >= resp['pages']: break diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py index 371178af6..8cc15295a 100644 --- a/swanlab/api/workspace/__init__.py +++ b/swanlab/api/workspace/__init__.py @@ -10,17 +10,11 @@ from swanlab.api.utils import get_properties from swanlab.core_python import Client from swanlab.core_python.api.type import WorkspaceType, RoleType -from swanlab.core_python.api.user import get_workspace_info class Workspace: - def __init__(self, *, data: WorkspaceType = None, client: Client = None, workspace: str = None) -> None: + def __init__(self, client: Client, *, data: WorkspaceType): self._client = client - - if data is None: - if workspace is None or client is 
None: - raise ValueError('workspace or client cannot both None') - data = get_workspace_info(self._client, workspace=workspace) self._data = data @property diff --git a/swanlab/api/workspaces/__init__.py b/swanlab/api/workspaces/__init__.py index 6f28127b3..dbe7c108f 100644 --- a/swanlab/api/workspaces/__init__.py +++ b/swanlab/api/workspaces/__init__.py @@ -26,8 +26,8 @@ def get_all_workspaces(self, username: str = None): def __iter__(self) -> Iterator[Workspace]: for space in self.get_all_workspaces(): - data = get_workspace_info(self._client, workspace=space) - yield Workspace(data=data) + data = get_workspace_info(self._client, path=space) + yield Workspace(self._client, data=data) __all__ = ['Workspaces'] diff --git a/swanlab/core_python/api/project/__init__.py b/swanlab/core_python/api/project/__init__.py index 8e807b658..09c4a0f47 100644 --- a/swanlab/core_python/api/project/__init__.py +++ b/swanlab/core_python/api/project/__init__.py @@ -7,7 +7,7 @@ from typing import Optional, List, TYPE_CHECKING -from swanlab.core_python.api.type import ProjResponseType +from swanlab.core_python.api.type import ProjResponseType, ProjectType if TYPE_CHECKING: from swanlab.core_python.client import Client @@ -16,7 +16,7 @@ def get_workspace_projects( client: "Client", *, - workspace: str, + path: str, page: int = 1, size: int = 20, sort: Optional[List[str]] = None, @@ -26,7 +26,7 @@ def get_workspace_projects( """ 获取指定页数和条件下的项目信息 :param client: 已登录的客户端实例 - :param workspace: 工作空间名称 + :param path: 工作空间名称 :param page: 页码 :param size: 每页项目数量 :param sort: 排序规则, 可选 @@ -40,8 +40,18 @@ def get_workspace_projects( 'search': search, 'detail': detail, } - res = client.get(f"/project/{workspace}", params=dict(params)) + res = client.get(f"/project/{path}", params=dict(params)) return res[0] -__all__ = ["get_workspace_projects"] +def get_project_info(client: "Client", *, path: str) -> ProjectType: + """ + 获取指定路径的项目信息 + :param client: 已登录的客户端实例 + :param path: 项目路径 'username/project' + 
""" + res = client.get(f"/project/{path}") + return res[0] + + +__all__ = ["get_workspace_projects", "get_project_info"] diff --git a/swanlab/core_python/api/user/__init__.py b/swanlab/core_python/api/user/__init__.py index 9756620c6..1aac2a8d4 100644 --- a/swanlab/core_python/api/user/__init__.py +++ b/swanlab/core_python/api/user/__init__.py @@ -42,13 +42,13 @@ def get_user_groups(client: "Client", *, username: str) -> List[GroupType]: return res[0] -def get_workspace_info(client: "Client", *, workspace: str) -> WorkspaceType: +def get_workspace_info(client: "Client", *, path: str) -> WorkspaceType: """ 获取指定工作空间的信息 :param client: 已登录的客户端实例 - :param workspace: 工作空间名称 + :param path: 工作空间名称 """ - res = client.get(f"/group/{workspace}") + res = client.get(f"/group/{path}") return res[0] diff --git a/test/unit/api/test_project.py b/test/unit/api/test_project.py index c91ef97ee..3c07fa991 100644 --- a/test/unit/api/test_project.py +++ b/test/unit/api/test_project.py @@ -20,7 +20,7 @@ def side_effect(*args, **kwargs): mock_projects = Projects( MagicMock(spec=Client), web_host=get_host_web(), - workspace='test_user', + path='test_user', ) projects = list(mock_projects) assert len(projects) == total From 3ed8ed85fea6a38d3452170204e5c7ab5b1726e6 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 03:33:45 +0800 Subject: [PATCH 16/21] Inline experiment history CSV fetch; remove pool Remove the threaded HistoryPool helper and simplify Experiment.history by fetching per-key CSVs inline. Deleted swanlab/api/experiment/thread.py and removed its import. The history method now normalizes keys (allowing a single string), appends x_axis, deduplicates, retrieves CSV URLs via the client, reads them with pandas, strips a common prefix and trailing "_step" from column names, and concatenates DataFrames (inner join when x_axis is present, outer otherwise). 
When x_axis is given, timestamp columns are dropped, the x_axis column is validated and moved to the front. The method now raises if keys is None and supports sample trimming. Note: concurrency via HistoryPool is removed and the pandas parameter is effectively unused (the method returns a pandas DataFrame). --- swanlab/api/experiment/__init__.py | 100 ++++++++++++++++++----------- swanlab/api/experiment/thread.py | 95 --------------------------- 2 files changed, 61 insertions(+), 134 deletions(-) delete mode 100644 swanlab/api/experiment/thread.py diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index e5f34f92e..9be1de9d3 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -12,7 +12,7 @@ from swanlab.core_python.api.type import RunType from swanlab.core_python.client import Client from swanlab.log import swanlog -from .thread import HistoryPool + class Experiment: @@ -144,24 +144,6 @@ def json(self): """ return get_properties(self) - def __full_history(self) -> Any: - """ - Get all metric keys' data of the experiment with timestamp. - """ - try: - import pandas as pd - except ImportError: - raise TypeError( - "OpenApi requires pandas to implement the run.history(). Please install with 'pip install pandas'." - ) - - df = pd.DataFrame() - if len(self.metric_keys) >= 1: - pool = HistoryPool(self._client, self.id, keys=self.metric_keys) - df = pool.execute() - - return df - def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: """ Get specific metric data of the experiment. 
@@ -178,7 +160,8 @@ def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None print(exp.history(keys=['loss'], sample=20, x_axis='t/accuracy')) Returns: - t/accuracy loss + t/accuracy loss + step 0 0.310770 0.525776 1 0.642817 0.479186 2 0.646031 0.362428 @@ -193,27 +176,66 @@ def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None raise TypeError( "OpenApi requires pandas to implement the run.history(). Please install with 'pip install pandas'." ) - - if keys is not None and not isinstance(keys, list): - swanlog.warning('keys must be specified as a list') - return pd.DataFrame() - elif keys is not None and len(keys) and not all(isinstance(k, str) for k in keys): - swanlog.warning('keys must be a list of string') - return pd.DataFrame() - - if keys is None and x_axis is None: - # x轴与keys都未指定时,获取所有指标数据 - df = self.__full_history() - else: - # 使用线程池并发获取所有的key的指标数据 - pool = HistoryPool(self._client, self.id, keys=keys, x_axis=x_axis) - df = pool.execute() - - # 截取前sample行 + + if keys is None: + raise ValueError("keys cannot be None") + + if isinstance(keys, str): + keys = [keys] + if x_axis is not None: + keys += [x_axis] + + # 去重 keys + keys = list(set(keys)) + dfs = [] + prefix = "" + for idx, key in enumerate(keys): + resp = self._client.get(f"/experiment/{self.id}/column/csv", params={"key": key}) + + url:str = resp[0].get("url", "") + df = pd.read_csv(url, index_col=0) + + if idx == 0: + # 从第一列名提取 prefix,例如 "t0707-02:17-loss_step" 中提取 "t0707-02:17-" + first_col = df.columns[0] + suffix = f"{key}_" + if suffix in first_col: + prefix = first_col.split(suffix)[0] # 结果为 "t0707-02:17-" + else: + prefix = "" + + if prefix: + df.columns = [ + col[len(prefix):].removesuffix("_step") if col.startswith(prefix) else col.removesuffix("_step") + for col in df.columns + ] + else: + df.columns = [col.removesuffix("_step") for col in df.columns] + + dfs.append(df) + + # 如果有 x_axis,使用 inner join(交集);否则使用 outer join(并集) + join_type = 
"inner" if x_axis is not None else "outer" + result_df = pd.concat(dfs, axis=1, join=join_type) + + # 如果有 x_axis,进行特殊处理 + if x_axis is not None: + # 去掉所有带 _timestamp 后缀的列 + timestamp_cols = [col for col in result_df.columns if col.endswith("_timestamp")] + result_df = result_df.drop(columns=timestamp_cols) + + # 确保 x_axis 列存在 + if x_axis not in result_df.columns: + raise ValueError(f"x_axis '{x_axis}' not found in the result DataFrame") + + # 将 x_axis 列放到第一列 + cols = [x_axis] + [col for col in result_df.columns if col != x_axis] + result_df = result_df[cols] + if sample is not None: - df = df.head(sample) + result_df = result_df.head(sample) - return df if pandas else df.to_dict(orient='records') + return result_df __all__ = ['Experiment'] diff --git a/swanlab/api/experiment/thread.py b/swanlab/api/experiment/thread.py deleted file mode 100644 index 50e12fc9f..000000000 --- a/swanlab/api/experiment/thread.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -@author: Zhou QiYang -@file: thread.py -@time: 2025/12/30 15:08 -@description: 用于并发请求实验指标数据的封装类 -""" - -from concurrent.futures import ThreadPoolExecutor -from io import BytesIO -from typing import List, Any - -import requests - -from swanlab.core_python.api.experiment import get_experiment_metrics -from swanlab.core_python.client import Client -from swanlab.log import swanlog - - -class HistoryPool: - - def __init__(self, client: Client, expid: str, *, keys: List[str], x_axis: str = None, num_threads: int = 10): - try: - import pandas as pd - except ImportError: - raise TypeError("Api requires pandas to init the HistoryPool. 
Please install with 'pip install pandas'.") - - self._client = client - self._expid = expid - self._keys = keys - self._x_axis = x_axis - if self._x_axis is not None: - self._keys = [self._x_axis] + [k for k in self._keys if k != self._x_axis] - self._num_threads = num_threads - - # 使用 _results 字典收集每个 key 的 DataFrame,最后统一按顺序合并到 _history - self._executor = ThreadPoolExecutor(max_workers=self._num_threads) - self._futures = [] - self._results = dict() - self._history = pd.DataFrame() - - def _task(self, key: str): - """ - 处理单个key,获取对应csv - """ - import pandas as pd - - try: - csv_df = pd.DataFrame() - resp = get_experiment_metrics(self._client, expid=self._expid, key=key) - # 从返回网址中解析csv内容 - with requests.get(resp['url']) as response: - csv_df = pd.read_csv(BytesIO(response.content)) - return key, csv_df - except Exception as e: - swanlog.warning(f'Error processing key {key} in experiment {self._expid}: {e}') - return key, pd.DataFrame() - - def execute(self) -> Any: - if not self._keys: - return self._history - - # 将所有key提交到线程池 - for key in self._keys: - future = self._executor.submit(self._task, key) - self._futures.append((key, future)) - - # 等待所有任务完成并收集结果 - for key, future in self._futures: - try: - result_key, csv_df = future.result() - self._results[result_key] = csv_df - except Exception as e: - swanlog.warning(f'Error getting result for key {key} in experiment {self._expid}: {e}') - self._executor.shutdown(wait=True) - - # 按照 keys 的顺序统一合并 - for key in self._keys: - if key not in self._results: - continue - key_df = self._results[key] - step_col, value_col = key_df.columns[:2] # step 列, 指标值列 - - # 将 step 设为索引,其后基于索引自动对齐 - if self._history.empty: - self._history = key_df.set_index(step_col) - else: - self._history[value_col] = key_df.set_index(step_col)[value_col] - - # 若指定x轴,重置索引 - if self._x_axis is not None: - self._history = self._history.reset_index().iloc[:, 1:] - self._history = self._history.set_index(self._history.columns[0]) - else: - 
self._history.rename(columns={'step': '_step'}, inplace=True) - return self._history From c758ca58eb663dd526da6a4238d740e11ce411d9 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 03:36:22 +0800 Subject: [PATCH 17/21] Rename history() to metrics() in Experiment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove Experiment.config and Experiment.summary properties and rename the Experiment.history(...) method to Experiment.metrics(...). Update docstrings, example usage and the pandas import error message to reference metrics. This is an API change — update callers to use exp.metrics(...) and note that config/summary accessors were removed. --- swanlab/api/experiment/__init__.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index 9be1de9d3..b1c61a216 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -68,20 +68,6 @@ def labels(self) -> List[Label]: """ return [Label(label['name']) for label in self._data['labels']] - @property - def config(self) -> Dict[str, object]: - """ - Experiment configuration. Can be used as filter in the format of 'config.' - """ - return self._data.get('profile', dict()).get('config', dict()) - - @property - def summary(self) -> Dict[str, object]: - """ - Experiment metrics data. Can be used as filter in the format of 'summary.' - """ - return self._data.get('profile', dict()).get('scalar', dict()) - @property def state(self) -> str: """ @@ -144,7 +130,7 @@ def json(self): """ return get_properties(self) - def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: + def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: """ Get specific metric data of the experiment. 
:param keys: List of metric keys to obtain. If None, all metrics keys will be used. @@ -157,7 +143,7 @@ def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None ```python api = swanlab.OpenApi() exp = api.run(path="username/project/expid") # You can get expid from api.runs() - print(exp.history(keys=['loss'], sample=20, x_axis='t/accuracy')) + print(exp.metrics(keys=['loss'], sample=20, x_axis='t/accuracy')) Returns: t/accuracy loss @@ -174,7 +160,7 @@ def history(self, keys: List[str] = None, x_axis: str = None, sample: int = None import pandas as pd except ImportError: raise TypeError( - "OpenApi requires pandas to implement the run.history(). Please install with 'pip install pandas'." + "OpenApi requires pandas to implement the run.metrics(). Please install with 'pip install pandas'." ) if keys is None: From db7786f3b950bac491694580f94541528bbfc8d6 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 04:12:54 +0800 Subject: [PATCH 18/21] Refactor DataFrame joining and x_axis handling Introduce x_axis_state to centralize x_axis checks and only append x_axis to keys when applicable. Replace pd.concat(join_type) with iterative outer .join(...).sort_index() to build the result DataFrame, and keep timestamp columns dropped as before. After reordering columns to put x_axis first, filter out rows where x_axis is NaN. Minor cleanup and readability improvements. 
--- swanlab/api/experiment/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index b1c61a216..b6948594f 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -165,10 +165,12 @@ def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None if keys is None: raise ValueError("keys cannot be None") + + x_axis_state = x_axis is not None and x_axis != "step" if isinstance(keys, str): keys = [keys] - if x_axis is not None: + if x_axis_state: keys += [x_axis] # 去重 keys @@ -200,12 +202,14 @@ def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None dfs.append(df) - # 如果有 x_axis,使用 inner join(交集);否则使用 outer join(并集) - join_type = "inner" if x_axis is not None else "outer" - result_df = pd.concat(dfs, axis=1, join=join_type) + # 拼接整张表 + result_df = dfs[0] + if len(dfs) > 1: + for df in dfs[1:]: + result_df = result_df.join(df, how='outer').sort_index() # 如果有 x_axis,进行特殊处理 - if x_axis is not None: + if x_axis_state: # 去掉所有带 _timestamp 后缀的列 timestamp_cols = [col for col in result_df.columns if col.endswith("_timestamp")] result_df = result_df.drop(columns=timestamp_cols) @@ -217,6 +221,7 @@ def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None # 将 x_axis 列放到第一列 cols = [x_axis] + [col for col in result_df.columns if col != x_axis] result_df = result_df[cols] + result_df = result_df[result_df[x_axis].notna()] if sample is not None: result_df = result_df.head(sample) From 94224b1e22893c3fdbcf59380e2d6334d7adb1ee Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 04:22:35 +0800 Subject: [PATCH 19/21] fix test --- swanlab/api/experiment/__init__.py | 18 +++++++---- .../api/{test_history.py => test_metrics.py} | 30 +++++++++---------- 2 files changed, 27 insertions(+), 21 deletions(-) rename test/unit/api/{test_history.py => 
test_metrics.py} (76%) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index b6948594f..3d043f9fb 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -162,14 +162,20 @@ def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None raise TypeError( "OpenApi requires pandas to implement the run.metrics(). Please install with 'pip install pandas'." ) - - if keys is None: - raise ValueError("keys cannot be None") - - x_axis_state = x_axis is not None and x_axis != "step" - if isinstance(keys, str): + if keys is None: + swanlog.warning('keys cannot be None') + return pd.DataFrame() + elif not isinstance(keys, list): + swanlog.warning('keys must be specified as a list') + return pd.DataFrame() + elif isinstance(keys, str): keys = [keys] + elif len(keys) and not all(isinstance(k, str) for k in keys): + swanlog.warning('keys must be a list of string') + return pd.DataFrame() + + x_axis_state = x_axis is not None and x_axis != "step" if x_axis_state: keys += [x_axis] diff --git a/test/unit/api/test_history.py b/test/unit/api/test_metrics.py similarity index 76% rename from test/unit/api/test_history.py rename to test/unit/api/test_metrics.py index 0efb3d92a..eb634d9d2 100644 --- a/test/unit/api/test_history.py +++ b/test/unit/api/test_metrics.py @@ -1,8 +1,8 @@ """ @author: Zhou QiYang -@file: test_history.py +@file: test_metrics.py @time: 2026/1/11 16:36 -@description: 测试 Experiment.history() 方法,使用 MagicMock 和 monkeypatch 模拟网络请求 +@description: 测试 Experiment.metrics() 方法,使用 MagicMock 和 monkeypatch 模拟网络请求 """ from unittest.mock import patch, MagicMock @@ -62,47 +62,47 @@ def __exit__(self, *args): self.m.stop() -def test_history_basic(experiment, metrics_data): +def test_metrics_basic(experiment, metrics_data): """测试使用指定 keys 获取历史数据""" with MockSetup(metrics_data): - result = experiment.history(keys=['loss', 'accuracy']) + result = experiment.metrics(keys=['loss', 'accuracy']) 
assert len(result) == 10 assert 'loss' in result.columns assert 'accuracy' in result.columns -def test_history_with_x_axis(experiment, metrics_data): +def test_metrics_with_x_axis(experiment, metrics_data): """测试使用 x_axis 参数""" with MockSetup(metrics_data): - result = experiment.history(keys=['loss'], x_axis='accuracy') + result = experiment.metrics(keys=['loss'], x_axis='accuracy') # x_axis 应该作为索引 assert result.index.name == 'accuracy' -def test_history_with_sample(experiment, metrics_data): +def test_metrics_with_sample(experiment, metrics_data): """测试使用 sample 参数限制返回行数""" with MockSetup(metrics_data): - result = experiment.history(keys=['loss'], sample=5) + result = experiment.metrics(keys=['loss'], sample=5) # 只返回前 5 行 assert len(result) == 5 -def test_history_dict_mode(experiment, metrics_data): +def test_metrics_dict_mode(experiment, metrics_data): """测试 pandas=False 时返回 dict 格式""" with MockSetup(metrics_data): - result = experiment.history(keys=['loss'], pandas=False) + result = experiment.metrics(keys=['loss'], pandas=False) # 应该返回字典列表 assert all(isinstance(item, dict) for item in result) -def test_full_history(experiment, metrics_data): - """测试 keys 和 x_axis 都为 None 时调用 __full_history""" +def test_full_metrics(experiment, metrics_data): + """测试 keys 和 x_axis 都为 None 时调用 __full_metrics""" with MockSetup(metrics_data): - result = experiment.history() + result = experiment.metrics() assert len(result) == 10 assert 'loss' in result.columns @@ -110,8 +110,8 @@ def test_full_history(experiment, metrics_data): @pytest.mark.parametrize("keys", ('invalid_keys', ['loss', 123, 'accuracy'])) -def test_history_invalid_keys(experiment, metrics_data, keys): +def test_metrics_invalid_keys(experiment, metrics_data, keys): """测试 keys 参数类型错误的情况,返回空 DataFrame""" with MockSetup(metrics_data): - result = experiment.history(keys=keys) + result = experiment.metrics(keys=keys) assert len(result) == 0 From 1d8c58c489c55341750af4ee9eb98b7ffd8455a7 Mon Sep 17 00:00:00 2001 From: ZeYi 
Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 04:35:19 +0800 Subject: [PATCH 20/21] fix test --- test/unit/api/test_metrics.py | 61 +++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/test/unit/api/test_metrics.py b/test/unit/api/test_metrics.py index eb634d9d2..6fcf07fa8 100644 --- a/test/unit/api/test_metrics.py +++ b/test/unit/api/test_metrics.py @@ -8,7 +8,7 @@ from unittest.mock import patch, MagicMock import pytest -import requests_mock +import pandas as pd from swanlab.api.experiment import Experiment from swanlab.core_python.client import Client @@ -41,30 +41,42 @@ def metrics_data(): class MockSetup: """ 模拟网络请求 - 分别模拟获取 csv 网址和文件内容 + 直接 mock client.get 返回 URL,然后 mock pd.read_csv 返回 DataFrame """ - def __init__(self, metrics_data): + def __init__(self, metrics_data, experiment): self.metrics_data = metrics_data + self.experiment = experiment + # Create a lookup dict for metrics data + self._metric_lookup = {m[1]: m for m in metrics_data} + + def _create_df(self, key): + """Helper to create DataFrame from metric data""" + if key not in self._metric_lookup: + return pd.DataFrame() + step_values, metric_name, metric_values = self._metric_lookup[key] + return pd.DataFrame({ + 'step': step_values, + f'{metric_name}_step': metric_values + }).set_index('step') def __enter__(self): - self.mock_get_metrics = patch('swanlab.api.experiment.thread.get_experiment_metrics').start() - self.mock_get_metrics.side_effect = lambda client, expid, key: {'url': f'{get_host_api()}/{key}'} + # Mock client.get to return CSV URL + self.mock_get = patch.object(self.experiment._client, 'get').start() + self.mock_get.side_effect = lambda path, params: [{'url': f'{get_host_api()}/{params["key"]}'}] - self.m = requests_mock.Mocker() - self.m.start() - for metric in self.metrics_data: - self.m.get(f'{get_host_api()}/{metric[1]}', content=create_csv_data(*metric)) + # Mock pd.read_csv to return DataFrame directly + self.mock_read_csv = 
patch('pandas.read_csv').start() + self.mock_read_csv.side_effect = lambda url, index_col: self._create_df(url.split('/')[-1]) return self def __exit__(self, *args): patch.stopall() - self.m.stop() def test_metrics_basic(experiment, metrics_data): """测试使用指定 keys 获取历史数据""" - with MockSetup(metrics_data): + with MockSetup(metrics_data, experiment): result = experiment.metrics(keys=['loss', 'accuracy']) assert len(result) == 10 @@ -74,16 +86,16 @@ def test_metrics_basic(experiment, metrics_data): def test_metrics_with_x_axis(experiment, metrics_data): """测试使用 x_axis 参数""" - with MockSetup(metrics_data): + with MockSetup(metrics_data, experiment): result = experiment.metrics(keys=['loss'], x_axis='accuracy') - # x_axis 应该作为索引 - assert result.index.name == 'accuracy' + # x_axis 列应该在第一列 + assert result.columns[0] == 'accuracy' def test_metrics_with_sample(experiment, metrics_data): """测试使用 sample 参数限制返回行数""" - with MockSetup(metrics_data): + with MockSetup(metrics_data, experiment): result = experiment.metrics(keys=['loss'], sample=5) # 只返回前 5 行 @@ -91,27 +103,26 @@ def test_metrics_with_sample(experiment, metrics_data): def test_metrics_dict_mode(experiment, metrics_data): - """测试 pandas=False 时返回 dict 格式""" - with MockSetup(metrics_data): + """测试 pandas=False 时返回 DataFrame(当前实现只支持 DataFrame)""" + with MockSetup(metrics_data, experiment): result = experiment.metrics(keys=['loss'], pandas=False) - # 应该返回字典列表 - assert all(isinstance(item, dict) for item in result) + # 当前实现始终返回 DataFrame + assert isinstance(result, pd.DataFrame) def test_full_metrics(experiment, metrics_data): - """测试 keys 和 x_axis 都为 None 时调用 __full_metrics""" - with MockSetup(metrics_data): + """测试 keys=None 时返回空 DataFrame(当前实现不支持)""" + with MockSetup(metrics_data, experiment): result = experiment.metrics() - assert len(result) == 10 - assert 'loss' in result.columns - assert 'accuracy' in result.columns + # keys=None 时返回空 DataFrame + assert len(result) == 0 @pytest.mark.parametrize("keys", 
('invalid_keys', ['loss', 123, 'accuracy'])) def test_metrics_invalid_keys(experiment, metrics_data, keys): """测试 keys 参数类型错误的情况,返回空 DataFrame""" - with MockSetup(metrics_data): + with MockSetup(metrics_data, experiment): result = experiment.metrics(keys=keys) assert len(result) == 0 From cc22e54e853c6a54bb4b77ea91c19cc1ee44bf42 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 4 Feb 2026 04:48:04 +0800 Subject: [PATCH 21/21] Improve Experiment.metrics CSV parsing, validation Refactor Experiment.metrics: tighten keys validation (keys must be a non-empty list of strings), simplify docstring and pandas import error message, and treat pandas param as reserved. Normalize x_axis handling (default 'step'), append x_axis when needed, and fetch each metric CSV. Extract common prefix from the first column, strip that prefix and the "_step" suffix in a Python 3.8-compatible way, then outer-join and sort the DataFrames. When x_axis is used, drop timestamp columns, ensure x_axis exists, move it to the first column and drop rows with null x_axis. Finally apply the optional sample limit. --- swanlab/api/experiment/__init__.py | 124 +++++++++++++---------------- 1 file changed, 54 insertions(+), 70 deletions(-) diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py index 3d043f9fb..91423b911 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment/__init__.py @@ -132,103 +132,87 @@ def json(self): def metrics(self, keys: List[str] = None, x_axis: str = None, sample: int = None, pandas: bool = True) -> Any: """ - Get specific metric data of the experiment. - :param keys: List of metric keys to obtain. If None, all metrics keys will be used. - :param x_axis: The metric to be used as x-axis. If None, '_step' will be used as the x-axis. - :param sample: Number of rows to select from the beginning. - :param pandas: Whether to return a pandas DataFrame. 
If False, returns dict format: {key: [values], ...} - :return: Metric data. + Get metric data from the experiment. - Example: - ```python - api = swanlab.OpenApi() - exp = api.run(path="username/project/expid") # You can get expid from api.runs() - print(exp.metrics(keys=['loss'], sample=20, x_axis='t/accuracy')) + Args: + keys: List of metric keys to fetch. Required. A single string is also accepted. + x_axis: Metric to use as x-axis. Defaults to 'step'. + sample: Number of rows to return from the start. + pandas: Reserved parameter (always returns DataFrame). Returns: - t/accuracy loss - step - 0 0.310770 0.525776 - 1 0.642817 0.479186 - 2 0.646031 0.362428 - 3 0.608820 0.230555 - ... - 19 0.791999 0.180106 - ``` + pandas.DataFrame with metric data. + + Example: + >>> exp.metrics(keys=['loss', 'accuracy'], sample=20, x_axis='t/accuracy') + t/accuracy loss + step + 0 0.310770 0.525776 + 1 0.642817 0.479186 + ... """ try: import pandas as pd except ImportError: - raise TypeError( - "OpenApi requires pandas to implement the run.metrics(). Please install with 'pip install pandas'." - ) + raise TypeError("pandas is required for metrics(). 
Install with 'pip install pandas'.") + # Normalize keys: must be a non-empty list of strings if keys is None: swanlog.warning('keys cannot be None') return pd.DataFrame() - elif not isinstance(keys, list): - swanlog.warning('keys must be specified as a list') + if not isinstance(keys, list): + swanlog.warning('keys must be a list') return pd.DataFrame() - elif isinstance(keys, str): - keys = [keys] - elif len(keys) and not all(isinstance(k, str) for k in keys): - swanlog.warning('keys must be a list of string') + if not keys: + swanlog.warning('keys cannot be empty') + return pd.DataFrame() + if not all(isinstance(k, str) for k in keys): + swanlog.warning('keys must be a list of strings') return pd.DataFrame() - x_axis_state = x_axis is not None and x_axis != "step" - if x_axis_state: - keys += [x_axis] + # Determine if x_axis needs to be included + use_x_axis = x_axis is not None and x_axis != "step" + if use_x_axis: + keys.append(x_axis) - # 去重 keys - keys = list(set(keys)) + # Fetch and process each metric CSV dfs = [] prefix = "" for idx, key in enumerate(keys): resp = self._client.get(f"/experiment/{self.id}/column/csv", params={"key": key}) - - url:str = resp[0].get("url", "") + url = resp[0].get("url", "") df = pd.read_csv(url, index_col=0) + # Extract prefix from first column (e.g., "t0707-02:17-loss_step" → "t0707-02:17-") if idx == 0: - # 从第一列名提取 prefix,例如 "t0707-02:17-loss_step" 中提取 "t0707-02:17-" first_col = df.columns[0] suffix = f"{key}_" - if suffix in first_col: - prefix = first_col.split(suffix)[0] # 结果为 "t0707-02:17-" - else: - prefix = "" - - if prefix: - df.columns = [ - col[len(prefix):].removesuffix("_step") if col.startswith(prefix) else col.removesuffix("_step") - for col in df.columns - ] - else: - df.columns = [col.removesuffix("_step") for col in df.columns] - + prefix = first_col.split(suffix)[0] if suffix in first_col else "" + + # Strip "_step" suffix from column names (Python 3.8 compatible) + def strip_suffix(col, suffix="_step"): + 
return col[:-len(suffix)] if col.endswith(suffix) else col + + # Apply prefix removal and suffix stripping + df.columns = [ + strip_suffix(col[len(prefix):]) if prefix and col.startswith(prefix) else strip_suffix(col) + for col in df.columns + ] dfs.append(df) - # 拼接整张表 - result_df = dfs[0] - if len(dfs) > 1: - for df in dfs[1:]: - result_df = result_df.join(df, how='outer').sort_index() - - # 如果有 x_axis,进行特殊处理 - if x_axis_state: - # 去掉所有带 _timestamp 后缀的列 - timestamp_cols = [col for col in result_df.columns if col.endswith("_timestamp")] - result_df = result_df.drop(columns=timestamp_cols) - - # 确保 x_axis 列存在 + # Merge all DataFrames + result_df = dfs[0].join(dfs[1:], how='outer') if len(dfs) > 1 else dfs[0] + result_df = result_df.sort_index() + + # Handle x_axis: drop timestamp columns, reorder, filter nulls + if use_x_axis: + result_df = result_df.drop(columns=[c for c in result_df.columns if c.endswith("_timestamp")], errors='ignore') if x_axis not in result_df.columns: - raise ValueError(f"x_axis '{x_axis}' not found in the result DataFrame") - - # 将 x_axis 列放到第一列 - cols = [x_axis] + [col for col in result_df.columns if col != x_axis] - result_df = result_df[cols] - result_df = result_df[result_df[x_axis].notna()] - + raise ValueError(f"x_axis '{x_axis}' not found in result DataFrame") + cols = [x_axis] + [c for c in result_df.columns if c != x_axis] + result_df = result_df[cols].dropna(subset=[x_axis]) + + # Apply sample limit if sample is not None: result_df = result_df.head(sample)