-
Notifications
You must be signed in to change notification settings - Fork 127
Open
Description
When creating many tensors rapidly on Metal/wgpu, early tensors get corrupted (all zeros) while later tensors are correct.
Related: tracel-ai/burn#3463
Reproduction
use burn::backend::wgpu::Wgpu;
use burn::tensor::{Tensor, TensorData};
fn main() {
// Acquire the default wgpu device (Metal on an M-series Mac in this report).
let device = burn::backend::wgpu::WgpuDevice::default();
let mut tensors = Vec::new();
let mut expected = Vec::new();
// Create ~290 tensors (simulating LLM model load)
// 1 large embedding + 24 layers of attention/MLP tensors
// NOTE(review): as written this chain yields 97 tensors (1 + 24*4), not ~290 —
// the per-layer tensor list was likely trimmed for the repro; the reported
// result below ("tensors 133+") suggests the original run had more tensors.
let sizes = vec![
(151936, 896), // embedding
].into_iter().chain(
(0..24).flat_map(|_| vec![(896, 896), (896, 128), (896, 4864), (4864, 896)])
);
for (rows, cols) in sizes {
// Deterministic, mostly non-zero fill pattern so an all-zeros readback is
// unambiguously corruption rather than legitimate data.
let data: Vec<f32> = (0..rows*cols).map(|i| ((i % 1000) as f32) * 0.0001 - 0.05).collect();
// CPU-side reference sum, taken before the data is handed to the device.
expected.push(data.iter().sum::<f32>());
// Each from_data triggers a host->GPU buffer upload; issuing many of these
// back-to-back without synchronization is what provokes the bug.
tensors.push(Tensor::<Wgpu, 2>::from_data(TensorData::new(data, [rows, cols]), &device));
}
// Check results: read every tensor back and compare its sum to the CPU reference.
let mut corrupted = 0;
for (i, (t, exp)) in tensors.iter().zip(expected.iter()).enumerate() {
let sum: f32 = t.clone().into_data().to_vec::<f32>().unwrap().iter().sum();
// A near-zero readback sum where the reference is clearly non-zero means the
// upload was lost and the tensor's buffer reads back as all zeros.
if sum.abs() < 1e-6 && exp.abs() > 1e-6 {
corrupted += 1;
// Print only the first few corrupted tensors to keep output readable.
if corrupted <= 5 { println!("Tensor {} ALL ZEROS (expected sum {:.2})", i, exp); }
}
}
println!("Corrupted: {}/{}", corrupted, tensors.len());
}Result on M1 Pro: Tensors 0-132 are ALL ZEROS, tensors 133+ are correct.
Cause
`write_to_buffer` in `stream.rs` issues uploads via the asynchronous `queue.write_buffer()` with no periodic synchronization. When many writes accumulate, wgpu's internal staging-buffer pool is exhausted and staging buffers are reused before the corresponding GPU copies complete, which can leave the earliest destination buffers filled with zeros.
Metadata
Metadata
Assignees
Labels
No labels