diff --git a/auth.go b/auth/auth.go similarity index 90% rename from auth.go rename to auth/auth.go index b88721d..7378251 100644 --- a/auth.go +++ b/auth/auth.go @@ -1,4 +1,4 @@ -package main +package auth import ( "bytes" @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "net/url" + "os" "strconv" "strings" "sync" @@ -15,6 +16,8 @@ import ( "github.com/tg123/go-htpasswd" "golang.org/x/crypto/bcrypt" + + clog "github.com/SenseUnit/dumbproxy/log" ) const AUTH_REQUIRED_MSG = "Proxy authentication required.\n" @@ -27,7 +30,7 @@ type Auth interface { Stop() } -func NewAuth(paramstr string, logger *CondLogger) (Auth, error) { +func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) { url, err := url.Parse(paramstr) if err != nil { return nil, err @@ -47,7 +50,7 @@ func NewAuth(paramstr string, logger *CondLogger) (Auth, error) { } } -func NewStaticAuth(param_url *url.URL, logger *CondLogger) (*BasicAuth, error) { +func NewStaticAuth(param_url *url.URL, logger *clog.CondLogger) (*BasicAuth, error) { values, err := url.ParseQuery(param_url.RawQuery) if err != nil { return nil, err @@ -100,14 +103,14 @@ type BasicAuth struct { pwFilename string pwFile *htpasswd.File pwMux sync.RWMutex - logger *CondLogger + logger *clog.CondLogger hiddenDomain string stopOnce sync.Once stopChan chan struct{} lastReloaded time.Time } -func NewBasicFileAuth(param_url *url.URL, logger *CondLogger) (*BasicAuth, error) { +func NewBasicFileAuth(param_url *url.URL, logger *clog.CondLogger) (*BasicAuth, error) { values, err := url.ParseQuery(param_url.RawQuery) if err != nil { return nil, err @@ -268,3 +271,18 @@ func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, b } func (_ CertAuth) Stop() {} + +func fileModTime(filename string) (time.Time, error) { + f, err := os.Open(filename) + if err != nil { + return time.Time{}, fmt.Errorf("fileModTime(): can't open file %q: %w", filename, err) + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return time.Time{}, fmt.Errorf("fileModTime(): can't stat file %q: %w", filename, err) + } + + return fi.ModTime(), nil +} diff --git a/dialer/dialer.go b/dialer/dialer.go new file mode 100644 index 0000000..44a2f9d --- /dev/null +++ b/dialer/dialer.go @@ -0,0 +1,85 @@ +package dialer + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + + xproxy "golang.org/x/net/proxy" +) + +type Dialer = xproxy.Dialer +type ContextDialer = xproxy.ContextDialer + +var registerDialerTypesOnce sync.Once + +func ProxyDialerFromURL(proxyURL string, forward Dialer) (Dialer, error) { + registerDialerTypesOnce.Do(func() { + xproxy.RegisterDialerType("http", HTTPProxyDialerFromURL) + xproxy.RegisterDialerType("https", HTTPProxyDialerFromURL) + }) + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("unable to parse proxy URL: %w", err) + } + d, err := xproxy.FromURL(parsedURL, forward) + if err != nil { + return nil, fmt.Errorf("unable to construct proxy dialer from URL %q: %w", proxyURL, err) + } + return d, nil +} + +type wrappedDialer struct { + d Dialer +} + +func (wd wrappedDialer) Dial(net, address string) (net.Conn, error) { + return wd.d.Dial(net, address) +} + +func (wd wrappedDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var ( + conn net.Conn + done = make(chan struct{}, 1) + err error + ) + go func() { + conn, err = wd.d.Dial(network, address) + close(done) + if conn != nil && ctx.Err() != nil { + conn.Close() + } + }() + select { + case <-ctx.Done(): + err = ctx.Err() + case <-done: + } + return conn, err +} + +func MaybeWrapWithContextDialer(d Dialer) ContextDialer { + if xd, ok := d.(ContextDialer); ok { + return xd + } + return wrappedDialer{d} +} + +func parseIPList(list string) ([]net.IP, error) { + res := make([]net.IP, 0) + for _, elem := range strings.Split(list, ",") { + elem = strings.TrimSpace(elem) + if len(elem) == 0 { + continue + } + if parsed := net.ParseIP(elem); parsed == nil { + return nil, fmt.Errorf("unable to parse IP address %q", elem) + } else { + res = append(res, parsed) + } + } + return res, nil +} diff --git a/hintdialer.go b/dialer/hintdialer.go similarity index 99% rename from hintdialer.go rename to dialer/hintdialer.go index e76efce..d8a0408 100644 --- a/hintdialer.go +++ b/dialer/hintdialer.go @@ -1,4 +1,4 @@ -package main +package dialer import ( "context" diff --git a/upstream.go b/dialer/upstream.go similarity index 96% rename from upstream.go rename to dialer/upstream.go index d223ab5..9a325e2 100644 --- a/upstream.go +++ b/dialer/upstream.go @@ -1,4 +1,4 @@ -package main +package dialer import ( "bufio" @@ -29,7 +29,7 @@ func NewHTTPProxyDialer(address string, tls bool, userinfo *url.Userinfo, next D return &HTTPProxyDialer{ address: address, tls: tls, - next: maybeWrapWithContextDialer(next), + next: MaybeWrapWithContextDialer(next), userinfo: userinfo, } } @@ -106,7 +106,7 @@ func (d *HTTPProxyDialer) DialContext(ctx context.Context, network, address stri if d.userinfo != nil { fmt.Fprintf(&reqBuf, "Proxy-Authorization: %s\r\n", basicAuthHeader(d.userinfo)) } - fmt.Fprintf(&reqBuf, "User-Agent: dumbproxy/%s\r\n\r\n", version) + fmt.Fprintf(&reqBuf, "User-Agent: dumbproxy\r\n\r\n") _, err = io.Copy(conn, &reqBuf) if err != nil { conn.Close() diff --git a/handler.go b/handler/handler.go similarity index 88% rename from handler.go rename to handler/handler.go index cbca628..7d3d0cb 100644 --- a/handler.go +++ b/handler/handler.go @@ -1,4 +1,4 @@ -package main +package handler import ( "context" @@ -8,6 +8,10 @@ import ( "strings" "sync" "time" + + "github.com/SenseUnit/dumbproxy/auth" + "github.com/SenseUnit/dumbproxy/dialer" + clog "github.com/SenseUnit/dumbproxy/log" ) const HintsHeaderName = "X-Src-IP-Hints" @@ -18,8 +22,8 @@ type HandlerDialer interface { type ProxyHandler struct { timeout time.Duration - auth Auth - logger *CondLogger + auth auth.Auth + logger *clog.CondLogger dialer HandlerDialer httptransport http.RoundTripper outbound map[string]string @@ -27,8 +31,8 @@ type ProxyHandler struct { userIPHints bool } -func NewProxyHandler(timeout time.Duration, auth Auth, dialer HandlerDialer, - userIPHints bool, logger *CondLogger) *ProxyHandler { +func NewProxyHandler(timeout time.Duration, auth auth.Auth, dialer HandlerDialer, + userIPHints bool, logger *clog.CondLogger) *ProxyHandler { httptransport := &http.Transport{ DialContext: dialer.DialContext, DisableKeepAlives: true, @@ -122,14 +126,14 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { if originator, isLoopback := s.isLoopback(req); isLoopback { s.logger.Critical("Loopback tunnel detected: %s is an outbound "+ "address for another request from %s", req.RemoteAddr, originator) - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + http.Error(wr, auth.BAD_REQ_MSG, http.StatusBadRequest) return } isConnect := strings.ToUpper(req.Method) == "CONNECT" if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || req.Host == "" && req.ProtoMajor == 2 { - http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + http.Error(wr, auth.BAD_REQ_MSG, http.StatusBadRequest) return } @@ -149,7 +153,7 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { ipHints = &hintValues[0] } } - newCtx := context.WithValue(req.Context(), BoundDialerContextKey{}, BoundDialerContextValue{ + newCtx := context.WithValue(req.Context(), dialer.BoundDialerContextKey{}, dialer.BoundDialerContextValue{ Hints: ipHints, LocalAddr: trimAddrPort(localAddr), }) diff --git a/handler/proxy.go b/handler/proxy.go new file mode 100644 index 0000000..df022dd --- /dev/null +++ b/handler/proxy.go @@ -0,0 +1,140 @@ +package handler + +import ( + "bufio" + "context" + "errors" + "io" + "net" + "net/http" + "sync" + "time" +) + +const COPY_BUF = 128 * 1024 + +func proxy(ctx context.Context, left, right net.Conn) { + wg := sync.WaitGroup{} + cpy := func(dst, src net.Conn) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + wg.Add(2) + go cpy(left, right) + go cpy(right, left) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + left.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + +func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { + wg := sync.WaitGroup{} + ltr := func(dst net.Conn, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + rtl := func(dst io.Writer, src io.Reader) { + defer wg.Done() + copyBody(dst, src) + } + wg.Add(2) + go ltr(right, leftreader) + go rtl(leftwriter, right) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + leftreader.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Connection", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func delHopHeaders(header http.Header) { + for _, h := range hopHeaders { + header.Del(h) + } +} + +func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { + hj, ok := hijackable.(http.Hijacker) + if !ok { + return nil, nil, errors.New("Connection doesn't support hijacking") + } + conn, rw, err := hj.Hijack() + if err != nil { + return nil, nil, err + } + var emptytime time.Time + err = conn.SetDeadline(emptytime) + if err != nil { + conn.Close() + return nil, nil, err + } + return conn, rw, nil +} + +func flush(flusher interface{}) bool { + f, ok := flusher.(http.Flusher) + if !ok { + return false + } + f.Flush() + return true +} + +func copyBody(wr io.Writer, body io.Reader) { + buf := make([]byte, COPY_BUF) + for { + bread, read_err := body.Read(buf) + var write_err error + if bread > 0 { + _, write_err = wr.Write(buf[:bread]) + flush(wr) + } + if read_err != nil || write_err != nil { + break + } + } +} diff --git a/condlog.go b/log/condlog.go similarity index 98% rename from condlog.go rename to log/condlog.go index 96a18f3..b40572a 100644 --- a/condlog.go +++ b/log/condlog.go @@ -1,4 +1,4 @@ -package main +package log import ( "fmt" diff --git a/logwriter.go b/log/logwriter.go similarity index 98% rename from logwriter.go rename to log/logwriter.go index 657c2f3..b36f727 100644 --- a/logwriter.go +++ b/log/logwriter.go @@ -1,4 +1,4 @@ -package main +package log import ( "errors" diff --git a/main.go b/main.go index 76a29da..9b9d616 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,11 @@ package main import ( "crypto/tls" + "crypto/x509" + "errors" "flag" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -16,6 +19,12 @@ import ( "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/ssh/terminal" + + "github.com/SenseUnit/dumbproxy/auth" + "github.com/SenseUnit/dumbproxy/dialer" + "github.com/SenseUnit/dumbproxy/handler" + clog "github.com/SenseUnit/dumbproxy/log" ) var ( @@ -197,39 +206,44 @@ func run() int { return 0 } - logWriter := NewLogWriter(os.Stderr) + logWriter := clog.NewLogWriter(os.Stderr) defer logWriter.Close() - mainLogger := NewCondLogger(log.New(logWriter, "MAIN : ", + mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ", log.LstdFlags|log.Lshortfile), args.verbosity) - proxyLogger := NewCondLogger(log.New(logWriter, "PROXY : ", + proxyLogger := clog.NewCondLogger(log.New(logWriter, "PROXY : ", log.LstdFlags|log.Lshortfile), args.verbosity) - authLogger := NewCondLogger(log.New(logWriter, "AUTH : ", + authLogger := clog.NewCondLogger(log.New(logWriter, "AUTH : ", log.LstdFlags|log.Lshortfile), args.verbosity) - auth, err := NewAuth(args.auth, authLogger) + auth, err := auth.NewAuth(args.auth, authLogger) if err != nil { mainLogger.Critical("Failed to instantiate auth provider: %v", err) return 3 } defer auth.Stop() - var dialer Dialer = NewBoundDialer(new(net.Dialer), args.sourceIPHints) + var d dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints) for _, proxyURL := range args.proxy { - newDialer, err := proxyDialerFromURL(proxyURL, dialer) + newDialer, err := dialer.ProxyDialerFromURL(proxyURL, d) if err != nil { mainLogger.Critical("Failed to create dialer for proxy %q: %v", proxyURL, err) return 3 } - dialer = newDialer + d = newDialer } server := http.Server{ - Addr: args.bind_address, - Handler: NewProxyHandler(args.timeout, auth, maybeWrapWithContextDialer(dialer), args.userIPHints, proxyLogger), + Addr: args.bind_address, + Handler: handler.NewProxyHandler( + args.timeout, + auth, + dialer.MaybeWrapWithContextDialer(d), + args.userIPHints, + proxyLogger), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ReadTimeout: 0, ReadHeaderTimeout: 0, @@ -309,6 +323,158 @@ func run() int { return 0 } +func makeServerTLSConfig(certfile, keyfile, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) { + cfg := tls.Config{ + MinVersion: minVer, + MaxVersion: maxVer, + } + cert, err := tls.LoadX509KeyPair(certfile, keyfile) + if err != nil { + return nil, err + } + cfg.Certificates = []tls.Certificate{cert} + if cafile != "" { + roots := x509.NewCertPool() + certs, err := ioutil.ReadFile(cafile) + if err != nil { + return nil, err + } + if ok := roots.AppendCertsFromPEM(certs); !ok { + return nil, errors.New("Failed to load CA certificates") + } + cfg.ClientCAs = roots + cfg.ClientAuth = tls.VerifyClientCertIfGiven + } + cfg.CipherSuites = makeCipherList(ciphers) + if h2 { + cfg.NextProtos = []string{"h2", "http/1.1"} + } else { + cfg.NextProtos = []string{"http/1.1"} + } + return &cfg, nil +} + +func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) { + if cafile != "" { + roots := x509.NewCertPool() + certs, err := ioutil.ReadFile(cafile) + if err != nil { + return nil, err + } + if ok := roots.AppendCertsFromPEM(certs); !ok { + return nil, errors.New("Failed to load CA certificates") + } + cfg.ClientCAs = roots + cfg.ClientAuth = tls.VerifyClientCertIfGiven + } + cfg.CipherSuites = makeCipherList(ciphers) + if h2 { + cfg.NextProtos = []string{"h2", "http/1.1", "acme-tls/1"} + } else { + cfg.NextProtos = []string{"http/1.1", "acme-tls/1"} + } + cfg.MinVersion = minVer + cfg.MaxVersion = maxVer + return cfg, nil +} + +func makeCipherList(ciphers string) []uint16 { + if ciphers == "" { + return nil + } + + cipherIDs := make(map[string]uint16) + for _, cipher := range tls.CipherSuites() { + cipherIDs[cipher.Name] = cipher.ID + } + + cipherNameList := strings.Split(ciphers, ":") + cipherIDList := make([]uint16, 0, len(cipherNameList)) + + for _, name := range cipherNameList { + id, ok := cipherIDs[name] + if !ok { + log.Printf("WARNING: Unknown cipher \"%s\"", name) + } + cipherIDList = append(cipherIDList, id) + } + + return cipherIDList +} + +func list_ciphers() { + for _, cipher := range tls.CipherSuites() { + fmt.Println(cipher.Name) + } +} + +func passwd(filename string, cost int, args ...string) error { + var ( + username, password, password2 string + err error + ) + + if len(args) > 0 { + username = args[0] + } else { + username, err = prompt("Enter username: ", false) + if err != nil { + return fmt.Errorf("can't get username: %w", err) + } + } + + if len(args) > 1 { + password = args[1] + } else { + password, err = prompt("Enter password: ", true) + if err != nil { + return fmt.Errorf("can't get password: %w", err) + } + password2, err = prompt("Repeat password: ", true) + if err != nil { + return fmt.Errorf("can't get password (repeat): %w", err) + } + if password != password2 { + return fmt.Errorf("passwords do not match") + } + } + + hash, err := bcrypt.GenerateFromPassword([]byte(password), cost) + if err != nil { + return fmt.Errorf("can't generate password hash: %w", err) + } + + f, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("can't open file: %w", err) + } + defer f.Close() + + _, err = f.WriteString(fmt.Sprintf("%s:%s\n", username, hash)) + if err != nil { + return fmt.Errorf("can't write to file: %w", err) + } + + return nil +} + +func prompt(prompt string, secure bool) (string, error) { + var input string + fmt.Print(prompt) + + if secure { + b, err := terminal.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return "", err + } + input = string(b) + fmt.Println() + } else { + fmt.Scanln(&input) + } + return input, nil +} + func main() { os.Exit(run()) } diff --git a/utils.go b/utils.go deleted file mode 100644 index 1e11180..0000000 --- a/utils.go +++ /dev/null @@ -1,392 +0,0 @@ -package main - -import ( - "bufio" - "context" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "net/http" - "net/url" - "os" - "strings" - "sync" - "time" - - "golang.org/x/crypto/bcrypt" - "golang.org/x/crypto/ssh/terminal" - xproxy "golang.org/x/net/proxy" -) - -const COPY_BUF = 128 * 1024 - -func proxy(ctx context.Context, left, right net.Conn) { - wg := sync.WaitGroup{} - cpy := func(dst, src net.Conn) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - wg.Add(2) - go cpy(left, right) - go cpy(right, left) - groupdone := make(chan struct{}, 1) - go func() { - wg.Wait() - groupdone <- struct{}{} - }() - select { - case <-ctx.Done(): - left.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return -} - -func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { - wg := sync.WaitGroup{} - ltr := func(dst net.Conn, src io.Reader) { - defer wg.Done() - io.Copy(dst, src) - dst.Close() - } - rtl := func(dst io.Writer, src io.Reader) { - defer wg.Done() - copyBody(dst, src) - } - wg.Add(2) - go ltr(right, leftreader) - go rtl(leftwriter, right) - groupdone := make(chan struct{}, 1) - go func() { - wg.Wait() - groupdone <- struct{}{} - }() - select { - case <-ctx.Done(): - leftreader.Close() - right.Close() - case <-groupdone: - return - } - <-groupdone - return -} - -// Hop-by-hop headers. These are removed when sent to the backend. -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html -var hopHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Connection", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailers", - "Transfer-Encoding", - "Upgrade", -} - -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} - -func delHopHeaders(header http.Header) { - for _, h := range hopHeaders { - header.Del(h) - } -} - -func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { - hj, ok := hijackable.(http.Hijacker) - if !ok { - return nil, nil, errors.New("Connection doesn't support hijacking") - } - conn, rw, err := hj.Hijack() - if err != nil { - return nil, nil, err - } - var emptytime time.Time - err = conn.SetDeadline(emptytime) - if err != nil { - conn.Close() - return nil, nil, err - } - return conn, rw, nil -} - -func flush(flusher interface{}) bool { - f, ok := flusher.(http.Flusher) - if !ok { - return false - } - f.Flush() - return true -} - -func copyBody(wr io.Writer, body io.Reader) { - buf := make([]byte, COPY_BUF) - for { - bread, read_err := body.Read(buf) - var write_err error - if bread > 0 { - _, write_err = wr.Write(buf[:bread]) - flush(wr) - } - if read_err != nil || write_err != nil { - break - } - } -} - -func makeServerTLSConfig(certfile, keyfile, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) { - cfg := tls.Config{ - MinVersion: minVer, - MaxVersion: maxVer, - } - cert, err := tls.LoadX509KeyPair(certfile, keyfile) - if err != nil { - return nil, err - } - cfg.Certificates = []tls.Certificate{cert} - if cafile != "" { - roots := x509.NewCertPool() - certs, err := ioutil.ReadFile(cafile) - if err != nil { - return nil, err - } - if ok := roots.AppendCertsFromPEM(certs); !ok { - return nil, errors.New("Failed to load CA certificates") - } - cfg.ClientCAs = roots - cfg.ClientAuth = tls.VerifyClientCertIfGiven - } - cfg.CipherSuites = makeCipherList(ciphers) - if h2 { - cfg.NextProtos = []string{"h2", "http/1.1"} - } else { - cfg.NextProtos = []string{"http/1.1"} - } - return &cfg, nil -} - -func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) { - if cafile != "" { - roots := x509.NewCertPool() - certs, err := ioutil.ReadFile(cafile) - if err != nil { - return nil, err - } - if ok := roots.AppendCertsFromPEM(certs); !ok { - return nil, errors.New("Failed to load CA certificates") - } - cfg.ClientCAs = roots - cfg.ClientAuth = tls.VerifyClientCertIfGiven - } - cfg.CipherSuites = makeCipherList(ciphers) - if h2 { - cfg.NextProtos = []string{"h2", "http/1.1", "acme-tls/1"} - } else { - cfg.NextProtos = []string{"http/1.1", "acme-tls/1"} - } - cfg.MinVersion = minVer - cfg.MaxVersion = maxVer - return cfg, nil -} - -func makeCipherList(ciphers string) []uint16 { - if ciphers == "" { - return nil - } - - cipherIDs := make(map[string]uint16) - for _, cipher := range tls.CipherSuites() { - cipherIDs[cipher.Name] = cipher.ID - } - - cipherNameList := strings.Split(ciphers, ":") - cipherIDList := make([]uint16, 0, len(cipherNameList)) - - for _, name := range cipherNameList { - id, ok := cipherIDs[name] - if !ok { - log.Printf("WARNING: Unknown cipher \"%s\"", name) - } - cipherIDList = append(cipherIDList, id) - } - - return cipherIDList -} - -func list_ciphers() { - for _, cipher := range tls.CipherSuites() { - fmt.Println(cipher.Name) - } -} - -func passwd(filename string, cost int, args ...string) error { - var ( - username, password, password2 string - err error - ) - - if len(args) > 0 { - username = args[0] - } else { - username, err = prompt("Enter username: ", false) - if err != nil { - return fmt.Errorf("can't get username: %w", err) - } - } - - if len(args) > 1 { - password = args[1] - } else { - password, err = prompt("Enter password: ", true) - if err != nil { - return fmt.Errorf("can't get password: %w", err) - } - password2, err = prompt("Repeat password: ", true) - if err != nil { - return fmt.Errorf("can't get password (repeat): %w", err) - } - if password != password2 { - return fmt.Errorf("passwords do not match") - } - } - - hash, err := bcrypt.GenerateFromPassword([]byte(password), cost) - if err != nil { - return fmt.Errorf("can't generate password hash: %w", err) - } - - f, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return fmt.Errorf("can't open file: %w", err) - } - defer f.Close() - - _, err = f.WriteString(fmt.Sprintf("%s:%s\n", username, hash)) - if err != nil { - return fmt.Errorf("can't write to file: %w", err) - } - - return nil -} - -func fileModTime(filename string) (time.Time, error) { - f, err := os.Open(filename) - if err != nil { - return time.Time{}, fmt.Errorf("fileModTime(): can't open file %q: %w", filename, err) - } - defer f.Close() - - fi, err := f.Stat() - if err != nil { - return time.Time{}, fmt.Errorf("fileModTime(): can't stat file %q: %w", filename, err) - } - - return fi.ModTime(), nil -} - -func prompt(prompt string, secure bool) (string, error) { - var input string - fmt.Print(prompt) - - if secure { - b, err := terminal.ReadPassword(int(os.Stdin.Fd())) - if err != nil { - return "", err - } - input = string(b) - fmt.Println() - } else { - fmt.Scanln(&input) - } - return input, nil -} - -type Dialer xproxy.Dialer -type ContextDialer xproxy.ContextDialer - -var registerDialerTypesOnce sync.Once - -func proxyDialerFromURL(proxyURL string, forward Dialer) (Dialer, error) { - registerDialerTypesOnce.Do(func() { - xproxy.RegisterDialerType("http", HTTPProxyDialerFromURL) - xproxy.RegisterDialerType("https", HTTPProxyDialerFromURL) - }) - parsedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("unable to parse proxy URL: %w", err) - } - d, err := xproxy.FromURL(parsedURL, forward) - if err != nil { - return nil, fmt.Errorf("unable to construct proxy dialer from URL %q: %w", proxyURL, err) - } - return d, nil -} - -type wrappedDialer struct { - d Dialer -} - -func (wd wrappedDialer) Dial(net, address string) (net.Conn, error) { - return wd.d.Dial(net, address) -} - -func (wd wrappedDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - var ( - conn net.Conn - done = make(chan struct{}, 1) - err error - ) - go func() { - conn, err = wd.d.Dial(network, address) - close(done) - if conn != nil && ctx.Err() != nil { - conn.Close() - } - }() - select { - case <-ctx.Done(): - err = ctx.Err() - case <-done: - } - return conn, err -} - -func maybeWrapWithContextDialer(d Dialer) ContextDialer { - if xd, ok := d.(ContextDialer); ok { - return xd - } - return wrappedDialer{d} -} - -func parseIPList(list string) ([]net.IP, error) { - res := make([]net.IP, 0) - for _, elem := range strings.Split(list, ",") { - elem = strings.TrimSpace(elem) - if len(elem) == 0 { - continue - } - if parsed := net.ParseIP(elem); parsed == nil { - return nil, fmt.Errorf("unable to parse IP address %q", elem) - } else { - res = append(res, parsed) - } - } - return res, nil -}