Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 105 additions & 80 deletions qdp/qdp-kernels/src/iqp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,26 @@ __global__ void iqp_encode_kernel_naive(
// FWT O(n * 2^n) Implementation
// ============================================================================

// Step 1: Compute f[x] = exp(i*theta(x)) for all x.
// Uses a grid-stride loop so large state vectors can reuse a fixed launch size.
//
// Parameters:
//   data       - input values, forwarded unchanged to compute_phase()
//   state      - output array; state[x] is overwritten for every x < state_len
//   state_len  - number of entries of `state` to fill
//   num_qubits - forwarded unchanged to compute_phase()
//   enable_zz  - forwarded unchanged to compute_phase()
__global__ void iqp_phase_kernel(
    const double* __restrict__ data,
    cuDoubleComplex* __restrict__ state,
    size_t state_len,
    unsigned int num_qubits,
    int enable_zz
) {
    // gridDim.x, blockDim.x and blockIdx.x are 32-bit; widen one operand first
    // so the products are evaluated in size_t instead of wrapping at 2^32.
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (size_t x = first; x < state_len; x += stride) {
        const double phase = compute_phase(data, x, num_qubits, enable_zz);

        // exp(i*phase) = cos(phase) + i*sin(phase); sincos yields both at once.
        double cos_phase, sin_phase;
        sincos(phase, &sin_phase, &cos_phase);
        state[x] = make_cuDoubleComplex(cos_phase, sin_phase);
    }
}

// Step 2a: FWT butterfly stage for global memory (n > threshold)
Expand All @@ -142,39 +145,53 @@ __global__ void iqp_phase_kernel(
// Applies one in-place butterfly stage (a, b) -> (a + b, a - b) to `state`.
//
// Parameters:
//   state       - array transformed in place
//   state_len   - number of entries in `state`; state_len / 2 pairs are processed
//   stage       - stage index s; paired entries are 2^s apart
//   norm_factor - scale applied to both outputs of every pair; a value of
//                 exactly 1.0 skips the multiplication
//
// Within a stage each pair index maps to a distinct (i, j), so no two loop
// iterations touch the same element.
__global__ void fwt_butterfly_stage_kernel(
    cuDoubleComplex* __restrict__ state,
    size_t state_len,
    unsigned int stage,  // 0 to n-1
    double norm_factor
) {
    // gridDim.x, blockDim.x and blockIdx.x are 32-bit; widen one operand first
    // so the products are evaluated in size_t instead of wrapping at 2^32.
    const size_t grid_stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // Each loop iteration processes one butterfly pair
    // For stage s, butterflies are separated by 2^s
    const size_t stride = 1ULL << stage;
    const size_t offset_mask = stride - 1;   // low `stage` bits
    const size_t num_pairs = state_len >> 1; // state_len / 2 total pairs
    const bool apply_normalization = (norm_factor != 1.0);

    for (size_t idx = first; idx < num_pairs; idx += grid_stride) {
        // stride is a power of two, so idx / stride and idx % stride reduce to
        // a shift and a mask; each block of pairs spans 2^(s+1) elements.
        const size_t block_idx = idx >> stage;
        const size_t pair_offset = idx & offset_mask;
        const size_t i = (block_idx << (stage + 1)) + pair_offset;
        const size_t j = i + stride;

        // Load values
        const cuDoubleComplex a = state[i];
        const cuDoubleComplex b = state[j];
        cuDoubleComplex sum = cuCadd(a, b);
        cuDoubleComplex diff = cuCsub(a, b);

        if (apply_normalization) {
            sum = make_cuDoubleComplex(cuCreal(sum) * norm_factor, cuCimag(sum) * norm_factor);
            diff = make_cuDoubleComplex(cuCreal(diff) * norm_factor, cuCimag(diff) * norm_factor);
        }

        // Butterfly: (a, b) -> (a + b, a - b)
        state[i] = sum;
        state[j] = diff;
    }
}

// Step 2b: FWT using shared memory (n <= threshold)
// All stages in single kernel launch
__global__ void fwt_shared_memory_kernel(
// Step 1 + 2 + 3 fused: phase computation, full shared-memory FWT, normalization
// Used when the entire state fits in shared memory.
__global__ void iqp_phase_fwt_shared_normalize_kernel(
const double* __restrict__ data,
cuDoubleComplex* __restrict__ state,
size_t state_len,
unsigned int num_qubits
unsigned int num_qubits,
int enable_zz,
double norm_factor
) {
extern __shared__ cuDoubleComplex shared_state[];

Expand All @@ -185,9 +202,13 @@ __global__ void fwt_shared_memory_kernel(
// Block 0 handles the full transform
if (bid > 0) return;

// Load state into shared memory
// Materialize the phase vector directly in shared memory to avoid an
// intermediate global-memory write before the FWT.
for (size_t i = tid; i < state_len; i += blockDim.x) {
shared_state[i] = state[i];
double phase = compute_phase(data, i, num_qubits, enable_zz);
double cos_phase, sin_phase;
sincos(phase, &sin_phase, &cos_phase);
shared_state[i] = make_cuDoubleComplex(cos_phase, sin_phase);
}
__syncthreads();

Expand All @@ -213,9 +234,13 @@ __global__ void fwt_shared_memory_kernel(
__syncthreads();
}

// Write back to global memory
// Write back once, after applying the global 1/2^n normalization.
for (size_t i = tid; i < state_len; i += blockDim.x) {
state[i] = shared_state[i];
cuDoubleComplex val = shared_state[i];
state[i] = make_cuDoubleComplex(
cuCreal(val) * norm_factor,
cuCimag(val) * norm_factor
);
}
}

Expand All @@ -225,14 +250,17 @@ __global__ void normalize_state_kernel(
size_t state_len,
double norm_factor
) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= state_len) return;
const size_t stride = gridDim.x * blockDim.x;

cuDoubleComplex val = state[idx];
state[idx] = make_cuDoubleComplex(
cuCreal(val) * norm_factor,
cuCimag(val) * norm_factor
);
for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < state_len;
idx += stride) {
cuDoubleComplex val = state[idx];
state[idx] = make_cuDoubleComplex(
cuCreal(val) * norm_factor,
cuCimag(val) * norm_factor
);
}
}

// ============================================================================
Expand Down Expand Up @@ -272,15 +300,16 @@ __global__ void iqp_encode_batch_kernel_naive(
// FWT O(n * 2^n) Batch Implementation
// ============================================================================

// Step 1: Compute f[x] = exp(i*theta(x)) for all x, for all samples in batch
// Step 1: Compute the normalized phase vector for all samples in batch.
__global__ void iqp_phase_batch_kernel(
const double* __restrict__ data_batch,
cuDoubleComplex* __restrict__ state_batch,
size_t num_samples,
size_t state_len,
unsigned int num_qubits,
unsigned int data_len,
int enable_zz
int enable_zz,
double norm_factor
) {
const size_t total_elements = num_samples * state_len;
const size_t stride = gridDim.x * blockDim.x;
Expand All @@ -297,7 +326,10 @@ __global__ void iqp_phase_batch_kernel(

double cos_phase, sin_phase;
sincos(phase, &sin_phase, &cos_phase);
state_batch[global_idx] = make_cuDoubleComplex(cos_phase, sin_phase);
state_batch[global_idx] = make_cuDoubleComplex(
cos_phase * norm_factor,
sin_phase * norm_factor
);
}
}

Expand Down Expand Up @@ -385,7 +417,7 @@ extern "C" {
/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
/// 1. Phase computation: f[x] = exp(i*theta(x)) - O(2^n)
/// 2. FWT transform: WHT of phase array - O(n * 2^n)
/// 3. Normalization: divide by 2^n - O(2^n)
/// 3. Normalization: fused into the shared-memory path or the final FWT stage
/// Total: O(n * 2^n) vs naive O(4^n)
int launch_iqp_encode(
const double* data_d,
Expand All @@ -401,6 +433,7 @@ int launch_iqp_encode(

cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_d);
const int blockSize = DEFAULT_BLOCK_SIZE;
const double norm_factor = 1.0 / (double)state_len;

// Use naive kernel for small n (FWT overhead not worth it)
if (num_qubits < FWT_MIN_QUBITS) {
Expand All @@ -416,48 +449,46 @@ int launch_iqp_encode(
}

// FWT-based implementation for larger n
const int gridSize = (state_len + blockSize - 1) / blockSize;
const size_t blocks_needed = (state_len + blockSize - 1) / blockSize;
const int gridSize = (int)((blocks_needed < MAX_GRID_BLOCKS) ? blocks_needed : MAX_GRID_BLOCKS);

// Step 1: Compute phase array f[x] = exp(i*theta(x))
iqp_phase_kernel<<<gridSize, blockSize, 0, stream>>>(
data_d,
state_complex_d,
state_len,
num_qubits,
enable_zz
);

// Step 2: Apply FWT
if (num_qubits <= FWT_SHARED_MEM_THRESHOLD) {
// Shared memory FWT - all stages in one kernel
// Shared-memory fast path: phase generation, full FWT, and normalization
// happen in a single launch and touch global memory only once.
size_t shared_mem_size = state_len * sizeof(cuDoubleComplex);
fwt_shared_memory_kernel<<<1, blockSize, shared_mem_size, stream>>>(
iqp_phase_fwt_shared_normalize_kernel<<<1, blockSize, shared_mem_size, stream>>>(
data_d,
state_complex_d,
state_len,
num_qubits
num_qubits,
enable_zz,
norm_factor
);
} else {
// Step 1: Compute phase array f[x] = exp(i*theta(x))
iqp_phase_kernel<<<gridSize, blockSize, 0, stream>>>(
data_d,
state_complex_d,
state_len,
num_qubits,
enable_zz
);

// Global memory FWT - one kernel launch per stage
const size_t num_pairs = state_len >> 1;
const int fwt_grid_size = (num_pairs + blockSize - 1) / blockSize;
const size_t fwt_blocks_needed = (num_pairs + blockSize - 1) / blockSize;
const int fwt_grid_size = (int)((fwt_blocks_needed < MAX_GRID_BLOCKS) ? fwt_blocks_needed : MAX_GRID_BLOCKS);

for (unsigned int stage = 0; stage < num_qubits; ++stage) {
fwt_butterfly_stage_kernel<<<fwt_grid_size, blockSize, 0, stream>>>(
state_complex_d,
state_len,
stage
stage,
(stage + 1 == num_qubits) ? norm_factor : 1.0
);
}
}

// Step 3: Normalize by 1/2^n
double norm_factor = 1.0 / (double)state_len;
normalize_state_kernel<<<gridSize, blockSize, 0, stream>>>(
state_complex_d,
state_len,
norm_factor
);

return (int)cudaGetLastError();
}

Expand All @@ -478,9 +509,9 @@ int launch_iqp_encode(
///
/// # Algorithm
/// For num_qubits >= FWT_MIN_QUBITS, uses Fast Walsh-Hadamard Transform:
/// 1. Phase computation for all samples - O(batch * 2^n)
/// 1. Normalized phase computation for all samples - O(batch * 2^n)
/// 2. FWT transform for all samples - O(batch * n * 2^n)
/// 3. Normalization - O(batch * 2^n)
/// 3. No standalone normalization kernel; the scale factor is fused into phase
/// Total: O(batch * n * 2^n) vs naive O(batch * 4^n)
int launch_iqp_encode_batch(
const double* data_batch_d,
Expand All @@ -501,6 +532,7 @@ int launch_iqp_encode_batch(
const size_t total_elements = num_samples * state_len;
const size_t blocks_needed = (total_elements + blockSize - 1) / blockSize;
const size_t gridSize = (blocks_needed < MAX_GRID_BLOCKS) ? blocks_needed : MAX_GRID_BLOCKS;
const double norm_factor = 1.0 / (double)state_len;

// Use naive kernel for small n (FWT overhead not worth it)
if (num_qubits < FWT_MIN_QUBITS) {
Expand All @@ -526,7 +558,8 @@ int launch_iqp_encode_batch(
state_len,
num_qubits,
data_len,
enable_zz
enable_zz,
norm_factor
);

// Step 2: Apply FWT to all samples (global memory version for batch)
Expand All @@ -546,14 +579,6 @@ int launch_iqp_encode_batch(
);
}

// Step 3: Normalize by 1/2^n
double norm_factor = 1.0 / (double)state_len;
normalize_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
state_complex_d,
total_elements,
norm_factor
);

return (int)cudaGetLastError();
}

Expand Down
Loading
Loading