From ea8754f42b381527e9ac1f93ead8829a5f13cc12 Mon Sep 17 00:00:00 2001 From: "Michael.Ma" Date: Mon, 24 Mar 2025 19:25:40 +0800 Subject: [PATCH 1/2] --- src/data/db_manager.py | 10 ++++---- src/gui/i18n/translations.py | 8 +++++++ src/gui/settings/gpu_settings.py | 33 ++++++++++++++++++++++---- src/monitor/gpu_monitor.py | 40 ++++++++++++++++++++++---------- 4 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/data/db_manager.py b/src/data/db_manager.py index f6c7a38..5eb9af7 100644 --- a/src/data/db_manager.py +++ b/src/data/db_manager.py @@ -180,9 +180,10 @@ def _init_tables(self): name TEXT UNIQUE NOT NULL, host TEXT NOT NULL, username TEXT NOT NULL, - password TEXT NOT NULL, + password TEXT, port INTEGER DEFAULT 22, is_active BOOLEAN DEFAULT 0, + pkey_path TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) ''') @@ -389,15 +390,16 @@ def add_gpu_server(self, server_data: Dict) -> bool: try: self.cursor.execute(''' INSERT OR REPLACE INTO gpu_servers - (name, host, username, password, port, is_active) - VALUES (?, ?, ?, ?, ?, ?) + (name, host, username, password, port, is_active,pkey_path) + VALUES (?, ?, ?, ?, ?, ?, ?) ''', ( server_data["name"], server_data["host"], server_data["username"], server_data.get("password"), server_data.get("port", 22), # 默认端口为22 - server_data.get("is_active", False) + server_data.get("is_active", False), + server_data.get("pkey_path") )) self.conn.commit() return True diff --git a/src/gui/i18n/translations.py b/src/gui/i18n/translations.py index c84bbba..42746c1 100644 --- a/src/gui/i18n/translations.py +++ b/src/gui/i18n/translations.py @@ -236,6 +236,10 @@ 'select_result_first': 'Please select a result first', 'file_not_exist': 'Result file does not exist', 'load_error': 'Failed to load result file', + 'auth_mode': 'Auth mode', + 'pkey_auth': 'Private Key', + 'sel_pkey_file': 'Select Private Key File', + 'pwd_auth': 'Password', }, 'zh_CN': { 'settings': '设置', @@ -474,6 +478,10 @@ 'select_result_first': '请先选择一个结果', 'file_not_exist': '结果文件不存在', 'load_error': '加载结果文件失败', + 'auth_mode': '认证方式', + 'pkey_auth': '私钥认证', + 'sel_pkey_file': '选择私钥文件', + 'pwd_auth': '密码认证', }, 'fr': { 'settings': 'Paramètres', diff --git a/src/gui/settings/gpu_settings.py b/src/gui/settings/gpu_settings.py index 12fc2f9..0044fbc 100644 --- a/src/gui/settings/gpu_settings.py +++ b/src/gui/settings/gpu_settings.py @@ -5,8 +5,10 @@ QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QLabel, QLineEdit, QSpinBox, QPushButton, QListWidget, QMessageBox, QFormLayout, - QDialog, QCheckBox, QDoubleSpinBox, QListWidgetItem + QDialog, QCheckBox, QDoubleSpinBox, QListWidgetItem, + QPushButton,QTabWidget,QFileDialog ) +from PyQt6.QtGui import QAction from PyQt6.QtCore import Qt, pyqtSignal from src.utils.logger import setup_logger from src.utils.config import config @@ -50,11 +52,23 @@ def init_ui(self): # 用户名 self.username_input = QLineEdit() layout.addRow(self.tr('username') + ":", self.username_input) + + + # 认证方式 + tab_layout = QTabWidget() # 密码 self.password_input = QLineEdit() self.password_input.setEchoMode(QLineEdit.EchoMode.Password) - layout.addRow(self.tr('password') + ":", self.password_input) + tab_layout.addTab(self.password_input, self.tr('pwd_auth')) + + # 私钥文件路径 + self.pkey_path_btn = QPushButton() + self.pkey_path_btn.setText(self.tr('sel_pkey_file')) # 设置按钮文本 + self.pkey_path_btn.clicked.connect(self.get_pkey_path) + + tab_layout.addTab(self.pkey_path_btn, self.tr('pkey_auth')) + layout.addRow(self.tr('auth_mode') + ":", tab_layout) # 按钮 button_box = QHBoxLayout() @@ -67,6 +81,15 @@ def init_ui(self): layout.addRow("", button_box) self.setLayout(layout) + + def get_pkey_path(self): + """获取私钥文件路径""" + file_dialog = QFileDialog() + file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile) + file_dialog.setNameFilter("All Files (*.*)") + if file_dialog.exec(): + self.pkey_path_btn.setText(file_dialog.selectedFiles()[0]) + def load_server_data(self): """加载服务器数据""" @@ -75,6 +98,7 @@ def load_server_data(self): self.port_input.setValue(self.server_data.get("port", 22)) self.username_input.setText(self.server_data.get("username", "")) self.password_input.setText(self.server_data.get("password", "")) + self.pkey_path_btn.setText(self.server_data.get("pkey_path", "")) def get_server_data(self) -> dict: """获取服务器数据""" @@ -83,7 +107,8 @@ def get_server_data(self) -> dict: "host": self.host_input.text().strip(), "port": self.port_input.value(), "username": self.username_input.text().strip(), - "password": self.password_input.text().strip() + "password": self.password_input.text().strip(), + "pkey_path": self.pkey_path_btn.text().strip(), } def tr(self, key): @@ -95,7 +120,7 @@ def test_connection(self): try: data = self.get_server_data() monitor = GPUMonitorManager() - monitor.setup_monitor(data["host"], data["username"], data["password"], data["port"]) + monitor.setup_monitor(data["host"], data["username"], data["password"], data["port"], data["pkey_path"]) stats = monitor.get_stats() if stats: diff --git a/src/monitor/gpu_monitor.py b/src/monitor/gpu_monitor.py index 2c465d5..7858502 100644 --- a/src/monitor/gpu_monitor.py +++ b/src/monitor/gpu_monitor.py @@ -105,11 +105,12 @@ def get_gpu_memory_util(self, index: int) -> float: class GPUMonitor: """远程GPU监控类""" - def __init__(self, host: str, username: str, password: str, port: int = 22): + def __init__(self, host: str, username: str, password: str, port: int = 22, pkey: str = ""): self.host = host self.username = username self.password = password self.port = port + self.pkey = pkey self.client = None self.max_retries = 3 # 最大重试次数 self.retry_interval = 2 # 重试间隔(秒) @@ -130,13 +131,26 @@ def _connect(self) -> bool: self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.client.connect( - self.host, - port=self.port, - username=self.username, - password=self.password, - timeout=5 - ) + logger.debug(self.pkey) + if self.pkey: + private_key = paramiko.RSAKey.from_private_key_file(self.pkey) + self.client.connect( + self.host, + port=self.port, + username=self.username, + pkey=private_key, + timeout=5 + ) + elif self.password: + self.client.connect( + self.host, + port=self.port, + username=self.username, + password=self.password, + timeout=5 + ) + else: + raise ValueError("Neither password nor private key provided") logger.info(f"成功连接到远程服务器: {self.host}:{self.port}") return True except Exception as e: @@ -428,8 +442,9 @@ def init_monitor(self): self.setup_monitor( server["host"], server["username"], - server["password"], - server.get("port", 22) # 使用默认端口22 + server.get("password", ""), + server.get("port", 22), # 使用默认端口22 + server.get("pkey_path", "") # 使用默认私钥路径 ) logger.info(f"GPU监控器初始化成功: {server['host']}:{server.get('port', 22)}") else: @@ -439,7 +454,7 @@ def init_monitor(self): logger.error(f"初始化GPU监控器失败: {e}") self.monitor = None - def setup_monitor(self, host: str, username: str, password: str, port: int = 22): + def setup_monitor(self, host: str, username: str, password: str, port: int = 22, pkey: str = ""): """设置GPU监控器 Args: @@ -447,9 +462,10 @@ def setup_monitor(self, host: str, username: str, password: str, port: int = 22) username: 用户名 password: 密码 port: SSH端口 + pkey: 私钥路径(可选) """ try: - self.monitor = GPUMonitor(host, username, password, port) + self.monitor = GPUMonitor(host, username, password, port, pkey) logger.info(f"GPU监控器设置成功: {host}:{port}") except Exception as e: logger.error(f"设置GPU监控器失败: {e}") From 89c6ab11a7e675ed636d9e70600e7e8c4a7a9d01 Mon Sep 17 00:00:00 2001 From: "Michael.Ma" Date: Mon, 24 Mar 2025 19:35:28 +0800 Subject: [PATCH 2/2] --- src/gui/settings/gpu_settings.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gui/settings/gpu_settings.py b/src/gui/settings/gpu_settings.py index 0044fbc..f34cc22 100644 --- a/src/gui/settings/gpu_settings.py +++ b/src/gui/settings/gpu_settings.py @@ -16,6 +16,7 @@ from src.data.db_manager import db_manager from src.monitor.gpu_monitor import gpu_monitor from src.gui.i18n.language_manager import LanguageManager +import os logger = setup_logger("gpu_settings") @@ -102,13 +103,17 @@ def load_server_data(self): def get_server_data(self) -> dict: """获取服务器数据""" + pkey_path = self.pkey_path_btn.text().strip() + if pkey_path and not os.path.exists(pkey_path): + logger.warning(f"私钥文件不存在: {pkey_path}") + pkey_path = "" return { "name": self.name_input.text().strip(), "host": self.host_input.text().strip(), "port": self.port_input.value(), "username": self.username_input.text().strip(), "password": self.password_input.text().strip(), - "pkey_path": self.pkey_path_btn.text().strip(), + "pkey_path": pkey_path, } def tr(self, key):