-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_model.py
More file actions
117 lines (98 loc) · 3.58 KB
/
inference_model.py
File metadata and controls
117 lines (98 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Standard library imports
import time
import argparse
from datetime import datetime
from dataclasses import dataclass, fields
# Third-party imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
# Local application/library specific imports
from src.transformer_utils import TransformerWithTimeEmbeddings
from src.utils import r2_score, load_dataclass, Config
from src.dataloader import (
train_test_split,
OrderBookDataset,
load_stats_and_normalize,
apply_fft_high_pass_filter,
generate_orderbook_features,
apply_pca_keep_variance,
)
def collect_predictions(model: nn.Module, batches, device: torch.device):
    """Run ``model`` over ``batches`` and gather targets and predictions.

    Args:
        model: Module invoked as ``model(inputs)``. The caller is expected to
            have switched it to eval mode beforehand.
        batches: Iterable yielding ``(inputs, targets)`` tensor pairs, e.g. a
            ``DataLoader`` (optionally wrapped in ``tqdm``).
        device: Device both tensors are moved to before the forward pass.

    Returns:
        A ``(y_true, y_pred)`` tuple: the targets concatenated along dim 0,
        and the model outputs flattened to 1-D and concatenated in batch
        order.

    Raises:
        ValueError: If ``batches`` yields nothing.
    """
    y_true, y_pred = [], []
    # Gradients are never needed at inference time; disabling autograd avoids
    # building (and keeping alive) a graph for every batch.
    with torch.no_grad():
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            y_true.append(targets)
            # reshape(-1) instead of squeeze(): squeeze() collapses a final
            # batch of size 1 into a 0-d tensor, which torch.cat rejects.
            # NOTE(review): assumes one output value per sample (n_out=1).
            y_pred.append(outputs.reshape(-1))
    if not y_pred:
        raise ValueError("The dataset produced no batches; nothing to predict.")
    return torch.cat(y_true), torch.cat(y_pred)


def main() -> None:
    """Predict on a CSV dataset, save the predictions and report the R2 score.

    Reads the CSV given on the command line, applies the preprocessing steps
    enabled in ``final_model/config.pkl``, runs the model stored in
    ``final_model/model.pt`` and writes the predictions to ``prediction.csv``.
    The R2 score is printed only when the input contains a ``y`` column.

    Raises:
        NotImplementedError: If the loaded config has ``use_transformer``
            disabled, since no other model is constructed by this script.
        ValueError: If the dataset yields no batches.
    """
    parser = argparse.ArgumentParser(
        description="Inference script to calculate R2 score from a CSV dataset containing Orderbook ."
    )
    parser.add_argument(
        "csv_path", type=str, help="Path to the CSV file containing the dataset."
    )
    args = parser.parse_args()

    print("Loading dataset ... ")
    data = pd.read_csv(args.csv_path)

    # Without a 'y' column there is nothing to score against; a placeholder
    # column is added so the rest of the pipeline still receives a 'y' column.
    calculate_R2 = "y" in data.columns
    if not calculate_R2:
        data["y"] = 0

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the .pkl extension suggests load_dataclass unpickles this
    # file - only point it at trusted artifacts. Confirm against src.utils.
    config = load_dataclass("final_model/config.pkl")

    print("Preparing dataset ...")
    # Reduce the number of features
    if config.reduce_features:
        data = generate_orderbook_features(data)
    if config.reduce_features_svd:
        data = apply_pca_keep_variance(
            data, variance_threshold=config.svd_variance_threshold
        )
    # Remove very low frequency components to make model more generalizable to unseen data
    data = apply_fft_high_pass_filter(data, config.fft_cutoff)
    # Load the feature statistics and normalize the data
    data = load_stats_and_normalize(data, "final_model/feature_stats.csv")

    # Create dataloader; shuffle=False keeps predictions in dataset order.
    dataset = OrderBookDataset(data, config.window_size)
    test_dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)

    if not config.use_transformer:
        # Previously this path fell through to a NameError on `model`.
        raise NotImplementedError(
            "config.use_transformer is disabled, but this script only builds "
            "the transformer model."
        )

    input_size = data.shape[1] - 1  # Number of features, excluding target variable
    transformer_kwargs = {
        "n_out": 1,
        "emb": config.hidden_size,
        "heads": config.transformer_heads,
        "depth": config.transformer_depth,
        "dropout": config.dropout_rate,
    }
    model = TransformerWithTimeEmbeddings(
        input_size=input_size, nband=1, **transformer_kwargs
    ).to(device)
    # `device` is already a torch.device, so it can be passed straight through.
    model.load_state_dict(torch.load("final_model/model.pt", map_location=device))
    model.eval()

    print("Calculating R2 ... ")
    y_true_test, y_pred_test = collect_predictions(
        model, tqdm(test_dataloader), device
    )

    # Save CSV file
    pd.DataFrame(y_pred_test.cpu().numpy()).to_csv("prediction.csv", index=False)

    if calculate_R2:
        test_r2 = r2_score(y_true_test, y_pred_test)
        print(f"Test R2 score: {test_r2.item()}")
    else:
        print(
            "No 'y' column found in the dataset. Prediction saved to 'prediction.csv'."
        )


if __name__ == "__main__":
    main()