Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ravil/multi constrains #5

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 143 additions & 63 deletions PartitionMetis.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,31 @@ class PartitionMetis

// Number of cells (mesh elements) owned by this rank.
const idx_t m_numCells;

#ifdef USE_MPI
// Cached rank of this process and size of the communicator; the size is also
// used as the number of partitions (nparts) handed to ParMETIS.
int m_rank = 0;
int m_nparts = 0;
#endif

// Better to use the same convention as METIS, even though we have the prefix m_
idx_t m_numflag = 0; // Since we are a C++ library, numflag for Metis will be always 0
idx_t m_ncommonnodes = 3; // TODO adapt for hex, but until then we can use constexpr for these

// Distributed dual graph in ParMETIS CSR form; filled lazily by
// generateGraphFromMesh() and either all three are populated or all are empty.
std::vector<idx_t> m_vtxdist;
std::vector<idx_t> m_xadj;
std::vector<idx_t> m_adjncy;
public:
// Constructs the partitioner over the local cells.
// With MPI enabled the communicator defaults to MPI_COMM_WORLD and the
// rank/communicator size are cached for later ParMETIS calls.
// Note: the diff extraction had left the removed one-line body "{ }" next to
// the new body; this is the reconstructed post-change constructor.
PartitionMetis(const cell_t* cells, unsigned int numCells) :
#ifdef USE_MPI
	m_comm(MPI_COMM_WORLD),
#endif // USE_MPI
	m_cells(cells), m_numCells(numCells)
{
#ifdef USE_MPI
	MPI_Comm_rank(m_comm, &m_rank);
	MPI_Comm_size(m_comm, &m_nparts);
#endif
}

#ifdef USE_MPI
void setComm(MPI_Comm comm)
Expand All @@ -55,6 +73,82 @@ class PartitionMetis
}
#endif // USE_MPI

#ifdef USE_MPI
// Builds the distributed dual graph of the mesh (CSR arrays m_vtxdist,
// m_xadj, m_adjncy) via ParMETIS_V3_Mesh2Dual.
// Idempotent: if the graph was already generated, this is a no-op.
void generateGraphFromMesh() {
    // Invariant: the three graph arrays are either all populated or all empty.
    assert((m_vtxdist.empty() && m_xadj.empty() && m_adjncy.empty()) ||
           (!m_vtxdist.empty() && !m_xadj.empty() && !m_adjncy.empty()));

    if (!(m_vtxdist.empty() && m_xadj.empty() && m_adjncy.empty())) {
        return; // graph already built
    }

    // Gather the per-rank cell counts, then exclusive-scan them into the
    // global element distribution expected by ParMETIS.
    // vector's fill constructor zero-initializes; no separate std::fill needed.
    std::vector<idx_t> elemdist(m_nparts + 1, 0);
    idx_t localNumCells = m_numCells; // local copy: MPI send buffers are non-const
    MPI_Allgather(&localNumCells, 1, IDX_T, elemdist.data(), 1, IDX_T, m_comm);

    idx_t sum = 0;
    for (int i = 0; i < m_nparts; i++) {
        const idx_t count = elemdist[i];
        elemdist[i] = sum;
        sum += count;
    }
    elemdist[m_nparts] = sum;

    const auto verticesPerCell = internal::Topology<Topo>::cellvertices();

    // Local mesh in the ParMETIS eptr/eind (CSR-like) format.
    std::vector<idx_t> eptr(m_numCells + 1, 0);
    std::vector<idx_t> eind(m_numCells * verticesPerCell, 0);

    for (idx_t i = 0; i < m_numCells; i++) {
        eptr[i] = i * verticesPerCell;

        for (unsigned int j = 0; j < verticesPerCell; j++) {
            eind[i * verticesPerCell + j] = m_cells[i][j];
        }
    }
    eptr[m_numCells] = m_numCells * verticesPerCell;

    idx_t* metisXadj = nullptr;
    idx_t* metisAdjncy = nullptr;

    ParMETIS_V3_Mesh2Dual(elemdist.data(), eptr.data(), eind.data(), &m_numflag,
                          &m_ncommonnodes, &metisXadj, &metisAdjncy, &m_comm);

    m_vtxdist = std::move(elemdist);
    assert(m_vtxdist.size() == static_cast<size_t>(m_nparts + 1));

    // xadj holds one entry per local vertex plus a leading 0, i.e.
    // vtxdist[rank + 1] - vtxdist[rank] + 1 entries.
    const size_t numElements = m_vtxdist[m_rank + 1] - m_vtxdist[m_rank] + 1;
    m_xadj.reserve(numElements);
    std::copy(metisXadj, metisXadj + numElements, std::back_inserter(m_xadj));

    // The last entry of xadj is the total length of adjncy.
    const size_t adjncySize = m_xadj[numElements - 1];
    m_adjncy.reserve(adjncySize);
    std::copy(metisAdjncy, metisAdjncy + adjncySize, std::back_inserter(m_adjncy));

    // NOTE(review): these buffers are allocated by ParMETIS; METIS_Free matches
    // the pre-existing usage — confirm it is the correct deallocator for this
    // ParMETIS build.
    METIS_Free(metisXadj);
    METIS_Free(metisAdjncy);
}
#endif

#ifdef USE_MPI
// Returns the (vtxdist, xadj, adjncy) CSR arrays of the dual graph,
// generating the graph on first use (lazy initialization).
std::tuple<const std::vector<idx_t>&, const std::vector<idx_t>&, const std::vector<idx_t>&> getGraph() {
    const bool notYetBuilt = m_xadj.empty() && m_adjncy.empty();
    if (notYetBuilt) {
        generateGraphFromMesh();
    }

    return {m_vtxdist, m_xadj, m_adjncy};
}
#endif

