diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index af102a6..126d670 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -188,6 +188,10 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons stats.Incr("bricksllm.authenticator.authenticate_http_request.found_key_from_memdb", nil, 1) } + if key == nil { + key = a.kms.GetKey(raw) + } + if key == nil { key, err = a.ks.GetKeyByHash(hash) if err != nil { diff --git a/internal/key/key.go b/internal/key/key.go index cc3b988..8976b9b 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -17,6 +17,7 @@ type UpdateKey struct { Tags []string `json:"tags"` Revoked *bool `json:"revoked"` RevokedReason string `json:"revokedReason"` + Key string `json:"key"` SettingId string `json:"settingId"` SettingIds []string `json:"settingIds"` CostLimitInUsd *float64 `json:"costLimitInUsd"` @@ -29,6 +30,7 @@ type UpdateKey struct { ShouldLogResponse *bool `json:"shouldLogResponse"` RotationEnabled *bool `json:"rotationEnabled"` PolicyId *string `json:"policyId"` + IsKeyNotHashed *bool `json:"isKeyNotHashed"` } func (uk *UpdateKey) Validate() error { @@ -169,6 +171,7 @@ type RequestKey struct { ShouldLogResponse bool `json:"shouldLogResponse"` RotationEnabled bool `json:"rotationEnabled"` PolicyId string `json:"policyId"` + IsKeyNotHashed bool `json:"isKeyNotHashed"` } func (rk *RequestKey) Validate() error { @@ -317,6 +320,7 @@ type ResponseKey struct { ShouldLogResponse bool `json:"shouldLogResponse"` RotationEnabled bool `json:"rotationEnabled"` PolicyId string `json:"policyId"` + IsKeyNotHashed bool `json:"isKeyNotHashed"` } func (rk *ResponseKey) GetSettingIds() []string { diff --git a/internal/manager/key.go b/internal/manager/key.go index 7632a5d..61dd8e4 100644 --- a/internal/manager/key.go +++ b/internal/manager/key.go @@ -61,13 +61,16 @@ func (m *Manager) GetKeys(tags, keyIds []string, provider string) ([]*key.Respon func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { rk.CreatedAt = time.Now().Unix() rk.UpdatedAt = time.Now().Unix() - rk.Key = encrypter.Encrypt(rk.Key) rk.KeyId = util.NewUuid() if err := rk.Validate(); err != nil { return nil, err } + if !rk.IsKeyNotHashed { + rk.Key = encrypter.Encrypt(rk.Key) + } + if len(rk.SettingId) != 0 { if _, err := m.s.GetProviderSetting(rk.SettingId); err != nil { return nil, err @@ -102,6 +105,19 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err return nil, err } + existing, err := m.s.GetKey(id) + if err != nil { + return nil, err + } + + if uk.IsKeyNotHashed != nil && !*uk.IsKeyNotHashed { + uk.Key = encrypter.Encrypt(existing.Key) + } + + if uk.IsKeyNotHashed == nil || (uk.IsKeyNotHashed != nil && *uk.IsKeyNotHashed) { + uk.Key = "" + } + if len(uk.SettingId) != 0 { if _, err := m.s.GetProviderSetting(uk.SettingId); err != nil { return nil, err diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index 4125273..369252e 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -55,7 +55,7 @@ func (s *Store) AlterKeysTable() error { END IF; END $$; - ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT ''; + ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE; ` ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -156,6 +156,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { return nil, err } @@ -208,6 +209,7 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) { &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ) if err != nil { @@ -270,6 +272,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { return nil, err } @@ -333,6 +336,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) { &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { return nil, err } @@ -391,6 +395,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) { &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { return nil, err } @@ -531,6 +536,12 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error counter++ } + if len(uk.Key) != 0 { + values = append(values, uk.Key) + fields = append(fields, fmt.Sprintf("key = $%d", counter)) + counter++ + } + query := fmt.Sprintf("UPDATE keys SET %s WHERE key_id = $1 RETURNING *;", strings.Join(fields, ",")) ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -561,6 +572,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { if err == sql.ErrNoRows { return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id)) @@ -585,8 +597,8 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { query := ` - INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) + INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22) RETURNING *; ` @@ -617,6 +629,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { rk.ShouldLogResponse, rk.RotationEnabled, rk.PolicyId, + rk.IsKeyNotHashed, } ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -648,6 +661,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { &k.ShouldLogResponse, &k.RotationEnabled, &k.PolicyId, + &k.IsKeyNotHashed, ); err != nil { return nil, err }