fix: add TPU v5e compatibility for trace parsing #1

Open
midiareshadi wants to merge 1 commit into scalesim-project:main from midiareshadi:fix/tpu-v5e-compatibility

Problem

On TPU v5e (and likely any TPU newer than v4), this repo could not
produce valid results: two issues in the trace parsing pipeline caused
every Actual_Duration_us value to be 0.0 and MAPE to be inf%.
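
To illustrate why zeroed actual durations lead to an infinite MAPE, here is a minimal sketch. It assumes MAPE is the mean of |actual - predicted| / |actual| and that division by zero follows floating-point rules (as NumPy division would); the repo's exact formula may differ, and `mape_percent` and the sample numbers are hypothetical.

```python
import math


def mape_percent(actual, predicted):
    """Mean absolute percentage error, with IEEE-style division by zero.

    Hypothetical helper for illustration; not the repo's implementation.
    """
    errors = []
    for a, p in zip(actual, predicted):
        if a == 0.0:
            # Mimic floating-point semantics: x/0 is inf, 0/0 is nan.
            errors.append(math.inf if p != 0.0 else math.nan)
        else:
            errors.append(abs((a - p) / a))
    return 100.0 * sum(errors) / len(errors)


# Every actual duration is 0.0, as happened when no device events were matched.
broken = mape_percent([0.0, 0.0], [5.0, 7.0])

# Non-zero actual durations give a finite percentage.
healthy = mape_percent([10.0, 20.0], [9.0, 22.0])
```

A single zero in the actual durations is enough to push the mean to infinity, which is why the metric only became meaningful once real hardware events were matched.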

Root Causes and Fixes

1. trace_parser.py — JSON format difference

TPU v4 wraps trace events in a dict: {"traceEvents": [...]}
TPU v5e emits a raw JSON list: [...]

The parser called trace_data.get('traceEvents', []) which silently
returns [] on v5e, so no events were ever parsed.

Fix: check if trace_data is already a list before calling .get()
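
A minimal sketch of this check, assuming the parsed JSON is available as `trace_data`. The helper name `extract_trace_events` is hypothetical and is not necessarily how trace_parser.py structures it.

```python
import json


def extract_trace_events(trace_data):
    """Return the event list from either trace layout.

    Hypothetical helper: accepts a raw list (as seen on TPU v5e) or a
    dict with a 'traceEvents' key (as seen on TPU v4).
    """
    if isinstance(trace_data, list):
        # The top-level JSON value is already the list of events.
        return trace_data
    if isinstance(trace_data, dict):
        # The events are wrapped under the 'traceEvents' key.
        return trace_data.get("traceEvents", [])
    raise TypeError(f"Unexpected trace JSON type: {type(trace_data).__name__}")


# Both layouts yield the same event list.
wrapped = json.loads('{"traceEvents": [{"name": "fusion", "dur": 4.2}]}')
raw = json.loads('[{"name": "fusion", "dur": 4.2}]')
events_from_dict = extract_trace_events(wrapped)
events_from_list = extract_trace_events(raw)
```

Raising on any other top-level type is a design choice in this sketch, so that a third, unknown format fails loudly instead of producing an empty event list.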

2. flexible_validation.py — Hardcoded TPU device pid

The event filter hardcoded pid == 8 to identify TPU device events.
On TPU v5e the device pid is 3, causing all real hardware events to
be filtered out and Actual_Duration_us to always be 0.0.

Fix: dynamically detect the TPU device pid by scanning trace metadata
events for a process_name containing 'TPU', making it robust across
all TPU generations rather than hardcoding any specific pid value.
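
A rough sketch of the detection, assuming the trace follows the Chrome Trace Event Format, where metadata events have `ph == "M"`, `name == "process_name"`, and the process label in `args["name"]`. The function name `find_tpu_pid` and the sample process names are illustrative assumptions, not taken from flexible_validation.py.

```python
def find_tpu_pid(events):
    """Return the pid of the first process whose name contains 'TPU', or None.

    Hypothetical helper; assumes Chrome-trace-style metadata events.
    """
    for event in events:
        if not isinstance(event, dict):
            continue
        if event.get("ph") == "M" and event.get("name") == "process_name":
            process_name = str(event.get("args", {}).get("name", ""))
            if "TPU" in process_name:
                return event.get("pid")
    return None


# Illustrative events; the process names here are assumptions.
sample_events = [
    {"ph": "M", "name": "process_name", "pid": 1, "args": {"name": "/host:CPU"}},
    {"ph": "M", "name": "process_name", "pid": 3, "args": {"name": "/device:TPU:0"}},
    {"ph": "X", "name": "fusion", "pid": 3, "dur": 4.2},
    {"ph": "X", "name": "host_op", "pid": 1, "dur": 1.0},
]

tpu_pid = find_tpu_pid(sample_events)

# Keep only complete ("X") events that belong to the detected device process.
device_events = [
    e for e in sample_events if e.get("ph") == "X" and e.get("pid") == tpu_pid
]
```

If a trace lists several TPU processes (for example, one per core), this sketch returns only the first match; collecting all matching pids may be more appropriate in that case.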

Testing

Tested on Google Colab with:

  • TPU v5e
  • JAX 0.7.2
  • Python 3.12.13

Results after fixes:

  • RMSE dropped from 8.44μs to 4.26μs
  • MAPE changed from inf% to 88.65%

The remaining MAPE is expected: the linear models were calibrated on
TPU v4 hardware, and TPU v5e runs matmuls ~1.9x faster on average.
Recalibrating the linear models for v5e would be a worthwhile
follow-up contribution.

- trace_parser.py: handle both list and dict JSON formats since TPU v5e
  emits a raw list while TPU v4 wraps events in a traceEvents dict key
- flexible_validation.py: dynamically detect TPU device pid from trace
  metadata instead of hardcoding pid=8, fixing compatibility with TPU v5e
  and future TPU generations

Tested on Google Colab with TPU v5e, JAX 0.7.2, Python 3.12.13