diff --git a/flexible_validation.py b/flexible_validation.py index 02707d1..ad2ca3e 100644 --- a/flexible_validation.py +++ b/flexible_validation.py @@ -230,8 +230,19 @@ def parse_profile_trace(self): def filter_profile_trace_events(self, trace_events, write_to_file: bool = False): filtered_events = [] kernel_function_name = self.config.kernel_type.get_kernel().__name__ + # Dynamically detect the TPU device pid from process_name metadata events + # This ensures compatibility across TPU generations (v4, v5e, etc.) + tpu_pid = None for event in trace_events: - if "pid" not in event.keys() or event['pid'] != 8: + if (event.get('ph') == 'M' and + event.get('name') == 'process_name' and + isinstance(event.get('args', {}).get('name', ''), str) and + 'TPU' in event['args']['name']): + tpu_pid = event['pid'] + break + + for event in trace_events: + if "pid" not in event.keys() or event['pid'] != tpu_pid: continue if "name" in event.keys() and (kernel_function_name in event['name'] or "jit_kernel" in event['name']) and "args" in event.keys(): diff --git a/trace_parser.py b/trace_parser.py index 4290b1f..6c6fb35 100644 --- a/trace_parser.py +++ b/trace_parser.py @@ -59,7 +59,7 @@ def parse_trace_csv(self): return None # Extract trace events - trace_events = trace_data.get('traceEvents', []) + trace_events = trace_data if isinstance(trace_data, list) else trace_data.get("traceEvents", []) if not trace_events: print("No trace events found in the data") return None