diff --git a/base/Corrupt.h b/base/Corrupt.h index 885daa2..0ecb0b3 100755 --- a/base/Corrupt.h +++ b/base/Corrupt.h @@ -9,17 +9,17 @@ /* Answers a question with a randomized unknown head. */ -ent_id corrupt_head( +int64_t corrupt_head( unsigned long int id, - ent_id tail, - rel_id relation) + int64_t tail, + int64_t relation) { // using precalculated better range const auto range = std::equal_range(lefHead[tail], rigHead[tail], Triple(0, tail, relation), Triple::cmp_r); const auto lower = range.first; const auto upper = range.second; - const ent_id x = rand_max(id, entityTotal - (upper - lower)); + const int64_t x = rand_max(id, entityTotal - (upper - lower)); if (x < lower->t) return x; if (x + (upper - lower) > (upper-1)->t) @@ -43,14 +43,14 @@ ent_id corrupt_head( /* Answers a question with a randomized unknown tail. */ -ent_id corrupt_tail(unsigned long int id, ent_id head, rel_id relation) +int64_t corrupt_tail(unsigned long int id, int64_t head, int64_t relation) { // using precalculated better range const auto range = std::equal_range(lefTail[head], rigTail[head], Triple(head, 0, relation), Triple::cmp_r); const auto lower = range.first; const auto upper = range.second; - const ent_id x = rand_max(id, entityTotal - (upper - lower)); + const int64_t x = rand_max(id, entityTotal - (upper - lower)); if (x < lower->h) return x; if (x + (upper - lower) > (upper-1)->h) @@ -75,14 +75,14 @@ ent_id corrupt_tail(unsigned long int id, ent_id head, rel_id relation) Answers a question with a randomized unknown relation. FIXME SIGSEGV */ -rel_id corrupt_rel(unsigned long int id, ent_id head, ent_id tail) +int64_t corrupt_rel(unsigned long int id, int64_t head, int64_t tail) { // using precalculated better range const auto range = std::equal_range(lefRel[head], rigHead[head], Triple(head, tail, 0), Triple::cmp_t); const auto lower = range.first; const auto upper = range.second; - const ent_id x = rand_max(id, relationTotal - (upper - lower)); + const int64_t x = rand_max(id, relationTotal - (upper - lower)); if (x < lower->h) return x; if (x + (upper - lower) > (upper-1)->h) diff --git a/base/Reader.h b/base/Reader.h index 5ca69fa..d1122c6 100755 --- a/base/Reader.h +++ b/base/Reader.h @@ -33,10 +33,7 @@ std::vector meant; extern "C" -int importTrainFiles( - const char* inPath, - ent_id entities, - rel_id relations) +int importTrainFiles(const char* inPath, int64_t entities, int64_t relations) try { entityTotal = entities; @@ -68,7 +65,7 @@ try { const auto end = trainHead.cend(); auto i = trainHead.cbegin(); - ent_id last = i->t; + int64_t last = i->t; for (++i; i != end; last = i->t, ++i) { if (i->t == last) @@ -84,7 +81,7 @@ try { const auto end = trainTail.cend(); auto i = trainTail.cbegin(); - ent_id last = i->h; + int64_t last = i->h; for (++i; i != end; last = i->h, ++i) { if (i->h == last) @@ -100,7 +97,7 @@ try { const auto end = trainRel.cend(); auto i = trainRel.cbegin(); - ent_id last = i->h; + int64_t last = i->h; for (++i; i != end; last = i->h, ++i) { if (i->h == last) @@ -112,7 +109,7 @@ try rigRel.at(trainRel.back().h) = trainRel.cend(); meanh.assign(relations, 0.); - for (ent_id i = 0; i < entities; ++i) + for (int64_t i = 0; i < entities; ++i) { const auto lower = lefHead[i]; const auto upper = rigHead[i]; @@ -122,11 +119,11 @@ try if (j->r != (j - 1)->r) meanh.at(j->r) += 1.; } - for (rel_id i = 0; i < relations; ++i) + for (int64_t i = 0; i < relations; ++i) meanh[i] = meanh[i] > .5 ? freqr[i] / meanh[i] : 0; meant.assign(relations, 0.); - for (ent_id i = 0; i < entities; ++i) + for (int64_t i = 0; i < entities; ++i) { const auto lower = lefTail[i]; const auto upper = rigTail[i]; @@ -137,7 +134,7 @@ try meant.at(j->r) += 1.; } - for (rel_id i = 0; i < relations; ++i) + for (int64_t i = 0; i < relations; ++i) meant[i] = meant[i] > .5 ? freqr[i] / meant[i] : 0; return 0; } diff --git a/base/Setting.h b/base/Setting.h index 450034b..d700356 100755 --- a/base/Setting.h +++ b/base/Setting.h @@ -3,65 +3,9 @@ #include "Triple.h" #include // std::vector -extern "C" -uint64_t getEntityTotal(); - -extern "C" -uint64_t getRelationTotal(); - -extern "C" -uint64_t getTripleTotal(); - -extern "C" -uint64_t getTrainTotal(); - -extern "C" -uint64_t getTestTotal(); - - uint64_t entityTotal = 0; uint64_t relationTotal = 0; std::vector trainList; std::vector testList; -extern "C" -uint64_t getEntityTotal() -{ - return entityTotal; -} - - -extern "C" -uint64_t getRelationTotal() -{ - return relationTotal; -} - - -extern "C" -uint64_t getTripleTotal() -{ - return trainList.size() + testList.size(); -} - - -extern "C" -uint64_t getTrainTotal() -{ - return trainList.size(); -} - - -extern "C" -uint64_t getTestTotal() -{ - return testList.size(); -} - - -extern "C" -uint64_t getValidTotal() -{ - return 0; -} #endif diff --git a/base/Test.h b/base/Test.h index 4cc153f..7fdddc5 100644 --- a/base/Test.h +++ b/base/Test.h @@ -6,15 +6,15 @@ # include // std::binary_search -extern "C" void query_head(char*, ent_id, rel_id); -extern "C" void query_tail(ent_id, char*, rel_id); -extern "C" void query_rel(ent_id, ent_id, char*); +extern "C" void query_head(char*, int64_t, int64_t); +extern "C" void query_tail(int64_t, char*, int64_t); +extern "C" void query_rel(int64_t, int64_t, char*); extern "C" void query_head( char* out, - ent_id tail, - rel_id relation) + int64_t tail, + int64_t relation) try { // assuming out has size `entityTotal` and is zero-initialized @@ -31,9 +31,9 @@ catch (std::out_of_range& e) extern "C" void query_tail( - ent_id head, + int64_t head, char* out, - rel_id relation) + int64_t relation) try { // assuming out has size `entityTotal` and is zero-initialized @@ -50,8 +50,8 @@ catch (std::out_of_range& e) extern "C" void query_rel( - ent_id head, - ent_id tail, + int64_t head, + int64_t tail, char* out) try { diff --git a/base/Triple.h b/base/Triple.h index ca716da..f4452fe 100755 --- a/base/Triple.h +++ b/base/Triple.h @@ -4,107 +4,38 @@ # include // int64_t -using ent_id = int64_t; -using rel_id = int64_t; - - struct Triple { + int64_t h; + int64_t t; + int64_t r; - - ent_id h; - ent_id t; - rel_id r; - - - bool operator<( - const Triple& other) - { - return std::min(h, t) > std::min(other.h, other.t); - } - - - static bool cmp_hrt( - const Triple& a, - const Triple& b) - { - return (a.h < b.h) or (a.h == b.h and a.r < b.r) - or (a.h == b.h and a.r == b.r and a.t < b.t); - } - - - static bool cmp_trh( - const Triple& a, - const Triple& b) - { - return (a.t < b.t) or (a.t == b.t and a.r < b.r) - or (a.t == b.t and a.r == b.r and a.h < b.h); - } - - - static bool cmp_htr( - const Triple& a, - const Triple& b) - { - return (a.h < b.h) or (a.h == b.h and a.t < b.t) - or (a.h == b.h and a.t == b.t and a.r < b.r); + static bool cmp_hrt(const Triple& a, const Triple& b) { + return (a.h < b.h) or (a.h == b.h and a.r < b.r) or (a.h == b.h and a.r == b.r and a.t < b.t); } - - static bool cmp_hr( - const Triple& a, - const Triple& b) - { - return (a.h < b.h) or (a.h == b.h and a.r < b.r); + static bool cmp_trh(const Triple& a, const Triple& b) { + return (a.t < b.t) or (a.t == b.t and a.r < b.r) or (a.t == b.t and a.r == b.r and a.h < b.h); } - - static bool cmp_tr( - const Triple& a, - const Triple& b) - { - return (a.t < b.t) or (a.t == b.t and a.r < b.r); + static bool cmp_htr(const Triple& a, const Triple& b) { + return (a.h < b.h) or (a.h == b.h and a.t < b.t) or (a.h == b.h and a.t == b.t and a.r < b.r); } - - static bool cmp_h( - const Triple& a, - const Triple& b) - { + static bool cmp_h(const Triple& a, const Triple& b) { return (a.h < b.h); } - - static bool cmp_t( - const Triple& a, - const Triple& b) - { + static bool cmp_t(const Triple& a, const Triple& b) { return (a.t < b.t); } - - static bool cmp_r( - const Triple& a, - const Triple& b) - { + static bool cmp_r(const Triple& a, const Triple& b) { return (a.r < b.r); } + Triple(const int64_t& head, const int64_t& tail, const int64_t& rel): h{head}, t{tail}, r{rel} {} - Triple( - const ent_id& head, - const ent_id& tail, - const rel_id& rel): - h{head}, t{tail}, r{rel} - { - } - - - Triple(void): - Triple(0, 0, 0) - { - } - - + Triple(void): Triple(0, 0, 0) {} }; #endif // TRIPLE_H diff --git a/base/openke.hpp b/base/openke.hpp index 8c94e17..27a01bc 100644 --- a/base/openke.hpp +++ b/base/openke.hpp @@ -15,34 +15,17 @@ void bernSampling(uint64_t* h, uint64_t* t, uint64_t* r, float* y, uint64_t size extern "C" void sampling(uint64_t* h, uint64_t* t, uint64_t* r, float* y, uint64_t size, uint64_t ne, uint64_t nr, uint64_t workers); - -extern "C" -int importTrainFiles(const char* inPath, ent_id entities, rel_id relations); - - -extern "C" -void query_head(char*, ent_id, rel_id); -extern "C" -void query_tail(ent_id, char*, rel_id); extern "C" -void query_rel(ent_id, ent_id, char*); - - -extern "C" -uint64_t getEntityTotal(void); +int importTrainFiles(const char* inPath, int64_t entities, int64_t relations); extern "C" -uint64_t getRelationTotal(void); +void query_head(char*, int64_t, int64_t); extern "C" -uint64_t getTripleTotal(void); +void query_tail(int64_t, char*, int64_t); extern "C" -uint64_t getTrainTotal(void); - -extern "C" -uint64_t getTestTotal(void); - +void query_rel(int64_t, int64_t, char*); # if 0 extern "C" diff --git a/openke/Config.py b/openke/Config.py index 16874ee..4262f3d 100755 --- a/openke/Config.py +++ b/openke/Config.py @@ -49,7 +49,7 @@ def __init__(self, filename: str, library: str = './libopenke.so', temp_dir: str parser.ent_count, parser.rel_count, ) - self.size = self.__library.getTrainTotal() + self.size = parser.train_count self.shape = parser.ent_count, parser.rel_count def __len__(self): diff --git a/openke/parser.py b/openke/parser.py index 9ea2f66..4de864f 100644 --- a/openke/parser.py +++ b/openke/parser.py @@ -42,6 +42,7 @@ def __init__(self, filename: str, temp_dir: str, generate_valid_test: bool = Fal self.test_file = os.path.join(self.output_dir, "test2id.txt") self.ent_count = None self.rel_count = None + self.train_count = None def parse(self): """ @@ -55,6 +56,8 @@ def parse(self): self.ent_count = int(f.readline()) with open(self.relation_file) as f: self.rel_count = int(f.readline()) + with open(self.train_file) as f: + self.train_count = int(f.readline()) else: self.create_benchmark() @@ -260,3 +263,4 @@ def create_benchmark(self): self.ent_count = len(srs_ent) self.rel_count = len(srs_rel) + self.train_count = len(df_train)