Skip to content

Commit

Permalink
Improve code clarity and robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas-bouvier committed Nov 6, 2019
1 parent eaad5b5 commit 03f2c26
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 174 deletions.
14 changes: 11 additions & 3 deletions associationController.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,33 @@ func GetAllAssociationsController(w http.ResponseWriter, r *http.Request) {
func AddAssociationController(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
var association Association
_ = decoder.Decode(&association)
decoder.Decode(&association)

isValidMail := VerifyEmail(association.Email)
if !isValidMail {
w.WriteHeader(http.StatusConflict)
_ = json.NewEncoder(w).Encode(bson.M{"error": "email already used"})
json.NewEncoder(w).Encode(bson.M{"error": "email already used"})
return
}

res := AddAssociation(association)
password := GeneratePassword()

userID, err := GetUserFromRequest(r)
if err != nil {
w.WriteHeader(http.StatusNotAcceptable)
json.NewEncoder(w).Encode(bson.M{"error": "could not get user ID"})
return
}

var user AssociationUser
user.Association = res.ID
user.Username = res.Email
user.Master = false
user.Owner = GetUserFromRequest(r)
user.Owner = userID
user.Password = GetMD5Hash(password)
AddAssociationUser(user)

_ = SendAssociationEmailSubscription(user.Username, password)
_ = json.NewEncoder(w).Encode(res)
}
Expand Down
223 changes: 102 additions & 121 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,187 +54,137 @@ func InitJWT() error {
}

// CreateNewTokens creates auth and refresh tokens.
func CreateNewTokens(ID bson.ObjectId, role string) (string, string, error) {
// Generate the auth token
authTokenString, err := createAuthTokenString(ID, role)
if err != nil {
return "", "", err
}

// Generate the refresh token
refreshTokenString, err := createRefreshTokenString(ID, role)
if err != nil {
return "", "", err
}

return authTokenString, refreshTokenString, nil
func CreateNewTokens(ID bson.ObjectId, role string) (*jwt.Token, *jwt.Token) {
return createAuthToken(ID, role), createRefreshToken(ID, role)
}

// CheckAndRefreshTokens renews the auth token, if needed.
func CheckAndRefreshTokens(authTokenString string, refreshTokenString string, role string) (string, string, error) {
var newAuthTokenString string
var newRefreshTokenString string

// Check that it matches with the auth token claims
authToken, err := jwt.ParseWithClaims(authTokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})

// CheckAndRefreshStringTokens renews the auth token, if needed.
func CheckAndRefreshStringTokens(authStringToken string, refreshStringToken string, role string) (*jwt.Token, *jwt.Token, error) {
authToken, refreshToken, err := parseStringTokens(authStringToken, refreshStringToken)
if err != nil {
return "", "", err
return nil, nil, err
}

// The auth token is still valid
if _, ok := authToken.Claims.(*TokenClaims); ok && authToken.Valid {
if authToken.Valid {
// Check the role
if authToken.Claims.(*TokenClaims).Role != role {
return "", "", errors.New("Unauthorized")
return nil, nil, errors.New("Unauthorized")
}

// Update the expiration time of refresh token
newRefreshTokenString, err = updateRefreshTokenExpiration(refreshTokenString)
newRefreshToken := updateRefreshTokenExpiration(refreshToken)

return authTokenString, newRefreshTokenString, nil
return authToken, newRefreshToken, nil
}

if ve, ok := err.(*jwt.ValidationError); ok {
// The auth token has expired
if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
newAuthTokenString, err = updateAuthTokenString(authTokenString, refreshTokenString)
newAuthToken, err := updateAuthToken(authToken, refreshToken)
if err != nil {
return "", "", err
return nil, nil, err
}

// Update the expiration time of refresh token string
newRefreshTokenString, err = updateRefreshTokenExpiration(refreshTokenString)
if err != nil {
return "", "", err
}
newRefreshToken := updateRefreshTokenExpiration(refreshToken)

return newAuthTokenString, newRefreshTokenString, nil
return newAuthToken, newRefreshToken, nil
}
}

return "", "", err
return nil, nil, err
}

// RevokeRefreshToken deletes the given token from the database, if valid.
func RevokeRefreshToken(refreshTokenString string) error {
refreshToken, err := jwt.ParseWithClaims(refreshTokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})

// RevokeRefreshStringToken deletes the given token from the database, if valid.
func RevokeRefreshStringToken(refreshStringToken string) error {
refreshToken, err := parseRefreshStringToken(refreshStringToken)
if err != nil {
return errors.New("Could not parse refresh token with claims")
}

refreshTokenClaims, ok := refreshToken.Claims.(*TokenClaims)
if !ok {
return errors.New("Could not read refresh token claims")
return err
}

deleteRefreshToken(refreshTokenClaims.StandardClaims.Id)
deleteRefreshToken(refreshToken.Claims.(*TokenClaims).StandardClaims.Id)

return nil
}

func GetUserFromRequest(r *http.Request) bson.ObjectId {
authCookie, _ := r.Cookie("AuthToken")
// GetUserFromRequest returns the User or AssociationUser ID from the auth cookie.
func GetUserFromRequest(r *http.Request) (bson.ObjectId, error) {
authCookie, err1 := r.Cookie("AuthToken")
if err1 != nil {
return bson.ObjectId(""), err1
}

// Check that it matches with the auth token claims
authToken, _ := jwt.ParseWithClaims(authCookie.Value, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})
authToken, err2 := parseAuthStringToken(authCookie.Value)
if err2 != nil {
return bson.ObjectId(""), err2
}

return authToken.Claims.(*TokenClaims).ID
return authToken.Claims.(*TokenClaims).ID, nil
}

// createAuthTokenString creates an auth token
func createAuthTokenString(id bson.ObjectId, role string) (string, error) {
authTokenExpiration := time.Now().Add(authTokenValidTime).Unix()

authClaims := TokenClaims{
ID: id,
Role: role,
StandardClaims: jwt.StandardClaims{
ExpiresAt: authTokenExpiration,
},
func parseStringTokens(authStringToken string, refreshStringToken string) (*jwt.Token, *jwt.Token, error) {
authToken, err1 := parseAuthStringToken(authStringToken)
if err1 != nil {
return nil, nil, err1
}

// Create a signer for rsa 256
authJwt := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), authClaims)
refreshToken, err2 := parseRefreshStringToken(refreshStringToken)
if err2 != nil {
return nil, nil, err2
}

// Generate the auth token string
return authJwt.SignedString(signKey)
return authToken, refreshToken, nil
}

func updateAuthTokenString(authTokenString string, refreshTokenString string) (string, error) {
refreshToken, err := jwt.ParseWithClaims(refreshTokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
func parseAuthStringToken(authStringToken string) (*jwt.Token, error) {
authToken, err := jwt.ParseWithClaims(authStringToken, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})

refreshTokenClaims, ok := refreshToken.Claims.(*TokenClaims)
if !ok {
return "", err
if err != nil {
return nil, err
}

// Check that the refresh token has not been revoked
if checkRefreshToken(refreshTokenClaims.StandardClaims.Id) {
// Has the refresh token expired?
if refreshToken.Valid {
// We can issue a new auth token
authToken, err := jwt.ParseWithClaims(authTokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})

authTokenClaims, ok := authToken.Claims.(*TokenClaims)
if !ok {
return "", err
}

return createAuthTokenString(authTokenClaims.ID, authTokenClaims.Role)
}

// The refresh token has expired: revoke the token
deleteRefreshToken(refreshTokenClaims.StandardClaims.Id)

return "", errors.New("Unauthorized")
_, ok := authToken.Claims.(*TokenClaims)
if !ok {
return nil, errors.New("Auth token parse error")
}

// The refresh token has been revoked!
return "", errors.New("Unauthorized")
return authToken, nil
}

func updateRefreshTokenExpiration(refreshTokenString string) (string, error) {
refreshToken, err := jwt.ParseWithClaims(refreshTokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
func parseRefreshStringToken(refreshStringToken string) (*jwt.Token, error) {
refreshToken, err := jwt.ParseWithClaims(refreshStringToken, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
return verifyKey, nil
})

refreshTokenClaims, ok := refreshToken.Claims.(*TokenClaims)
if err != nil {
return nil, err
}
_, ok := refreshToken.Claims.(*TokenClaims)
if !ok {
return "", err
return nil, errors.New("Refresh token parse error")
}

refreshTokenExpiration := time.Now().Add(refreshTokenValidTime).Unix()
return refreshToken, nil
}

refreshClaims := TokenClaims{
ID: refreshTokenClaims.ID,
Role: refreshTokenClaims.Role,
// createAuthToken creates an auth token
func createAuthToken(id bson.ObjectId, role string) *jwt.Token {
authTokenExpiration := time.Now().Add(authTokenValidTime).Unix()

authClaims := TokenClaims{
ID: id,
Role: role,
StandardClaims: jwt.StandardClaims{
Id: refreshTokenClaims.StandardClaims.Id,
ExpiresAt: refreshTokenExpiration,
ExpiresAt: authTokenExpiration,
},
}

// Create a signer for rsa 256
refreshJwt := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), refreshClaims)

// Generate the refresh token string
return refreshJwt.SignedString(signKey)
return jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), authClaims)
}

// createRefreshTokenString create a refresh token
func createRefreshTokenString(id bson.ObjectId, role string) (string, error) {
// createRefreshToken create a refresh token
func createRefreshToken(id bson.ObjectId, role string) *jwt.Token {
refreshTokenExpiration := time.Now().Add(refreshTokenValidTime).Unix()

// Store a token in the database
Expand All @@ -250,10 +200,41 @@ func createRefreshTokenString(id bson.ObjectId, role string) (string, error) {
}

// Create a signer for rsa 256
refreshJwt := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), refreshClaims)
return jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), refreshClaims)
}

