Skip to content

Commit

Permalink
Fixed a bug that occured when using unsorted plot values.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrucker committed May 1, 2024
1 parent ce7d887 commit 55c3ac3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
9 changes: 8 additions & 1 deletion coba/results/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1933,8 +1933,15 @@ def _group_p(self, l:Union[str, Sequence[str]], p:Union[str, Sequence[str]]):
indexes = list(self._grouped_ys(p,l,'environment_id','learner_id','evaluator_id',y=None,card='S'))
n_levels = len(set(map(itemgetter(1),indexes)))

try:
indexes = sorted(indexes)
except:
sorted_=False
else:
sorted_=True

to_keep, to_remove, n_larger, n_smaller = [], [], 0, 0
for _, group in groupby(indexes,key=itemgetter(0)):
for _, group in grouper(indexes,key=itemgetter(0), sorted_=sorted_):
group = list(group)
if len(group) > n_levels:
n_larger += 1
Expand Down
47 changes: 47 additions & 0 deletions coba/tests/test_results_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,53 @@ def test_where_fin_multi_p(self):
self.assertEqual("We removed 1 data_id because more than one existed for each learner_id.", CobaContext.logger.sink.items[0])
self.assertEqual("There was no data_id which was finished for every learner_id.", CobaContext.logger.sink.items[1])

def test_where_fin_multi_unordered_p(self):

CobaContext.logger = IndentLogger()
CobaContext.logger.sink = ListSink()

envs = [['environment_id','data_id','seed' ],[1,2,1],[2,3,2],[3,2,2],[4,3,2]]
lrns = [['learner_id' ],[1],[2]]
vals = [['evaluator_id' ],[1]]
ints = [['environment_id','learner_id','evaluator_id','index'],[1,1,1,0],[2,1,1,0],[3,2,1,0],[4,2,1,0]]

original_result = Result(envs, lrns, vals, ints)
filtered_result = original_result.where_fin(l='learner_id',p='data_id')

self.assertEqual(4, len(original_result.environments))
self.assertEqual(2, len(original_result.learners))
self.assertEqual(1, len(original_result.evaluators))
self.assertEqual(4, len(original_result.interactions))

self.assertEqual(4, len(filtered_result.environments))
self.assertEqual(2, len(filtered_result.learners))
self.assertEqual(1, len(filtered_result.evaluators))
self.assertEqual(4, len(filtered_result.interactions))

def test_where_fin_multi_unsortable_p(self):

CobaContext.logger = IndentLogger()
CobaContext.logger.sink = ListSink()

envs = [['environment_id','data_id','seed' ],[1,'a',1],[2,3,2],[3,'a',2],[4,3,2]]
lrns = [['learner_id' ],[1],[2]]
vals = [['evaluator_id' ],[1]]
ints = [['environment_id','learner_id','evaluator_id','index'],[1,1,1,0],[2,1,1,0],[3,2,1,0],[4,2,1,0]]

original_result = Result(envs, lrns, vals, ints)
filtered_result = original_result.where_fin(l='learner_id',p='data_id')

self.assertEqual(4, len(original_result.environments))
self.assertEqual(2, len(original_result.learners))
self.assertEqual(1, len(original_result.evaluators))
self.assertEqual(4, len(original_result.interactions))

self.assertEqual(4, len(filtered_result.environments))
self.assertEqual(2, len(filtered_result.learners))
self.assertEqual(1, len(filtered_result.evaluators))
self.assertEqual(4, len(filtered_result.interactions))


def test_where_best_family(self):

CobaContext.logger = IndentLogger()
Expand Down
4 changes: 2 additions & 2 deletions coba/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __missing__(self, key):
return value

_T = TypeVar("_T")
def peek_first(items: Iterable[_T], n:int=1, reduce:bool=True) -> Tuple[Union[_T,Sequence[_T]], Iterable[_T]]:
def peek_first(items:Iterable[_T], n:int=1, reduce:bool=True) -> Tuple[Union[_T,Sequence[_T]], Iterable[_T]]:
items = iter(items)
first = list(islice(items,n))

Expand Down Expand Up @@ -184,7 +184,7 @@ def minobj(o):

return obj

def grouper(items: Sequence[Any], key:Callable[[Any],Hashable]=None, val:Callable[[Any],Any]=None, sorted_: bool= False) -> Iterable[Tuple[Hashable,Iterable[Any]]]:
def grouper(items:Sequence[Any], key:Callable[[Any],Hashable]=None, val:Callable[[Any],Any]=None, sorted_:bool=False) -> Iterable[Tuple[Hashable,Iterable[Any]]]:

if sorted_:
for k, group in groupby(items,key=key):
Expand Down

0 comments on commit 55c3ac3

Please sign in to comment.