
B13 Thermal Diffusion Prediction UI

Overview

A web-based interface for running B13 thermal diffusion model predictions with real-time results display.

File Structure

chromosome/
├── auth_service.py                 # Flask backend with prediction endpoint
├── predict_b13_diffusion.py        # Prediction script (CLI & API)
├── exp/
│   └── best_model.pt              # Trained B13 diffusion model
├── test/                          # Sample input directory
│   ├── images/                    # Input PNG images (3+ frames)
│   └── metadata/                  # Input JSON metadata files
├── predictions/                   # Output directory (created automatically)
│   ├── images/                    # Generated predictions
│   └── metadata/                  # Generation metadata with context
└── frontend/
    ├── index.html                 # Main UI (login + prediction dashboard)
    └── static/
        ├── script.js              # Frontend logic (auth + predictions)
        └── style.css              # Styling with results grid

Features

Authentication Flow

  1. Email OTP Login - Sign in with a one-time passcode sent to the user's email
  2. Session Management - Redis-based session storage
  3. Chaos Mode - Fun UI chaos toggle (preserved from original)

Prediction Workflow

  1. Input Configuration

    • Input Directory: Path to directory with images/ and metadata/ subdirectories
    • Model Path: Path to .pt trained model file (default: ./exp/best_model.pt)
    • Output Directory: Where predictions will be saved (default: ./predictions)
  2. Prediction Execution

    • Validates input directory and model existence
    • Runs predict_b13_diffusion.py via subprocess
    • Uses 3 consecutive frames as context → generates 1 future frame
    • Displays real-time status updates
  3. Results Display

    • Image gallery grid with thumbnails
    • Metadata overlay showing:
      • Prediction time
      • Context frames used
      • Geographic region
      • Inference steps
    • Download button for each image
    • Smooth scroll to results

API Endpoints

Authentication

  • POST /send-otp - Send OTP to email
  • POST /verify-otp - Verify OTP and create session
  • POST /check-session - Validate session token
  • POST /logout - Invalidate session

Prediction (NEW)

  • POST /predict - Run B13 diffusion prediction

    {
      "session_token": "abc123...",
      "input_dir": "./test",
      "output_dir": "./predictions",
      "model_path": "./exp/best_model.pt"
    }
  • GET /predictions/<path> - Serve generated images
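The request above can be issued from any HTTP client. As a minimal sketch, the snippet below builds the documented JSON body with the standard library; the host/port assume the default Flask setup described later, and the actual network call is left commented out since it requires a running server.

```python
import json
from urllib import request

def build_predict_request(session_token, input_dir="./test",
                          output_dir="./predictions",
                          model_path="./exp/best_model.pt"):
    """Build a POST /predict request with the JSON body documented above."""
    payload = {
        "session_token": session_token,
        "input_dir": input_dir,
        "output_dir": output_dir,
        "model_path": model_path,
    }
    return request.Request(
        "http://localhost:5000/predict",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

if __name__ == "__main__":
    req = build_predict_request("abc123...")
    # with request.urlopen(req) as resp:   # requires the Flask server running
    #     print(json.load(resp))
```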

Setup & Usage

1. Install Dependencies

# Python dependencies (if not already installed)
pip install flask flask-cors redis torch torchvision diffusers pillow tqdm

# Start Redis server
redis-server

2. Prepare Input Data

Create input directory structure:

mkdir -p test/images test/metadata

# Add at least 3 consecutive PNG images to test/images/
# Add corresponding JSON metadata files to test/metadata/

Example input structure:

test/
├── images/
│   ├── frame_t0.png    # t-20 minutes
│   ├── frame_t1.png    # t-10 minutes
│   └── frame_t2.png    # t-0 minutes (current)
└── metadata/
    ├── frame_t0.json
    ├── frame_t1.json
    └── frame_t2.json

3. Start Backend Server

python auth_service.py

Server will start on http://localhost:5000

4. Open Web Interface

Open browser to: http://localhost:5000

5. Login & Run Prediction

  1. Enter email → Click "Send OTP"
  2. Check email for 6-digit code
  3. Enter OTP → Click "Verify"
  4. Configure paths in dashboard:
    • Input Directory: ./test
    • Model Path: ./exp/best_model.pt
    • Output Directory: ./predictions
  5. Click "Start Prediction"
  6. Wait for prediction to complete (shown in status)
  7. View results in image gallery below

Input Data Format

Images (PNG)

  • RGB format, any resolution (resized to 512x512 during preprocessing)
  • Consecutive time series at 10-minute intervals
  • Named so that alphabetical sort order matches chronological order

Metadata (JSON)

Each JSON file should contain:

{
  "observation_time_utc": "2025-10-12T00:00:00Z",
  "min_lat": 20.0,
  "max_lat": 50.0,
  "min_lon": 120.0,
  "max_lon": 150.0,
  "segment_index": 1,
  "satellite": "Himawari-8",
  "enhanced": false,
  "composite_bands": ["B13"]
}
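Malformed metadata is a common source of failed runs, so a lightweight check against the field set in the example above can help. This sketch only verifies presence and lat/lon ordering; it does not reflect any validation the prediction script itself performs.

```python
REQUIRED_FIELDS = {
    "observation_time_utc", "min_lat", "max_lat",
    "min_lon", "max_lon", "segment_index",
    "satellite", "enhanced", "composite_bands",
}

def check_metadata(meta):
    """Return a list of problems with one frame's metadata dict
    (field names taken from the example above)."""
    problems = [f"missing field: {f}"
                for f in sorted(REQUIRED_FIELDS - meta.keys())]
    if not problems:
        if meta["min_lat"] >= meta["max_lat"]:
            problems.append("min_lat must be below max_lat")
        if meta["min_lon"] >= meta["max_lon"]:
            problems.append("min_lon must be below max_lon")
    return problems
```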

Output Structure

predictions/
├── images/
│   └── future_from_frame_t0_to_frame_t2.png    # Generated future frame
└── metadata/
    └── future_from_frame_t0_to_frame_t2.json   # With generation info

Output metadata includes:

  • All input metadata fields
  • _generation object with:
    • Model name and version
    • Generation timestamp
    • Inference parameters (steps, guidance)
    • Context frame references
    • Temporal context times

Model Architecture

ConditionalLatentDiffusion (must match training script exactly)

  • VAE: stabilityai/sd-vae-ft-mse (frozen)
  • Conditioning: 14-dim metadata → 768-dim embedding
    • Temporal: 6-dim (hour, day, month - cyclical)
    • Spatial: 4-dim (lat/lon bounds)
    • Metadata: 4-dim (segment, satellite, enhanced, composite)
  • U-Net: 4 channel latent space with cross-attention
  • Scheduler: DDIM (50 steps default)
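The 14-dim conditioning vector above can be sketched as follows. The sin/cos cyclical encoding for hour/day/month is the standard approach implied by "cyclical"; the exact normalisation and the encodings for the satellite and composite fields are assumptions for illustration, not the training code's definitions.

```python
import math
from datetime import datetime

def conditioning_vector(meta):
    """Sketch of the 14-dim conditioning vector:
    6 temporal + 4 spatial + 4 metadata dims."""
    t = datetime.fromisoformat(
        meta["observation_time_utc"].replace("Z", "+00:00"))

    def cyc(value, period):
        # Cyclical encoding: one sin/cos pair per quantity.
        angle = 2 * math.pi * value / period
        return [math.sin(angle), math.cos(angle)]

    temporal = cyc(t.hour, 24) + cyc(t.day, 31) + cyc(t.month, 12)  # 6 dims
    spatial = [meta["min_lat"], meta["max_lat"],
               meta["min_lon"], meta["max_lon"]]                    # 4 dims
    extra = [float(meta["segment_index"]),
             1.0,                          # satellite id (hypothetical encoding)
             float(meta["enhanced"]),
             float(len(meta["composite_bands"]))]                   # 4 dims
    return temporal + spatial + extra
```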

Troubleshooting

"Model not found"

  • Check model path: ./exp/best_model.pt
  • Ensure model was trained with exp/train_b13_single_gpu.py

"Input directory not found"

  • Verify path is relative to project root
  • Check directory has images/ and metadata/ subdirectories

"Architecture mismatch"

  • Model must be trained with same architecture as predict_b13_diffusion.py
  • Conditioning encoder: Linear(14, 256) → GELU → Linear(256, 512) → GELU → Linear(512, 768)
  • No temporal context encoder in current version
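The layer chain above translates directly into a PyTorch module; this sketch uses exactly the sizes listed, but it is a reconstruction for reference, not the project's actual class definition.

```python
import torch
from torch import nn

# Conditioning encoder per the spec above:
# Linear(14, 256) -> GELU -> Linear(256, 512) -> GELU -> Linear(512, 768)
conditioning_encoder = nn.Sequential(
    nn.Linear(14, 256), nn.GELU(),
    nn.Linear(256, 512), nn.GELU(),
    nn.Linear(512, 768),
)

# Maps a batch of 14-dim metadata vectors to 768-dim embeddings
# for cross-attention in the U-Net.
```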

"Prediction timed out"

  • Default timeout: 10 minutes
  • Reduce --num-inference-steps for faster generation (lower quality)
  • Check GPU availability: torch.cuda.is_available()

Images not displaying

  • Check browser console for CORS errors
  • Verify Flask server is running
  • Check /predictions/<path> endpoint serves images correctly

Performance Notes

  • Generation Time: ~30-60 seconds per frame (GPU) / 5-10 minutes (CPU)
  • Memory Usage: ~4GB GPU VRAM for inference
  • Inference Steps: 50 (default) - reduce for speed, increase for quality
  • Batch Processing: Processes multiple 3-frame sequences sequentially
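The sequential batch behaviour can be sketched as a sliding window over the sorted frame list; whether the real `load_input_data()` uses overlapping or disjoint windows is not stated here, so the overlap below is an assumption.

```python
def group_context_windows(frame_names, context=3):
    """Group an alphabetically sorted frame list into consecutive
    context windows of 3 frames, one prediction per window."""
    frames = sorted(frame_names)
    return [frames[i:i + context]
            for i in range(len(frames) - context + 1)]
```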

Future Enhancements

  1. Real-time Progress: WebSocket for live inference progress
  2. Temporal Context Encoder: Use actual pixel data from previous frames (requires retraining)
  3. Batch Upload: Upload images directly via web interface
  4. Comparison View: Side-by-side comparison of input/output
  5. Animation: Create GIF/video from sequence of predictions
  6. Export Metadata: Download metadata as CSV for analysis

Development Notes

Frontend Changes

  • Added inputDir, modelPath, outputDir input fields
  • Implemented handleDirectorySubmit() to call /predict endpoint
  • Created displayResults() to render image gallery
  • Added responsive grid layout for results
  • Container expands to 1400px when dashboard active

Backend Changes

  • Added POST /predict endpoint in auth_service.py
  • Subprocess execution of predict_b13_diffusion.py
  • Image serving via /predictions/<path> route
  • Session validation for prediction requests
  • 10-minute timeout with error handling
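The subprocess-with-timeout pattern above can be sketched as follows. The CLI flag names passed to `predict_b13_diffusion.py` are assumptions (only `--num-inference-steps` is documented elsewhere in this README), so treat this as an illustration of the timeout handling, not the endpoint's exact code.

```python
import subprocess
import sys

def run_prediction(input_dir, output_dir, model_path, timeout_s=600):
    """Run the prediction script as a subprocess with a 10-minute
    default timeout, returning (success, message)."""
    cmd = [
        sys.executable, "predict_b13_diffusion.py",
        "--input-dir", input_dir,        # hypothetical flag names
        "--output-dir", output_dir,
        "--model-path", model_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True,
                                timeout=timeout_s)
        return result.returncode == 0, result.stdout or result.stderr
    except subprocess.TimeoutExpired:
        return False, f"prediction timed out after {timeout_s} seconds"
```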

Prediction Script

  • Updated load_input_data() to group 3-frame sequences
  • Modified predict() to use metadata-only conditioning
  • Architecture matches training script exactly (no context encoder)
  • Output includes temporal context metadata

Credits

  • Satellite Data: Himawari-8 B13 thermal infrared
  • Model: Conditional Latent Diffusion with VAE + U-Net
  • Framework: PyTorch, Hugging Face Diffusers
  • Frontend: Vanilla JavaScript (no frameworks)