Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/data/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def _init_tables(self):
name TEXT UNIQUE NOT NULL,
host TEXT NOT NULL,
username TEXT NOT NULL,
password TEXT NOT NULL,
password TEXT,
port INTEGER DEFAULT 22,
is_active BOOLEAN DEFAULT 0,
pkey_path TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''')
Expand Down Expand Up @@ -389,15 +390,16 @@ def add_gpu_server(self, server_data: Dict) -> bool:
try:
self.cursor.execute('''
INSERT OR REPLACE INTO gpu_servers
(name, host, username, password, port, is_active)
VALUES (?, ?, ?, ?, ?, ?)
(name, host, username, password, port, is_active,pkey_path)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
server_data["name"],
server_data["host"],
server_data["username"],
server_data.get("password"),
server_data.get("port", 22), # 默认端口为22
server_data.get("is_active", False)
server_data.get("is_active", False),
server_data.get("pkey_path")
))
self.conn.commit()
return True
Expand Down
8 changes: 8 additions & 0 deletions src/gui/i18n/translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@
'select_result_first': 'Please select a result first',
'file_not_exist': 'Result file does not exist',
'load_error': 'Failed to load result file',
'auth_mode': 'Auth mode',
'pkey_auth': 'Private Key',
'sel_pkey_file': 'Select Private Key File',
'pwd_auth': 'Password',
},
'zh_CN': {
'settings': '设置',
Expand Down Expand Up @@ -474,6 +478,10 @@
'select_result_first': '请先选择一个结果',
'file_not_exist': '结果文件不存在',
'load_error': '加载结果文件失败',
'auth_mode': '认证方式',
'pkey_auth': '私钥认证',
'sel_pkey_file': '选择私钥文件',
'pwd_auth': '密码认证',
},
'fr': {
'settings': 'Paramètres',
Expand Down
38 changes: 34 additions & 4 deletions src/gui/settings/gpu_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox,
QLabel, QLineEdit, QSpinBox, QPushButton,
QListWidget, QMessageBox, QFormLayout,
QDialog, QCheckBox, QDoubleSpinBox, QListWidgetItem
QDialog, QCheckBox, QDoubleSpinBox, QListWidgetItem,
QPushButton,QTabWidget,QFileDialog
)
from PyQt6.QtGui import QAction
from PyQt6.QtCore import Qt, pyqtSignal
from src.utils.logger import setup_logger
from src.utils.config import config
from src.monitor.gpu_monitor import GPUMonitorManager
from src.data.db_manager import db_manager
from src.monitor.gpu_monitor import gpu_monitor
from src.gui.i18n.language_manager import LanguageManager
import os

logger = setup_logger("gpu_settings")

Expand Down Expand Up @@ -50,11 +53,23 @@ def init_ui(self):
# 用户名
self.username_input = QLineEdit()
layout.addRow(self.tr('username') + ":", self.username_input)


# 认证方式
tab_layout = QTabWidget()

# 密码
self.password_input = QLineEdit()
self.password_input.setEchoMode(QLineEdit.EchoMode.Password)
layout.addRow(self.tr('password') + ":", self.password_input)
tab_layout.addTab(self.password_input, self.tr('pwd_auth'))

# 私钥文件路径
self.pkey_path_btn = QPushButton()
self.pkey_path_btn.setText(self.tr('sel_pkey_file')) # 设置按钮文本
self.pkey_path_btn.clicked.connect(self.get_pkey_path)

tab_layout.addTab(self.pkey_path_btn, self.tr('pkey_auth'))
layout.addRow(self.tr('auth_mode') + ":", tab_layout)

# 按钮
button_box = QHBoxLayout()
Expand All @@ -67,6 +82,15 @@ def init_ui(self):
layout.addRow("", button_box)

self.setLayout(layout)

def get_pkey_path(self):
"""获取私钥文件路径"""
file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.FileMode.ExistingFile)
file_dialog.setNameFilter("All Files (*.*)")
if file_dialog.exec():
self.pkey_path_btn.setText(file_dialog.selectedFiles()[0])


