Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMQ - 2724 - Add Auth Callout #2731

Merged
merged 18 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/http/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
errors.Contains(err, apiutil.ErrMissingRoleName),
errors.Contains(err, apiutil.ErrMissingRoleID),
errors.Contains(err, apiutil.ErrMissingPolicyEntityType),
errors.Contains(err, apiutil.ErrMissingRoleMembers):
errors.Contains(err, apiutil.ErrMissingRoleMembers),
errors.Contains(err, apiutil.ErrMissingDescription),
errors.Contains(err, apiutil.ErrMissingEntityID):
err = unwrap(err)
w.WriteHeader(http.StatusBadRequest)

Expand Down
3 changes: 3 additions & 0 deletions api/http/util/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ var (
// ErrMissingID indicates missing entity ID.
ErrMissingID = errors.New("missing entity id")

// ErrMissingEntityID indicates missing entity ID.
ErrMissingEntityID = errors.New("missing entity id")

// ErrMissingClientID indicates missing client ID.
ErrMissingClientID = errors.New("missing cient id")

Expand Down
148 changes: 81 additions & 67 deletions auth/README.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion auth/api/http/keys/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ func newService() (auth.Service, *mocks.KeyRepository) {
pService := new(policymocks.Service)
pEvaluator := new(policymocks.Evaluator)
t := jwt.New([]byte(secret))
callback := new(mocks.CallBack)

return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration, callback), krepo
}

func newServer(svc auth.Service) *httptest.Server {
Expand Down
113 changes: 113 additions & 0 deletions auth/callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0

package auth

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"

"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
)

type callback struct {
httpClient *http.Client
urls []string
method string
}

// CallBack send auth request to an external service.
//
//go:generate mockery --name CallBack --output=./mocks --filename callback.go --quiet --note "Copyright (c) Abstract Machines"
type CallBack interface {
Authorize(ctx context.Context, pr policies.Policy) error
}

// NewCallback creates a new instance of CallBack.
func NewCallback(httpClient *http.Client, method string, urls []string) (CallBack, error) {
if httpClient == nil {
httpClient = http.DefaultClient
}
if method != http.MethodPost && method != http.MethodGet {
return nil, fmt.Errorf("unsupported auth callback method: %s", method)
}

return &callback{
httpClient: httpClient,
urls: urls,
method: method,
}, nil
}

func (c *callback) Authorize(ctx context.Context, pr policies.Policy) error {
if len(c.urls) == 0 {
return nil
}

payload := map[string]string{
"domain": pr.Domain,
"subject": pr.Subject,
"subject_type": pr.SubjectType,
"subject_kind": pr.SubjectKind,
"subject_relation": pr.SubjectRelation,
"object": pr.Object,
"object_type": pr.ObjectType,
"object_kind": pr.ObjectKind,
"relation": pr.Relation,
"permission": pr.Permission,
}

var err error
// We use a single URL at a time and others as fallbacks
// the first positive result returned by a callback in the chain is considered to be final
for i := range c.urls {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an explanation comment here.

if err = c.makeRequest(ctx, c.urls[i], payload); err == nil {
return nil
}
}

return err
}

func (c *callback) makeRequest(ctx context.Context, urlStr string, params map[string]string) error {
var req *http.Request
var err error

switch c.method {
case http.MethodGet:
query := url.Values{}
for key, value := range params {
query.Set(key, value)
}
req, err = http.NewRequestWithContext(ctx, c.method, urlStr+"?"+query.Encode(), nil)
case http.MethodPost:
data, jsonErr := json.Marshal(params)
if jsonErr != nil {
return jsonErr
}

Check warning on line 93 in auth/callback.go

View check run for this annotation

Codecov / codecov/patch

auth/callback.go#L92-L93

Added lines #L92 - L93 were not covered by tests
req, err = http.NewRequestWithContext(ctx, c.method, urlStr, bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
}

if err != nil {
return err
}

Check warning on line 100 in auth/callback.go

View check run for this annotation

Codecov / codecov/patch

auth/callback.go#L99-L100

Added lines #L99 - L100 were not covered by tests

resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, resp.StatusCode)
}

