@@ -245,21 +245,9 @@ impl QuantumEncoder for AmplitudeEncoder {
245245 buffer
246246 } ;
247247
248- // Validate norms on host to catch zero or NaN samples early
249- {
250- crate :: profile_scope!( "GPU::NormValidation" ) ;
251- let host_inv_norms = device
252- . dtoh_sync_copy ( & inv_norms_gpu)
253- . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
254-
255- if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
256- return Err ( MahoutError :: InvalidInput (
257- "One or more samples have zero or invalid norm" . to_string ( ) ,
258- ) ) ;
259- }
260- }
261-
262- // Launch batch kernel
248+ // Launch batch encode kernel — takes GPU norm buffer directly, no D2H needed yet.
249+ // We defer the norm validation D2H copy until AFTER the encode kernel + sync so that
250+ // the norm kernel → encode kernel sequence runs without an intermediate GPU-CPU roundtrip.
263251 {
264252 crate :: profile_scope!( "GPU::BatchKernelLaunch" ) ;
265253 let state_ptr = batch_state_vector. ptr_f64 ( ) . ok_or_else ( || {
@@ -288,14 +276,30 @@ impl QuantumEncoder for AmplitudeEncoder {
288276 }
289277 }
290278
291- // Synchronize
279+ // Synchronize — all GPU work (norm + encode) complete after this point.
292280 {
293281 crate :: profile_scope!( "GPU::Synchronize" ) ;
294282 device
295283 . synchronize ( )
296284 . map_err ( |e| MahoutError :: Cuda ( format ! ( "Sync failed: {:?}" , e) ) ) ?;
297285 }
298286
287+ // Validate norms on host AFTER sync so the D2H copy no longer sits between the
288+ // norm and encode kernels. Zero/NaN samples are still rejected, but only after the
289+ // encode kernel has consumed them; on error the encoded batch is not returned.
290+ {
291+ crate :: profile_scope!( "GPU::NormValidation" ) ;
292+ let host_inv_norms = device
293+ . dtoh_sync_copy ( & inv_norms_gpu)
294+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
295+
296+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
297+ return Err ( MahoutError :: InvalidInput (
298+ "One or more samples have zero or invalid norm" . to_string ( ) ,
299+ ) ) ;
300+ }
301+ }
302+
299303 Ok ( batch_state_vector)
300304 }
301305
@@ -412,17 +416,8 @@ impl QuantumEncoder for AmplitudeEncoder {
412416 }
413417 buffer
414418 } ;
415- {
416- crate :: profile_scope!( "GPU::NormValidation" ) ;
417- let host_inv_norms = device
418- . dtoh_sync_copy ( & inv_norms_gpu)
419- . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
420- if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
421- return Err ( MahoutError :: InvalidInput (
422- "One or more samples have zero or invalid norm" . to_string ( ) ,
423- ) ) ;
424- }
425- }
419+ // Launch encode kernel before D2H norm validation: GPU norm buffer is passed directly,
420+ // so the encode kernel can run immediately after the norm kernel without a CPU roundtrip.
426421 {
427422 crate :: profile_scope!( "GPU::BatchKernelLaunch" ) ;
428423 use cudarc:: driver:: DevicePtr ;
@@ -450,10 +445,22 @@ impl QuantumEncoder for AmplitudeEncoder {
450445 ) ) ) ;
451446 }
452447 }
448+ // Synchronize first; then validate norms on host (D2H after all GPU work is done).
453449 {
454450 crate :: profile_scope!( "GPU::Synchronize" ) ;
455451 sync_cuda_stream ( stream, "CUDA stream synchronize failed" ) ?;
456452 }
453+ {
454+ crate :: profile_scope!( "GPU::NormValidation" ) ;
455+ let host_inv_norms = device
456+ . dtoh_sync_copy ( & inv_norms_gpu)
457+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
458+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
459+ return Err ( MahoutError :: InvalidInput (
460+ "One or more samples have zero or invalid norm" . to_string ( ) ,
461+ ) ) ;
462+ }
463+ }
457464 Ok ( batch_state_vector)
458465 }
459466
0 commit comments