-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
30 lines (22 loc) · 1006 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt

import dataset
from bandits import *
from evaluator import evaluate
def main(seed: int = 42, size: int = 100) -> None:
    """Evaluate several contextual-bandit policies on the Yahoo! R6 click log.

    Loads the R6 events via ``dataset.get_yahoo_events``, builds three
    ``NeuralTSDiag`` variants plus a ``LinUCB`` baseline, runs each through
    ``evaluate`` and writes the second element of the pair it returns
    (``deploy``) to ``results/<algo.name>.csv`` next to this script.

    Args:
        seed: Seed applied to torch and NumPy before the policies are built
            and again before each policy's evaluation run.
        size: Forwarded unchanged to ``evaluate`` as its ``size`` argument.
    """
    here = Path(__file__).parent

    # A real 1-tuple (trailing comma): the loader receives a collection of
    # file names. Previously this was a parenthesised str, not a tuple.
    # The path is anchored to the script directory, like the results path
    # below, so the script no longer depends on the current working dir.
    data_file = here / "dataset" / "R6" / "ydata-fp-td-clicks-v1_0.20090501"
    files = (str(data_file),)
    dataset.get_yahoo_events(files)

    # Input dimension handed to NeuralTSDiag: context features x arms.
    dim = dataset.n_context_features * dataset.n_arms

    # Seed before constructing the policies as well: NeuralTSDiag presumably
    # initialises network weights in its constructor, which was previously
    # unseeded. NOTE(review): confirm against bandits.py.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Extra options used only by the 'rand_ucb' variant.
    rand_ucb_kwargs = {'N': 20, 'decoupled_arms': None, 'std': 0.0625, 'is_optimistic': False}
    # `lamdba` (sic) is kept verbatim; presumably it is the keyword name the
    # NeuralTSDiag constructor expects.
    algos = [
        NeuralTSDiag(dim, n_hidden=1, hidden_size=100, style='rand_ucb', **rand_ucb_kwargs),
        NeuralTSDiag(dim, lamdba=1, nu=1e-1, n_hidden=1, hidden_size=100, style='ucb'),
        NeuralTSDiag(dim, lamdba=1, nu=1e-2, n_hidden=1, hidden_size=100, style='ts'),
        LinUCB(0.3, context="user"),
    ]

    results_dir = here / 'results'
    # np.savetxt does not create missing parent directories.
    results_dir.mkdir(parents=True, exist_ok=True)

    for algo in algos:
        # Re-seed so every policy's evaluation starts from the same RNG state.
        torch.manual_seed(seed)
        np.random.seed(seed)
        # Only the second element of evaluate's result is persisted.
        _, deploy = evaluate(algo, size=size)
        np.savetxt(results_dir / (algo.name + '.csv'), deploy, delimiter=',')


if __name__ == '__main__':
    main()