#ifdef USE_MPI
enum Status {
Ok,
Expand All @@ -73,99 +167,85 @@ class PartitionMetis
const int* vertexWeights = nullptr,
const double* imbalances = nullptr,
int nWeightsPerVertex = 1,
const double* nodeWeights = nullptr) {
int rank, procs;
MPI_Comm_rank(m_comm, &rank);
MPI_Comm_size(m_comm, &procs);

idx_t* elemdist = new idx_t[procs+1];
MPI_Allgather(const_cast<idx_t*>(&m_numCells), 1, IDX_T, elemdist, 1, IDX_T, m_comm);
idx_t sum = 0;
for (int i = 0; i < procs; i++) {
idx_t e = elemdist[i];
elemdist[i] = sum;
sum += e;
}
elemdist[procs] = sum;
const double* nodeWeights = nullptr,
const int* edgeWeights = nullptr,
size_t edgeCount = 0)
{
generateGraphFromMesh();

idx_t* eptr = new idx_t[m_numCells+1];
idx_t* eind = new idx_t[m_numCells * internal::Topology<Topo>::cellvertices()];
unsigned long m = 0;
for (idx_t i = 0; i < m_numCells; i++) {
eptr[i] = i * internal::Topology<Topo>::cellvertices();
idx_t wgtflag = 0;

for (unsigned int j = 0; j < internal::Topology<Topo>::cellvertices(); j++) {
m = std::max(m, m_cells[i][j]);
eind[i*internal::Topology<Topo>::cellvertices() + j] = m_cells[i][j];
}
// set the flag
if (nodeWeights == nullptr && edgeWeights == nullptr) {
wgtflag = 0;
} else if (nodeWeights != nullptr && edgeWeights != nullptr) {
wgtflag = 3;
} else if (nodeWeights == nullptr && edgeWeights != nullptr) {
wgtflag = 1;
} else {
wgtflag = 2;
}
eptr[m_numCells] = m_numCells * internal::Topology<Topo>::cellvertices();

idx_t wgtflag = 0;

idx_t ncon = nWeightsPerVertex;
idx_t* elmwgt = nullptr;

if (vertexWeights != nullptr) {
wgtflag = 2;
elmwgt = new idx_t[m_numCells * ncon];
for (idx_t cell = 0; cell < m_numCells; ++cell) {
for (idx_t j = 0; j < ncon; ++j) {
elmwgt[ncon * cell + j] = static_cast<idx_t>(vertexWeights[ncon * cell + j]);
elmwgt[ncon * cell + j] = static_cast<idx_t>(vertexWeights[ncon * cell + j]);
}
}
}

idx_t numflag = 0;
idx_t ncommonnodes = 3; // TODO adapt for hex
idx_t nparts = procs;
idx_t* edgewgt = nullptr;
if (edgeWeights != nullptr) {
assert(edgeCount != 0);
edgewgt = new idx_t[edgeCount];
for (size_t i = 0; i < edgeCount; ++i) {
edgewgt[i] = static_cast<idx_t>(edgeWeights[i]);
}
}

real_t* tpwgts = new real_t[nparts * ncon];
real_t* tpwgts = new real_t[m_nparts* ncon];
if (nodeWeights != nullptr) {
for (idx_t i = 0; i < nparts; i++) {
for (idx_t i = 0; i < m_nparts; i++) {
for (idx_t j = 0; j < ncon; ++j) {
tpwgts[i*ncon + j] = nodeWeights[i];
tpwgts[i * ncon + j] = nodeWeights[i];
}
}
} else {
for (idx_t i = 0; i < nparts * ncon; i++) {
tpwgts[i] = static_cast<real_t>(1.) / nparts;
for (idx_t i = 0; i < m_nparts * ncon; i++) {
tpwgts[i] = static_cast<real_t>(1.) / m_nparts;
}
}

real_t* ubvec = new real_t[ncon];
for (idx_t i = 0; i < ncon; ++i) {
ubvec[i] = imbalances[i];
}

idx_t edgecut;
idx_t options[3] = {1, 1, METIS_RANDOM_SEED};

idx_t* part = new idx_t[m_numCells];

auto metisResult = ParMETIS_V3_PartMeshKway(elemdist,
eptr,
eind,
elmwgt,
&wgtflag,
&numflag,
&ncon,
&ncommonnodes,
&nparts,
tpwgts,
ubvec,
options,
&edgecut,
part,
&m_comm);

delete [] elemdist;
delete [] eptr;
delete [] eind;
delete [] tpwgts;
delete [] ubvec;

for (idx_t i = 0; i < m_numCells; i++)
partition[i] = part[i];

delete [] part;
assert(m_xadj.size() == static_cast<size_t>(m_vtxdist[m_rank + 1] - m_vtxdist[m_rank] + 1));
assert(m_adjncy.size() == static_cast<size_t>(m_xadj.back()));

auto metisResult = ParMETIS_V3_PartKway(m_vtxdist.data(), m_xadj.data(), m_adjncy.data(), elmwgt, edgewgt, &wgtflag,
&m_numflag, &ncon, &m_nparts, tpwgts, ubvec, options, &edgecut, part, &m_comm);

delete[] tpwgts;
delete[] ubvec;
delete[] elmwgt;
delete[] edgewgt;

for (idx_t i = 0; i < m_numCells; i++){
partition[i] = static_cast<int>(part[i]);
}

delete[] part;

return (metisResult == METIS_OK) ? Status::Ok : Status::Error;
}
Expand Down