// Generate the refresh token string
return refreshJwt.SignedString(signKey)
func updateAuthToken(authToken *jwt.Token, refreshToken *jwt.Token) (*jwt.Token, error) {
// Check that the refresh token has not been revoked
if checkRefreshToken(refreshToken.Claims.(*TokenClaims).StandardClaims.Id) {
// The refresh token has not expired: issue a new auth token
if refreshToken.Valid {
return createAuthToken(authToken.Claims.(*TokenClaims).ID, authToken.Claims.(*TokenClaims).Role), nil
}

// The refresh token has expired: revoke the token
deleteRefreshToken(refreshToken.Claims.(*TokenClaims).StandardClaims.Id)

return nil, errors.New("Unauthorized")
}

// The refresh token has been revoked!
return nil, errors.New("Unauthorized")
}

func updateRefreshTokenExpiration(refreshToken *jwt.Token) *jwt.Token {
refreshTokenExpiration := time.Now().Add(refreshTokenValidTime).Unix()

refreshClaims := TokenClaims{
ID: refreshToken.Claims.(*TokenClaims).ID,
Role: refreshToken.Claims.(*TokenClaims).Role,
StandardClaims: jwt.StandardClaims{
Id: refreshToken.Claims.(*TokenClaims).StandardClaims.Id,
ExpiresAt: refreshTokenExpiration,
},
}

// Create a signer for rsa 256
return jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), refreshClaims)
}

func checkRefreshToken(jti string) bool {
Expand Down
14 changes: 10 additions & 4 deletions eventController.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ func GetEventController(w http.ResponseWriter, r *http.Request) {
// GetFutureEventsController will answer a JSON
// containing all future events from "NOW"
func GetFutureEventsController(w http.ResponseWriter, r *http.Request) {
userID := GetUserFromRequest(r)
user := GetUser(userID)
os := GetNotificationUserForUser(userID).Os
id, err := GetUserFromRequest(r)
if err != nil {
w.WriteHeader(http.StatusNotAcceptable)
json.NewEncoder(w).Encode(bson.M{"error": "could not get user ID"})
return
}

user := GetUser(id)
os := GetNotificationUserForUser(id).Os
events := GetFutureEvents()
res := Events{}
if user.ID != "" {
Expand All @@ -37,7 +43,7 @@ func GetFutureEventsController(w http.ResponseWriter, r *http.Request) {
} else {
res = events
}
_ = json.NewEncoder(w).Encode(res)
json.NewEncoder(w).Encode(res)
}

func GetEventsForAssociationController(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading

0 comments on commit 03f2c26

Please sign in to comment.