Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: F test for basic GWR #90

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
55 changes: 55 additions & 0 deletions include/gwmodelpp/GWRBasic.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,25 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe
typedef arma::mat (GWRBasic::*FitCoreCalculator)(const arma::mat&, const arma::vec&, const SpatialWeight&); //!< \~english Fit function declaration. \~chinese 拟合函数声明。
typedef arma::mat (GWRBasic::*FitCoreSHatCalculator)(const arma::mat&, const arma::vec&, const SpatialWeight&, arma::vec&); //!< \~english Fit function declaration. \~chinese 拟合函数声明。
typedef arma::mat (GWRBasic::*FitCoreCVCalculator)(const arma::mat&, const arma::vec&, const SpatialWeight&); //!< \~english Fit function declaration. \~chinese 拟合函数声明。
typedef void (GWRBasic::*FTestCalculator)();
typedef double (GWRBasic::*TrQtQCalculator)();
typedef double (GWRBasic::*TrQtQCoreCalculator)();
typedef arma::vec (GWRBasic::*DiagBCalculator)(arma::uword);
typedef arma::vec (GWRBasic::*DiagBCoreCalculator)(arma::uword, const arma::vec&);

typedef double (GWRBasic::*BandwidthSelectionCriterionCalculator)(BandwidthWeight*); //!< \~english Declaration of criterion calculator for bandwidth selection. \~chinese 带宽优选指标计算函数声明。
typedef double (GWRBasic::*IndepVarsSelectCriterionCalculator)(const std::vector<std::size_t>&); //!< \~english Declaration of criterion calculator for variable selection. \~chinese 变量优选指标计算函数声明。

struct FTestResult
{
double s = 0.0;
double df1 = 0.0;
double df2 = 0.0;
double p = 0.0;
};

using FTestResultCombine = std::tuple<FTestResult, FTestResult, std::vector<FTestResult>, FTestResult>;

private:

/**
Expand Down Expand Up @@ -388,6 +403,15 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe

void setStoreC(bool flag) { mStoreC = flag; }

bool isDoFTest() { return mIsDoFTest; };

void setIsDoFtest(bool value) { mIsDoFTest = value; }

FTestResultCombine fTestResults()
{
return std::make_tuple(mF1Test, mF2Test, mF3Test, mF4Test);
}

public: // Implement Algorithm
bool isValid() override;

Expand All @@ -396,6 +420,11 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe

arma::mat fit() override;

void fTest()
{
(this->*mFTestFunction)();
}

public: // Implement IVariableSelectable
Status getCriterion(const std::vector<size_t>& variables, double& criterion) override
{
Expand Down Expand Up @@ -508,6 +537,11 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe
*/
arma::mat fitBase();

void fTestBase();

double calcTrQtQBase();

arma::vec calcDiagBBase(arma::uword i);

private:

Expand All @@ -517,6 +551,10 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe

arma::mat fitCoreCVSerial(const arma::mat& x, const arma::vec& y, const SpatialWeight& sw);

double calcTrQtQCoreSerial();

arma::vec calcDiagBCoreSerial(arma::uword i, const arma::vec& c);

#ifdef ENABLE_OPENMP

/**
Expand Down Expand Up @@ -593,6 +631,10 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe
*/
arma::mat fitCoreSHatOmp(const arma::mat& x, const arma::vec& y, const SpatialWeight& sw, arma::vec& shat);

double calcTrQtQCoreOmp();

arma::vec calcDiagBCoreOmp(arma::uword i, const arma::vec& c);

#endif

#ifdef ENABLE_CUDA
Expand Down Expand Up @@ -678,6 +720,8 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe
double bandwidthSizeCriterionCVMpi(BandwidthWeight* bandwidthWeight);
double bandwidthSizeCriterionAICMpi(BandwidthWeight* bandwidthWeight);
arma::mat fitMpi();
double calcTrQtQMpi();
arma::vec calcDiagBMpi(arma::uword i);
#endif // ENABLE_MPI

public: // Implement IParallelizable
Expand Down Expand Up @@ -772,6 +816,17 @@ class GWRBasic : public GWRBase, public IBandwidthSelectable, public IVarialbeSe
arma::cube mC;//!< \~english All \f$S\f$ matrices. \~chinese 所有 \f$C\f$ 矩阵。
bool mStoreS = false; //!< \~english Whether to save S \~chinese 是否保存 S 矩阵
bool mStoreC = false; //!< \~english Whether to save C \~chinese 是否保存 C 矩阵

bool mIsDoFTest = false;
FTestResult mF1Test;
FTestResult mF2Test;
std::vector<FTestResult> mF3Test;
FTestResult mF4Test;
FTestCalculator mFTestFunction = &GWRBasic::fTestBase;
TrQtQCalculator mCalcTrQtQFunction = &GWRBasic::calcTrQtQBase;
TrQtQCoreCalculator mCalcTrQtQCoreFunction = &GWRBasic::calcTrQtQCoreSerial;
DiagBCalculator mCalcDiagBFunction = &GWRBasic::calcDiagBBase;
DiagBCoreCalculator mCalcDiagBCoreFunction = &GWRBasic::calcDiagBCoreSerial;
};

}
Expand Down
6 changes: 5 additions & 1 deletion include/gwmodelpp/utils/armampi.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@
#define GWM_MPI_UWORD MPI_UNSIGNED_LONG_LONG
#endif // ARMA_32BIT_WORD

void mat_mul_mpi(arma::mat& a, arma::mat& b, arma::mat& c, const int ip, const int np, const size_t range);
void mat_mul_mpi(arma::mat& a, arma::mat& b, arma::mat& c, const int ip, const int np);

void mat_quad_mpi(arma::mat& a, arma::mat& aTa, const int ip, const int np);

double mat_trace_mpi(arma::mat& a, const int ip, const int np);
Loading
Loading