Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serve /alpaca.pac to send non-DIRECT requests to alpaca #16

Merged
merged 6 commits into from
Sep 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions contextid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package main

import (
"context"
"net/http"
)

// AddContextID wraps a http.Handler to add a strictly increasing
// uint to the context of the http.Request with the key "id" (string)
// as it passes through the request to the next handler.
func AddContextID(next http.Handler) http.Handler {
// TODO(#17): Use sync/atomic AddUint64 instead of channel/goroutine
ids := make(chan uint)
go func() {
for id := uint(0); ; id++ {
ids <- id
}
}()
camh- marked this conversation as resolved.
Show resolved Hide resolved
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// TODO(#17): Use package scoped type instead of string for key
ctx := context.WithValue(req.Context(), "id", <-ids)
camh- marked this conversation as resolved.
Show resolved Hide resolved
next.ServeHTTP(w, req.WithContext(ctx))
})
}
35 changes: 35 additions & 0 deletions contextid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package main

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func getIDFromRequest(t *testing.T, server *httptest.Server) uint {
res, err := http.Get(server.URL)
require.NoError(t, err)
b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
id, err := strconv.ParseUint(string(b), 10, 64)
require.NoError(t, err)
return uint(id)
}

func TestContextID(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id, ok := r.Context().Value("id").(uint)
assert.True(t, ok, "Unexpected type for context id value")
_, err := w.Write([]byte(strconv.FormatUint(uint64(id), 10)))
require.NoError(t, err)
})
server := httptest.NewServer(AddContextID(handler))
defer server.Close()
assert.Equal(t, uint(0), getIDFromRequest(t, server))
assert.Equal(t, uint(1), getIDFromRequest(t, server))
}
32 changes: 14 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"crypto/tls"
"flag"
"fmt"
"golang.org/x/crypto/ssh/terminal"
"log"
"net/http"
"net/url"
"os"
"os/user"

"golang.org/x/crypto/ssh/terminal"
)

var getCredentialsFromKeyring func() (authenticator, error)
Expand Down Expand Up @@ -63,22 +63,18 @@ func main() {
}
}

var handler ProxyHandler
if len(pacURL) == 0 {
log.Println("No PAC URL specified or detected; all requests will be made directly")
handler = NewProxyHandler(func(req *http.Request) (*url.URL, error) {
log.Printf(`[%d] %s %s via "DIRECT"`,
req.Context().Value("id"), req.Method, req.URL)
return nil, nil
}, nil)
} else if _, err := url.Parse(pacURL); err != nil {
log.Fatalf("Couldn't find a valid PAC URL: %v", pacURL)
} else {
pf := NewProxyFinder(pacURL)
handler = NewProxyHandler(func(req *http.Request) (*url.URL, error) {
return pf.findProxyForRequest(req)
}, &a)
}
pacWrapper := NewPACWrapper(PACData{Port: *port})
proxyFinder := NewProxyFinder(pacURL, pacWrapper)
proxyHandler := NewProxyHandler(proxyFinder.findProxyForRequest, &a)
mux := http.NewServeMux()
pacWrapper.SetupHandlers(mux)

// build the handler by wrapping middleware upon middleware
var handler http.Handler = mux
handler = RequestLogger(handler)
handler = proxyHandler.WrapHandler(handler)
handler = proxyFinder.WrapHandler(handler)
handler = AddContextID(handler)

