Staggered invert test cleanup + expanded staggered gtest support #1421

Merged
merged 61 commits into from
Jan 10, 2024
Commits
b22beda
Misc things I noticed elsewhere
weinbe2 Aug 17, 2023
b78720e
Removed argc/argv hacks because they are not used under the hood anyway
weinbe2 Aug 17, 2023
b2fc275
Cleaned up staggered dslash (c)test, enabled support for testing half…
weinbe2 Aug 18, 2023
ec0359c
Merge branch 'features/multishift-masses' into feature/stag-cleanup
weinbe2 Aug 18, 2023
cc31020
Various cleanup of gauge fields in staggered test exes
weinbe2 Aug 29, 2023
104c404
Various bits of function cleanup, making host verify names more consi…
weinbe2 Aug 30, 2023
80048eb
Merge branch 'develop' into feature/stag-cleanup
weinbe2 Aug 30, 2023
1f8f89c
Added support for mdagm tests for staggered, asqtad
weinbe2 Aug 31, 2023
94a332f
Small cleanup of treatment of naik terms
weinbe2 Sep 12, 2023
dd67aa0
Massive cleanup for staggered_invert/eigensolve_test, removed all enu…
weinbe2 Sep 12, 2023
9e492a5
Merge remote-tracking branch 'origin/develop' into feature/stag-cleanup
weinbe2 Oct 31, 2023
cb6d965
Merge branch 'develop' into feature/stag-cleanup
weinbe2 Nov 29, 2023
b39297b
Misc cleanup to make hisq_stencil_test match some conventions in stag…
weinbe2 Nov 29, 2023
a1049cf
Significant refactoring to hist_stencil_test, getting closer to simpl…
weinbe2 Nov 30, 2023
14f1407
hisq_stencil_test now runs via gtest, creating a ctest is outstanding
weinbe2 Nov 30, 2023
eff6773
Created a working hisq_stencil_ctest, woohoo!
weinbe2 Nov 30, 2023
3d14c35
Merge branch 'develop' into feature/stag-cleanup
weinbe2 Nov 30, 2023
cb94021
Some cleanup of staggered_invert_test, working towards a ctest
weinbe2 Nov 30, 2023
b3508be
Added a mostly working gtest!
weinbe2 Dec 1, 2023
5aa628f
More fully pipecleaned staggered ctest; split grid testing outstanding
weinbe2 Dec 1, 2023
550a5a9
Enabled split grid
weinbe2 Dec 1, 2023
1d25616
Merge remote-tracking branch 'origin/develop' into feature/stag-cleanup
weinbe2 Dec 1, 2023
b2560b5
Added info on how to run the old tests
weinbe2 Dec 1, 2023
e952379
Added Laplace ctests, tweaked some tolerances, uncovered a BiCGStab i…
weinbe2 Dec 1, 2023
afb41b6
Quality of life BiCGStab readability changes
weinbe2 Dec 1, 2023
15eeb82
Strong BiCGStab cleanup, still need to reconcile a host verification …
weinbe2 Dec 4, 2023
fc65b73
Various misc cleanup
weinbe2 Dec 5, 2023
7c5c2c5
Fixed a verify issue for full parity solves
weinbe2 Dec 6, 2023
b66fc76
Various staggered_invert_test cleanup, made it look more like invert_…
weinbe2 Dec 6, 2023
a5b89eb
Updated verifyStaggeredInversion to look like the regular verifyInver…
weinbe2 Dec 6, 2023
6e33961
Refactored staggered_eigensolve_test to look more like eigensolve_tes…
weinbe2 Dec 6, 2023
829ce62
Abstracted the staggered eigensolver test into a gtest. The tests and…
weinbe2 Dec 7, 2023
4837bc6
Added verify functions for eigenvectors and singular vectors
weinbe2 Dec 7, 2023
30eb5ee
Added a ctest for staggered eigensolves, fixed the verify function
weinbe2 Dec 7, 2023
e161cdc
All sorts of cleanup, moved various is_*_[solve/solution/etc] routine…
weinbe2 Dec 7, 2023
bf4eaad
Wilson-type compile fix
weinbe2 Dec 7, 2023
f9598ae
Merge remote-tracking branch 'origin/develop' into feature/stag-cleanup
weinbe2 Dec 7, 2023
b4300cc
Changed dwf tolerance check to use is_chiral
weinbe2 Dec 7, 2023
ca2be8d
Some BiCGStab cleanup, SVD deflation is being quirky
weinbe2 Dec 12, 2023
6bcbdab
Added an asqtad splitgrid test to probe loading both fat and fat+long…
weinbe2 Dec 12, 2023
67b5ef3
Further split grid cleanup, some tolerance fixes
weinbe2 Dec 12, 2023
19d6f02
Potential logic in (block) trlm related to using a max norm for getti…
weinbe2 Dec 12, 2023
862896e
Restored norm behavior for (block)TRLM LR convergence
weinbe2 Dec 13, 2023
5f901e4
Updated Wilson bits of split grid to use GaugeField objects as approp…
weinbe2 Dec 13, 2023
aa236f3
WAR for blowing out argument sizes for diluting typical staggered MG nc
weinbe2 Dec 13, 2023
c3709f2
Merge branch 'develop' into feature/stag-cleanup
weinbe2 Dec 26, 2023
48d7a21
Comment cleanup in eigensolver
weinbe2 Dec 26, 2023
d2b2372
doxygen
weinbe2 Dec 26, 2023
9b631c3
Cleaned up some unnecessary temporary fields outside of verify functions
weinbe2 Dec 26, 2023
0f366c1
Added a simple staggered host stag_matdag_mat verify function
weinbe2 Dec 26, 2023
9121110
Added a few extra parity checks to staggered dslash host verifies
weinbe2 Dec 26, 2023
a1303bd
Commented out the asqtad spectrum ctests
weinbe2 Jan 3, 2024
3252b6c
Removed twisted mass from the CI pipeline, other 4-d Wilson ops are s…
weinbe2 Jan 3, 2024
c8d301b
Added an explicit link to the Nc = 64, 96 issue in spinor_dilute.in.cu
weinbe2 Jan 3, 2024
6b0ba62
Cleaned up C-style casts, plus unnecessary newlines in errorQuda
weinbe2 Jan 4, 2024
4e533fe
Added a cmake flag QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST to toggl…
weinbe2 Jan 4, 2024
1228db9
is_laplace_enabled -> is_enabled_laplace, other misc cleanup
weinbe2 Jan 4, 2024
53dced9
Small stylistic updates to BiCGstab to match conventions in other mod…
weinbe2 Jan 4, 2024
209d554
Fixed using BiCGstab for generating near-null vectors
weinbe2 Jan 5, 2024
ca6b814
clang-format
weinbe2 Jan 9, 2024
431c4ec
Updated comments in TRLM to reflect code changes
weinbe2 Jan 10, 2024
8 changes: 7 additions & 1 deletion CMakeLists.txt
@@ -228,6 +228,8 @@ option(QUDA_CLOVER_DYNAMIC "Dynamically invert the clover term" ON)
option(QUDA_CLOVER_RECONSTRUCT "set to ON to enable compressed clover storage (requires QUDA_CLOVER_DYNAMIC)" ON)
option(QUDA_CLOVER_CHOLESKY_PROMOTE "Whether to promote the internal precision when inverting the clover term" ON)

option(QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST "Whether to run eigensolver ctests against the improved staggered operator (requires QUDA_DIRAC_STAGGERED)" OFF)

# Set CTest options
option(QUDA_CTEST_SEP_DSLASH_POLICIES "Test Dslash policies separately in ctest instead of only autotuning them." OFF)
option(QUDA_CTEST_DISABLE_BENCHMARKS "Disable benchmark test" ON)
@@ -391,7 +393,11 @@ set(CMAKE_EXE_LINKER_FLAGS_SANITIZE
CACHE STRING "Flags used by the linker during sanitizer debug builds.")

if(QUDA_CLOVER_RECONSTRUCT AND NOT QUDA_CLOVER_DYNAMIC)
message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC)")
message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC")
endif()

if (QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST AND NOT QUDA_DIRAC_STAGGERED)
message(SEND_ERROR "QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST requires QUDA_DIRAC_STAGGERED")
endif()

find_package(Threads REQUIRED)
1 change: 0 additions & 1 deletion ci/docker/Dockerfile.build
@@ -40,7 +40,6 @@ RUN QUDA_TEST_GRID_SIZE="1 1 1 2" cmake -S /quda/src \
-DQUDA_DIRAC_DEFAULT_OFF=ON \
-DQUDA_DIRAC_WILSON=ON \
-DQUDA_DIRAC_CLOVER=ON \
-DQUDA_DIRAC_TWISTED_MASS=ON \
-DQUDA_DIRAC_TWISTED_CLOVER=ON \
-DQUDA_DIRAC_STAGGERED=ON \
-GNinja \
23 changes: 21 additions & 2 deletions include/invert_quda.h
@@ -1048,17 +1048,36 @@ namespace quda {

private:
const DiracMdagM matMdagM; // used by the eigensolver
// pointers to fields to avoid multiple creation overhead
ColorSpinorField *yp, *rp, *pp, *vp, *tmpp, *tp;

ColorSpinorField y; // Full precision solution accumulator
ColorSpinorField r; // Full precision residual vector
ColorSpinorField p; // Sloppy precision search direction
ColorSpinorField v; // Sloppy precision A * p
ColorSpinorField t; // Sloppy precision vector used for minres step
ColorSpinorField r0; // Bi-orthogonalization vector
ColorSpinorField r_sloppy; // Slopy precision residual vector
ColorSpinorField x_sloppy; // Sloppy solution accumulator vector
bool init = false;

/**
@brief Initiate the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);

public:
BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile);
virtual ~BiCGstab();

void operator()(ColorSpinorField &out, ColorSpinorField &in) override;

/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual() override;

virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */

virtual QudaInverterType getInverterType() const final { return QUDA_BICGSTAB_INVERTER; }
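The header change above swaps pointer members for value members, keeps an `init` flag, and adds a `create()` method plus a `get_residual()` accessor. A minimal sketch of that allocate-once pattern follows. `Field` and `ToySolver` are hypothetical stand-ins for `ColorSpinorField` and `BiCGstab`, not QUDA API, and the body of the real `create()` is not shown in this diff, so the guard below is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a field type; not QUDA's ColorSpinorField.
struct Field {
  std::vector<double> data;
  Field() = default;
  explicit Field(std::size_t n) : data(n, 0.0) {}
};

// Hypothetical solver that owns its work vectors by value instead of by pointer.
class ToySolver
{
  Field r;           // residual vector
  Field p;           // search direction
  bool init = false; // true once the work vectors have been allocated
  int n_alloc = 0;   // allocation counter, for illustration only

  // Allocate the work vectors on first use only (assumed behaviour).
  void create(const Field &b)
  {
    if (init) return;
    r = Field(b.data.size());
    p = Field(b.data.size());
    init = true;
    ++n_alloc;
  }

public:
  // A placeholder "solve": zero the guess so the residual is simply b.
  void operator()(Field &x, const Field &b)
  {
    create(b);
    x.data.assign(b.data.size(), 0.0);
    r.data = b.data;
  }

  // Residual left over from the most recent call.
  const Field &get_residual() const { return r; }
  int allocations() const { return n_alloc; }
};
```

Because the members are values, no destructor bookkeeping is needed to release them, and repeated calls reuse the same storage.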
19 changes: 14 additions & 5 deletions lib/eig_block_trlm.cpp
@@ -105,16 +105,25 @@ namespace quda
eigensolveFromBlockArrowMat();
profile.TPSTART(QUDA_PROFILE_COMPUTE);

// mat_norm is updated.
// mat_norm is updated and used for LR
for (int i = num_locked; i < n_kr; i++)
if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]);

// Lambda that returns mat_norm for LR and returns the relevant alpha
// (the corresponding Ritz value) for SR
auto check_norm = [&](double sr_norm) -> double {
if (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG)
return mat_norm;
else
return sr_norm;
};

// Locking check
iter_locked = 0;
for (int i = 1; i < (n_kr - num_locked); i++) {
if (residua[i + num_locked] < epsilon * mat_norm) {
if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
epsilon * mat_norm);
epsilon * check_norm(alpha[i + num_locked]));
iter_locked = i;
} else {
// Unlikely to find new locked pairs
Expand All @@ -125,9 +134,9 @@ namespace quda
// Convergence check
iter_converged = iter_locked;
for (int i = iter_locked + 1; i < n_kr - num_locked; i++) {
if (residua[i + num_locked] < tol * mat_norm) {
if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
tol * mat_norm);
tol * check_norm(alpha[i + num_locked]));
iter_converged = i;
} else {
// Unlikely to find new converged pairs
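The `check_norm` lambda in this file picks the scale against which a residual is compared: the running matrix-norm estimate for the largest-real spectrum, and the Ritz value itself for the smallest-real case. A standalone sketch of that selection is below. The `Spectrum` enum and the `passes` function are hypothetical names for illustration; only the lambda body mirrors the diff.

```cpp
#include <cassert>

// Hypothetical selector; stands in for eig_param->spectrum.
enum class Spectrum { LargestReal, SmallestReal };

// True when `residual` is below `tol` scaled by the norm chosen for this spectrum.
bool passes(double residual, double tol, double ritz_value, double mat_norm, Spectrum spectrum)
{
  // LR: use the matrix-norm estimate. Otherwise: use the Ritz value passed in.
  auto check_norm = [&](double sr_norm) -> double {
    if (spectrum == Spectrum::LargestReal)
      return mat_norm;
    else
      return sr_norm;
  };
  return residual < tol * check_norm(ritz_value);
}
```

With `mat_norm = 100` and a Ritz value of `0.01`, a residual of `1e-7` at `tol = 1e-6` passes the LR test but fails the SR test, which is the stricter relative criterion the change introduces for small eigenvalues.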
23 changes: 16 additions & 7 deletions lib/eig_trlm.cpp
@@ -86,16 +86,25 @@ namespace quda
eigensolveFromArrowMat();
profile.TPSTART(QUDA_PROFILE_COMPUTE);

// mat_norm is updated.
// mat_norm is updated and used for LR
for (int i = num_locked; i < n_kr; i++)
if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]);

// Lambda that returns mat_norm for LR and returns the relevant alpha
// (the corresponding Ritz value) for SR
auto check_norm = [&](double sr_norm) -> double {
if (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG)
return mat_norm;
else
return sr_norm;
};

// Locking check
iter_locked = 0;
for (int i = 1; i < (n_kr - num_locked); i++) {
if (residua[i + num_locked] < epsilon * mat_norm) {
if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
epsilon * mat_norm);
epsilon * check_norm(alpha[i + num_locked]));
iter_locked = i;
} else {
// Unlikely to find new locked pairs
Expand All @@ -106,9 +115,9 @@ namespace quda
// Convergence check
iter_converged = iter_locked;
for (int i = iter_locked + 1; i < n_kr - num_locked; i++) {
if (residua[i + num_locked] < tol * mat_norm) {
if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
tol * mat_norm);
tol * check_norm(alpha[i + num_locked]));
iter_converged = i;
} else {
// Unlikely to find new converged pairs
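The two loops in this hunk run back to back: a locking pass against `epsilon`, then a convergence pass against `tol` that starts where locking stopped. The sketch below reproduces that two-pass scan on plain vectors. `ScanResult` and `scan_pairs` are hypothetical names, the per-entry `norm` vector stands in for `check_norm(alpha[...])`, and the early `break` in each `else` branch is an assumption, since the diff elides those lines.

```cpp
#include <cassert>
#include <vector>

struct ScanResult {
  int iter_locked;    // last index that passed the locking test
  int iter_converged; // last index that passed the convergence test
};

// Entry [i + num_locked] of residua/norm belongs to candidate i, as in the diff.
ScanResult scan_pairs(const std::vector<double> &residua, const std::vector<double> &norm, int num_locked,
                      double epsilon, double tol)
{
  const int n_kr = static_cast<int>(residua.size());
  ScanResult out {0, 0};

  // Locking check: stop at the first candidate that fails (assumed early exit).
  for (int i = 1; i < n_kr - num_locked; i++) {
    if (residua[i + num_locked] < epsilon * norm[i + num_locked]) {
      out.iter_locked = i;
    } else {
      break;
    }
  }

  // Convergence check: continue from the last locked candidate.
  out.iter_converged = out.iter_locked;
  for (int i = out.iter_locked + 1; i < n_kr - num_locked; i++) {
    if (residua[i + num_locked] < tol * norm[i + num_locked]) {
      out.iter_converged = i;
    } else {
      break;
    }
  }
  return out;
}
```

Under the assumed early exit, a small residual that appears after a failing one is not counted in that restart.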
@@ -165,8 +174,8 @@ namespace quda
logQuda(QUDA_SUMMARIZE, "TRLM computed the requested %d vectors in %d restart steps and %d OP*x operations.\n",
n_conv, restart_iter, iter);

// Dump all Ritz values and residua if using Chebyshev
for (int i = 0; i < n_conv && eig_param->use_poly_acc; i++) {
// Dump all Ritz values and residua
for (int i = 0; i < n_conv; i++) {
logQuda(QUDA_SUMMARIZE, "RitzValue[%04d]: (%+.16e, %+.16e) residual %.16e\n", i, alpha[i], 0.0, residua[i]);
}

123 changes: 62 additions & 61 deletions lib/interface_quda.cpp
@@ -3014,7 +3014,7 @@ void loadFatLongGaugeQuda(QudaInvertParam *inv_param, QudaGaugeParam *gauge_para
template <class Interface, class... Args>
void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // color spinor field pointers, and inv_param
void *h_gauge, void *milc_fatlinks, void *milc_longlinks,
QudaGaugeParam *gauge_param, // gauge field pointers
QudaGaugeParam *gauge_param_, // gauge field pointers
void *h_clover, void *h_clovinv, // clover field pointers
Interface op, Args... args)
{
@@ -3034,14 +3034,17 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
errorQuda("split_key = [%d,%d,%d,%d] is not valid", split_key[0], split_key[1], split_key[2], split_key[3]);
}

// Create a local copy of gauge_param that we can modify without perturbing
// the original one
if (!gauge_param_) errorQuda("Input gauge_param is null");
QudaGaugeParam gauge_param = *gauge_param_;

if (num_sub_partition == 1) { // In this case we don't split the grid.

for (int n = 0; n < param->num_src; n++) { op(_hp_x[n], _hp_b[n], param, args...); }

} else {

if (gauge_param == nullptr) { errorQuda("gauge_param == nullptr"); }

// Doing the sub-partition arithmatics
if (param->num_src_per_sub_partition * num_sub_partition != param->num_src) {
errorQuda("We need to have split_grid[0](=%d) * split_grid[1](=%d) * split_grid[2](=%d) * split_grid[3](=%d) * "
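This hunk renames the incoming pointer to `gauge_param_`, null-checks it up front, and works on a by-value copy so that later rescaling does not touch the caller's struct. A minimal sketch of that pattern is below. `ToyGaugeParam` and `scaled_copy` are hypothetical, and an exception stands in for `errorQuda`.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical parameter struct; stands in for QudaGaugeParam.
struct ToyGaugeParam {
  int X[4];   // local lattice extents
  int ga_pad; // padding
};

// Validate the caller's pointer once, then modify only a local copy.
ToyGaugeParam scaled_copy(const ToyGaugeParam *gauge_param_, const int split_key[4])
{
  if (!gauge_param_) throw std::invalid_argument("Input gauge_param is null");
  ToyGaugeParam gauge_param = *gauge_param_; // local copy, safe to modify
  for (int d = 0; d < 4; d++) gauge_param.X[d] *= split_key[d];
  return gauge_param;
}
```

The caller's struct keeps its original extents, so the original pointer can be reused unchanged when the gauge field is restored at the end of the function.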
@@ -3058,44 +3061,50 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col

checkInvertParam(param, _hp_x[0], _hp_b[0]);

bool is_staggered;
bool is_staggered = false;
bool is_asqtad = false;
if (h_gauge) {
is_staggered = false;
} else if (milc_fatlinks) {
is_staggered = true;
if (param->dslash_type == QUDA_ASQTAD_DSLASH) {
if (!milc_longlinks) errorQuda("milc_longlinks is null for an asqtad dslash");
is_asqtad = true;
}
} else {
errorQuda("Both h_gauge and milc_fatlinks are null.");
is_staggered = true; // to suppress compiler warning/error.
}

// Gauge fields/params
GaugeFieldParam *gf_param = nullptr;
GaugeField *in = nullptr;
GaugeFieldParam gf_param;
GaugeField in;
// Staggered gauge fields/params
GaugeFieldParam *milc_fatlink_param = nullptr;
GaugeFieldParam *milc_longlink_param = nullptr;
GaugeField *milc_fatlink_field = nullptr;
GaugeField *milc_longlink_field = nullptr;
GaugeFieldParam milc_fatlink_param;
GaugeFieldParam milc_longlink_param;
quda::GaugeField milc_fatlink_field;
quda::GaugeField milc_longlink_field;

// set up the gauge field params.
if (!is_staggered) { // not staggered
gf_param = new GaugeFieldParam(*gauge_param, h_gauge);
if (gf_param->order <= 4) gf_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
in = GaugeField::Create(*gf_param);
gf_param = GaugeFieldParam(gauge_param, h_gauge);
in = GaugeField(gf_param);
} else { // staggered
milc_fatlink_param = new GaugeFieldParam(*gauge_param, milc_fatlinks);
if (milc_fatlink_param->order <= 4) milc_fatlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
milc_fatlink_field = GaugeField::Create(*milc_fatlink_param);
milc_longlink_param = new GaugeFieldParam(*gauge_param, milc_longlinks);
if (milc_longlink_param->order <= 4) milc_longlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
milc_longlink_field = GaugeField::Create(*milc_longlink_param);
milc_fatlink_param = GaugeFieldParam(gauge_param, milc_fatlinks);
milc_fatlink_param.order = QUDA_MILC_GAUGE_ORDER;
milc_fatlink_field = GaugeField(milc_fatlink_param);

if (is_asqtad) {
milc_longlink_param = GaugeFieldParam(gauge_param, milc_longlinks);
milc_longlink_param.order = QUDA_MILC_GAUGE_ORDER;
milc_longlink_field = GaugeField(milc_longlink_param);
}
}

// Create the temp host side helper fields, which are just wrappers of the input pointers.
bool pc_solution
= (param->solution_type == QUDA_MATPC_SOLUTION) || (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION);

lat_dim_t X = {gauge_param->X[0], gauge_param->X[1], gauge_param->X[2], gauge_param->X[3]};
lat_dim_t X = {gauge_param.X[0], gauge_param.X[1], gauge_param.X[2], gauge_param.X[3]};
ColorSpinorParam cpuParam(_hp_b[0], *param, X, pc_solution, param->input_location);
std::vector<ColorSpinorField *> _h_b(param->num_src);
for (int i = 0; i < param->num_src; i++) {
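The flag logic added in this hunk distinguishes three cases from which host pointers are non-null: a Wilson-type gauge field, staggered fat links only, or asqtad fat plus long links. A standalone sketch of that decision is below. `DslashType`, `LinkKind`, and `classify_links` are hypothetical names, and exceptions stand in for `errorQuda`.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical operator selector; stands in for param->dslash_type.
enum class DslashType { Wilson, Staggered, Asqtad };

struct LinkKind {
  bool is_staggered;
  bool is_asqtad;
};

// Decide which gauge fields must be handled, based on the pointers supplied.
LinkKind classify_links(const void *h_gauge, const void *milc_fatlinks, const void *milc_longlinks, DslashType type)
{
  LinkKind kind {false, false};
  if (h_gauge) {
    kind.is_staggered = false;
  } else if (milc_fatlinks) {
    kind.is_staggered = true;
    if (type == DslashType::Asqtad) {
      // Long links are only required for the asqtad operator.
      if (!milc_longlinks) throw std::invalid_argument("milc_longlinks is null for an asqtad dslash");
      kind.is_asqtad = true;
    }
  } else {
    throw std::invalid_argument("Both h_gauge and milc_fatlinks are null.");
  }
  return kind;
}
```

Keeping `is_asqtad` separate from `is_staggered` is what lets the later code skip the long-link field entirely for plain staggered solves.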
@@ -3119,16 +3128,14 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
errorQuda("Split not possible: %2d %% %2d != 0", comm_dim(d), split_key[d]);
}
if (!is_staggered) {
gf_param->x[d] *= split_key[d];
gf_param->pad *= split_key[d];
gf_param.x[d] *= split_key[d];
gf_param.pad *= split_key[d];
} else {
milc_fatlink_param->x[d] *= split_key[d];
milc_fatlink_param->pad *= split_key[d];
milc_longlink_param->x[d] *= split_key[d];
milc_longlink_param->pad *= split_key[d];
milc_fatlink_param.x[d] *= split_key[d];
if (is_asqtad) milc_longlink_param.x[d] *= split_key[d];
}
gauge_param->X[d] *= split_key[d];
gauge_param->ga_pad *= split_key[d];
gauge_param.X[d] *= split_key[d];
if (!is_staggered) gauge_param.ga_pad *= split_key[d];
}

// Deal with clover field. For Multi source computatons, clover field construction is done
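The loop in this hunk rejects a split when a communicator dimension is not divisible by its split key, then multiplies each local extent by that key. The sketch below captures that arithmetic on plain arrays. `Dims`, `SplitPlan`, and `plan_split` are hypothetical names, and treating the sub-partition count as the product of the split keys is an assumption based on the error message earlier in the function.

```cpp
#include <array>
#include <cassert>
#include <stdexcept>

using Dims = std::array<int, 4>;

struct SplitPlan {
  Dims collected_dims;   // per-rank extents after collecting onto a sub-partition
  int num_sub_partition; // product of the split keys (assumed)
};

// Check divisibility of the process grid, then scale the local extents.
SplitPlan plan_split(const Dims &local_dims, const Dims &comm_dims, const Dims &split_key)
{
  SplitPlan plan {local_dims, 1};
  for (int d = 0; d < 4; d++) {
    if (split_key[d] <= 0 || comm_dims[d] % split_key[d] != 0) {
      throw std::invalid_argument("Split not possible");
    }
    plan.collected_dims[d] *= split_key[d];
    plan.num_sub_partition *= split_key[d];
  }
  return plan;
}
```

For a `1 1 1 2` process grid split with key `1 1 1 2`, a local `4 4 4 8` lattice becomes `4 4 4 16` on each of two sub-partitions.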
@@ -3171,26 +3178,30 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
}
}

quda::GaugeField *collected_gauge = nullptr;
quda::GaugeField *collected_milc_fatlink_field = nullptr;
quda::GaugeField *collected_milc_longlink_field = nullptr;
quda::GaugeField collected_gauge;
quda::GaugeField collected_milc_fatlink_field;
quda::GaugeField collected_milc_longlink_field;

if (!is_staggered) {
gf_param->create = QUDA_NULL_FIELD_CREATE;
collected_gauge = new quda::GaugeField(*gf_param);
gf_param.create = QUDA_NULL_FIELD_CREATE;
collected_gauge = quda::GaugeField(gf_param);
std::vector<quda::GaugeField *> v_g(1);
v_g[0] = in;
quda::split_field(*collected_gauge, v_g, split_key);
v_g[0] = &in;
quda::split_field(collected_gauge, v_g, split_key);
} else {
milc_fatlink_param->create = QUDA_NULL_FIELD_CREATE;
milc_longlink_param->create = QUDA_NULL_FIELD_CREATE;
collected_milc_fatlink_field = new quda::GaugeField(*milc_fatlink_param);
collected_milc_longlink_field = new quda::GaugeField(*milc_longlink_param);
std::vector<quda::GaugeField *> v_g(1);
v_g[0] = milc_fatlink_field;
quda::split_field(*collected_milc_fatlink_field, v_g, split_key);
v_g[0] = milc_longlink_field;
quda::split_field(*collected_milc_longlink_field, v_g, split_key);

milc_fatlink_param.create = QUDA_NULL_FIELD_CREATE;
collected_milc_fatlink_field = GaugeField(milc_fatlink_param);
v_g[0] = &milc_fatlink_field;
quda::split_field(collected_milc_fatlink_field, v_g, split_key);

if (is_asqtad) {
milc_longlink_param.create = QUDA_NULL_FIELD_CREATE;
collected_milc_longlink_field = GaugeField(milc_longlink_param);
v_g[0] = &milc_longlink_field;
quda::split_field(collected_milc_longlink_field, v_g, split_key);
}
}

profileInvertMultiSrc.TPSTART(QUDA_PROFILE_PREAMBLE);
@@ -3223,10 +3234,10 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
// the split topology.
logQuda(QUDA_DEBUG_VERBOSE, "Split grid loading gauge field...\n");
if (!is_staggered) {
loadGaugeQuda(collected_gauge->raw_pointer(), gauge_param);
loadGaugeQuda(collected_gauge.raw_pointer(), &gauge_param);
} else {
loadFatLongGaugeQuda(param, gauge_param, collected_milc_fatlink_field->raw_pointer(),
collected_milc_longlink_field->raw_pointer());
loadFatLongGaugeQuda(param, &gauge_param, collected_milc_fatlink_field.raw_pointer(),
(is_asqtad) ? collected_milc_longlink_field.raw_pointer() : nullptr);
}
logQuda(QUDA_DEBUG_VERBOSE, "Split grid loaded gauge field...\n");

@@ -3251,8 +3262,8 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
comm_barrier();

for (int d = 0; d < CommKey::n_dim; d++) {
gauge_param->X[d] /= split_key[d];
gauge_param->ga_pad /= split_key[d];
gauge_param.X[d] /= split_key[d];
if (!is_staggered) gauge_param.ga_pad /= split_key[d];
}

for (int n = 0; n < param->num_src_per_sub_partition; n++) {
@@ -3268,27 +3279,17 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
for (auto p : _h_x) { delete p; }
for (auto p : _h_b) { delete p; }

if (!is_staggered) {
delete in;
delete collected_gauge;
} else {
delete milc_fatlink_field;
delete milc_longlink_field;
delete collected_milc_fatlink_field;
delete collected_milc_longlink_field;
}

if (input_clover) { delete input_clover; }
if (collected_clover) { delete collected_clover; }

profileInvertMultiSrc.TPSTOP(QUDA_PROFILE_EPILOGUE);

// Restore the gauge field
if (!is_staggered) {
loadGaugeQuda(h_gauge, gauge_param);
loadGaugeQuda(h_gauge, gauge_param_);
} else {
freeGaugeQuda();
loadFatLongGaugeQuda(param, gauge_param, milc_fatlinks, milc_longlinks);
loadFatLongGaugeQuda(param, gauge_param_, milc_fatlinks, milc_longlinks);
}

if (param->dslash_type == QUDA_CLOVER_WILSON_DSLASH || param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH) {