Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion generic_ks/ks_multicg_offset_quda.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ int ks_multicg_offset_field_gpu(
double* final_residual = (double*)malloc(num_offsets*sizeof(double));
double* final_relative_residual = (double*)malloc(num_offsets*sizeof(double));

double max_residual = 0.0, min_residual = __DBL_MAX__;
for(i=0; i<num_offsets; ++i){

residual[i] = qic[i].resid;
Expand All @@ -145,6 +146,8 @@ int ks_multicg_offset_field_gpu(
// scale the shifted residual relative to the residue
residual[i] = fabs(residue[0] / residue[i]) * residual[0];
if (residual[i] < 1e-14) residual[i] = 1e-14;
if (residual[i] > max_residual) max_residual = residual[i];
if (residual[i] < min_residual) min_residual = residual[i];
} else {
residual[i] = qic[i].resid; // for a mixed-precision solver use residual for higher shifts
}
Expand All @@ -160,7 +163,15 @@ int ks_multicg_offset_field_gpu(
}

inv_args.max_iter = qic[0].max*qic[0].nrestart;
#if defined(MAX_MIXED) || defined(HALF_MIXED) // never do half precision with multi-shift solver
#if defined(MAX_MIXED)
if (residual[0] > 3e-5 && min_residual / max_residual > 4e-3) {
// 3e-5 ~ 2**-15 (Machine epsilon of QUDA's half precision)
// 4e-3 ~ 2**-8 (Refinement will not use too many iterations)
inv_args.mixed_precision = 2;
} else {
inv_args.mixed_precision = 1;
}
#elif defined(HALF_MIXED)
inv_args.mixed_precision = 1;
#else
inv_args.mixed_precision = 0;
Expand Down