return nil
}
143 changes: 143 additions & 0 deletions auth/callback_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0

package auth_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/stretchr/testify/assert"
)

func TestCallback_Authorize(t *testing.T) {
policy := policies.Policy{
Domain: "test-domain",
Subject: "test-subject",
SubjectType: "user",
SubjectKind: "individual",
SubjectRelation: "owner",
Object: "test-object",
ObjectType: "message",
ObjectKind: "event",
Relation: "publish",
Permission: "allow",
}

cases := []struct {
desc string
method string
respStatus int
expectError bool
}{
{
desc: "successful GET authorization",
method: http.MethodGet,
respStatus: http.StatusOK,
expectError: false,
},
{
desc: "successful POST authorization",
method: http.MethodPost,
respStatus: http.StatusOK,
expectError: false,
},
{
desc: "failed authorization",
method: http.MethodPost,
respStatus: http.StatusForbidden,
expectError: true,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, tc.method, r.Method)

if tc.method == http.MethodGet {
query := r.URL.Query()
assert.Equal(t, policy.Domain, query.Get("domain"))
assert.Equal(t, policy.Subject, query.Get("subject"))
}

w.WriteHeader(tc.respStatus)
}))
defer ts.Close()

cb, err := auth.NewCallback(http.DefaultClient, tc.method, []string{ts.URL})
assert.NoError(t, err)
err = cb.Authorize(context.Background(), policy)

if tc.expectError {
assert.Error(t, err)
assert.True(t, errors.Contains(err, svcerr.ErrAuthorization), "expected authorization error")
} else {
assert.NoError(t, err)
}
})
}
}

func TestCallback_MultipleURLs(t *testing.T) {
ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts1.Close()

ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts2.Close()

cb, err := auth.NewCallback(http.DefaultClient, http.MethodPost, []string{ts1.URL, ts2.URL})
assert.NoError(t, err)
err = cb.Authorize(context.Background(), policies.Policy{})
assert.NoError(t, err)
}

func TestCallback_InvalidURL(t *testing.T) {
cb, err := auth.NewCallback(http.DefaultClient, http.MethodPost, []string{"http://invalid-url"})
assert.NoError(t, err)
err = cb.Authorize(context.Background(), policies.Policy{})
assert.Error(t, err)
}

func TestCallback_InvalidMethod(t *testing.T) {
_, err := auth.NewCallback(http.DefaultClient, "invalid-method", []string{"http://example.com"})
assert.Error(t, err)
}

func TestCallback_CancelledContext(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

ctx, cancel := context.WithCancel(context.Background())
cancel()

cb, err := auth.NewCallback(http.DefaultClient, http.MethodPost, []string{ts.URL})
assert.NoError(t, err)
err = cb.Authorize(ctx, policies.Policy{})
assert.Error(t, err)
}

func TestNewCallback_NilClient(t *testing.T) {
cb, err := auth.NewCallback(nil, http.MethodPost, []string{"test"})
assert.NoError(t, err)
assert.NotNil(t, cb)
}

func TestCallback_NoURL(t *testing.T) {
cb, err := auth.NewCallback(http.DefaultClient, http.MethodPost, []string{})
assert.NoError(t, err)
err = cb.Authorize(context.Background(), policies.Policy{})
assert.NoError(t, err)
}
49 changes: 49 additions & 0 deletions auth/mocks/callback.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions auth/pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"strings"
"time"

apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/errors"
)

Expand Down Expand Up @@ -275,15 +276,16 @@
return errInvalidScope
}
if s.EntityID == "" {
return errors.New("missing entityID")
return apiutil.ErrMissingEntityID

Check warning on line 279 in auth/pat.go

View check run for this annotation

Codecov / codecov/patch

auth/pat.go#L279

Added line #L279 was not covered by tests
}

switch s.EntityType {
case ChannelsType, GroupsType, ClientsType:
if s.OptionalDomainID == "" {
return errors.New("missing domainID")
return apiutil.ErrMissingDomainID

Check warning on line 285 in auth/pat.go

View check run for this annotation

Codecov / codecov/patch

auth/pat.go#L285

Added line #L285 was not covered by tests
}
}

return nil
}

Expand Down
Loading
Loading