Skip to content

Commit

Permalink
fix for adjoints with preequilibration and sensitivities (#334)
Browse files Browse the repository at this point in the history
* Fix preequilibration for adjoints

* fix cvodes errors during prequilibration simulation by creating preequilibrating solver (new CVode object)

* fix scrambled sx0 in matlab and  transposed sx0 in hdf5

* add preequilibration tests (Fixes  #333)

* better naming for the steady state simulation routine and more cleanup
  • Loading branch information
paulstapor authored Jun 25, 2018
1 parent 7a9bd16 commit 2e2e445
Show file tree
Hide file tree
Showing 15 changed files with 388 additions and 59 deletions.
3 changes: 1 addition & 2 deletions include/amici/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ typedef enum AMICI_sensi_order_TAG {
typedef enum AMICI_sensi_meth_TAG {
AMICI_SENSI_NONE,
AMICI_SENSI_FSA,
AMICI_SENSI_ASA,
AMICI_SENSI_SS
AMICI_SENSI_ASA
} AMICI_sensi_meth;

/** linear solvers for CVODES/IDAS */
Expand Down
5 changes: 2 additions & 3 deletions include/amici/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ class Solver {
* AMIFree frees allocation solver memory
*/
virtual void AMIFree() = 0;

/**
* AMIAdjInit initializes the adjoint problem
*
Expand All @@ -807,7 +807,7 @@ class Solver {
*
*/
virtual void AMIAdjInit(long int steps, int interp) = 0;

/**
* AMICreateB specifies solver method and initializes solver memory for the
* backward problem
Expand Down Expand Up @@ -1122,7 +1122,6 @@ class Solver {

/** flag indicating whether sensitivities are supposed to be computed */
AMICI_sensi_order sensi = AMICI_SENSI_ORDER_NONE;

};

bool operator ==(const Solver &a, const Solver &b);
Expand Down
24 changes: 12 additions & 12 deletions include/amici/solver_cvodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class CVodeSolver : public Solver {
* @return The clone
*/
virtual Solver* clone() const override;



void *AMICreate(int lmm, int iter) override;

void AMISStolerances(double rtol, double atol) override;
Expand Down Expand Up @@ -95,6 +94,17 @@ class CVodeSolver : public Solver {

int AMISolveF(realtype tout, AmiVector *yret, AmiVector *ypret, realtype *tret,
int itask, int *ncheckPtr) override;


static int fxdot(realtype t, N_Vector x, N_Vector xdot, void *user_data);

static int fJSparse(realtype t, N_Vector x, N_Vector xdot, SlsMat J,
void *user_data, N_Vector tmp1, N_Vector tmp2,
N_Vector tmp3);

static int fJ(long int N, realtype t, N_Vector x, N_Vector xdot,
DlsMat J, void *user_data, N_Vector tmp1,
N_Vector tmp2, N_Vector tmp3);

void AMISolveB(realtype tBout, int itaskB) override;

Expand Down Expand Up @@ -196,18 +206,10 @@ class CVodeSolver : public Solver {

void setJacTimesVecFnB(int which) override;

static int fJ(long int N, realtype t, N_Vector x, N_Vector xdot,
DlsMat J, void *user_data, N_Vector tmp1,
N_Vector tmp2, N_Vector tmp3);

static int fJB(long int NeqBdot, realtype t, N_Vector x, N_Vector xB,
N_Vector xBdot, DlsMat JB, void *user_data, N_Vector tmp1B,
N_Vector tmp2B, N_Vector tmp3B);

static int fJSparse(realtype t, N_Vector x, N_Vector xdot, SlsMat J,
void *user_data, N_Vector tmp1, N_Vector tmp2,
N_Vector tmp3);

static int fJSparseB(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot,
SlsMat JB, void *user_data, N_Vector tmp1B,
N_Vector tmp2B, N_Vector tmp3B);
Expand All @@ -232,8 +234,6 @@ class CVodeSolver : public Solver {
static int froot(realtype t, N_Vector x, realtype *root,
void *user_data);

static int fxdot(realtype t, N_Vector x, N_Vector xdot, void *user_data);

static int fxBdot(realtype t, N_Vector x, N_Vector xB,
N_Vector xBdot, void *user_data);

Expand Down
9 changes: 6 additions & 3 deletions include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include <nvector/nvector_serial.h>

#include <functional>
#include <memory>

namespace amici {

Expand Down Expand Up @@ -35,10 +37,11 @@ class SteadystateProblem {
Model *model, int newton_status,
double run_time, int it);

void getNewtonSimulation(ReturnData *rdata, Solver *solver,
Model *model, int it);

void getSteadystateSimulation(ReturnData *rdata, Solver *solver,
Model *model, int it);

std::unique_ptr<void, std::function<void(void *)>> createSteadystateSimSolver(Solver *solver, Model *model, realtype tstart);

/** default constructor
* @param t pointer to time variable
* @param x pointer to state variables
Expand Down
8 changes: 4 additions & 4 deletions matlab/@amimodel/generateMatlabWrapper.m
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,19 @@ function generateMatlabWrapper(nx, ny, np, nk, nz, o2flag, amimodelo2, wrapperFi
fprintf(fid,'init = struct();\n');
fprintf(fid,'if(~isempty(options_ami.x0))\n');
fprintf(fid,' if(size(options_ami.x0,2)~=1)\n');
fprintf(fid,' error(''x0 field must be a row vector!'');\n');
fprintf(fid,' error(''x0 field must be a column vector!'');\n');
fprintf(fid,' end\n');
fprintf(fid,' if(size(options_ami.x0,1)~=nxfull)\n');
fprintf(fid,' error(''Number of columns in x0 field does not agree with number of states!'');\n');
fprintf(fid,' error(''Number of rows in x0 field does not agree with number of states!'');\n');
fprintf(fid,' end\n');
fprintf(fid,' init.x0 = options_ami.x0;\n');
fprintf(fid,'end\n');
fprintf(fid,'if(~isempty(options_ami.sx0))\n');
fprintf(fid,' if(size(options_ami.sx0,2)~=nplist)\n');
fprintf(fid,' error(''Number of rows in sx0 field does not agree with number of model parameters!'');\n');
fprintf(fid,' error(''Number of columns in sx0 field does not agree with number of model parameters!'');\n');
fprintf(fid,' end\n');
fprintf(fid,' if(size(options_ami.sx0,1)~=nxfull)\n');
fprintf(fid,' error(''Number of columns in sx0 field does not agree with number of states!'');\n');
fprintf(fid,' error(''Number of rows in sx0 field does not agree with number of states!'');\n');
fprintf(fid,' end\n');
fprintf(fid,' init.sx0 = bsxfun(@times,options_ami.sx0,1./permute(chainRuleFactor(:),[2,1]));\n');
fprintf(fid,'end\n');
Expand Down
2 changes: 1 addition & 1 deletion python/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def getFieldAsNumPyArray(rdata, field):
'x': [rdata.nt, rdata.nx],
'x0': [rdata.nx],
'sx': [rdata.nt, rdata.nplist, rdata.nx],
'sx0': [rdata.nx, rdata.nplist],
'sx0': [rdata.nplist, rdata.nx],

# observables
'y': [rdata.nt, rdata.ny],
Expand Down
12 changes: 10 additions & 2 deletions src/hdf5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ std::unique_ptr<ExpData> readSimulationExpData(std::string const& hdf5Filename,
checkEventDimensionsCompatible(m, n, model);
}

if(locationExists(file, hdf5Root + "/condition")) {
edata->fixedParameters = getDoubleDataset1D(file, hdf5Root + "/condition");
}

if(locationExists(file, hdf5Root + "/conditionPreequilibration")) {
edata->fixedParametersPreequilibration = getDoubleDataset1D(file, hdf5Root + "/conditionPreequilibration");
}

return edata;
}

Expand Down Expand Up @@ -171,7 +179,7 @@ void writeReturnData(const ReturnData &rdata, H5::H5File const& file, const std:
rdata.nJ - 1, rdata.nplist);

if (rdata.sx0.size())
createAndWriteDouble2DDataset(file, hdf5Location + "/sx0", rdata.sx0.data(), rdata.nx, rdata.nplist);
createAndWriteDouble2DDataset(file, hdf5Location + "/sx0", rdata.sx0.data(), rdata.nplist, rdata.nx);

if (rdata.sx.size())
createAndWriteDouble3DDataset(file, hdf5Location + "/sx", rdata.sx.data(), rdata.nt, rdata.nplist, rdata.nx);
Expand Down Expand Up @@ -499,7 +507,7 @@ void readModelDataFromHDF5(const H5::H5File &file, Model &model, const std::stri
hsize_t length1 = 0;
auto sx0 = getDoubleDataset2D(file, datasetPath + "/sx0", length0, length1);
if(sx0.size()) {
if (length0 != (unsigned) model.nx && length1 != (unsigned) model.nplist())
if (length0 != (unsigned) model.nplist() && length1 != (unsigned) model.nx)
throw(AmiException("Dimension mismatch when reading sx0. Expected %dx%d, got %llu, %llu.",
model.nx, model.nplist(), length0, length1));
model.setInitialStateSensitivities(sx0);
Expand Down
5 changes: 3 additions & 2 deletions src/returndata_matlab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ mxArray *initMatlabReturnFields(ReturnData const *rdata) {
mxArray *matlabSolutionStruct =
mxCreateStructMatrix(1, 1, numFields, field_names_sol);

std::vector<int> perm0 = {1, 0};
std::vector<int> perm1 = {0, 1};
std::vector<int> perm2 = {0, 2, 1};
std::vector<int> perm3 = {0, 2, 3, 1};
Expand All @@ -69,15 +70,15 @@ mxArray *initMatlabReturnFields(ReturnData const *rdata) {
}
if (rdata->nx > 0) {
writeMatlabField2(matlabSolutionStruct, "x", rdata->x, rdata->nt, rdata->nx, perm1);
writeMatlabField2(matlabSolutionStruct, "x0", rdata->x0, 1, rdata->nx, perm1);
writeMatlabField2(matlabSolutionStruct, "x0", rdata->x0, rdata->nx, 1, perm1);
}
if (rdata->ny > 0) {
writeMatlabField2(matlabSolutionStruct, "y", rdata->y, rdata->nt, rdata->ny, perm1);
writeMatlabField2(matlabSolutionStruct, "sigmay", rdata->sigmay, rdata->nt, rdata->ny, perm1);
}
if (rdata->sensi >= AMICI_SENSI_ORDER_FIRST) {
writeMatlabField1(matlabSolutionStruct, "sllh", rdata->sllh, rdata->nplist);
writeMatlabField2(matlabSolutionStruct, "sx0", rdata->sx0, rdata->nx, rdata->nplist, perm1);
writeMatlabField2(matlabSolutionStruct, "sx0", rdata->sx0, rdata->nplist, rdata->nx, perm0);

if (rdata->sensi_meth == AMICI_SENSI_FSA) {
writeMatlabField3(matlabSolutionStruct, "sx", rdata->sx, rdata->nt, rdata->nplist, rdata->nx, perm2);
Expand Down
6 changes: 2 additions & 4 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,10 @@ void Solver::setupAMI(ForwardProblem *fwd, Model *model) {
}
}

if (sensi_meth == AMICI_SENSI_ASA) {
if (model->nx > 0) {
if (sensi_meth == AMICI_SENSI_ASA)
if (model->nx > 0)
/* Allocate space for the adjoint computation */
AMIAdjInit(maxsteps, interpType);
}
}
}

AMISetId(model);
Expand Down
Loading

0 comments on commit 2e2e445

Please sign in to comment.