Skip to content

Commit

Permalink
BCDA-2840 Expire admin credentials (#43)
Browse files Browse the repository at this point in the history
* Expire admin credentials
  • Loading branch information
dhgreene authored Apr 15, 2020
1 parent ed3febf commit 048c67d
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 22 deletions.
19 changes: 12 additions & 7 deletions ssas/service/admin/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,35 @@ func requireBasicAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientID, secret, ok := r.BasicAuth()
if !ok {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
formatError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return
}

system, err := ssas.GetSystemByClientID(clientID)
if err != nil {
formatError(w, http.StatusText(http.StatusUnauthorized), "invalid client id")
formatError(w, http.StatusUnauthorized, "invalid client id")
return
}

savedSecret, err := system.GetSecret()
if err != nil || !ssas.Hash(savedSecret.Hash).IsHashOf(secret) {
formatError(w, http.StatusText(http.StatusUnauthorized), "invalid client secret")
formatError(w, http.StatusUnauthorized, "invalid client secret")
return
}

if savedSecret.IsExpired() {
formatError(w, http.StatusUnauthorized, "credentials expired")
return
}

next.ServeHTTP(w, r)
})
}

func formatError(w http.ResponseWriter, error string, description string) {
ssas.Logger.Printf("%s; %s", description, error)
w.WriteHeader(http.StatusBadRequest)
body := []byte(fmt.Sprintf(`{"error":"%s","error_description":"%s"}`, error, description))
func formatError(w http.ResponseWriter, errorcode int, description string) {
ssas.Logger.Printf("%s; %s", description, http.StatusText(errorcode))
w.WriteHeader(errorcode)
body := []byte(fmt.Sprintf(`{"error":"%s","error_description":"%s"}`, http.StatusText(errorcode), description))
_, err := w.Write(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
Expand Down
99 changes: 99 additions & 0 deletions ssas/service/admin/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package admin

import (
"encoding/base64"
"encoding/json"
"github.com/CMSgov/bcda-ssas-app/ssas"
"github.com/go-chi/chi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
)

type AdminMiddlewareTestSuite struct {
suite.Suite
server *httptest.Server
rr *httptest.ResponseRecorder
basicAuth string
badAuth string
}

func (s *AdminMiddlewareTestSuite) CreateRouter(handler ...func(http.Handler) http.Handler) http.Handler {
router := chi.NewRouter()
router.With(handler...).Get("/", func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("Test router"))
if err != nil {
log.Fatal(err)
}
})

return router
}

func (s *AdminMiddlewareTestSuite) SetupTest() {
s.rr = httptest.NewRecorder()
encCreds, err := ssas.ResetAdminCreds()
assert.NoError(s.T(), err)
s.basicAuth = encCreds

badAuth := "31e029ef-0e97-47f8-873c-0e8b7e7f99bf:This_is_not_the_secret"
s.badAuth = base64.StdEncoding.EncodeToString([]byte(badAuth))
}

func (s *AdminMiddlewareTestSuite) TestRequireBasicAuthSuccess() {
testAuth(s.basicAuth, http.StatusOK, s)
}

func (s *AdminMiddlewareTestSuite) TestRequireBasicAuthFailure() {
r := testAuth(s.badAuth, http.StatusUnauthorized, s)

b, err := ioutil.ReadAll(r.Body)
if err != nil {
assert.FailNow(s.T(), err.Error())
}
var result map[string]interface{}
_ = json.Unmarshal(b, &result)
assert.Equal(s.T(), "invalid client secret", result["error_description"], string(b))
}

func (s *AdminMiddlewareTestSuite) TestRequireBasicAuthExpired() {
ssas.ExpireAdminCreds()
r := testAuth(s.basicAuth, http.StatusUnauthorized, s)

b, err := ioutil.ReadAll(r.Body)
if err != nil {
assert.FailNow(s.T(), err.Error())
}
var result map[string]interface{}
_ = json.Unmarshal(b, &result)
assert.Equal(s.T(), "credentials expired", result["error_description"], string(b))
}

func testAuth(base64Creds string, statusCode int, s *AdminMiddlewareTestSuite) *http.Response {
s.server = httptest.NewServer(s.CreateRouter(requireBasicAuth))
client := s.server.Client()

// Valid credentials should return a 200 response
req, err := http.NewRequest("GET", s.server.URL, nil)
if err != nil {
assert.FailNow(s.T(), err.Error())
}

req.Header.Add("Authorization", "Basic "+base64Creds)

resp, err := client.Do(req)
if err != nil {
assert.FailNow(s.T(), err.Error())
}
assert.Equal(s.T(), statusCode, resp.StatusCode)

return resp
}

func TestAdminMiddlewareTestSuite(t *testing.T) {
suite.Run(t, new(AdminMiddlewareTestSuite))
}
21 changes: 6 additions & 15 deletions ssas/service/admin/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,12 @@ type RouterTestSuite struct {
}

func (s *RouterTestSuite) SetupSuite() {
id := "31e029ef-0e97-47f8-873c-0e8b7e7f99bf"
system, err := ssas.GetSystemByClientID(id)
if err != nil {
s.FailNow(err.Error())
}
encSecret, err := ssas.ResetAdminCreds()
assert.NoError(s.T(), err)

creds, err := system.ResetSecret(id)
if err != nil {
s.FailNow(err.Error())
}
s.basicAuth = encSecret

basicAuth := id + ":" + creds.ClientSecret
s.basicAuth = base64.StdEncoding.EncodeToString([]byte(basicAuth))

badAuth := id + ":This_is_not_the_secret"
badAuth := "31e029ef-0e97-47f8-873c-0e8b7e7f99bf:This_is_not_the_secret"
s.badAuth = base64.StdEncoding.EncodeToString([]byte(badAuth))
}

Expand All @@ -58,7 +49,7 @@ func (s *RouterTestSuite) TestUnauthorized() {
rr := httptest.NewRecorder()
s.router.ServeHTTP(rr, req)
res := rr.Result()
assert.Equal(s.T(), http.StatusBadRequest, res.StatusCode)
assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode)
}