s := &http.Server{
// Set the addr to localhost so that we only listen locally.
Expand Down
73 changes: 73 additions & 0 deletions pacwrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package main

import (
"bytes"
"log"
"net/http"
"text/template"
)

// PACData contains program configuration to be made available to the pacWrapTmpl.
type PACData struct {
Port int
}

type pacData struct {
PACData
UpstreamPAC string
}

type PACWrapper struct {
data pacData
tmpl *template.Template
alpacaPAC string
}

// PACWrapper template for serving a PAC file to point at alpaca or DIRECT. If we have a valid
// PAC file, we wrap that PAC file with a wrapper function that only returns "DIRECT" or
// "localhost:port". If we do not have a PAC file, the PAC function we serve only returns "DIRECT",
// which should prevent all requests reaching us.
var pacWrapTmpl = `// Wrapped for and by alpaca
function FindProxyForURL(url, host) {
{{ if .UpstreamPAC }}
return FindProxyForURL(url, host) === "DIRECT" ? "DIRECT" : "PROXY localhost:{{.Port}}";
{{.UpstreamPAC}}
{{ else }}
return "DIRECT";
{{ end }}
}
`

func NewPACWrapper(data PACData) *PACWrapper {
t := template.Must(template.New("alpaca").Parse(pacWrapTmpl))
return &PACWrapper{pacData{data, ""}, t, ""}
}

func (pw *PACWrapper) Wrap(pacjs []byte) {
pac := string(pacjs)
if pac == pw.data.UpstreamPAC && pw.alpacaPAC != "" {
return
}
pw.data.UpstreamPAC = pac
b := &bytes.Buffer{}
if err := pw.tmpl.Execute(b, pw.data); err != nil {
log.Printf("error executing PAC wrap template: %v\n", err)
return
}
pw.alpacaPAC = b.String()
}

func (pw *PACWrapper) SetupHandlers(mux *http.ServeMux) {
mux.HandleFunc("/alpaca.pac", pw.handlePAC)
}

func (pw *PACWrapper) handlePAC(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/x-ns-proxy-autoconfig")
if _, err := w.Write([]byte(pw.alpacaPAC)); err != nil {
log.Printf("Error writing PAC to response: %v\n", err)
}
}
47 changes: 47 additions & 0 deletions pacwrapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWrapPAC(t *testing.T) {
camh- marked this conversation as resolved.
Show resolved Hide resolved
pw := NewPACWrapper(PACData{Port: 1234})
pac := `function FindProxyForURL(url, host) { return "DIRECT" }`
pw.Wrap([]byte(pac))
assert.Contains(t, pw.alpacaPAC, pac)
assert.Contains(t, pw.alpacaPAC, `"DIRECT" : "PROXY localhost:1234"`)
}

func TestWrapEmptyPAC(t *testing.T) {
pw := NewPACWrapper(PACData{Port: 1234})
pw.Wrap(nil)
assert.Contains(t, pw.alpacaPAC, `return "DIRECT"`)
}

func TestPACServe(t *testing.T) {
pw := NewPACWrapper(PACData{Port: 1234})
pac := `function FindProxyForURL(url, host) { return "DIRECT" }`
pw.Wrap([]byte(pac))
mux := http.NewServeMux()
pw.SetupHandlers(mux)
server := httptest.NewServer(mux)
defer server.Close()

resp, err := http.Get(server.URL + "/alpaca.pac")
require.NoError(t, err)

assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, "application/x-ns-proxy-autoconfig", resp.Header.Get("Content-Type"))
b, err := ioutil.ReadAll(resp.Body)
body := string(b)
require.NoError(t, err)
assert.Contains(t, body, pac)
assert.Contains(t, body, `"DIRECT" : "PROXY localhost:1234"`)
resp.Body.Close()
}
28 changes: 15 additions & 13 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
Expand All @@ -17,29 +16,32 @@ import (
type ProxyHandler struct {
transport *http.Transport
auth *authenticator
ids chan uint
}

type proxyFunc func(*http.Request) (*url.URL, error)

func NewProxyHandler(proxy proxyFunc, auth *authenticator) ProxyHandler {
return newProxyHandler(&http.Transport{Proxy: proxy}, auth)
return ProxyHandler{&http.Transport{Proxy: proxy}, auth}
}

func newProxyHandler(tr *http.Transport, auth *authenticator) ProxyHandler {
ids := make(chan uint)
go func() {
for id := uint(0); ; id++ {
ids <- id
func (ph ProxyHandler) WrapHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Pass CONNECT requests and absolute-form URIs to the ProxyHandler.
// If the request URL has a scheme, it is an absolute-form URI
// (RFC 7230 Section 5.3.2).
if req.Method == http.MethodConnect || req.URL.Scheme != "" {
ph.ServeHTTP(w, req)
return
}
}()
return ProxyHandler{tr, auth, ids}
// The request URI is an origin-form or asterisk-form target which we
// handle as an origin server (RFC 7230 5.3). authority-form URIs
// are only for CONNECT, which has already been dispatched to the
// ProxyHandler.
next.ServeHTTP(w, req)
})
}

func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
ctx = context.WithValue(ctx, "id", <-ph.ids)
req = req.WithContext(ctx)
deleteRequestHeaders(req)
if req.Method == http.MethodConnect {
ph.handleConnect(w, req)
Expand Down
27 changes: 23 additions & 4 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testServer struct {
Expand Down Expand Up @@ -74,7 +75,9 @@ func TestGetViaProxy(t *testing.T) {
requests := make(chan string, 2)
server := httptest.NewServer(testServer{requests})
defer server.Close()
proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy()})
// Proxy request should not go to the mux. The empty mux will always return 404.
mux := http.NewServeMux()
proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy().WrapHandler(mux)})
camh- marked this conversation as resolved.
Show resolved Hide resolved
defer proxy.Close()
tr := &http.Transport{Proxy: proxyServer(t, proxy)}
testGetRequest(t, tr, server.URL)
Expand All @@ -87,7 +90,9 @@ func TestGetOverTlsViaProxy(t *testing.T) {
requests := make(chan string, 2)
server := httptest.NewTLSServer(testServer{requests})
defer server.Close()
proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy()})
// Proxy request should not go to the mux. The empty mux will always return 404.
mux := http.NewServeMux()
proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy().WrapHandler(mux)})
defer proxy.Close()
tr := &http.Transport{Proxy: proxyServer(t, proxy), TLSClientConfig: tlsConfig(server)}
testGetRequest(t, tr, server.URL)
Expand All @@ -96,6 +101,20 @@ func TestGetOverTlsViaProxy(t *testing.T) {
assert.Equal(t, "GET to server", <-requests)
}

func TestGetOriginURLsNotProxied(t *testing.T) {
requests := make(chan string, 2)
mux := http.NewServeMux()
mux.HandleFunc("/origin", func(w http.ResponseWriter, req *http.Request) {
_, err := w.Write([]byte("Hello, client\n"))
require.NoError(t, err)
})
proxy := httptest.NewServer(testProxy{requests, "proxy", newDirectProxy().WrapHandler(mux)})
defer proxy.Close()
testGetRequest(t, &http.Transport{}, proxy.URL+"/origin")
require.Len(t, requests, 1)
assert.Equal(t, "GET to proxy", <-requests)
}

func TestGetViaTwoProxies(t *testing.T) {
requests := make(chan string, 3)
server := httptest.NewServer(testServer{requests})
Expand Down
Loading