Skip to content

Commit

Permalink
Add golang.org/x/net/context support.
Browse files Browse the repository at this point in the history
Read http://blog.golang.org/context for its benefits.

This PR has not utilized contexts yet; just passing them
to every customization points to help them add/retrieve
request context values.
  • Loading branch information
ymmt2005 committed Mar 8, 2016
1 parent 3a873e9 commit 385bbe4
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 30 deletions.
41 changes: 27 additions & 14 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"io"
"net"
"strings"

"golang.org/x/net/context"
)

const (
Expand Down Expand Up @@ -34,7 +36,7 @@ var (

// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
Rewrite(request *Request) *AddrSpec
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
}

// AddrSpec is used to return the target AddrSpec
Expand Down Expand Up @@ -105,33 +107,36 @@ func NewRequest(bufConn io.Reader) (*Request, error) {

// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error {
ctx := context.Background()

// Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" {
addr, err := s.config.Resolver.Resolve(dest.FQDN)
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
}
ctx = ctx_
dest.IP = addr
}

// Apply any address rewrites
req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil {
req.realDestAddr = s.config.Rewriter.Rewrite(req)
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
}

// Switch on the command
switch req.Command {
case ConnectCommand:
return s.handleConnect(conn, req)
return s.handleConnect(ctx, conn, req)
case BindCommand:
return s.handleBind(conn, req)
return s.handleBind(ctx, conn, req)
case AssociateCommand:
return s.handleAssociate(conn, req)
return s.handleAssociate(ctx, conn, req)
default:
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
Expand All @@ -141,22 +146,26 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
}

// handleConnect is used to handle a connect command
func (s *Server) handleConnect(conn conn, req *Request) error {
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}

// Attempt to connect
addr := (&net.TCPAddr{IP: req.realDestAddr.IP, Port: req.realDestAddr.Port}).String()
dial := s.config.Dial
if dial == nil {
dial = net.Dial
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
return net.Dial(net_, addr)
}
}
target, err := dial("tcp", addr)
target, err := dial(ctx, "tcp", addr)
if err != nil {
msg := err.Error()
resp := hostUnreachable
Expand Down Expand Up @@ -196,13 +205,15 @@ func (s *Server) handleConnect(conn conn, req *Request) error {
}

// handleBind is used to handle a connect command
func (s *Server) handleBind(conn conn, req *Request) error {
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}

// TODO: Support bind
Expand All @@ -213,13 +224,15 @@ func (s *Server) handleBind(conn conn, req *Request) error {
}

// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(conn conn, req *Request) error {
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if !s.config.Rules.Allow(req) {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}

// TODO: Support associate
Expand Down
10 changes: 6 additions & 4 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@ package socks5

import (
"net"

"golang.org/x/net/context"
)

// NameResolver is used to implement custom name resolution
type NameResolver interface {
Resolve(name string) (net.IP, error)
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
}

// DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{}

func (d DNSResolver) Resolve(name string) (net.IP, error) {
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name)
if err != nil {
return nil, err
return ctx, nil, err
}
return addr.IP, err
return ctx, addr.IP, err
}
5 changes: 4 additions & 1 deletion resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package socks5

import (
"testing"

"golang.org/x/net/context"
)

func TestDNSResolver(t *testing.T) {
d := DNSResolver{}
ctx := context.Background()

addr, err := d.Resolve("localhost")
_, addr, err := d.Resolve(ctx, "localhost")
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down
16 changes: 10 additions & 6 deletions ruleset.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package socks5

import (
"golang.org/x/net/context"
)

// RuleSet is used to provide custom rules to allow or prohibit actions
type RuleSet interface {
Allow(req *Request) bool
Allow(ctx context.Context, req *Request) (context.Context, bool)
}

// PermitAll returns a RuleSet which allows all types of connections
Expand All @@ -23,15 +27,15 @@ type PermitCommand struct {
EnableAssociate bool
}

func (p *PermitCommand) Allow(req *Request) bool {
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command {
case ConnectCommand:
return p.EnableConnect
return ctx, p.EnableConnect
case BindCommand:
return p.EnableBind
return ctx, p.EnableBind
case AssociateCommand:
return p.EnableAssociate
return ctx, p.EnableAssociate
}

return false
return ctx, false
}
13 changes: 9 additions & 4 deletions ruleset_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
package socks5

import "testing"
import (
"testing"

"golang.org/x/net/context"
)

func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r := &PermitCommand{true, false, false}

if !r.Allow(&Request{Command: ConnectCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
t.Fatalf("expect connect")
}

if r.Allow(&Request{Command: BindCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
t.Fatalf("do not expect bind")
}

if r.Allow(&Request{Command: AssociateCommand}) {
if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
t.Fatalf("do not expect associate")
}
}
4 changes: 3 additions & 1 deletion socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"log"
"net"
"os"

"golang.org/x/net/context"
)

const (
Expand Down Expand Up @@ -45,7 +47,7 @@ type Config struct {
Logger *log.Logger

// Optional function for dialing out
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
}

// Server is reponsible for accepting connections and handling
Expand Down

0 comments on commit 385bbe4

Please sign in to comment.