func (s *RouterTestSuite) TestNonBasicAuth() {
Expand All @@ -76,7 +67,7 @@ func (s *RouterTestSuite) TestBadSecret() {
rr := httptest.NewRecorder()
s.router.ServeHTTP(rr, req)
res := rr.Result()
assert.Equal(s.T(), http.StatusBadRequest, res.StatusCode)
assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode)
}

func (s *RouterTestSuite) TestRevokeToken() {
Expand Down
15 changes: 15 additions & 0 deletions ssas/systems.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,21 @@ func (system *System) ResetSecret(trackingID string) (Credentials, error) {
return creds, nil
}

// RevokeActiveCreds revokes all credentials for the specified GroupID
func RevokeActiveCreds(groupID string) error {
systems, err := GetSystemsByGroupIDString(groupID)
if err != nil {
return err
}
for _, system := range systems {
err = system.RevokeSecret("ssas.RevokeActiveCreds for GroupID " + groupID)
if err != nil {
return err
}
}
return nil
}

// CleanDatabase deletes the given group and associated systems, encryption keys, and secrets.
func CleanDatabase(group Group) error {
var (
Expand Down
30 changes: 30 additions & 0 deletions ssas/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@ import (
"net"
)

func ResetAdminCreds() (encSecret string, err error) {
err = RevokeActiveCreds("admin")
if err != nil {
return
}

id := "31e029ef-0e97-47f8-873c-0e8b7e7f99bf"
system, err := GetSystemByClientID(id)
if err != nil {
return
}

creds, err := system.ResetSecret(id)
if err != nil {
return
}

basicAuth := id + ":" + creds.ClientSecret
encSecret = base64.StdEncoding.EncodeToString([]byte(basicAuth))

return
}

func ExpireAdminCreds() {
db := GetGORMDbConnection()
defer Close(db)

db.Exec("UPDATE secrets SET created_at = '2000-01-01', updated_at = '2000-01-01' WHERE system_id IN (SELECT id FROM systems WHERE client_id = '31e029ef-0e97-47f8-873c-0e8b7e7f99bf')")
}

func GeneratePublicKey(bits int) (string, error) {
keyPair, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
Expand Down

0 comments on commit 048c67d

Please sign in to comment.