-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmerge_datasets.py
69 lines (48 loc) · 2.11 KB
/
merge_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import gzip
import numpy as np
import pandas as pd
import scipy.sparse as sps
from RecSysFramework.ExperimentalConfig import EXPERIMENTAL_CONFIG
def gzip_file(filename):
    """Write a gzip-compressed copy of ``filename`` to ``filename + '.gz'``.

    The source file is read in binary mode and left in place; the archive
    is opened with mode ``'wb'``, so an existing archive is overwritten.
    """
    archive_name = filename + '.gz'
    # Open the source first so a missing input never creates an empty archive.
    with open(filename, 'rb') as raw:
        with gzip.open(archive_name, 'wb') as packed:
            for line in raw:
                packed.write(line)
def merge_datasets(datasets_to_merge):
    """Merge the data files of several dataset folders into one new folder.

    For each of ``train.tsv.gz``, ``train_5core.tsv.gz``,
    ``valid_qrel.tsv.gz`` and ``valid_run.tsv.gz`` the per-folder tables are
    read, concatenated in the order given, de-duplicated and written to
    ``EXPERIMENTAL_CONFIG['dataset_folder'] + "<sorted names joined by '-'>-iu"``.

    If a ``.gz`` file is absent but its uncompressed counterpart exists, the
    latter is gzipped in place first (via ``gzip_file``).

    Args:
        datasets_to_merge: list of folder names located under
            ``EXPERIMENTAL_CONFIG['dataset_folder']``. Order matters: when
            duplicate keys occur, the row from the folder listed last is kept.

    Raises:
        FileNotFoundError: if a required file exists neither compressed nor
            uncompressed in one of the folders.
    """
    print("Merging", "-".join(datasets_to_merge))

    run_file = "valid_run.tsv.gz"
    dataset_root = EXPERIMENTAL_CONFIG['dataset_folder']
    # The output location does not depend on the file being processed.
    new_dir = dataset_root + "-".join(sorted(datasets_to_merge)) + "-iu" + os.sep

    for data in ["train.tsv.gz", "train_5core.tsv.gz", "valid_qrel.tsv.gz", run_file]:
        is_run = data == run_file
        if is_run:
            # The run file is read without a header row and has two columns;
            # it is de-duplicated per user only.
            header = None
            colnames = ["userid", "itemid"]
            dedup_keys = ["userid"]
        else:
            # Row 0 is consumed as the header and replaced by ``colnames``.
            header = 0
            colnames = ["userid", "itemid", "score"]
            dedup_keys = ["userid", "itemid"]

        dfs = []
        for folder in datasets_to_merge:
            filename = dataset_root + folder + os.sep + data
            if not os.path.isfile(filename):
                # Fall back to the uncompressed file (name without ".gz").
                uncompressed = filename[:-3]
                if not os.path.isfile(uncompressed):
                    raise FileNotFoundError("File {} is missing!".format(uncompressed))
                gzip_file(uncompressed)
            dfs.append(pd.read_csv(filename, sep="\t", index_col=False, header=header, names=colnames))

        df = pd.concat(dfs)
        del dfs

        # drop_duplicates returns a new frame; the result must be assigned,
        # otherwise the duplicates are still written out.
        df = df.drop_duplicates(subset=dedup_keys, keep="last")

        os.makedirs(new_dir, exist_ok=True)
        df.to_csv(new_dir + data, sep="\t", header=not is_run, index=False)
if __name__ == "__main__":
    # Every source-market combination to be paired with the target markets.
    source_markets = [
        ["s1"], ["s2"], ["s3"], ["s1", "s2", "s3"],
        ["s1", "s2"], ["s2", "s3"], ["s1", "s3"],
    ]
    # Target-market combinations.
    target_markets = [
        ["t1"], ["t2"], ["t1", "t2"]
    ]

    # The target-only merge is produced first.
    merge_datasets(["t1", "t2"])

    # Then one merge per (source, target) pair, with folder names sorted.
    market_combinations = [
        sorted(sources + targets)
        for sources in source_markets
        for targets in target_markets
    ]
    for combination in market_combinations:
        merge_datasets(combination)