diff --git a/internal/meters/amberflo.go b/internal/meters/amberflo.go index c2209e56..e5617f65 100644 --- a/internal/meters/amberflo.go +++ b/internal/meters/amberflo.go @@ -91,17 +91,34 @@ func (m *AmberFlo) getValue(user MeterUser, meterName string, startTime time.Tim if cfg.Name == "" { return 0, false } - - startTimeInSeconds := (time.Now().In(time.UTC).UnixNano() / int64(time.Second)) - (24 * 60 * 60) timeRange := &metering.TimeRange{ - StartTimeInSeconds: startTimeInSeconds, + StartTimeInSeconds: startTime.In(time.UTC).Unix(), + EndTimeInSeconds: endTime.In(time.UTC).Unix(), + } + if timeRange.EndTimeInSeconds > time.Now().In(time.UTC).Unix() { + timeRange.EndTimeInSeconds = 0 } + filter := make(map[string][]string) filter["customerId"] = []string{customerId} + for _, dim := range checkDims { + filter[dim.Key] = []string{dim.Value} + } + + timeGroupingInterval := metering.Hour + switch timeSpan := endTime.Unix() - startTime.Unix(); { + case timeSpan > 24*60*60: + timeGroupingInterval = metering.Month + case timeSpan > 60*60: + timeGroupingInterval = metering.Day + default: + timeGroupingInterval = metering.Hour + } + usageResult, err := m.usageClient.GetUsage(&metering.UsagePayload{ MeterApiName: cfg.Name, Aggregation: metering.Sum, - TimeGroupingInterval: metering.Day, + TimeGroupingInterval: timeGroupingInterval, GroupBy: []string{"customerId"}, TimeRange: timeRange, Filter: filter, @@ -114,9 +131,9 @@ func (m *AmberFlo) getValue(user MeterUser, meterName string, startTime time.Tim log.Error().Err(err).Str("user", user.ID()).Msg("could not get value; no client value meter") return 0, false } - cm := usageResult.ClientMeters[0].Values - cmv := cm[len(cm)-1].Value - return cmv, true + + total := usageResult.ClientMeters[0].GroupValue + return total, true } func (m *AmberFlo) sendMeter(user MeterUser, meterName string, value float64, extraDimensions Dimensions) error { diff --git a/internal/meters/cache.go b/internal/meters/cache.go deleted file mode 100644 index d4abcaf5..00000000 --- a/internal/meters/cache.go +++ /dev/null @@ -1,62 +0,0 @@ -package meters - -// func init() { -// var _ MeterProvider = &CacheMeterProvider{} -// } - -// type meterValueCache struct { -// Value float64 -// } - -// type CacheMeterProvider struct { -// cache *ecache.Cache[meterValueCache] -// provider MeterProvider -// } - -// func NewCacheMeterProvider(redisClient *redis.Client, provider MeterProvider) *CacheMeterProvider { -// return &CacheMeterProvider{ -// provider: provider, -// cache: ecache.NewCache[meterValueCache](redisClient, "cachemeter"), -// } -// } - -// func (c *CacheMeterProvider) NewMeter(u MeterUser) ApiMeter { -// return &CacheMeter{ -// userId: u.ID(), -// meter: c.provider.NewMeter(u), -// cm: c, -// } -// } - -// func (c *CacheMeterProvider) Close() error { -// return c.provider.Close() -// } - -// func (c *CacheMeterProvider) Flush() error { -// return c.provider.Flush() -// } - -// type CacheMeter struct { -// userId string -// meter ApiMeter -// cm *CacheMeterProvider -// } - -// func (c *CacheMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { -// return c.meter.Meter(meterName, value, extraDimensions) -// } - -// func (c *CacheMeter) GetValue(meterName string, d time.Duration, dims Dimensions) (float64, bool) { -// ctx := context.Background() -// key := fmt.Sprintf("%s:%s", c.userId, meterName) -// a, ok := c.cm.cache.Get(ctx, key) -// if !ok { -// a.Value, ok = c.meter.GetValue(meterName, d, dims) -// c.cm.cache.SetTTL(ctx, key, a, 10*time.Minute, 10*time.Minute) -// } -// return a.Value, ok -// } - -// func (c *CacheMeter) AddDimension(meterName string, key string, value string) { -// c.meter.AddDimension(meterName, key, value) -// } diff --git a/internal/meters/cache_test.go b/internal/meters/cache_test.go deleted file mode 100644 index 5acb1d8f..00000000 --- a/internal/meters/cache_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package meters - -// func TestCacheMeter(t *testing.T) { -// redisClient := testutil.MustOpenTestRedisClient() -// mp := NewDefaultMeterProvider() -// cmp := NewCacheMeterProvider(redisClient, mp) -// testConfig := testMeterConfig{ -// testMeter1: "test1", -// testMeter2: "test2", -// user1: &testUser{name: "test1"}, -// user2: &testUser{name: "test2"}, -// user3: &testUser{name: "test3"}, -// } -// testMeter(t, cmp, testConfig) -// } diff --git a/internal/meters/limit.go b/internal/meters/limit.go index 18464c40..ddbd59b6 100644 --- a/internal/meters/limit.go +++ b/internal/meters/limit.go @@ -2,22 +2,18 @@ package meters import ( "errors" + "fmt" "time" + + "github.com/rs/zerolog/log" ) func init() { var _ MeterProvider = &LimitMeterProvider{} } -type userMeterLimit struct { - User string - MeterName string - Dims Dimensions - Period string - Limit float64 -} - type LimitMeterProvider struct { + Enabled bool UserLimits map[string][]userMeterLimit MeterProvider } @@ -46,7 +42,7 @@ type LimitMeter struct { ApiMeter } -func (c *LimitMeter) GetLimit(meterName string, checkDims Dimensions) (time.Time, time.Time, float64, bool) { +func (c *LimitMeter) GetLimit(meterName string, checkDims Dimensions) (userMeterLimit, bool) { var lim userMeterLimit found := false for _, checkLim := range c.provider.UserLimits[c.userId] { @@ -57,13 +53,42 @@ func (c *LimitMeter) GetLimit(meterName string, checkDims Dimensions) (time.Time } } if !found { - // fmt.Println("no limit found") - return time.Now(), time.Now(), 0, false + return lim, false } + return lim, true +} + +func (c *LimitMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { + lim, foundLimit := c.GetLimit(meterName, extraDimensions) + d1, d2 := lim.Span() + if c.provider.Enabled && foundLimit { + currentValue, _ := c.GetValue(meterName, d1, d2, extraDimensions) + if foundLimit && currentValue+value > lim.Limit { + log.Info().Str("meter", meterName).Str("user", c.userId).Float64("current", currentValue).Float64("add", value).Str("dims", fmt.Sprintf("%v", extraDimensions)).Msg("rate limited") + return errors.New("rate limited") + } else { + log.Info().Str("meter", meterName).Str("user", c.userId).Float64("current", currentValue).Float64("add", value).Str("dims", fmt.Sprintf("%v", extraDimensions)).Msg("rate check") + } + } + return c.ApiMeter.Meter(meterName, value, extraDimensions) +} + +type userMeterLimit struct { + User string + MeterName string + Dims Dimensions + Period string + Limit float64 +} + +func (lim *userMeterLimit) Span() (time.Time, time.Time) { now := time.Now().In(time.UTC) d1 := now d2 := now - if lim.Period == "day" { + if lim.Period == "hour" { + d1 = time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, time.UTC) + d2 = d1.Add(3600 * time.Second) + } else if lim.Period == "day" { d1 = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) d2 = d1.AddDate(0, 0, 1) } else if lim.Period == "month" { @@ -76,22 +101,7 @@ func (c *LimitMeter) GetLimit(meterName string, checkDims Dimensions) (time.Time d1 = time.Unix(0, 0) d2 = time.Unix(1<<63-1, 0) } else { - return now, now, 0, false + return now, now } - // fmt.Println("limit found:", d1, d2, lim.Limit, lim.Period) - return d1, d2, lim.Limit, true -} - -func (c *LimitMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { - d1, d2, lim, foundLimit := c.GetLimit(meterName, extraDimensions) - a, valueFound := c.GetValue(meterName, d1, d2, extraDimensions) - _ = valueFound - // if !valueFound { - // fmt.Println("value not found") - // } - // fmt.Println("a:", a, "value:", value, "lim:", lim, "extraDims:", extraDimensions) - if foundLimit && a+value > lim { - return errors.New("rate limited") - } - return c.ApiMeter.Meter(meterName, value, extraDimensions) + return d1, d2 } diff --git a/internal/meters/limit_test.go b/internal/meters/limit_test.go index 488508e7..ba914d0d 100644 --- a/internal/meters/limit_test.go +++ b/internal/meters/limit_test.go @@ -1,9 +1,9 @@ package meters import ( + "fmt" "math" "testing" - "time" "github.com/stretchr/testify/assert" ) @@ -13,6 +13,7 @@ func TestLimitMeter(t *testing.T) { user := testUser{name: "testuser"} mp := NewDefaultMeterProvider() cmp := NewLimitMeterProvider(mp) + cmp.Enabled = true testLimitMeter(t, cmp, meterName, user) } @@ -23,61 +24,79 @@ func TestLimitMeter_Amberflo(t *testing.T) { return } cmp := NewLimitMeterProvider(mp) + cmp.Enabled = true testLimitMeter(t, cmp, testConfig.testMeter1, testUser{name: testConfig.user1.ID()}) } func testLimitMeter(t *testing.T, cmp *LimitMeterProvider, meterName string, user testUser) { m := cmp.NewMeter(user) - testDims1 := Dimensions{{Key: "ok", Value: "test"}} - testDims2 := Dimensions{{Key: "ok", Value: "bar"}} - - lim1 := 10.0 - lim2 := 11.0 - incr := 3.0 - - cmp.UserLimits[user.name] = append(cmp.UserLimits[user.name], - userMeterLimit{ + testKey := 1 // time.Now().In(time.UTC).Unix() + lims := []userMeterLimit{ + // foo tests + { + MeterName: meterName, + Period: "hour", + Limit: 5.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "day", + Limit: 8.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + { MeterName: meterName, Period: "month", - Limit: lim1, - Dims: testDims1, + Limit: 11.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + // bar tests + { + MeterName: meterName, + Period: "hour", + Limit: 14.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, }, - userMeterLimit{ + { MeterName: meterName, Period: "day", - Limit: lim2, - Dims: testDims2, + Limit: 17.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "month", + Limit: 20.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, }, - ) - - // 1 - successCount1 := 0.0 - for i := 0; i < 10; i++ { - err := m.Meter(meterName, incr, testDims1) - if err == nil { - successCount1 += 1 - } - cmp.Flush() - } - assert.Equal(t, successCount1, math.Floor(lim1/incr)) - - // 2 - successCount2 := 0.0 - for i := 0; i < 10; i++ { - err := m.Meter(meterName, incr, testDims2) - if err == nil { - successCount2 += 1 - } - cmp.Flush() } - assert.Equal(t, successCount2, math.Floor(lim2/incr)) - // total 1 - v1, _ := m.GetValue(meterName, time.Unix(0, 0), time.Now(), testDims1) - assert.Equal(t, successCount1*incr, v1) + incr := 3.0 + for _, lim := range lims { + t.Run(fmt.Sprintf("%v", lim), func(t *testing.T) { + startTime, endTime := lim.Span() + base, _ := m.GetValue(meterName, startTime, endTime, lim.Dims) + lim.Limit += base + cmp.UserLimits[user.name] = []userMeterLimit{lim} - // total 2 - v2, _ := m.GetValue(meterName, time.Unix(0, 0), time.Now(), testDims2) - assert.Equal(t, successCount2*incr, v2) + successCount := 0.0 + for i := 0; i < 10; i++ { + err := m.Meter(meterName, incr, lim.Dims) + if err == nil { + successCount += 1 + } + cmp.Flush() + } + expectCount := math.Floor((lim.Limit - base) / incr) + // fmt.Println("successCount:", successCount, "expectCount:", expectCount) + assert.Equal(t, expectCount, successCount) + total, _ := m.GetValue(meterName, startTime, endTime, lim.Dims) + total = total - base + expectTotal := successCount * incr + // fmt.Println("total:", total, "expectTotal:", expectTotal) + assert.Equal(t, expectTotal, total) + }) + } } diff --git a/internal/meters/meters_test.go b/internal/meters/meters_test.go index 30d7fb09..f16846ef 100644 --- a/internal/meters/meters_test.go +++ b/internal/meters/meters_test.go @@ -2,7 +2,6 @@ package meters import ( "testing" - "time" "github.com/stretchr/testify/assert" ) @@ -28,8 +27,7 @@ type testMeterConfig struct { } func testMeter(t *testing.T, mp MeterProvider, cfg testMeterConfig) { - d1 := time.Unix(0, 0) - d2 := time.Now().Add(10 * time.Second) + d1, d2 := (&userMeterLimit{Period: "hour"}).Span() t.Run("Meter", func(t *testing.T) { m := mp.NewMeter(cfg.user1) v, _ := m.GetValue(cfg.testMeter1, d1, d2, nil) @@ -86,35 +84,38 @@ func testMeter(t *testing.T, mp MeterProvider, cfg testMeterConfig) { }) t.Run("GetValue match dims", func(t *testing.T) { - addDims := []Dimension{{Key: "test", Value: "ok"}} - addDims2 := []Dimension{{Key: "test", Value: "not ok"}} - checkDims := []Dimension{{Key: "test", Value: "ok"}} - checkDims2 := []Dimension{{Key: "test", Value: "not ok"}} + addDims1 := []Dimension{{Key: "test", Value: "ok1"}} + addDims2 := []Dimension{{Key: "test", Value: "not ok1"}} + checkDims1 := addDims1 + checkDims2 := addDims2 + m1 := mp.NewMeter(cfg.user1) m2 := mp.NewMeter(cfg.user2) m3 := mp.NewMeter(cfg.user3) - v1, _ := m1.GetValue(cfg.testMeter1, d1, d2, checkDims) - v2, _ := m2.GetValue(cfg.testMeter1, d1, d2, checkDims) - v3, _ := m3.GetValue(cfg.testMeter1, d1, d2, checkDims) + + // Initial values + v1, _ := m1.GetValue(cfg.testMeter1, d1, d2, checkDims1) + v2, _ := m2.GetValue(cfg.testMeter1, d1, d2, checkDims2) + v3, _ := m3.GetValue(cfg.testMeter1, d1, d2, checkDims1) // m2 uses different dimension - m1.Meter(cfg.testMeter1, 1, addDims) + m1.Meter(cfg.testMeter1, 1, addDims1) m2.Meter(cfg.testMeter1, 2.0, addDims2) mp.Flush() - a, ok := m1.GetValue(cfg.testMeter1, d1, d2, checkDims) + a, ok := m1.GetValue(cfg.testMeter1, d1, d2, checkDims1) assert.Equal(t, 1.0, a-v1) assert.Equal(t, true, ok) - a, ok = m2.GetValue(cfg.testMeter1, d1, d2, checkDims) - assert.Equal(t, 0.0, a-v2) + a, ok = m2.GetValue(cfg.testMeter1, d1, d2, checkDims1) + assert.Equal(t, 0.0, a) assert.Equal(t, true, ok) a, ok = m2.GetValue(cfg.testMeter1, d1, d2, checkDims2) assert.Equal(t, 2.0, a-v2) assert.Equal(t, true, ok) - a, _ = m3.GetValue(cfg.testMeter1, d1, d2, checkDims) + a, _ = m3.GetValue(cfg.testMeter1, d1, d2, checkDims1) assert.Equal(t, 0.0, a-v3) })