Skip to content

Commit

Permalink
Merge pull request #294 from remotehq/feature/slack-provider-progress…
Browse files Browse the repository at this point in the history
…ive-profile-fetching

Feature - Slack Provider - Progressive Profile Fetching
  • Loading branch information
bentranter authored Oct 9, 2019
2 parents 124a0b3 + 64d022d commit 42e707e
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 27 deletions.
91 changes: 66 additions & 25 deletions providers/slack/slack.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ import (
"net/url"

"fmt"

"github.com/markbates/goth"
"golang.org/x/oauth2"
)

// Scopes
const (
ScopeUserRead string = "users:read"
)

// URLs and endpoints
const (
authURL string = "https://slack.com/oauth/authorize"
tokenURL string = "https://slack.com/api/oauth.access"
Expand Down Expand Up @@ -56,6 +63,7 @@ func (p *Provider) SetName(name string) {
p.providerName = name
}

// Client returns the http.Client used in the provider.
func (p *Provider) Client() *http.Client {
return goth.HTTPClientWithFallBack(p.HTTPClient)
}
Expand Down Expand Up @@ -88,11 +96,9 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
// Get the userID, slack needs userID in order to get user profile info
response, err := p.Client().Get(endpointUser + "?token=" + url.QueryEscape(sess.AccessToken))
if err != nil {
if response != nil {
response.Body.Close()
}
return user, err
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
Expand All @@ -103,36 +109,54 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
return user, err
}

u := struct {
UserID string `json:"user_id"`
}{}

err = json.NewDecoder(bytes.NewReader(bits)).Decode(&u)

// Get user profile info
response, err = p.Client().Get(endpointProfile + "?token=" + url.QueryEscape(sess.AccessToken) + "&user=" + u.UserID)
err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
if err != nil {
if response != nil {
response.Body.Close()
}
return user, err
}
defer response.Body.Close()

bits, err = ioutil.ReadAll(response.Body)
if err != nil {
return user, err
}
err = simpleUserFromReader(bytes.NewReader(bits), &user)

err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
if err != nil {
return user, err
if p.hasScope(ScopeUserRead) {
// Get user profile info
response, err = p.Client().Get(endpointProfile + "?token=" + url.QueryEscape(sess.AccessToken) + "&user=" + user.UserID)
if err != nil {
return user, err
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
}

bits, err = ioutil.ReadAll(response.Body)
if err != nil {
return user, err
}

err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
if err != nil {
return user, err
}

err = userFromReader(bytes.NewReader(bits), &user)
}

err = userFromReader(bytes.NewReader(bits), &user)
return user, err
}

func (p *Provider) hasScope(scope string) bool {
hasScope := false

for i := range p.config.Scopes {
if p.config.Scopes[i] == scope {
hasScope = true
break
}
}

return hasScope
}

func newConfig(provider *Provider, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: provider.ClientKey,
Expand All @@ -150,11 +174,28 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
c.Scopes = append(c.Scopes, scope)
}
} else {
c.Scopes = append(c.Scopes, "users:read")
c.Scopes = append(c.Scopes, ScopeUserRead)
}
return c
}

func simpleUserFromReader(r io.Reader, user *goth.User) error {
u := struct {
UserID string `json:"user_id"`
Name string `json:"user"`
}{}

err := json.NewDecoder(r).Decode(&u)
if err != nil {
return err
}

user.UserID = u.UserID
user.NickName = u.Name

return nil
}

