Skip to content

Commit

Permalink
add key integration
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Mar 25, 2024
1 parent c792ec2 commit 87ed692
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
4 changes: 4 additions & 0 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
stats.Incr("bricksllm.authenticator.authenticate_http_request.found_key_from_memdb", nil, 1)
}

if key == nil {
key = a.kms.GetKey(raw)
}

if key == nil {
key, err = a.ks.GetKeyByHash(hash)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/key/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type UpdateKey struct {
Tags []string `json:"tags"`
Revoked *bool `json:"revoked"`
RevokedReason string `json:"revokedReason"`
Key string `json:"key"`
SettingId string `json:"settingId"`
SettingIds []string `json:"settingIds"`
CostLimitInUsd *float64 `json:"costLimitInUsd"`
Expand All @@ -29,6 +30,7 @@ type UpdateKey struct {
ShouldLogResponse *bool `json:"shouldLogResponse"`
RotationEnabled *bool `json:"rotationEnabled"`
PolicyId *string `json:"policyId"`
IsKeyNotHashed *bool `json:"isKeyNotHashed"`
}

func (uk *UpdateKey) Validate() error {
Expand Down Expand Up @@ -169,6 +171,7 @@ type RequestKey struct {
ShouldLogResponse bool `json:"shouldLogResponse"`
RotationEnabled bool `json:"rotationEnabled"`
PolicyId string `json:"policyId"`
IsKeyNotHashed bool `json:"isKeyNotHashed"`
}

func (rk *RequestKey) Validate() error {
Expand Down Expand Up @@ -317,6 +320,7 @@ type ResponseKey struct {
ShouldLogResponse bool `json:"shouldLogResponse"`
RotationEnabled bool `json:"rotationEnabled"`
PolicyId string `json:"policyId"`
IsKeyNotHashed bool `json:"isKeyNotHashed"`
}

func (rk *ResponseKey) GetSettingIds() []string {
Expand Down
18 changes: 17 additions & 1 deletion internal/manager/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,16 @@ func (m *Manager) GetKeys(tags, keyIds []string, provider string) ([]*key.Respon
func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.CreatedAt = time.Now().Unix()
rk.UpdatedAt = time.Now().Unix()
rk.Key = encrypter.Encrypt(rk.Key)
rk.KeyId = util.NewUuid()

if err := rk.Validate(); err != nil {
return nil, err
}

if !rk.IsKeyNotHashed {
rk.Key = encrypter.Encrypt(rk.Key)
}

if len(rk.SettingId) != 0 {
if _, err := m.s.GetProviderSetting(rk.SettingId); err != nil {
return nil, err
Expand Down Expand Up @@ -102,6 +105,19 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
return nil, err
}

existing, err := m.s.GetKey(id)
if err != nil {
return nil, err
}

if uk.IsKeyNotHashed != nil && !*uk.IsKeyNotHashed {
uk.Key = encrypter.Encrypt(existing.Key)
}

if uk.IsKeyNotHashed == nil || (uk.IsKeyNotHashed != nil && *uk.IsKeyNotHashed) {
uk.Key = ""
}

if len(uk.SettingId) != 0 {
if _, err := m.s.GetProviderSetting(uk.SettingId); err != nil {
return nil, err
Expand Down
20 changes: 17 additions & 3 deletions internal/storage/postgresql/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *Store) AlterKeysTable() error {
END IF;
END
$$;
ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '';
ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE;
`

ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
Expand Down Expand Up @@ -156,6 +156,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -208,6 +209,7 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) {
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
)

if err != nil {
Expand Down Expand Up @@ -270,6 +272,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) {
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -333,6 +336,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) {
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -391,6 +395,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) {
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -531,6 +536,12 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
counter++
}

if len(uk.Key) != 0 {
values = append(values, uk.Key)
fields = append(fields, fmt.Sprintf("key = $%d", counter))
counter++
}

query := fmt.Sprintf("UPDATE keys SET %s WHERE key_id = $1 RETURNING *;", strings.Join(fields, ","))

ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
Expand Down Expand Up @@ -561,6 +572,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
if err == sql.ErrNoRows {
return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id))
Expand All @@ -585,8 +597,8 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error

func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
query := `
INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
RETURNING *;
`

Expand Down Expand Up @@ -617,6 +629,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.ShouldLogResponse,
rk.RotationEnabled,
rk.PolicyId,
rk.IsKeyNotHashed,
}

ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
Expand Down Expand Up @@ -648,6 +661,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
&k.IsKeyNotHashed,
); err != nil {
return nil, err
}
Expand Down

0 comments on commit 87ed692

Please sign in to comment.