Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
mitigate ddos on create call
Browse files Browse the repository at this point in the history
  • Loading branch information
timothy committed Apr 16, 2017
1 parent 9e47dae commit dde254b
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 189 deletions.
6 changes: 3 additions & 3 deletions channels/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ func setUpChannel(t *testing.T, capacity int64) (*Sender, *Receiver) {
t.Fatal(err)
}

r, err := NewReceiver(DefaultReceiverConfig, receiverWIF.PrivKey)
r, err := NewReceiver(DefaultReceiverConfig, addr2, receiverWIF.PrivKey)
if err != nil {
t.Fatal(err)
}
createResp, err := r.Create(addr2, createReq)
createResp, err := r.Create(createReq)
if err != nil {
t.Fatal(err)
}
Expand All @@ -66,7 +66,7 @@ func setUpChannel(t *testing.T, capacity int64) (*Sender, *Receiver) {
if err != nil {
t.Fatal(err)
}
openResp, err := r.Open(capacity, openReq)
openResp, err := r.Open(capacity, createResp.FundingAddress, openReq)
if err != nil {
t.Fatal(err)
}
Expand Down
78 changes: 58 additions & 20 deletions channels/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Receiver struct {
State SharedState
}

func NewReceiver(c ReceiverConfig, privKey *btcec.PrivateKey) (*Receiver, error) {
func NewReceiver(c ReceiverConfig, receiverOutput string, privKey *btcec.PrivateKey) (*Receiver, error) {
state := SharedState{
Net: c.Net,
Status: StatusCreated,
Expand All @@ -50,6 +50,11 @@ func NewReceiver(c ReceiverConfig, privKey *btcec.PrivateKey) (*Receiver, error)
}
state.ReceiverPubKey = pubKey.PubKey().SerializeCompressed()

if err := checkSupportedAddress(net, receiverOutput); err != nil {
return nil, errors.New("invalid receiverOutput")
}
state.ReceiverOutput = receiverOutput

return &Receiver{
config: c,
privKey: privKey,
Expand Down Expand Up @@ -92,13 +97,10 @@ func LoadReceiver(c ReceiverConfig, state SharedState, privKey *btcec.PrivateKey
}, nil
}

func (r *Receiver) Create(receiverOutput string, req *models.CreateRequest) (*models.CreateResponse, error) {
func (r *Receiver) Create(req *models.CreateRequest) (*models.CreateResponse, error) {
if r.State.Status != StatusCreated {
return nil, ErrNotStatusCreated
}
if err := checkSupportedAddress(r.net, receiverOutput); err != nil {
return nil, errors.New("invalid receiverOutput")
}

if req.Version != Version {
return nil, errors.New("unsupported version")
Expand All @@ -117,7 +119,6 @@ func (r *Receiver) Create(receiverOutput string, req *models.CreateRequest) (*mo
s.Version = Version
s.Timeout = r.config.Timeout
s.Fee = r.config.FeeRate * typicalCloseTxSize
s.ReceiverOutput = receiverOutput
s.SenderOutput = req.SenderOutput
s.SenderPubKey = req.SenderPubKey

Expand All @@ -126,21 +127,22 @@ func (r *Receiver) Create(receiverOutput string, req *models.CreateRequest) (*mo
return nil, err
}

r.State = s
//r.State = s

return &models.CreateResponse{
Version: r.State.Version,
Net: r.State.Net,
Timeout: r.State.Timeout,
Fee: r.State.Fee,
ReceiverPubKey: r.State.ReceiverPubKey,
ReceiverOutput: r.State.ReceiverOutput,
Version: s.Version,
Net: s.Net,
Timeout: s.Timeout,
Fee: s.Fee,
ReceiverPubKey: s.ReceiverPubKey,
ReceiverOutput: s.ReceiverOutput,
FundingAddress: fundingAddr,
}, nil
}

// TODO: add nconf param and validate according to config
func (r *Receiver) Open(amount int64, req *models.OpenRequest) (*models.OpenResponse, error) {
// TODO: pkscript instead of address
func (r *Receiver) Open(amount int64, address string, req *models.OpenRequest) (*models.OpenResponse, error) {
if r.State.Status != StatusCreated {
return nil, ErrNotStatusCreated
}
Expand All @@ -154,18 +156,54 @@ func (r *Receiver) Open(amount int64, req *models.OpenRequest) (*models.OpenResp
if len(req.SenderSig) == 0 {
return nil, errors.New("missing senderSig")
}
if req.Net != r.config.Net {
return nil, errors.New("wrong net")
}
if !bytes.Equal(req.ReceiverPubKey, r.State.ReceiverPubKey) {
return nil, errors.New("wrong receiverPubKey")
}
if req.ReceiverOutput != r.State.ReceiverOutput {
return nil, errors.New("wrong receiverOutput")
}

s := r.State
s.Status = StatusOpen
s.FundingTxID = req.TxID
s.FundingVout = req.Vout
s.Capacity = amount
s.SenderSig = req.SenderSig
s := SharedState{
Version: req.Version,
Net: req.Net,
Timeout: req.Timeout,
Fee: req.Fee,
Status: StatusOpen,
SenderPubKey: req.SenderPubKey,
ReceiverPubKey: req.ReceiverPubKey,
SenderOutput: req.SenderOutput,
ReceiverOutput: req.ReceiverOutput,
FundingTxID: req.TxID,
FundingVout: req.Vout,
Capacity: amount,
SenderSig: req.SenderSig,
}

_, fundingAddr, err := s.GetFundingScript()
if err != nil {
return nil, err
}
if fundingAddr != address {
return nil, errors.New("mismatched funding address")
}

if err := validateSenderSig(s, r.privKey); err != nil {
return nil, err
}

minFee := r.config.FeeRate * typicalCloseTxSize

acceptable := s.Version == Version &&
s.Timeout >= r.config.Timeout &&
s.Fee >= minFee

if !acceptable {
r.State.Status = StatusClosing
}

r.State = s

return &models.OpenResponse{}, nil
Expand Down
21 changes: 17 additions & 4 deletions channels/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ func (s *Sender) GotCreateResponse(resp *models.CreateResponse) error {
return errors.New("senderOutput is missing")
}

if !models.ValidateChannelID(resp.ID) {
return errors.New("invalid channel ID")
}
if resp.Version != Version {
return errors.New("unsupported version")
}
Expand Down Expand Up @@ -212,6 +209,17 @@ func (s *Sender) GetOpenRequest(txid string, vout uint32, amount int64) (*models
}

return &models.OpenRequest{
Version: s.State.Version,
Net: s.State.Net,
Timeout: s.State.Timeout,
Fee: s.State.Fee,

SenderPubKey: s.State.SenderPubKey,
SenderOutput: s.State.SenderOutput,

ReceiverPubKey: s.State.ReceiverPubKey,
ReceiverOutput: s.State.ReceiverOutput,

TxID: txid,
Vout: vout,
SenderSig: sig,
Expand Down Expand Up @@ -251,6 +259,8 @@ func (s *Sender) GetSendRequest(amount int64, payment []byte) (*models.SendReque
}

return &models.SendRequest{
TxID: s.State.FundingTxID,
Vout: s.State.FundingVout,
Payment: payment,
SenderSig: sig,
}, nil
Expand All @@ -275,7 +285,10 @@ func (s *Sender) GetCloseRequest() (*models.CloseRequest, error) {
return nil, ErrNotStatusOpen
}
s.State.Status = StatusClosing
return &models.CloseRequest{}, nil
return &models.CloseRequest{
TxID: s.State.FundingTxID,
Vout: s.State.FundingVout,
}, nil
}

func (s *Sender) GotCloseResponse(resp *models.CloseResponse) error {
Expand Down
31 changes: 17 additions & 14 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,8 @@ func NewClient(c *http.Client, endpoint string) (*Client, error) {
}, nil
}

func (c *Client) do(method, id string, req, resp interface{}) error {
url := c.endpoint
if id != "" {
if !models.ValidateChannelID(id) {
return errors.New("invalid channel ID")
}
url += "/" + id
}
func (c *Client) do(method, path string, req, resp interface{}) error {
url := c.endpoint + path

buf, err := json.Marshal(req)
if err != nil {
Expand Down Expand Up @@ -80,47 +74,56 @@ func (c *Client) do(method, id string, req, resp interface{}) error {

func (c *Client) Create(req models.CreateRequest) (*models.CreateResponse, error) {
var resp models.CreateResponse
if err := c.do(http.MethodPost, "", req, &resp); err != nil {
if err := c.do(http.MethodPost, "/create", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func getChannelID(txid string, vout uint32) string {
return fmt.Sprintf("%s-%d", txid, vout)
}

func (c *Client) Open(req models.OpenRequest) (*models.OpenResponse, error) {
path := "/open/" + getChannelID(req.TxID, req.Vout)
var resp models.OpenResponse
if err := c.do(http.MethodPatch, req.ID, req, &resp); err != nil {
if err := c.do(http.MethodPut, path, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Validate(req models.ValidateRequest) (*models.ValidateResponse, error) {
path := "/validate/" + getChannelID(req.TxID, req.Vout)
var resp models.ValidateResponse
if err := c.do(http.MethodPut, req.ID, req, &resp); err != nil {
if err := c.do(http.MethodPut, path, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Send(req models.SendRequest) (*models.SendResponse, error) {
path := "/send/" + getChannelID(req.TxID, req.Vout)
var resp models.SendResponse
if err := c.do(http.MethodPost, req.ID, req, &resp); err != nil {
if err := c.do(http.MethodPost, path, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Close(req models.CloseRequest) (*models.CloseResponse, error) {
path := "/close/" + getChannelID(req.TxID, req.Vout)
var resp models.CloseResponse
if err := c.do(http.MethodDelete, req.ID, req, &resp); err != nil {
if err := c.do(http.MethodDelete, path, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Status(req models.StatusRequest) (*models.StatusResponse, error) {
path := "/status/" + getChannelID(req.TxID, req.Vout)
var resp models.StatusResponse
if err := c.do(http.MethodGet, req.ID, req, &resp); err != nil {
if err := c.do(http.MethodGet, path, req, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand Down
32 changes: 15 additions & 17 deletions cmd/mbclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,20 @@ func create(args []string) error {
return err
}

fmt.Printf("Funding address: %s\n", addr)

// Sanity check to make sure client and server both agree on the state.
if addr != resp.FundingAddress {
return errors.New("state discrepancy")
}

if hasRemoteID(domain, resp.ID) {
return errors.New("reused channel id")
}
fmt.Printf("Funding address: %s\n", addr)

id := strconv.Itoa(n)
globalState.Channels[id] = Channel{
Domain: domain,
Host: host,
State: s.State,
KeyPath: n,
RemoteID: resp.ID,
Domain: domain,
Host: host,
State: s.State,
KeyPath: n,
ReceiverData: resp.ReceiverData,
}

fmt.Printf("%s\n", id)
Expand Down Expand Up @@ -186,7 +182,7 @@ func fund(args []string) error {
if err != nil {
return err
}
req.ID = ch.RemoteID
req.ReceiverData = ch.ReceiverData

c, err := getClient(id)
if err != nil {
Expand Down Expand Up @@ -254,7 +250,8 @@ func send(args []string) error {
return err
}
req := models.ValidateRequest{
ID: ch.RemoteID,
TxID: ch.State.FundingTxID,
Vout: ch.State.FundingVout,
Payment: payment,
}
resp, err := c.Validate(req)
Expand Down Expand Up @@ -296,7 +293,7 @@ func flush(id string) error {
if err != nil {
return err
}
sendReq.ID = ch.RemoteID
//sendReq.ID = ch.RemoteID

// Either the payment has been sent or it hasn't. Find out which one.

Expand All @@ -305,7 +302,8 @@ func flush(id string) error {
return err
}
req := models.StatusRequest{
ID: ch.RemoteID,
TxID: ch.State.FundingTxID,
Vout: ch.State.FundingVout,
}
resp, err := c.Status(req)
if err != nil {
Expand Down Expand Up @@ -348,7 +346,7 @@ func flushAction(args []string) error {
func closeAction(args []string) error {
id := args[0]

ch, sender, err := getChannel(id)
_, sender, err := getChannel(id)
if err != nil {
return err
}
Expand All @@ -357,7 +355,6 @@ func closeAction(args []string) error {
if err != nil {
return err
}
req.ID = ch.RemoteID

c, err := getClient(id)
if err != nil {
Expand Down Expand Up @@ -394,7 +391,8 @@ func status(args []string) error {
return err
}
req := models.StatusRequest{
ID: ch.RemoteID,
TxID: ch.State.FundingTxID,
Vout: ch.State.FundingVout,
}
resp, err := c.Status(req)
if err != nil {
Expand Down
Loading

0 comments on commit dde254b

Please sign in to comment.