func userFromReader(r io.Reader, user *goth.User) error {
u := struct {
User struct {
Expand All @@ -165,7 +206,7 @@ func userFromReader(r io.Reader, user *goth.User) error {
Name string `json:"real_name"`
AvatarURL string `json:"image_32"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
LastName string `json:"last_name"`
} `json:"profile"`
} `json:"user"`
}{}
Expand Down
188 changes: 186 additions & 2 deletions providers/slack/slack_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
package slack_test

import (
"context"
"crypto/tls"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/slack"
"github.com/stretchr/testify/assert"
"os"
"testing"
)

var (
testAuthTestResponseData = map[string]interface{}{
"user": "testuser",
"user_id": "user1234",
}

testUserInfoResponseData = map[string]interface{}{
"user": map[string]interface{}{
"id": testAuthTestResponseData["user_id"],
"name": testAuthTestResponseData["user"],
"profile": map[string]interface{}{
"real_name": "Test User",
"first_name": "Test",
"last_name": "User",
"image_32": "http://example.org/avatar.png",
"email": "[email protected]",
},
},
}
)

func Test_New(t *testing.T) {
Expand Down Expand Up @@ -34,6 +62,142 @@ func Test_BeginAuth(t *testing.T) {
a.Contains(s.AuthURL, "slack.com/oauth/authorize")
}

func Test_FetchUser(t *testing.T) {
t.Parallel()

for _, testData := range []struct {
name string
provider *slack.Provider
session goth.Session
handler http.Handler
expectedUser goth.User
expectErr bool
}{
{
name: "FetchesFullProfile",
provider: provider(),
session: &slack.Session{AccessToken: "TOKEN"},
handler: http.HandlerFunc(
func(res http.ResponseWriter, req *http.Request) {
switch req.URL.Path {
case "/api/auth.test":
res.WriteHeader(http.StatusOK)
json.NewEncoder(res).Encode(testAuthTestResponseData)
case "/api/users.info":
res.WriteHeader(http.StatusOK)
json.NewEncoder(res).Encode(testUserInfoResponseData)
default:
res.WriteHeader(http.StatusNotFound)
}
},
),
expectedUser: goth.User{
UserID: "user1234",
NickName: "testuser",
Name: "Test User",
FirstName: "Test",
LastName: "User",
AvatarURL: "http://example.org/avatar.png",
Email: "[email protected]",
AccessToken: "TOKEN",
},
expectErr: false,
},
{
name: "FetchesBasicProfileWhenLackingUserReadScope",
provider: slack.New(os.Getenv("SLACK_KEY"), os.Getenv("SLACK_SECRET"), "/foo", "commands"),
session: &slack.Session{AccessToken: "TOKEN"},
handler: http.HandlerFunc(
func(res http.ResponseWriter, req *http.Request) {
switch req.URL.Path {
case "/api/auth.test":
res.WriteHeader(http.StatusOK)
json.NewEncoder(res).Encode(testAuthTestResponseData)
default:
res.WriteHeader(http.StatusNotFound)
}
},
),
expectedUser: goth.User{
UserID: "user1234",
NickName: "testuser",
AccessToken: "TOKEN",
},
expectErr: false,
},
{
name: "FailsWithNoAccessToken",
provider: provider(),
session: &slack.Session{AccessToken: ""},
handler: nil,
expectErr: true,
},
{
name: "FailsWithBadAuthTestResponse",
provider: provider(),
session: &slack.Session{AccessToken: "TOKEN"},
handler: http.HandlerFunc(
func(res http.ResponseWriter, req *http.Request) {
switch req.URL.Path {
case "/api/auth.test":
res.WriteHeader(http.StatusForbidden)
}
},
),
expectedUser: goth.User{
AccessToken: "TOKEN",
},
expectErr: true,
},
{
name: "FailsWithBadUserInfoResponse",
provider: provider(),
session: &slack.Session{AccessToken: "TOKEN"},
handler: http.HandlerFunc(
func(res http.ResponseWriter, req *http.Request) {
switch req.URL.Path {
case "/api/auth.test":
res.WriteHeader(http.StatusOK)
json.NewEncoder(res).Encode(testAuthTestResponseData)
case "/api/users.info":
res.WriteHeader(http.StatusForbidden)
}
},
),
expectedUser: goth.User{
UserID: "user1234",
NickName: "testuser",
AccessToken: "TOKEN",
},
expectErr: true,
},
} {
t.Run(testData.name, func(t *testing.T) {
a := assert.New(t)

withMockServer(testData.provider, testData.handler, func(p *slack.Provider) {
user, err := p.FetchUser(testData.session)
a.NotZero(user)

if testData.expectErr {
a.Error(err)
} else {
a.NoError(err)
}

a.Equal(testData.expectedUser.UserID, user.UserID)
a.Equal(testData.expectedUser.NickName, user.NickName)
a.Equal(testData.expectedUser.Name, user.Name)
a.Equal(testData.expectedUser.FirstName, user.FirstName)
a.Equal(testData.expectedUser.LastName, user.LastName)
a.Equal(testData.expectedUser.AvatarURL, user.AvatarURL)
a.Equal(testData.expectedUser.Email, user.Email)
a.Equal(testData.expectedUser.AccessToken, user.AccessToken)
})
})
}
}

func Test_SessionFromJSON(t *testing.T) {
t.Parallel()
a := assert.New(t)
Expand All @@ -50,3 +214,23 @@ func Test_SessionFromJSON(t *testing.T) {
func provider() *slack.Provider {
return slack.New(os.Getenv("SLACK_KEY"), os.Getenv("SLACK_SECRET"), "/foo")
}

func withMockServer(p *slack.Provider, handler http.Handler, fn func(p *slack.Provider)) {
server := httptest.NewTLSServer(handler)
defer server.Close()

httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
return net.Dial(network, server.Listener.Addr().String())
},
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}

p.HTTPClient = httpClient

fn(p)
}

0 comments on commit 42e707e

Please sign in to comment.