Skip to content

Commit

Permalink
Tests for keylook up changes being manifested in token buckets
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjosephhorton committed Feb 9, 2025
1 parent fd49152 commit a509ca7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
8 changes: 7 additions & 1 deletion edsl/jobs/Jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,13 @@ def create_bucket_collection(self) -> BucketCollection:
>>> bc
BucketCollection(...)
"""
return BucketCollection.from_models(self.models)
bc = BucketCollection.from_models(self.models)

if self.run_config.environment.key_lookup is not None:
bc.update_from_key_lookup(
self.run_config.environment.key_lookup
)
return bc

def html(self):
"""Return the HTML representations for each scenario"""
Expand Down
36 changes: 36 additions & 0 deletions tests/jobs/test_KeyLookup_Modify_BucketCollection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from edsl import Model
from edsl import QuestionFreeText
from edsl.language_models.key_management.KeyLookup import KeyLookup

def test_default_rpm_bucket_collection():
# Create model with default RPM
m = Model('test')
jobs = QuestionFreeText.example().by(m)
bc = jobs.create_bucket_collection()
rr = bc[Model(model_name='test', temperature=0.5)].requests_bucket.refill_rate

# Check if actual refill rate is within 10 requests/min of target RPM
actual_rpm = rr * 60
assert abs(actual_rpm - m.rpm) < 10, \
f"Actual RPM ({actual_rpm}) differs from target RPM ({m.rpm}) by more than 10"

def test_custom_rpm_bucket_collection():
# Setup custom RPM via KeyLookup
kl = KeyLookup.example()
target_rpm = 1
kl['test'].rpm = target_rpm

# Create jobs with custom KeyLookup
m = Model('test')
jobs = QuestionFreeText.example().by(m)
jobs2 = jobs.using(kl)

# Create bucket collection and get refill rate
bc = jobs2.create_bucket_collection()
rr = bc[Model(model_name='test', temperature=0.5)].requests_bucket.refill_rate

# Check if actual refill rate is within 1 request/min of target RPM
actual_rpm = rr * 60
assert abs(actual_rpm - target_rpm) < 1, \
f"Actual RPM ({actual_rpm}) differs from target RPM ({target_rpm}) by more than 1"

0 comments on commit a509ca7

Please sign in to comment.