Skip to content

Commit

Permalink
Improved performance of linucb.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrucker committed Feb 7, 2024
1 parent 2824a3c commit 037b695
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion coba/learners/linucb.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _pmf(self,context,actions):
features = np.array([self._X_encoder.encode(x=context,a=action) for action in actions]).T

point_estimate = self._theta @ features
point_bounds = np.diagonal(features.T @ self._A_inv @ features)
point_bounds = np.einsum('ij,ij->j', self._A_inv @ features, features) #== np.diagonal(features.T @ self._A_inv @ features)

action_values = point_estimate + self._alpha*np.sqrt(point_bounds)
max_indexes = np.where(action_values == np.amax(action_values))[0]
Expand Down

0 comments on commit 037b695

Please sign in to comment.