From b70852272aa6c8f0df92f261e0893e561a0804ae Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 20 Aug 2025 02:52:40 +0000 Subject: [PATCH] Add PyTorch implementation of marker tracking models Co-authored-by: saleiferis --- pytorch/README.md | 49 +++++++ pytorch/__init__.py | 1 + pytorch/generate_data.py | 269 +++++++++++++++++++++++++++++++++++++++ pytorch/models.py | 180 ++++++++++++++++++++++++++ pytorch/train.py | 103 +++++++++++++++ pytorch/train_generic.py | 109 ++++++++++++++++ 6 files changed, 711 insertions(+) create mode 100644 pytorch/README.md create mode 100644 pytorch/__init__.py create mode 100644 pytorch/generate_data.py create mode 100644 pytorch/models.py create mode 100644 pytorch/train.py create mode 100644 pytorch/train_generic.py diff --git a/pytorch/README.md b/pytorch/README.md new file mode 100644 index 0000000..5340632 --- /dev/null +++ b/pytorch/README.md @@ -0,0 +1,49 @@ +## PyTorch Training + +This folder contains PyTorch equivalents of the original TensorFlow/Keras training scripts. You can train both the fixed-grid model and the generic model. + +### Setup + +- Python 3.8+ +- Install dependencies: + +``` +pip install torch torchvision opencv-python numpy +``` + +Optional for visualization/debugging: + +``` +pip install matplotlib +``` + +### Train the fixed-grid model (small) + +This mirrors the original TensorFlow `train.py` and learns marker displacements on a fixed 10x14 grid from 80x112 inputs. + +``` +python -m pytorch.train -p torch_small -lr 1e-5 --epochs 100 --steps 2000 --batch-size 32 +``` + +Arguments: + +- `-p/--prefix`: model save subfolder under `models/` (change the root with `--save-dir`) +- `-lr/--lr`: learning rate +- `--epochs`: number of epochs +- `--steps`: steps per epoch (each step pulls a fresh synthetic batch) +- `--batch-size`: synthetic batch size + +### Train the generic model (encoder-decoder) + +This mirrors the original TensorFlow `train_generic.py` and learns a dense flow field at multiple scales from variable-sized inputs and marker grids. 
+ +``` +python -m pytorch.train_generic -p torch_generic -lr 1e-5 --epochs 100 --steps 2000 --batch-size 32 +``` + +Notes: + +- Models are saved to `models/<prefix>/tracking_XXX_LOSS.pt` (epoch number and validation loss) whenever validation loss improves. +- Scripts are self-contained and generate synthetic training data on the fly (no external datasets required). +- Run from the repository root as modules (`python -m pytorch.<script>`) so the absolute `pytorch.*` package imports resolve. + diff --git a/pytorch/__init__.py b/pytorch/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/pytorch/__init__.py @@ -0,0 +1 @@ + diff --git a/pytorch/generate_data.py b/pytorch/generate_data.py new file mode 100644 index 0000000..c1a85c0 --- /dev/null +++ b/pytorch/generate_data.py @@ -0,0 +1,269 @@ +import cv2 +import numpy as np +import random + + +def draw_square(img, x, y, marker_size, xx, yy, theta): + width, height = img.shape[0], img.shape[1] + marker_size_large = marker_size * 2 ** 0.5 + + lx_raw, rx_raw = x - marker_size, x + marker_size + ly_raw, ry_raw = y - marker_size, y + marker_size + + lx, rx = x - marker_size_large, x + marker_size_large + ly, ry = y - marker_size_large, y + marker_size_large + + lx, rx = np.clip(lx, 0, width), np.clip(rx, -1, width - 1) + ly, ry = np.clip(ly, 0, height), np.clip(ry, -1, height - 1) + + lxi, lyi = int(lx), int(ly) + rxi, ryi = int(np.ceil(rx)), int(np.ceil(ry)) + + xx_r, yy_r = xx[lxi : rxi + 1, lyi : ryi + 1], yy[lxi : rxi + 1, lyi : ryi + 1] + xx_r, yy_r = ( + np.cos(theta) * (xx_r - x) - np.sin(theta) * (yy_r - y) + x, + np.sin(theta) * (xx_r - x) + np.cos(theta) * (yy_r - y) + y, + ) + + def intensity(val, left, right): + return 1 - np.clip(np.maximum(left - val, val - right), 0, 1) + + darkness = 0.3 + 0.7 * np.random.random() + scale = 1 - darkness * intensity(xx_r, lx_raw, rx_raw) * intensity(yy_r, ly_raw, ry_raw) + for channel in range(3): + img[lxi : rxi + 1, lyi : ryi + 1, channel] *= scale + + +def generate(xx, yy, img_blur=None, rng=0.0, W=48, H=48, N=6, M=6, degree=None): + if img_blur is None: + 
img_blur = (np.random.random((W // 3, H // 3, 3)) * 0.9) + 0.1 + img_blur = cv2.resize(img_blur, (H, W)) + + yy_whole, xx_whole = np.meshgrid(np.arange(H), np.arange(W)) + + img = img_blur + np.random.randn(W, H, 3) * 0.05 - 0.025 + + for i in range(N): + for j in range(M): + r = yy[i, j] + c = xx[i, j] + + if degree is None: + theta = np.random.normal(0, 0.5) * 45 / 180 * np.pi + else: + theta = degree + + draw_square(img, r, c, 0.5 + rng * 1, xx_whole, yy_whole, theta) + + img[:, :1] *= np.random.random(img[:, :1].shape) * 0.5 + img = cv2.GaussianBlur(img, (3, 3), 0) + img = np.clip(img, 0.0, 1.0) + return img + + +def shear(center_x, center_y, sigma, shear_x, shear_y, xx, yy): + gaussian = np.exp(-(((xx - center_x) ** 2 + (yy - center_y) ** 2)) / (2.0 * sigma ** 2)) + return xx + shear_x * gaussian, yy + shear_y * gaussian + + +def twist(center_x, center_y, sigma, theta, xx, yy): + gaussian = np.exp(-(((xx - center_x) ** 2 + (yy - center_y) ** 2)) / (2.0 * sigma ** 2)) + dx = xx - center_x + dy = yy - center_y + rotated_x = dx * np.cos(theta) - dy * np.sin(theta) + rotated_y = dx * np.sin(theta) + dy * np.cos(theta) + return xx + (rotated_x - dx) * gaussian, yy + (rotated_y - dy) * gaussian + + +def random_shear(xx, yy, W, H, interval=8): + shear_ratio = 5 + center_x = random.random() * W + center_y = random.random() * H + sigma = random.random() * W / 2 + if np.random.random() < 0.3: + normal = np.array([center_x - W / 2, center_y - H / 2]) + normal = normal / (np.linalg.norm(normal) + 1e-6) + shear_x = random.random() * interval * shear_ratio * normal[0] + shear_y = random.random() * interval * shear_ratio * normal[1] + else: + shear_x = random.random() * interval * shear_ratio - interval * shear_ratio / 2 + shear_y = random.random() * interval * shear_ratio - interval * shear_ratio / 2 + return shear(center_x, center_y, sigma, shear_x, shear_y, xx, yy) + + +def random_twist(xx, yy, W, H): + twist_degree = 100 + center_x = random.random() * W + center_y = 
random.random() * H + sigma = random.random() * W / 2 + theta = (random.random() * twist_degree - twist_degree / 2.0) / 180.0 * np.pi + return twist(center_x, center_y, sigma, theta, xx, yy) + + +def preprocessing(img, W, H): + ret = img.copy() + x_grid = np.arange(0, W, 1) + y_grid = np.arange(0, H, 1) + xx, yy = np.meshgrid(y_grid, x_grid) + for _ in range(5): + size_x = int(2 + random.random() * 15) + size_y = int(2 + random.random() * 15) + x = int(random.random() * (W - size_x)) + y = int(random.random() * (H - size_y)) + theta = np.random.random() * np.pi + rng = 0.7 + xr = (xx - x) * np.cos(theta) - (yy - y) * np.sin(theta) + yr = (xx - x) * np.sin(theta) + (yy - y) * np.cos(theta) + mask = np.logical_and.reduce([(xr >= -size_x), (xr <= size_x), (yr >= -size_y), (yr <= size_y)]) + ret[mask] *= 1 + (np.random.random(3) * rng * 2 - rng) + return ret + + +def generate_batch_fixed(batch_size=32, setting=(80, 112, 10, 14)): + W, H, N, M = setting + x = np.arange(0, W, 1) + y = np.arange(0, H, 1) + xx0, yy0 = np.meshgrid(y, x) + + interval_x = W / N + interval_y = H / M + x_positions = np.arange(interval_x / 2, W, interval_x)[:N] + y_positions = np.arange(interval_y / 2, H, interval_y)[:M] + xind, yind = np.meshgrid(y_positions, x_positions) + xind = xind.reshape([-1]).astype(np.int64) + yind = yind.reshape([-1]).astype(np.int64) + xind += (np.random.random(xind.shape) * 2 - 1).astype(np.int64) + yind += (np.random.random(yind.shape) * 2 - 1).astype(np.int64) + + X, Y = [], [] + for _ in range(batch_size): + xx = xx0 + (np.random.random(xx0.shape) * 2 - 1) + yy = yy0 + (np.random.random(yy0.shape) * 2 - 1) + rng = np.random.random() + + img0 = generate( + xx[yind, xind].reshape([N, M]), + yy[yind, xind].reshape([N, M]), + img_blur=None, + rng=rng, + W=W, + H=H, + N=N, + M=M, + degree=0, + ) + + xx_distorted, yy_distorted = xx, yy + xx_distorted, yy_distorted = random_shear(xx_distorted, yy_distorted, W, H) + xx_distorted, yy_distorted = random_twist(xx_distorted, 
yy_distorted, W, H) + xx_distorted += np.random.random(xx_distorted.shape) * 1 - 0.5 + yy_distorted += np.random.random(yy_distorted.shape) * 1 - 0.5 + + img = generate( + xx_distorted[yind, xind].reshape([N, M]), + yy_distorted[yind, xind].reshape([N, M]), + img_blur=None, + rng=rng, + W=W, + H=H, + N=N, + M=M, + ) + img = preprocessing(img, W, H) + + target = np.zeros([N, M, 2], dtype=np.float32) + target[:, :, 0] = ( + xx_distorted[yind, xind].reshape([N, M]) - xx[yind, xind].reshape([N, M]) + ) + target[:, :, 1] = ( + yy_distorted[yind, xind].reshape([N, M]) - yy[yind, xind].reshape([N, M]) + ) + + X.append(np.dstack([img0 - 0.5, img - 0.5])) + Y.append(target) + + X = np.asarray(X, dtype=np.float32) + Y = np.asarray(Y, dtype=np.float32) + return X, Y + + +def generate_batch_generic(batch_size=32, setting=None): + X, Y = [], [] + if setting is None: + N, M = np.random.randint(4, 15), np.random.randint(4, 15) + W = np.random.randint(N * 6, 96) + H = np.random.randint(M * 6, 96) + W = (W // 16 + 1) * 16 + H = (H // 16 + 1) * 16 + else: + W, H, N, M = setting + + x = np.arange(0, W, 1) + y = np.arange(0, H, 1) + xx, yy = np.meshgrid(y, x) + + interval_x = W / (N + 1) + interval_y = H / (M + 1) + x_positions = np.arange(interval_x, W, interval_x)[:N] + y_positions = np.arange(interval_y, H, interval_y)[:M] + xind, yind = np.meshgrid(x_positions, y_positions) + xind = xind.reshape([-1]).astype(np.int64) + yind = yind.reshape([-1]).astype(np.int64) + xind += (np.random.random(xind.shape) * 4 - 2).astype(np.int64) + yind += (np.random.random(yind.shape) * 4 - 2).astype(np.int64) + + for _ in range(batch_size): + rng = np.random.random() + img0 = generate( + xx[xind, yind].reshape([N, M]), + yy[xind, yind].reshape([N, M]), + img_blur=None, + rng=rng, + W=W, + H=H, + N=N, + M=M, + ) + + xx_distorted, yy_distorted = random_shear(xx, yy, W, H) + xx_distorted, yy_distorted = random_twist(xx_distorted, yy_distorted, W, H) + + img = generate( + xx_distorted[xind, 
yind].reshape([N, M]), + yy_distorted[xind, yind].reshape([N, M]), + img_blur=None, + rng=rng, + W=W, + H=H, + N=N, + M=M, + ) + + target = np.zeros([W, H, 2], dtype=np.float32) + target[:, :, 0] = xx_distorted - xx + target[:, :, 1] = yy_distorted - yy + + features = np.dstack( + [ + img0 - 0.5, + img - 0.5, + np.reshape(xx, [W, H, 1]), + np.reshape(yy, [W, H, 1]), + ] + ) + X.append(features) + Y.append(target) + + X = np.asarray(X, dtype=np.float32) + Y = np.asarray(Y, dtype=np.float32) + + # multi-scale downsampling of targets + Y_list = [ + Y, + Y[:, 1::2, 1::2], + Y[:, 2::4, 2::4], + Y[:, 4::8, 4::8], + Y[:, 8::16, 8::16], + ] + return X, Y_list + diff --git a/pytorch/models.py b/pytorch/models.py new file mode 100644 index 0000000..138aa23 --- /dev/null +++ b/pytorch/models.py @@ -0,0 +1,180 @@ +from typing import List, Tuple +import torch +import torch.nn as nn + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 5): + super().__init__() + padding = kernel_size // 2 + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class SmallModel(nn.Module): + """PyTorch port of build_model_small from train.py. 
+ + Input: B x 6 x H x W + Output: B x 2 x H/8 x W/8 (matches TF sequence: pool 2x2 -> 3 pools total) + """ + + def __init__(self): + super().__init__() + self.enc1 = ConvBlock(6, 16) + self.pool1 = nn.AvgPool2d(2) + + self.enc2 = ConvBlock(16, 32) + self.pool2 = nn.AvgPool2d(2) + + self.enc3 = nn.Sequential( + nn.Conv2d(32, 128, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 5, padding=2), + nn.ReLU(inplace=True), + ) + self.pool3 = nn.AvgPool2d(2) + + self.enc4 = nn.Sequential( + nn.Conv2d(128, 256, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 5, padding=2), + nn.Sigmoid(), + ) + self.out_conv = nn.Conv2d(256, 2, 5, padding=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.enc1(x) + x = self.pool1(x) + x = self.enc2(x) + x = self.pool2(x) + x = self.enc3(x) + x = self.pool3(x) + x = self.enc4(x) + x = self.out_conv(x) + return x + + +class UpBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 5): + super().__init__() + padding = kernel_size // 2 + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding), + nn.ReLU(inplace=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class GenericAE(nn.Module): + """PyTorch port of build_model_ae from train_generic.py. + + Returns a tuple of multi-scale outputs: (full, x2, x4, x8, x16) + Each output has 2 channels (flow x, y). 
+ """ + + def __init__(self): + super().__init__() + padding = 2 + # Encoder + self.enc1 = ConvBlock(8, 32) + self.pool1 = nn.MaxPool2d(2) + + self.enc2 = ConvBlock(32, 64) + self.pool2 = nn.MaxPool2d(2) + + self.enc3 = nn.Sequential( + nn.Conv2d(64, 128, 5, padding=padding), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 5, padding=padding), + nn.ReLU(inplace=True), + ) + self.pool3 = nn.MaxPool2d(2) + + self.enc4 = nn.Sequential( + nn.Conv2d(128, 128, 5, padding=padding), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 5, padding=padding), + nn.ReLU(inplace=True), + ) + self.pool4 = nn.MaxPool2d(2) + + # Decoder with lateral connections and intermediate outputs + self.dec4 = UpBlock(128, 128) + self.dec4_out = nn.Conv2d(128, 2, 5, padding=2) + + self.dec3 = UpBlock(128 + 128, 128) + self.dec3_out = nn.Conv2d(128, 2, 5, padding=2) + + # Input to dec2 is cat([upsampled d3 (128), c3 (128)]) -> 256 + self.dec2 = UpBlock(256, 64) + self.dec2_out = nn.Conv2d(64, 2, 5, padding=2) + + # After upsample dec2 (64) and concat with c2 (64) -> 128 + self.dec1 = UpBlock(128, 32) + self.dec1_out = nn.Conv2d(32, 2, 5, padding=2) + + # After upsample dec1 (32) and concat with c1 (32) -> 64 + self.final = nn.Sequential( + nn.Conv2d(64, 32, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(32, 32, 5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(32, 2, 5, padding=2), + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Encoder + c1 = self.enc1(x) + p1 = self.pool1(c1) + + c2 = self.enc2(p1) + p2 = self.pool2(c2) + + c3 = self.enc3(p2) + p3 = self.pool3(c3) + + c4 = self.enc4(p3) + p4 = self.pool4(c4) + + # Decoder stage 4 + d4 = self.dec4(p4) + d4_out = self.dec4_out(d4) + d4 = nn.functional.interpolate(d4, scale_factor=2, mode="nearest") + d4 = torch.cat([d4, c4], dim=1) + + # Decoder stage 3 + d3 = self.dec3(d4) + d3_out = self.dec3_out(d3) + d3 = nn.functional.interpolate(d3, scale_factor=2, 
mode="nearest") + d3 = torch.cat([d3, c3], dim=1) + + # Decoder stage 2 + d2 = self.dec2(d3) + d2_out = self.dec2_out(d2) + d2 = nn.functional.interpolate(d2, scale_factor=2, mode="nearest") + d2 = torch.cat([d2, c2], dim=1) + + # Decoder stage 1 + d1 = self.dec1(d2) + d1_out = self.dec1_out(d1) + d1 = nn.functional.interpolate(d1, scale_factor=2, mode="nearest") + d1 = torch.cat([d1, c1], dim=1) + + out_full = self.final(d1) + return out_full, d1_out, d2_out, d3_out, d4_out + diff --git a/pytorch/train.py b/pytorch/train.py new file mode 100644 index 0000000..f9206ac --- /dev/null +++ b/pytorch/train.py @@ -0,0 +1,103 @@ +import os +import argparse +import shutil +import time +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from pytorch.models import SmallModel +from pytorch.generate_data import generate_batch_fixed + + +def numpy_to_tensor(images: np.ndarray, targets: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]: + # images: B x H x W x 6 -> B x 6 x H x W + images_t = torch.from_numpy(images.transpose(0, 3, 1, 2)).float() + # targets: B x N x M x 2 + targets_t = torch.from_numpy(targets.transpose(0, 3, 1, 2)).float() + return images_t, targets_t + + +def evaluate(model: nn.Module, device: torch.device, setting=(80, 112, 10, 14)) -> float: + model.eval() + with torch.no_grad(): + X_val, Y_val = generate_batch_fixed(batch_size=1000, setting=setting) + images_t, targets_t = numpy_to_tensor(X_val, Y_val) + images_t = images_t.to(device) + targets_t = targets_t.to(device) + outputs = model(images_t) + loss = nn.functional.mse_loss(outputs, targets_t).item() + return float(loss) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--prefix", default="torch_small") + parser.add_argument("-lr", "--lr", type=float, default=1e-5) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--steps", type=int, default=2000) + 
parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--save-dir", default="models") + args = parser.parse_args() + + prefix = args.prefix + lr = args.lr + epochs = args.epochs + steps_per_epoch = args.steps + batch_size = args.batch_size + save_root = args.save_dir + + os.makedirs(os.path.join(save_root, prefix), exist_ok=True) + # Archive training script and data generator for reproducibility + shutil.copy(__file__, os.path.join(save_root, prefix, "train.py")) + src_gen = os.path.join(os.path.dirname(__file__), "generate_data.py") + shutil.copy(src_gen, os.path.join(save_root, prefix, "generate_data.py")) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SmallModel().to(device) + optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999)) + criterion = nn.MSELoss() + + best_loss = float("inf") + val_setting = (80, 112, 10, 14) + + for epoch in range(epochs): + model.train() + running_loss = 0.0 + start_time = time.time() + for step in range(steps_per_epoch): + X_batch, Y_batch = generate_batch_fixed(batch_size=batch_size, setting=val_setting) + images_t, targets_t = numpy_to_tensor(X_batch, Y_batch) + images_t = images_t.to(device) + targets_t = targets_t.to(device) + + optimizer.zero_grad() + outputs = model(images_t) + loss = criterion(outputs, targets_t) + loss.backward() + optimizer.step() + + running_loss += float(loss.item()) + + avg_loss = running_loss / steps_per_epoch + val_loss = evaluate(model, device, setting=val_setting) + elapsed = time.time() - start_time + print(f"epoch {epoch}: train_loss={avg_loss:.6f}, val_loss={val_loss:.6f}, time={elapsed:.1f}s") + + if val_loss < best_loss: + best_loss = val_loss + save_path = os.path.join(save_root, prefix, f"tracking_{epoch:03d}_{val_loss:.3f}.pt") + torch.save({ + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, save_path) + + +if __name__ == "__main__": + 
main() + diff --git a/pytorch/train_generic.py b/pytorch/train_generic.py new file mode 100644 index 0000000..eeefb66 --- /dev/null +++ b/pytorch/train_generic.py @@ -0,0 +1,109 @@ +import os +import argparse +import shutil +import time +from typing import Tuple, List + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from pytorch.models import GenericAE +from pytorch.generate_data import generate_batch_generic + + +def numpy_to_tensor(images: np.ndarray, targets: List[np.ndarray]): + # images: B x H x W x 8 -> B x 8 x H x W + images_t = torch.from_numpy(images.transpose(0, 3, 1, 2)).float() + # targets: list of levels with shapes: + # L0: B x H x W x 2 + # L1: B x H/2 x W/2 x 2 + # L2: B x H/4 x W/4 x 2 + # L3: B x H/8 x W/8 x 2 + # L4: B x H/16 x W/16 x 2 + targets_t = [torch.from_numpy(t.transpose(0, 3, 1, 2)).float() for t in targets] + return images_t, targets_t + + +def evaluate(model: nn.Module, device: torch.device, batch_size: int = 1000) -> float: + model.eval() + with torch.no_grad(): + X_val, Y_val_list = generate_batch_generic(batch_size=batch_size) + images_t, targets_t = numpy_to_tensor(X_val, Y_val_list) + images_t = images_t.to(device) + targets_t = [t.to(device) for t in targets_t] + outputs = model(images_t) + loss = 0.0 + for out, tgt in zip(outputs, targets_t): + loss += nn.functional.mse_loss(out, tgt) + return float(loss.item()) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--prefix", default="torch_generic") + parser.add_argument("-lr", "--lr", type=float, default=1e-5) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--steps", type=int, default=2000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--save-dir", default="models") + args = parser.parse_args() + + prefix = args.prefix + lr = args.lr + epochs = args.epochs + steps_per_epoch = args.steps + batch_size = args.batch_size + save_root = 
args.save_dir + + os.makedirs(os.path.join(save_root, prefix), exist_ok=True) + shutil.copy(__file__, os.path.join(save_root, prefix, "train_generic.py")) + src_gen = os.path.join(os.path.dirname(__file__), "generate_data.py") + shutil.copy(src_gen, os.path.join(save_root, prefix, "generate_data.py")) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = GenericAE().to(device) + optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999)) + + best_loss = float("inf") + + for epoch in range(epochs): + model.train() + running_loss = 0.0 + start_time = time.time() + for step in range(steps_per_epoch): + X_batch, Y_batch_list = generate_batch_generic(batch_size=batch_size) + images_t, targets_t = numpy_to_tensor(X_batch, Y_batch_list) + images_t = images_t.to(device) + targets_t = [t.to(device) for t in targets_t] + + optimizer.zero_grad() + outputs = model(images_t) + loss = 0.0 + for out, tgt in zip(outputs, targets_t): + loss = loss + nn.functional.mse_loss(out, tgt) + loss.backward() + optimizer.step() + + running_loss += float(loss.item()) + + avg_loss = running_loss / steps_per_epoch + val_loss = evaluate(model, device, batch_size=1000) + elapsed = time.time() - start_time + print(f"epoch {epoch}: train_loss={avg_loss:.6f}, val_loss={val_loss:.6f}, time={elapsed:.1f}s") + + if val_loss < best_loss: + best_loss = val_loss + save_path = os.path.join(save_root, prefix, f"tracking_{epoch:03d}_{val_loss:.3f}.pt") + torch.save({ + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, save_path) + + +if __name__ == "__main__": + main() +