-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests for keylook up changes being manifested in token buckets
- Loading branch information
1 parent
fd49152
commit a509ca7
Showing
2 changed files
with
43 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pytest | ||
from edsl import Model | ||
from edsl import QuestionFreeText | ||
from edsl.language_models.key_management.KeyLookup import KeyLookup | ||
|
||
def test_default_rpm_bucket_collection(): | ||
# Create model with default RPM | ||
m = Model('test') | ||
jobs = QuestionFreeText.example().by(m) | ||
bc = jobs.create_bucket_collection() | ||
rr = bc[Model(model_name='test', temperature=0.5)].requests_bucket.refill_rate | ||
|
||
# Check if actual refill rate is within 10 requests/min of target RPM | ||
actual_rpm = rr * 60 | ||
assert abs(actual_rpm - m.rpm) < 10, \ | ||
f"Actual RPM ({actual_rpm}) differs from target RPM ({m.rpm}) by more than 10" | ||
|
||
def test_custom_rpm_bucket_collection(): | ||
# Setup custom RPM via KeyLookup | ||
kl = KeyLookup.example() | ||
target_rpm = 1 | ||
kl['test'].rpm = target_rpm | ||
|
||
# Create jobs with custom KeyLookup | ||
m = Model('test') | ||
jobs = QuestionFreeText.example().by(m) | ||
jobs2 = jobs.using(kl) | ||
|
||
# Create bucket collection and get refill rate | ||
bc = jobs2.create_bucket_collection() | ||
rr = bc[Model(model_name='test', temperature=0.5)].requests_bucket.refill_rate | ||
|
||
# Check if actual refill rate is within 1 request/min of target RPM | ||
actual_rpm = rr * 60 | ||
assert abs(actual_rpm - target_rpm) < 1, \ | ||
f"Actual RPM ({actual_rpm}) differs from target RPM ({target_rpm}) by more than 1" |