def load_server_data(self):
"""加载服务器数据"""
Expand All @@ -75,15 +99,21 @@ def load_server_data(self):
self.port_input.setValue(self.server_data.get("port", 22))
self.username_input.setText(self.server_data.get("username", ""))
self.password_input.setText(self.server_data.get("password", ""))
self.pkey_path_btn.setText(self.server_data.get("pkey_path", ""))

def get_server_data(self) -> dict:
"""获取服务器数据"""
pkey_path = self.pkey_path_btn.text().strip()
if pkey_path and not os.path.exists(pkey_path):
logger.warning(f"私钥文件不存在: {pkey_path}")
pkey_path = ""
return {
"name": self.name_input.text().strip(),
"host": self.host_input.text().strip(),
"port": self.port_input.value(),
"username": self.username_input.text().strip(),
"password": self.password_input.text().strip()
"password": self.password_input.text().strip(),
"pkey_path": pkey_path,
}

def tr(self, key):
Expand All @@ -95,7 +125,7 @@ def test_connection(self):
try:
data = self.get_server_data()
monitor = GPUMonitorManager()
monitor.setup_monitor(data["host"], data["username"], data["password"], data["port"])
monitor.setup_monitor(data["host"], data["username"], data["password"], data["port"], data["pkey_path"])
stats = monitor.get_stats()

if stats:
Expand Down
40 changes: 28 additions & 12 deletions src/monitor/gpu_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ def get_gpu_memory_util(self, index: int) -> float:

class GPUMonitor:
"""远程GPU监控类"""
def __init__(self, host: str, username: str, password: str, port: int = 22):
def __init__(self, host: str, username: str, password: str, port: int = 22, pkey: str = ""):
self.host = host
self.username = username
self.password = password
self.port = port
self.pkey = pkey
self.client = None
self.max_retries = 3 # 最大重试次数
self.retry_interval = 2 # 重试间隔(秒)
Expand All @@ -130,13 +131,26 @@ def _connect(self) -> bool:

self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.client.connect(
self.host,
port=self.port,
username=self.username,
password=self.password,
timeout=5
)
logger.debug(self.pkey)
if self.pkey:
private_key = paramiko.RSAKey.from_private_key_file(self.pkey)
self.client.connect(
self.host,
port=self.port,
username=self.username,
pkey=private_key,
timeout=5
)
elif self.password:
self.client.connect(
self.host,
port=self.port,
username=self.username,
password=self.password,
timeout=5
)
else:
raise ValueError("Neither password nor private key provided")
logger.info(f"成功连接到远程服务器: {self.host}:{self.port}")
return True
except Exception as e:
Expand Down Expand Up @@ -428,8 +442,9 @@ def init_monitor(self):
self.setup_monitor(
server["host"],
server["username"],
server["password"],
server.get("port", 22) # 使用默认端口22
server.get("password", ""),
server.get("port", 22), # 使用默认端口22
server.get("pkey_path", "") # 使用默认私钥路径
)
logger.info(f"GPU监控器初始化成功: {server['host']}:{server.get('port', 22)}")
else:
Expand All @@ -439,17 +454,18 @@ def init_monitor(self):
logger.error(f"初始化GPU监控器失败: {e}")
self.monitor = None

def setup_monitor(self, host: str, username: str, password: str, port: int = 22):
def setup_monitor(self, host: str, username: str, password: str, port: int = 22, pkey: str = ""):
"""设置GPU监控器

Args:
host: 主机地址
username: 用户名
password: 密码
port: SSH端口
pkey: 私钥路径(可选)
"""
try:
self.monitor = GPUMonitor(host, username, password, port)
self.monitor = GPUMonitor(host, username, password, port, pkey)
logger.info(f"GPU监控器设置成功: {host}:{port}")
except Exception as e:
logger.error(f"设置GPU监控器失败: {e}")
Expand Down