-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 853bd7c
Showing
12 changed files
with
1,710 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
## Fast Approximate Convex Hull Construction in Networks via Node Embedding | ||
|
||
## Authors | ||
|
||
* Dmitrii Gavrilev | ||
* Ilya Makarov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import torch | ||
import dgl | ||
|
||
|
||
def loss_dist(X, X_nograd, A, idx, norm=1):
    """Mean squared error between embedding distances and target distances.

    Pairwise p-norm distances (``p=norm``) between the rows of ``X`` and the
    rows of ``X_nograd`` are computed with ``torch.cdist`` and compared with
    the target rows ``A[idx]``.

    Args:
        X: embeddings of the selected nodes, one row per entry of ``idx``.
        X_nograd: embeddings the rows of ``X`` are compared against.
        A: target distances; only the rows ``A[idx]`` are read
            (presumably the all-pairs graph distance matrix -- confirm).
        idx: indices of the rows of ``A`` that correspond to ``X``.
        norm: order ``p`` of the norm passed to ``torch.cdist``.

    Returns:
        Sum of squared differences divided by
        ``len(idx) * (X_nograd.shape[0] - 1)``.
    """
    pairwise = torch.cdist(X, X_nograd, p=norm)
    # One entry per selected row is left out of the normaliser
    # (presumably the zero self-distance -- confirm).
    num_pairs = len(idx) * (X_nograd.shape[0] - 1)
    residual = pairwise - A[idx]
    return (residual ** 2).sum() / num_pairs
|
||
|
||
def loss_log_dist(X, X_nograd, A, idx, norm=1):
    """Mean squared error between log embedding distances and log targets.

    Same comparison as a plain squared-distance loss, but both the
    ``torch.cdist`` distances and the target rows ``A[idx]`` are passed
    through ``log(. + 1e-9)`` first.

    Args:
        X: embeddings of the selected nodes, one row per entry of ``idx``.
        X_nograd: embeddings the rows of ``X`` are compared against.
        A: target distances; only the rows ``A[idx]`` are read.
        idx: indices of the rows of ``A`` that correspond to ``X``.
        norm: order ``p`` of the norm passed to ``torch.cdist``.

    Returns:
        Sum of squared log-differences divided by
        ``len(idx) * (X_nograd.shape[0] - 1)``.
    """
    eps = 1e-9  # keeps log() finite for zero distances
    pairwise = torch.cdist(X, X_nograd, p=norm)
    num_pairs = len(idx) * (X_nograd.shape[0] - 1)
    log_residual = torch.log(pairwise + eps) - torch.log(A[idx] + eps)
    return (log_residual ** 2).sum() / num_pairs
|
||
|
||
def l1_loss_dist(X, X_nograd, A, idx, norm=1):
    """Mean absolute error between embedding distances and target distances.

    Args:
        X: embeddings of the selected nodes, one row per entry of ``idx``.
        X_nograd: embeddings the rows of ``X`` are compared against.
        A: target distances; only the rows ``A[idx]`` are read.
        idx: indices of the rows of ``A`` that correspond to ``X``.
        norm: order ``p`` of the norm passed to ``torch.cdist``.

    Returns:
        Sum of absolute differences between the ``torch.cdist`` distances and
        ``A[idx]``, divided by ``len(idx) * (X_nograd.shape[0] - 1)``.
    """
    pairwise = torch.cdist(X, X_nograd, p=norm)
    num_pairs = len(idx) * (X_nograd.shape[0] - 1)
    abs_residual = (pairwise - A[idx]).abs()
    return abs_residual.sum() / num_pairs
