-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
executable file
·124 lines (106 loc) · 4.85 KB
/
Copy pathrun.py
File metadata and controls
executable file
·124 lines (106 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import toml
import typer
import logging
import pickle
import warnings
from tqdm import tqdm
from dotenv import load_dotenv
from datetime import datetime
from typing import Optional
from puppy import MarketEnvironment, LLMAgent
# Suppress all warning messages for the whole process.
warnings.filterwarnings("ignore")
# Load environment variables from a .env file. override=True asks
# python-dotenv to let .env values take precedence over variables that are
# already set in the process environment.
load_dotenv(override=True)
def _build_market_environment(
    config: dict,
    market_data_path: str,
    start_date_str: str,
    end_date_str: str,
) -> MarketEnvironment:
    """Build a fresh MarketEnvironment from a pickled market-data file.

    Args:
        config: Parsed TOML configuration; ``config["general"]["trading_symbol"]``
            is used as the environment's symbol.
        market_data_path: Path to the market data pickle file.
        start_date_str: Simulation start date in YYYY-MM-DD format.
        end_date_str: Simulation end date in YYYY-MM-DD format.

    Returns:
        A newly constructed MarketEnvironment.

    Raises:
        KeyError: If the configuration lacks ``general.trading_symbol``.
        ValueError: If a date string is not in YYYY-MM-DD format.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code while
    # deserializing. Only point this at market-data files you trust.
    with open(market_data_path, "rb") as f:
        env_data_pkl = pickle.load(f)
    return MarketEnvironment(
        symbol=config["general"]["trading_symbol"],
        env_data_pkl=env_data_pkl,
        start_date=datetime.strptime(start_date_str, "%Y-%m-%d").date(),
        end_date=datetime.strptime(end_date_str, "%Y-%m-%d").date(),
    )


def run_simulation(
    config_path: str = typer.Option(
        "config/tsla_gpt_config.toml",
        "-c", "--config",
        help="Path to the TOML configuration file."
    ),
    market_data_path: str = typer.Option(
        "data/03_model_input/add_filing_tsla.pkl",
        "-m", "--market-data",
        help="Path to the market data pickle file."
    ),
    start_date_str: str = typer.Option(
        "2022-10-06",
        "-s", "--start-date",
        help="Simulation start date in YYYY-MM-DD format."
    ),
    end_date_str: str = typer.Option(
        "2023-04-10",
        "-e", "--end-date",
        help="Simulation end date in YYYY-MM-DD format."
    ),
    output_path: str = typer.Option(
        "data/05_model_output/default_run",
        "-o", "--output-path",
        help="Directory to save all checkpoints and final results. Resumes from here if checkpoints exist."
    ),
    update_market_data: bool = typer.Option(
        False,
        "--update-market",
        help="에이전트는 유지하고, 새로운 시장 데이터로 환경을 업데이트하여 재개합니다."
    ),
):
    """Start an agent simulation, or resume one from checkpoints.

    Three modes are selected in this order: with --update-market the agent is
    restored from its checkpoint and the environment is rebuilt from the
    market data file; otherwise, if both agent and environment checkpoints
    exist under the output path, both are restored; otherwise a new
    simulation is started. Checkpoints are written after every step, and
    progress is logged to run.log inside the output directory.
    """
    os.makedirs(output_path, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # Create the file handler only when it will actually be attached;
    # FileHandler opens the file on construction, so building one and then
    # discarding it would leak an open file handle.
    if not logger.handlers:
        # The log messages contain Korean text and emoji, so write UTF-8
        # instead of relying on the platform default encoding.
        file_handler = logging.FileHandler(
            os.path.join(output_path, "run.log"), mode="a", encoding="utf-8"
        )
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(file_handler)

    # TOML documents are UTF-8 by specification.
    with open(config_path, "r", encoding="utf-8") as f:
        config = toml.load(f)

    agent_checkpoint_path = os.path.join(output_path, "agent")
    env_checkpoint_path = os.path.join(output_path, "env")

    if update_market_data:
        logger.info("🔄 데이터 업데이트 모드로 실행합니다: 에이전트는 복구하고 환경은 새로 구성합니다.")
        if not os.path.exists(agent_checkpoint_path):
            error_message = "❌ 오류: 에이전트 체크포인트를 찾을 수 없습니다. 먼저 일반 모드로 시뮬레이션을 실행하세요."
            # Logger.error() accepts no `fg` keyword (it raised TypeError);
            # the colored console output belongs to typer.secho instead.
            logger.error(error_message)
            typer.secho(error_message, fg=typer.colors.RED, err=True)
            # Exit with a non-zero status so callers can detect the failure.
            raise typer.Exit(code=1)
        the_agent = LLMAgent.load_checkpoint(path=agent_checkpoint_path)
        environment = _build_market_environment(
            config, market_data_path, start_date_str, end_date_str
        )
        # Position the rebuilt environment at the step implied by the
        # restored agent's counter.
        environment.current_step = the_agent.counter - 1
    elif os.path.exists(agent_checkpoint_path) and os.path.exists(env_checkpoint_path):
        logger.info(f"'{output_path}'에서 체크포인트를 발견했습니다. 시뮬레이션을 재개합니다.")
        environment = MarketEnvironment.load_checkpoint(path=env_checkpoint_path)
        the_agent = LLMAgent.load_checkpoint(path=agent_checkpoint_path)
    else:
        logger.info("체크포인트를 찾을 수 없습니다. 새로운 시뮬레이션을 시작합니다.")
        environment = _build_market_environment(
            config, market_data_path, start_date_str, end_date_str
        )
        the_agent = LLMAgent.from_config(config)

    with tqdm(
        total=environment.simulation_length,
        initial=the_agent.counter - 1,
        desc="Simulating Agent",
    ) as pbar:
        while True:
            market_info = environment.step()
            the_agent.step(market_info=market_info)
            pbar.update(1)
            # Checkpoint after every step so an interrupted run can resume.
            the_agent.save_checkpoint(path=agent_checkpoint_path, force=True)
            environment.save_checkpoint(path=env_checkpoint_path, force=True)
            # The last element of market_info is treated as the "done" flag.
            if market_info[-1]:
                logger.info("시뮬레이션 기간의 마지막에 도달했습니다.")
                break

    logger.info("시뮬레이션이 성공적으로 완료되었습니다.")
    logger.info(f"최종 결과 및 체크포인트가 '{output_path}'에 저장되었습니다.")
# Script entry point: hand run_simulation to Typer, which runs it as the
# command-line command.
if __name__ == "__main__":
    typer.run(run_simulation)