-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlfm-bpr.py
53 lines (44 loc) · 1.96 KB
/
lfm-bpr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
import argparse
from scipy import sparse
from lightfm import LightFM
# Command-line interface: input interactions file, output file, and training
# hyper-parameters for the BPR model below.
parser = argparse.ArgumentParser(description='Argument Parser')
# --train and --save have no usable default: both are opened unconditionally
# later, so fail early with a usage message instead of open(None).
parser.add_argument('--train', required=True,
                    help='train file: one "user item click" triple per line')
parser.add_argument('--save', required=True, help='representation file')
parser.add_argument('--dim', type=int, default=128,
                    help='number of latent factors (embedding dimension)')
parser.add_argument('--iter', type=int, default=100,
                    help='number of training epochs')
parser.add_argument('--worker', type=int, default=1, help='# of workers')
args = parser.parse_args()
# LightFM factorization model optimized with the BPR pairwise ranking loss.
# no_components is the embedding width requested on the command line;
# user_alpha / item_alpha are LightFM's regularization strengths.
model = LightFM(
    no_components=args.dim,
    loss='bpr',
    learning_rate=0.025,
    user_alpha=1e-5,
    item_alpha=1e-5,
)
# Parse the whitespace-separated "user item click" training file. Raw user and
# item ids are mapped to contiguous 0-based indices in first-seen order;
# *_index maps raw id -> index and *_map maps index -> raw id, so the saved
# representations can be labeled with the original ids.
users = []
items = []
clicks = []
user_map, user_index = {}, {}
item_map, item_index = {}, {}
sys.stderr.write('load train data ...\n')
with open(args.train) as f:
    for lineno, line in enumerate(f, 1):
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing on the tuple unpack below.
            continue
        if len(fields) != 3:
            raise ValueError('%s:%d: expected "user item click", got %d fields'
                             % (args.train, lineno, len(fields)))
        user, item, click = fields
        if user not in user_index:
            user_index[user] = len(user_index)
            user_map[user_index[user]] = user
        if item not in item_index:
            item_index[item] = len(item_index)
            item_map[item_index[item]] = item
        users.append(user_index[user])
        items.append(item_index[item])
        clicks.append(float(click))
if not clicks:
    # scipy cannot infer a shape from empty index arrays; report the real cause.
    raise ValueError('%s: no training interactions found' % args.train)
# Indices are dense (0..n-1), so this explicit shape equals what scipy would
# infer; stating it keeps the matrix size tied to the id maps above.
train = sparse.coo_matrix((clicks, (users, items)),
                          shape=(len(user_index), len(item_index)))
# Train on the user x item interaction matrix for the requested number of
# epochs; verbose=True turns on LightFM's own progress output.
print('start training ...', file=sys.stderr)
model.fit(train, num_threads=args.worker, epochs=args.iter, verbose=True)
# Export user and item vectors in word2vec-style text format: a
# "<count> <width>" header, then one "<id> <v1> ... <vwidth>" line per entity.
# Each vector is the LightFM embedding prefixed with two bias slots:
#   users -> [user_bias, 1.0, emb...]
#   items -> [1.0, item_bias, emb...]
# so the dot product of a user line and an item line equals
# user_bias + item_bias + <user_emb, item_emb>.
sys.stderr.write('save representations ...\n')
user_biases, user_embs = model.get_user_representations()
item_biases, item_embs = model.get_item_representations()
# The header must state the width actually written (embedding size plus the
# two bias slots); word2vec-format readers validate every line against it.
vec_dim = user_embs.shape[1] + 2
rep_res = ['%d %d' % (len(user_embs) + len(item_embs), vec_dim)]
for e, user_rep in enumerate(user_embs):
    rep_res.append('%s %f 1.0 %s' % (user_map[e], user_biases[e],
                                     ' '.join(map(str, user_rep))))
for e, item_rep in enumerate(item_embs):
    rep_res.append('%s 1.0 %f %s' % (item_map[e], item_biases[e],
                                     ' '.join(map(str, item_rep))))
with open(args.save, 'w') as f:
    f.write('%s\n' % ('\n'.join(rep_res)))