Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
jons-pf committed Mar 5, 2025
1 parent 5676d3f commit 6e5e281
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ VectorXd NonEmptyVectorOr(const std::vector<double>& vec, const double val) {
// Write object to the specified HDF5 file, under key "vmecinternalresults".
absl::Status vmecpp::VmecInternalResults::WriteTo(H5::H5File& file) const {
file.createGroup(H5key);
WRITEMEMBER(return_outputs_even_if_not_converged);
WRITEMEMBER(sign_of_jacobian);
WRITEMEMBER(num_full);
WRITEMEMBER(num_half);
Expand Down Expand Up @@ -115,11 +114,6 @@ absl::Status vmecpp::VmecInternalResults::WriteTo(H5::H5File& file) const {

absl::Status vmecpp::VmecInternalResults::LoadInto(
vmecpp::VmecInternalResults& obj, H5::H5File& from_file) {
if (from_file.exists("return_outputs_even_if_not_converged")) {
READMEMBER(return_outputs_even_if_not_converged);
} else {
obj.return_outputs_even_if_not_converged = false;
}
READMEMBER(sign_of_jacobian);
READMEMBER(num_full);
READMEMBER(num_half);
Expand Down Expand Up @@ -1255,7 +1249,6 @@ vmecpp::OutputQuantities vmecpp::ComputeOutputQuantities(
output_quantities.vmec_internal_results = GatherDataFromThreads(
sign_of_jacobian, s, fc, constants, radial_partitioning, decomposed_x,
models_from_threads, radial_profiles);
output_quantities.vmec_internal_results.return_outputs_even_if_not_converged = indata.return_outputs_even_if_not_converged;

if (vmec_status == VmecStatus::NORMAL_TERMINATION ||
vmec_status == VmecStatus::SUCCESSFUL_TERMINATION) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ namespace vmecpp {
// This is the data from inside VMEC, gathered from all threads,
// that form the basis of computing the output quantities.
struct VmecInternalResults {
// copy of corresponding VmecINDATA variable
// if true, compute full outputs even if VMEC did not converge
// defaults to false
bool return_outputs_even_if_not_converged;

int sign_of_jacobian;

// total number of full-grid points
Expand Down
40 changes: 21 additions & 19 deletions tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,52 +33,54 @@
def test_run(max_threads, input_file, verbose):
"""Test that the Python API works with different combinations of parameters."""

input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / input_file)
out = vmecpp.run(input, max_threads=max_threads, verbose=verbose)
vmec_input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / input_file)
vmec_output = vmecpp.run(vmec_input, max_threads=max_threads, verbose=verbose)

assert out.wout is not None
assert vmec_output.wout is not None


def test_get_outputs_if_non_converged_if_wanted():
"""Test that one can get the VMEC++ outputs even if a run did not converge."""

input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "solovev.json")
vmec_input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "solovev.json")

# only allow one iteration - VMEC++ will not converge that fast
input.niter_array[-1] = 1
vmec_input.niter_array[-1] = 1

# instruct VMEC++ to return the outputs, even if it did not converge
input.return_outputs_even_if_not_converged = True
vmec_input.return_outputs_even_if_not_converged = True

out = vmecpp.run(input)
vmec_output = vmecpp.run(vmec_input)

assert out.wout is not None
assert out.wout.niter == 2
assert vmec_output.wout is not None
assert vmec_output.wout.niter == 2


# We trust the C++ tests to cover the hot restart functionality properly,
# here we just want to test that the Python API for it works.
def test_run_with_hot_restart():
input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "cma.json")
vmec_input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "cma.json")

# base run
out = vmecpp.run(input, verbose=False)
vmec_output = vmecpp.run(vmec_input, verbose=False)

# now with hot restart
# (only a single multigrid step is supported)
input.ns_array = input.ns_array[-1:]
input.ftol_array = input.ftol_array[-1:]
input.niter_array = input.niter_array[-1:]
hot_restarted_out = vmecpp.run(input, verbose=False, restart_from=out)
vmec_input.ns_array = vmec_input.ns_array[-1:]
vmec_input.ftol_array = vmec_input.ftol_array[-1:]
vmec_input.niter_array = vmec_input.niter_array[-1:]
vmec_output_hot_restarted = vmecpp.run(
vmec_input, verbose=False, restart_from=vmec_output
)

assert hot_restarted_out.wout.niter == 2
assert vmec_output_hot_restarted.wout.niter == 2


@pytest.fixture(scope="module")
def cma_output() -> vmecpp.VmecOutput:
input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "cma.json")
out = vmecpp.run(input, verbose=False)
return out
vmec_input = vmecpp.VmecInput.from_file(TEST_DATA_DIR / "cma.json")
vmec_output = vmecpp.run(vmec_input, verbose=False)
return vmec_output


def test_vmecwout_save(cma_output):
Expand Down

0 comments on commit 6e5e281

Please sign in to comment.