Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added flag to allow redirect between subdomains #124

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ _Note for Caddy users_: Not all parameters are available in Caddy. See the table
| -osiam | value | | X | OSIAM login backend opts: endpoint=..,client_id=..,client_secret=.. |
| -port | string | "6789" | - | Port to listen on |
| -redirect | boolean | true | X | Allow dynamic overwriting of the the success by query parameter |
| -redirect-allow-subdomain | bool | false | X | If true redirect is allowed when the target is on a different subdomain |
| -redirect-query-parameter | string | "backTo" | X | URL parameter for the redirect target |
| -redirect-check-referer | boolean | true | X | Check the referer header to ensure it matches the host header on dynamic redirects |
| -redirect-host-file | string | "" | X | A file containing a list of domains that redirects are allowed to, one domain per line |
Expand Down
2 changes: 2 additions & 0 deletions caddy/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func TestSetup(t *testing.T) {
redirect_query_parameter comingFrom
redirect_check_referer true
redirect_host_file domainWhitelist.txt
redirect_allow_subdomain true
cookie_name cookiename
cookie_http_only false
cookie_domain example.com
Expand All @@ -60,6 +61,7 @@ func TestSetup(t *testing.T) {
Equal(t, cfg.RedirectQueryParameter, "comingFrom")
Equal(t, cfg.RedirectCheckReferer, true)
Equal(t, cfg.RedirectHostFile, "domainWhitelist.txt")
Equal(t, cfg.RedirectAllowSubdomain, true)
Equal(t, cfg.CookieName, "cookiename")
Equal(t, cfg.CookieHTTPOnly, false)
Equal(t, cfg.CookieDomain, "example.com")
Expand Down
3 changes: 3 additions & 0 deletions login/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func DefaultConfig() *Config {
RedirectQueryParameter: "backTo",
RedirectCheckReferer: true,
RedirectHostFile: "",
RedirectAllowSubdomain: false,
LogoutURL: "",
LoginPath: "/login",
CookieName: "jwt_token",
Expand Down Expand Up @@ -73,6 +74,7 @@ type Config struct {
RedirectQueryParameter string
RedirectCheckReferer bool
RedirectHostFile string
RedirectAllowSubdomain bool
LogoutURL string
Template string
LoginPath string
Expand Down Expand Up @@ -152,6 +154,7 @@ func (c *Config) ConfigureFlagSet(f *flag.FlagSet) {
f.StringVar(&c.RedirectQueryParameter, "redirect-query-parameter", c.RedirectQueryParameter, "URL parameter for the redirect target")
f.BoolVar(&c.RedirectCheckReferer, "redirect-check-referer", c.RedirectCheckReferer, "When redirecting check that the referer is the same domain")
f.StringVar(&c.RedirectHostFile, "redirect-host-file", c.RedirectHostFile, "A file containing a list of domains that redirects are allowed to, one domain per line")
f.BoolVar(&c.RedirectAllowSubdomain, "redirect-allow-subdomain", c.RedirectAllowSubdomain, "If true a redirect is allowed if the target is a different subdomain than loginsrv")

f.StringVar(&c.LogoutURL, "logout-url", c.LogoutURL, "The url or path to redirect after logout")
f.StringVar(&c.Template, "template", c.Template, "An alternative template for the login form")
Expand Down
32 changes: 32 additions & 0 deletions login/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,52 @@ func (h *Handler) allowRedirect(r *http.Request) bool {
logging.Application(r.Header).Warnf("couldn't parse redirect url %s", err)
return false
}

if referer.Host != r.Host {
logging.Application(r.Header).Warnf("redirect from referer domain: '%s', not matching current domain '%s'", referer.Host, r.Host)
return false
}
return true
}

func removeSubdomain(host string) string {
parts := strings.Split(host, ".")
if len(parts) == 1 {
return host
}
return strings.Join(parts[1:], ".")
}

// haveSubdomain checks that there's at least one subdomain
func haveSubdomain(host string) bool {
trimmed := strings.Trim(host, ".")
parts := strings.Split(trimmed, ".")
return len(parts) > 2
}

func (h *Handler) isSubdomainAllowed(target string, host string) bool {
if !h.config.RedirectAllowSubdomain {
return false
}
if target == "" || host == "" {
return false
}
if !haveSubdomain(target) || !haveSubdomain(host) {
return false
}
return removeSubdomain(target) == removeSubdomain(host)
}

func (h *Handler) redirectURL(r *http.Request, w http.ResponseWriter) string {
targetURL, foundTarget := h.getRedirectTarget(r)
if foundTarget && h.config.Redirect {
sameHost := targetURL.Host == "" || r.Host == targetURL.Host
if sameHost && targetURL.Path != "" {
return targetURL.Path
}
if h.isSubdomainAllowed(targetURL.Host, r.Host) {
return targetURL.String()
}
if !sameHost && h.isRedirectDomainWhitelisted(r, targetURL.Host) {
return targetURL.String()
}
Expand Down
71 changes: 71 additions & 0 deletions login/redirect_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package login

import (
"fmt"
"net/http/httptest"
"os"
"testing"
Expand Down Expand Up @@ -122,5 +123,75 @@ func TestRedirect_Whitelisting(t *testing.T) {
h.ServeHTTP(recorder, req("POST", "/login?backTo=https://evildomain.com/website", "username=bob&password=secret", TypeForm, AcceptHTML, BadReferer))
Equal(t, 303, recorder.Code)
Equal(t, "/", recorder.Header().Get("Location"))
}

func TestRemoveSubDomain(t *testing.T) {
tests := []struct {
input string
output string
}{
{input: "sub.home.com", output: "home.com"},
{input: "tld", output: "tld"},
{input: "home.com", output: "com"},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s should be %s", tt.input, tt.output), func(t *testing.T) {
Equal(t, tt.output, removeSubdomain(tt.input))
})
}
}

func TestHaveSubdomain(t *testing.T) {
tests := []struct {
input string
expect bool
}{
{input: "sub.home.com", expect: true},
{input: "tld", expect: false},
{input: "home.com", expect: false},
{input: "home.com.", expect: false},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s should be %v", tt.input, tt.expect), func(t *testing.T) {
Equal(t, tt.expect, haveSubdomain(tt.input))
})
}
}

func TestRedirect_Subdomain(t *testing.T) {

cfg := DefaultConfig()
cfg.RedirectAllowSubdomain = true
h := &Handler{
backends: []Backend{
NewSimpleBackend(map[string]string{"bob": "secret"}),
},
oauth: oauth2.NewManager(),
config: cfg,
}
recorder := httptest.NewRecorder()
h.ServeHTTP(recorder, req("POST", "http://auth.home.com/login?backTo=https://sub.home.com/website", "username=bob&password=secret", TypeForm, AcceptHTML, BadReferer))
Equal(t, 303, recorder.Code)
Equal(t, "https://sub.home.com/website", recorder.Header().Get("Location"))

// need at least one subdomain
recorder = httptest.NewRecorder()
h.ServeHTTP(recorder, req("POST", "http://home.com/login?backTo=https://google.com/website", "username=bob&password=secret", TypeForm, AcceptHTML, BadReferer))
Equal(t, 303, recorder.Code)
Equal(t, "/", recorder.Header().Get("Location"))

// make sure extra . is ignored
recorder = httptest.NewRecorder()
h.ServeHTTP(recorder, req("POST", "http://home.com./login?backTo=https://google.com./website", "username=bob&password=secret", TypeForm, AcceptHTML, BadReferer))
Equal(t, 303, recorder.Code)
Equal(t, "/", recorder.Header().Get("Location"))

// not allowed if current host is unknown
recorder = httptest.NewRecorder()
h.ServeHTTP(recorder, req("POST", "/login?backTo=https://sub.home.com/website", "username=bob&password=secret", TypeForm, AcceptHTML, BadReferer))
Equal(t, 303, recorder.Code)
Equal(t, "/", recorder.Header().Get("Location"))

}