Skip to content

Commit

Permalink
fix validator issues
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Feb 6, 2024
1 parent ed7c64d commit f84046e
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions internal/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error {
return err
}

err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit, promptCost)
err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit)
if err != nil {
return err
}

err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd, promptCost)
err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd)
if err != nil {
return err
}
Expand Down Expand Up @@ -96,14 +96,14 @@ func (v *Validator) validateRateLimitOverTime(keyId string, rateLimitOverTime in
return errors.New("failed to get rate limit counter")
}

if c+1 > int64(rateLimitOverTime) {
if c >= int64(rateLimitOverTime) {
return internal_errors.NewRateLimitError(fmt.Sprintf("key exceeded rate limit %d requests per %s", rateLimitOverTime, rateLimitUnit))
}

return nil
}

func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit, promptCost float64) error {
func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit) error {
if costLimitOverTime == 0 {
return nil
}
Expand All @@ -113,8 +113,8 @@ func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime fl
return errors.New("failed to get cached token cost")
}

if convertDollarToMicroDollars(promptCost)+cachedCost > convertDollarToMicroDollars(costLimitOverTime) {
return internal_errors.NewExpirationError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit), internal_errors.CostLimitExpiration)
if cachedCost >= convertDollarToMicroDollars(costLimitOverTime) {
return internal_errors.NewCostLimitError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit))
}

return nil
Expand All @@ -124,7 +124,7 @@ func convertDollarToMicroDollars(dollar float64) int64 {
return int64(dollar * 1000000)
}

func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCost float64) error {
func (v *Validator) validateCostLimit(keyId string, costLimit float64) error {
if costLimit == 0 {
return nil
}
Expand All @@ -134,7 +134,7 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCos
return errors.New("failed to get total token cost")
}

if convertDollarToMicroDollars(promptCost)+existingTotalCost > convertDollarToMicroDollars(costLimit) {
if existingTotalCost >= convertDollarToMicroDollars(costLimit) {
return internal_errors.NewExpirationError(fmt.Sprintf("total cost limit: %f has been reached", costLimit), internal_errors.CostLimitExpiration)
}

Expand Down

0 comments on commit f84046e

Please sign in to comment.