|
||
|
||
def train(model, num_epochs, g, dist, loader, opt, scheduler, log_loss=False, max_grad_norm=2., verbose=1):
    """Fit ``model`` so that embedding distances reproduce the targets in ``dist``.

    Args:
        model: network called as ``model(mfgs, inputs)``; ``len(model.conv)``
            gives the number of blocks built for a full-graph pass.
        num_epochs: number of passes over ``loader``.
        g: DGL graph whose node feature ``'x'`` is the model input.
        dist: target distances handed to the loss functions
            (presumably the all-pairs distance matrix indexed by node id
            -- confirm against the caller).
        loader: iterable yielding ``(input_nodes, output_nodes, mfgs)``.
        opt: optimizer; stepped once per batch.
        scheduler: optional LR scheduler; when not ``None`` it is stepped
            once per batch, right after the optimizer.
        log_loss: use ``loss_log_dist`` instead of ``loss_dist`` for training.
        max_grad_norm: threshold passed to ``clip_grad_norm_``.
        verbose: print the mean epoch loss every ``verbose`` epochs; a falsy
            value (``0`` / ``None``) disables the per-epoch print instead of
            raising ``ZeroDivisionError`` on ``epoch % 0``.

    Returns:
        Embeddings of all nodes of ``g`` from a final no-grad pass.
    """
    # Full-graph embeddings without gradients: they serve as the fixed
    # "other side" of every pairwise distance during mini-batch training.
    model.eval()
    with torch.no_grad():
        mfgs = [dgl.to_block(g) for _ in range(len(model.conv))]
        inputs = mfgs[0].srcdata['x']
        emb_nograd = model(mfgs, inputs)

    batch_loss = loss_log_dist if log_loss else loss_dist

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for input_nodes, output_nodes, mfgs in loader:
            inputs = mfgs[0].srcdata['x']
            emb = model(mfgs, inputs)
            # Refresh the cached rows for this batch before measuring the loss.
            emb_nograd[output_nodes] = emb.detach()
            loss = batch_loss(emb, emb_nograd, dist, output_nodes)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()
            if scheduler is not None:
                scheduler.step()
            epoch_loss += loss.item()
        epoch_loss /= len(loader)
        if verbose and epoch % verbose == 0:
            print(f'Epoch: {epoch}, loss: {epoch_loss}')

    # Final evaluation on the whole graph; these reports are always printed.
    model.eval()
    with torch.no_grad():
        mfgs = [dgl.to_block(g) for _ in range(len(model.conv))]
        inputs = mfgs[0].srcdata['x']
        emb = model(mfgs, inputs)
        loss = loss_dist(emb, emb, dist, g.nodes())
        print(f'Final loss: {loss.item()}')
        loss = loss_log_dist(emb, emb, dist, g.nodes())
        print(f'Final loss (log): {loss.item()}')
        J = l1_loss_dist(emb, emb, dist, g.nodes())
        print(f'Absolute loss (J): {J}')

    return emb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
