-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathppo_utils.py
executable file
·124 lines (102 loc) · 4.16 KB
/
ppo_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import math
import numpy as np
from numpy.random import randint
from typing import Any, Callable, Dict, List
# Module-wide generator with a fixed seed: elo_schedule draws every pairing
# (shuffle + odd-player drop) from it, so repeated runs pair players identically.
# Every call to elo_schedule advances this shared state.
rng = np.random.RandomState(seed=0)
def expected(score_A, score_B):
    """
    Return the Elo expected score of player A in a match against player B.

    Uses the base-10 logistic curve with a 400-point scale, so equal ratings
    give 0.5 and a 400-point advantage gives 10/11.

    score_A: rating of player A
    score_B: rating of player B
    """
    rating_gap = (score_B - score_A) / 400
    return 1 / (1 + 10 ** rating_gap)
def compute_elo(old, expected, score, k=16):
    """
    Return a player's rating after an Elo update.

    old: rating before the update
    expected: expected score (may be a sum over several matches)
    score: actual score (same scale as ``expected``)
    k: k-factor scaling how far the rating moves
    """
    surprise = score - expected
    return old + k * surprise
def elo_schedule(prior: Any,
                 players: List[Any],
                 match_function: Callable,
                 player_scores: List[float] = None,
                 samples: int = 8,
                 tournament_size: int = 1,
                 mbs: int = 1,
                 list_return_dict: List[Dict[Any, float]] = None) -> List[Dict[Any, float]]:
    """
    Play ``samples`` rounds of randomly paired matches, updating Elo ratings after each round.

    prior: opaque value forwarded unchanged as the first argument of ``match_function``
    players: list of players; each must be hashable because players become dict keys
    match_function: called as ``match_function(prior, players1, players2)`` with two
        equal-length lists of players; must return one result per pair, where 0
        credits the win to the first player of the pair and 1 to the second
    player_scores: initial rating of each player, aligned with ``players``
        (defaults to 1000 for everyone)
    samples: number of rounds to play; ratings are updated once per round
    tournament_size: number of random pairings of all players drawn per round
        (with an odd number of players, one player sits out of each pairing)
    mbs: maximum number of pairs passed to ``match_function`` in a single call
    list_return_dict: optional list to append the per-round results to; a new
        list is created when omitted
    returns: ``list_return_dict`` with one dict appended per round, mapping each
        player to their rating after that round
    raises: ValueError if ``mbs`` < 1 or if ``match_function`` returns a different
        number of results than the number of pairs it was given
    """
    if mbs < 1:
        raise ValueError(f"mbs must be a positive integer, got {mbs}")
    # Create the list per call: a `[]` default would be shared by every call.
    if list_return_dict is None:
        list_return_dict = []
    # On the first call every player starts at a rating of 1000.
    if player_scores is None:
        player_scores = [1000] * len(players)

    n_players = len(players)
    # Loop rather than recurse once per round so large `samples` cannot hit the recursion limit.
    for _ in range(samples):
        pairs = _draw_pairs(n_players, tournament_size)
        results = _play_matches(prior, players, pairs, match_function, mbs)
        player_scores = _update_ratings(player_scores, n_players, pairs, results)
        list_return_dict.append({players[pix]: player_scores[pix] for pix in range(n_players)})
    return list_return_dict


def _draw_pairs(n_players, tournament_size):
    """Return an (m, 2) array of player-index pairs, drawn from the module-level ``rng``."""
    pairings = []
    for _ in range(tournament_size):
        # Since many rounds are played, random pairing is enough.
        idxs = np.arange(n_players)
        rng.shuffle(idxs)
        # With an odd count, drop either the first or the last shuffled index so the rest pair up.
        if n_players & 1:
            idxs = np.delete(idxs, -rng.randint(2))
        pairings.append(idxs.reshape(-1, 2))
    if not pairings:
        return np.empty((0, 2), dtype=int)
    return np.vstack(pairings)


def _play_matches(prior, players, pairs, match_function, mbs):
    """Call ``match_function`` on consecutive micro-batches of at most ``mbs`` pairs; return all results."""
    players1 = [players[first] for first, _ in pairs]
    players2 = [players[second] for _, second in pairs]
    results = []
    for start in range(0, len(pairs), mbs):
        batch = slice(start, start + mbs)
        results.extend(match_function(prior, players1[batch], players2[batch]))
    if len(results) != len(pairs):
        raise ValueError(
            f"match_function returned {len(results)} results for {len(pairs)} pairs")
    return results


def _update_ratings(player_scores, n_players, pairs, results):
    """Return the ratings after one round as a float array, applying one Elo update per player."""
    wins = [0] * n_players
    opponents = [[] for _ in range(n_players)]
    for result, (p1, p2) in zip(results, pairs):
        # A result of 0 credits the first player of the pair, 1 credits the second.
        wins[p1] += 1 - result
        wins[p2] += result
        opponents[p1].append(p2)
        opponents[p2].append(p1)
    new_scores = np.zeros(n_players)
    for pix in range(n_players):
        # Expected and actual scores are summed over all of this round's matches,
        # then applied as a single update against the pre-round ratings.
        expected_score = sum(expected(player_scores[pix], player_scores[opp]) for opp in opponents[pix])
        new_scores[pix] = compute_elo(player_scores[pix], expected_score, wins[pix])
    return new_scores
# Critic backed by a language model, meant to be prompted for a single set of logits per comparison.
class ELOCriticModel:
    """Hold a model/tokenizer pair; presumably meant to supply ``match_function`` to elo_schedule (not implemented yet)."""

    def __init__(self, model, tokenizer):
        """
        model: a Hugging Face transformer model
        tokenizer: a tokenizer that takes a string and returns a list of ints
        """
        self.model, self.tokenizer = model, tokenizer

    def match_function(self, priors, player1, player2):
        """Compare ``player1`` against ``player2``; currently always raises NotImplementedError."""
        raise NotImplementedError()
if __name__ == '__main__':
    # Self-test for elo_schedule: rank the integers 0..4, where the larger number always wins.
    players = np.arange(5)

    def match_function(prior, xs1, xs2):
        # elo_schedule expects 0 when the first player wins and 1 when the second does,
        # so return 1 exactly when the second number is the larger one.
        return (np.array(xs1) < np.array(xs2)).astype(int)

    history = elo_schedule(None, players, match_function, mbs=3, tournament_size=10, samples=100)
    # elo_schedule returns one {player: rating} dict per round; the last one holds the final ratings.
    final_ratings = history[-1]
    print('ELO for number comparisons:')
    for sample, rating in sorted(final_ratings.items(), key=lambda item: item[1], reverse=True):
        print(f'[{rating:.0f}]', sample)
    scores = np.array([final_ratings[p] for p in players])
    # Ascending ratings should list the players in ascending numeric order.
    assert np.array_equal(players, np.argsort(scores))