Skip to content

Commit

Permalink
Merge pull request #286 from akramer/master
Browse files Browse the repository at this point in the history
Fix provider deduction from existing sessions so that an argument is not needed on the authentication landing page.
  • Loading branch information
bentranter authored Aug 12, 2019
2 parents 3b80120 + e211ab6 commit 5d9e6bb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
24 changes: 11 additions & 13 deletions gothic/gothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,6 @@ var GetProviderName = getProviderName

func getProviderName(req *http.Request) (string, error) {

// get all the used providers
providers := goth.GetProviders()

// loop over the used providers, if we already have a valid session for any provider (ie. user is already logged-in with a provider), then return that provider name
for _, provider := range providers {
p := provider.Name()
session, _ := Store.Get(req, p+SessionName)
value := session.Values[p]
if _, ok := value.(string); ok {
return p, nil
}
}

// try to get it from the url param "provider"
if p := req.URL.Query().Get("provider"); p != "" {
return p, nil
Expand All @@ -278,6 +265,17 @@ func getProviderName(req *http.Request) (string, error) {
return p, nil
}

// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
providers := goth.GetProviders()
session, _ := Store.Get(req, SessionName)
for _, provider := range providers {
p := provider.Name()
value := session.Values[p]
if _, ok := value.(string); ok {
return p, nil
}
}

// if not found then return an empty string with the corresponding error
return "", errors.New("you must select a provider")
}
Expand Down
46 changes: 36 additions & 10 deletions gothic/gothic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ import (
"github.com/stretchr/testify/assert"
)

type mapKey struct {
r *http.Request
n string
}

type ProviderStore struct {
Store map[*http.Request]*sessions.Session
Store map[mapKey]*sessions.Session
}

func NewProviderStore() *ProviderStore {
return &ProviderStore{map[*http.Request]*sessions.Session{}}
return &ProviderStore{map[mapKey]*sessions.Session{}}
}

func (p ProviderStore) Get(r *http.Request, name string) (*sessions.Session, error) {
s := p.Store[r]
s := p.Store[mapKey{r, name}]
if s == nil {
s, err := p.New(r, name)
return s, err
Expand All @@ -42,12 +47,12 @@ func (p ProviderStore) New(r *http.Request, name string) (*sessions.Session, err
Path: "/",
MaxAge: 86400 * 30,
}
p.Store[r] = s
p.Store[mapKey{r, name}] = s
return s, nil
}

func (p ProviderStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error {
p.Store[r] = s
p.Store[mapKey{r, s.Name()}] = s
return nil
}

Expand All @@ -68,7 +73,7 @@ func Test_BeginAuthHandler(t *testing.T) {

BeginAuthHandler(res, req)

sess, err := Store.Get(req, "faux"+SessionName)
sess, err := Store.Get(req, SessionName)
if err != nil {
t.Fatalf("error getting faux Gothic session: %v", err)
}
Expand Down Expand Up @@ -128,7 +133,28 @@ func Test_CompleteUserAuth(t *testing.T) {
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)

user, err := CompleteUserAuth(res, req)
a.NoError(err)

a.Equal(user.Name, "Homer Simpson")
a.Equal(user.Email, "[email protected]")
}

func Test_CompleteUserAuthWithSessionDeducedProvider(t *testing.T) {
a := assert.New(t)

res := httptest.NewRecorder()
// Inteintionally omit a provider argument, force looking in session.
req, err := http.NewRequest("GET", "/auth/callback", nil)
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)
Expand All @@ -148,7 +174,7 @@ func Test_Logout(t *testing.T) {
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)
Expand All @@ -160,7 +186,7 @@ func Test_Logout(t *testing.T) {
a.Equal(user.Email, "[email protected]")
err = Logout(res, req)
a.NoError(err)
session, _ = Store.Get(req, "faux"+SessionName)
session, _ = Store.Get(req, SessionName)
a.Equal(session.Values, make(map[interface{}]interface{}))
a.Equal(session.Options.MaxAge, -1)
}
Expand Down Expand Up @@ -188,7 +214,7 @@ func Test_StateValidation(t *testing.T) {
a.NoError(err)

BeginAuthHandler(res, req)
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)

// Assert that matching states will return a nil error
req, err = http.NewRequest("GET", "/auth/callback?provider=faux&state=state_REAL", nil)
Expand Down

0 comments on commit 5d9e6bb

Please sign in to comment.