#define _CRT_SECURE_NO_WARNINGS
#include <algorithm>
#include <cassert>
#include <chrono>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "algo.h"
|
||
using namespace std; | ||
#define NO_VALUE -1 | ||
|
||
/// Creates an empty sub-graph for `network`: one membership flag per vertex,
/// all cleared, and an empty insertion list.
SubGraph::SubGraph(const std::vector<std::vector<int>>& network)
    : present(network.size())
{
}
inline bool SubGraph::insert(int vertex) { | ||
bool inserted = !present[vertex]; | ||
if (inserted) { | ||
present[vertex] = 1; | ||
list.push_back(vertex); | ||
} | ||
return inserted; | ||
} | ||
|
||
/// Linear search: returns true when some element of `vec` compares equal
/// (operator==) to `el`, false otherwise (including for an empty vector).
template <typename T>
bool contains(const std::vector<T>& vec, T& el) {
    typename std::vector<T>::const_iterator it = vec.begin();
    while (it != vec.end() && !(*it == el)) {
        ++it;
    }
    return it != vec.end();
}
|
||
/// Reads an undirected network from a Pajek file into adjacency lists.
///
/// Expected layout: one header line (ignored), then one line per vertex up to
/// the first line starting with '*', then one "a b" edge per line with
/// 1-based vertex numbers.
///
/// @param fn    path of the Pajek file.
/// @param names optional output; when non-null every vertex line is appended
///              to it verbatim.
/// @return adjacency lists indexed by 0-based vertex id; each edge is stored
///         in both directions and duplicates are dropped.
///
/// Behaviour on bad input:
///  - the file cannot be opened: perror() and exit with status 1 (failure);
///  - the file ends before a '*' line: the vertex section simply ends there
///    (the fixed-buffer version looped forever);
///  - an edge line that does not hold two integers, or that names a vertex
///    outside 1..n, is skipped instead of indexing out of bounds;
///  - a blank line in the edge section ends parsing, as before.
std::vector<std::vector<int>> readPajek(std::string fn, std::vector<std::string>* names) {
    std::ifstream input(fn, std::ifstream::in);
    if (!input.is_open()) {
        std::perror(fn.c_str());
        std::exit(1);
    }

    std::string line;
    std::getline(input, line);  // "*vertices ..." header, content unused

    // Vertex section: count lines until the next '*' section marker or EOF.
    int n = 0;
    while (std::getline(input, line)) {
        if (!line.empty() && line[0] == '*') {
            break;
        }
        if (names != nullptr) {
            names->push_back(line);
        }
        ++n;
    }

    std::vector<std::vector<int>> res(n);

    // Appends `to` to the adjacency list of `from` unless already listed.
    auto link = [&res](int from, int to) {
        std::vector<int>& adjacent = res[from];
        for (int seen : adjacent) {
            if (seen == to) {
                return;
            }
        }
        adjacent.push_back(to);
    };

    // Edge section.
    while (std::getline(input, line)) {
        int a = 0;
        int b = 0;
        const int parsed = std::sscanf(line.c_str(), "%d %d", &a, &b);
        if (parsed == EOF) {
            break;  // blank line terminates the edge list
        }
        if (parsed != 2) {
            continue;  // not an "a b" pair
        }
        --a;  // pajek uses 1-based indexing
        --b;
        if (a < 0 || a >= n || b < 0 || b >= n) {
            continue;  // refers to a vertex that was never declared
        }
        link(a, b);
        link(b, a);
    }
    return res;
}
|
||
vector<vector<int>> distances(const vector<vector<int>>& network) { | ||
vector<vector<int>> res(network.size()); | ||
//res.reserve(network.size()); | ||
#pragma omp parallel for | ||
for (int vertex = 0; vertex < network.size(); vertex++) { | ||
vector<int> distances(network.size(), NO_VALUE); | ||
distances[vertex] = 0; | ||
deque<int> todo; | ||
todo.push_back(vertex); | ||
while (!todo.empty()) { | ||
int current = todo.front(); | ||
todo.pop_front(); | ||
for (int neighbor : network[current]) { | ||
if (distances[neighbor] == NO_VALUE) { | ||
distances[neighbor] = distances[current] + 1; | ||
todo.push_back(neighbor); | ||
} | ||
} | ||
} | ||
res[vertex] = move(distances); | ||
} | ||
return res; | ||
} | ||
|
||
vector<int> convexHull(const vector<vector<int>>& network, const vector<vector<int>>& distances, SubGraph& subGraph, vector<int> base) { | ||
deque<int> todo; | ||
vector<int> insertions; | ||
for (auto newVertex : base) { | ||
todo.push_back(newVertex); | ||
insertions.push_back(newVertex); | ||
subGraph.insert(newVertex); | ||
} | ||
while (!todo.empty()) { | ||
int current = todo.front(); | ||
todo.pop_front(); | ||
for (int neighbor : network[current]) { | ||
if (!subGraph.present[neighbor]) { | ||
for (int endVertex : subGraph.list) { | ||
if (distances[current][endVertex] >= distances[current][neighbor] + distances[neighbor][endVertex]) { | ||
todo.push_back(neighbor); | ||
insertions.push_back(neighbor); | ||
subGraph.insert(neighbor); | ||
break; | ||
} | ||
else { | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
return insertions; | ||
} | ||
|
||
vector<vector<int>> generateBases(string fn, int k, int repeats) { | ||
auto net = readPajek(fn); | ||
vector<vector<int>> bases(repeats); | ||
vector<int> nodes(net.size()); | ||
long long rnd_init = 14994518116208229; | ||
std::default_random_engine generator(rnd_init); | ||
for (int i = 0; i < net.size(); ++i) { | ||
nodes[i] = i; | ||
} | ||
for (int i = 0; i < repeats; ++i) { | ||
vector<int> s; | ||
std::sample(nodes.begin(), nodes.end(), std::back_inserter(s), k, generator); | ||
bases[i] = s; | ||
} | ||
return bases; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#pragma once | ||
#define _CRT_SECURE_NO_WARNINGS | ||
#include <vector> | ||
#include <deque> | ||
#include <fstream> | ||
#include <cstdio> | ||
#include <utility> | ||
#include <unordered_set> | ||
#include <random> | ||
#include <cassert> | ||
#include <numeric> | ||
#include <iostream> | ||
#include <chrono> | ||
#include <cstring> | ||
|
||
using namespace std; | ||
#define NO_VALUE -1 | ||
|
||
/// Vertex subset of a network that supports O(1) membership tests and keeps
/// its members in insertion order.
class SubGraph {
public:
    /// One flag per network vertex; non-zero once the vertex was inserted.
    std::vector<char> present;

    /// Inserted vertices, in the order insert() accepted them.
    std::vector<int> list;

    /// Creates an empty subset sized for the vertices of `network`.
    SubGraph(const std::vector<std::vector<int>>& network);

    /// Adds `vertex` if absent; returns whether it was newly inserted.
    bool insert(int vertex);
};
|
||
/// Returns true when `vec` holds an element equal to `el`.
/// NOTE(review): only declared here; the template definition is not in this
/// header, so it can be instantiated only where that definition is visible.
template <typename T>
bool contains(const std::vector<T>& vec, T& el);
|
||
/// Loads a network from the Pajek file `fn` and returns its adjacency lists
/// (0-based vertex ids). When `names` is non-null the vertex lines of the
/// file are appended to it.
std::vector<std::vector<int>> readPajek(std::string fn, std::vector<std::string>* names = nullptr);
|
||
/// Returns the matrix of pairwise hop distances of `network`;
/// entries for unreachable pairs are NO_VALUE.
std::vector<std::vector<int>> distances(const std::vector<std::vector<int>>& network);
|
||
vector<int> convexHull(const vector<vector<int>>& network, const vector<vector<int>>& distances, SubGraph& subGraph, vector<int> base); | ||
|
||
/// Samples `repeats` random vertex subsets of size `k` from the network
/// stored in the Pajek file `fn`.
std::vector<std::vector<int>> generateBases(std::string fn, int k, int repeats);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#include "algo.h" | ||
|
||
using namespace std; | ||
|
||
int main(int argc, char* argv[]) { | ||
string fn_in; | ||
string fn_out = "out.txt"; | ||
string fn_dist; | ||
int k; | ||
int repeats = 10; | ||
//parse args | ||
if (argc > 1) { | ||
fn_in = argv[1]; | ||
cout << fn_in << endl; | ||
if (argc > 2) { | ||
sscanf(argv[2], "%d", &k); | ||
cout << "k = " << k << endl; | ||
if (argc > 3) { | ||
sscanf(argv[3], "%d", &repeats); | ||
cout << "repeats = " << repeats << endl; | ||
if (argc > 4) { | ||
fn_out = argv[4]; | ||
cout << fn_out << endl; | ||
} | ||
} | ||
} | ||
else { | ||
cout << "Second argument should be number of elements!" << endl; | ||
} | ||
} | ||
else { | ||
cout << "First argument should be path to pajek file containing a network!" << endl; | ||
exit(0); | ||
} | ||
|
||
auto net = readPajek(fn_in); | ||
cout << "Network loaded." << endl; | ||
cout << "Nodes in the largest connected component: " << net.size() << endl; | ||
|
||
fn_dist = fn_out + ".dist"; | ||
fn_out = fn_out + ".out"; | ||
std::ofstream output(fn_out); | ||
output << repeats << endl; | ||
|
||
vector<vector<int>> bases = generateBases(fn_in, k, repeats); | ||
for (auto b : bases) { | ||
for (auto elem : b) { | ||
output << elem << ' '; | ||
} | ||
output << endl; | ||
} | ||
cout << "Bases have been sampled" << endl; | ||
|
||
std::ofstream dist_out(fn_dist); | ||
auto dists = distances(net); | ||
for (int i = 0; i < dists.size(); ++i) { | ||
for (int j = 0; j < dists[i].size(); ++j) { | ||
dist_out << dists[i][j] << " "; | ||
} | ||
dist_out << endl; | ||
} | ||
cout << "Distances have been calculated" << endl; | ||
|
||
int i = 0; | ||
for (auto b : bases) { | ||
SubGraph s(net); | ||
auto hull = convexHull(net, dists, s, b); | ||
for (auto elem : hull) { | ||
output << elem << ' '; | ||
} | ||
output << endl; | ||
cout << "Hull: " <<i << endl; | ||
++i; | ||
} | ||
|
||
output.close(); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
0 1 1 1 1 1 1 1 1 2 1 1 1 1 3 3 2 1 3 1 3 1 3 3 2 2 3 2 2 3 2 1 2 2 | ||
1 0 1 1 2 2 2 1 2 2 2 2 2 1 3 3 3 1 3 1 3 1 3 3 3 3 3 2 2 3 1 2 2 2 | ||
1 1 0 1 2 2 2 1 1 1 2 2 2 1 2 2 3 2 2 2 2 2 2 2 2 3 3 1 1 2 2 2 1 2 | ||
1 1 1 0 2 2 2 1 2 2 2 2 1 1 3 3 3 2 3 2 3 2 3 3 3 3 3 2 2 3 2 2 2 2 | ||
1 2 2 2 0 2 1 2 2 3 1 2 2 2 4 4 2 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 2 2 2 2 0 1 2 2 3 1 2 2 2 4 4 1 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 2 2 2 1 1 0 2 2 3 2 2 2 2 4 4 1 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 1 1 1 2 2 2 0 2 2 2 2 2 2 3 3 3 2 3 2 3 2 3 3 3 3 4 2 2 3 2 2 2 3 | ||
1 2 1 2 2 2 2 2 0 2 2 2 2 2 2 2 3 2 2 2 2 2 2 2 3 3 2 2 2 2 1 2 1 1 | ||
2 2 1 2 3 3 3 2 2 0 3 3 3 2 2 2 4 3 2 2 2 3 2 2 3 3 2 2 2 2 2 2 2 1 | ||
1 2 2 2 1 1 2 2 2 3 0 2 2 2 4 4 2 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 2 2 2 2 2 2 2 2 3 2 0 2 2 4 4 3 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 2 2 1 2 2 2 2 2 3 2 2 0 2 4 4 3 2 4 2 4 2 4 4 3 3 4 3 3 4 3 2 3 3 | ||
1 1 1 1 2 2 2 2 2 2 2 2 2 0 2 2 3 2 2 2 2 2 2 2 3 3 2 2 2 2 2 2 2 1 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 0 2 5 4 2 2 2 4 2 2 3 3 2 2 2 2 2 2 1 1 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 0 5 4 2 2 2 4 2 2 3 3 2 2 2 2 2 2 1 1 | ||
2 3 3 3 2 1 1 3 3 4 2 3 3 3 5 5 0 3 5 3 5 3 5 5 4 4 5 4 4 5 4 3 4 4 | ||
1 1 2 2 2 2 2 2 2 3 2 2 2 2 4 4 3 0 4 2 4 2 4 4 3 3 4 3 3 4 2 2 3 3 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 2 5 4 0 2 2 4 2 2 3 3 2 2 2 2 2 2 1 1 | ||
1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 0 2 2 2 2 3 3 2 2 2 2 2 2 2 1 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 2 5 4 2 2 0 4 2 2 3 3 2 2 2 2 2 2 1 1 | ||
1 1 2 2 2 2 2 2 2 3 2 2 2 2 4 4 3 2 4 2 4 0 4 4 3 3 4 3 3 4 2 2 3 3 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 2 5 4 2 2 2 4 0 2 3 3 2 2 2 2 2 2 1 1 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 2 5 4 2 2 2 4 2 0 2 1 2 1 2 1 2 2 1 1 | ||
2 3 2 3 3 3 3 3 3 3 3 3 3 3 3 3 4 3 3 3 3 3 3 2 0 1 3 1 2 3 3 1 2 2 | ||
2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 3 3 3 3 3 3 1 1 0 3 2 2 2 3 1 2 2 | ||
3 3 3 3 4 4 4 4 2 2 4 4 4 2 2 2 5 4 2 2 2 4 2 2 3 3 0 2 2 1 2 2 2 1 | ||
2 2 1 2 3 3 3 2 2 2 3 3 3 2 2 2 4 3 2 2 2 3 2 1 1 2 2 0 2 2 2 2 2 1 | ||
2 2 1 2 3 3 3 2 2 2 3 3 3 2 2 2 4 3 2 2 2 3 2 2 2 2 2 2 0 2 2 1 2 1 | ||
3 3 2 3 4 4 4 3 2 2 4 4 4 2 2 2 5 4 2 2 2 4 2 1 3 2 1 2 2 0 2 2 1 1 | ||
2 1 2 2 3 3 3 2 1 2 3 3 3 2 2 2 4 2 2 2 2 2 2 2 3 3 2 2 2 2 0 2 1 1 | ||
1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 2 2 2 2 2 1 1 2 2 1 2 2 0 1 1 | ||
2 2 1 2 3 3 3 2 1 2 3 3 3 2 1 1 4 3 1 2 1 3 1 1 2 2 2 2 2 1 1 1 0 1 | ||
2 2 2 2 3 3 3 3 1 1 3 3 3 1 1 1 4 3 1 1 1 3 1 1 2 2 1 1 1 1 1 1 1 0 |
Oops, something went wrong.