Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
add auth token
Browse files Browse the repository at this point in the history
  • Loading branch information
timothy committed Apr 16, 2017
1 parent dde254b commit 891076a
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 38 deletions.
25 changes: 14 additions & 11 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewClient(c *http.Client, endpoint string) (*Client, error) {
}, nil
}

func (c *Client) do(method, path string, req, resp interface{}) error {
func (c *Client) do(method, path string, authToken string, req, resp interface{}) error {
url := c.endpoint + path

buf, err := json.Marshal(req)
Expand All @@ -48,6 +48,9 @@ func (c *Client) do(method, path string, req, resp interface{}) error {
if err != nil {
return err
}
if authToken != "" {
hreq.Header.Add("Authorization", "Bearer "+authToken)
}

hresp, err := c.c.Do(hreq)
if err != nil {
Expand All @@ -74,7 +77,7 @@ func (c *Client) do(method, path string, req, resp interface{}) error {

func (c *Client) Create(req models.CreateRequest) (*models.CreateResponse, error) {
var resp models.CreateResponse
if err := c.do(http.MethodPost, "/create", req, &resp); err != nil {
if err := c.do(http.MethodPost, "/create", "", req, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand All @@ -87,43 +90,43 @@ func getChannelID(txid string, vout uint32) string {
func (c *Client) Open(req models.OpenRequest) (*models.OpenResponse, error) {
path := "/open/" + getChannelID(req.TxID, req.Vout)
var resp models.OpenResponse
if err := c.do(http.MethodPut, path, req, &resp); err != nil {
if err := c.do(http.MethodPut, path, "", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Validate(req models.ValidateRequest) (*models.ValidateResponse, error) {
func (c *Client) Validate(req models.ValidateRequest, authToken string) (*models.ValidateResponse, error) {
path := "/validate/" + getChannelID(req.TxID, req.Vout)
var resp models.ValidateResponse
if err := c.do(http.MethodPut, path, req, &resp); err != nil {
if err := c.do(http.MethodPut, path, authToken, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Send(req models.SendRequest) (*models.SendResponse, error) {
func (c *Client) Send(req models.SendRequest, authToken string) (*models.SendResponse, error) {
path := "/send/" + getChannelID(req.TxID, req.Vout)
var resp models.SendResponse
if err := c.do(http.MethodPost, path, req, &resp); err != nil {
if err := c.do(http.MethodPost, path, authToken, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Close(req models.CloseRequest) (*models.CloseResponse, error) {
func (c *Client) Close(req models.CloseRequest, authToken string) (*models.CloseResponse, error) {
path := "/close/" + getChannelID(req.TxID, req.Vout)
var resp models.CloseResponse
if err := c.do(http.MethodDelete, path, req, &resp); err != nil {
if err := c.do(http.MethodDelete, path, authToken, req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

func (c *Client) Status(req models.StatusRequest) (*models.StatusResponse, error) {
func (c *Client) Status(req models.StatusRequest, authToken string) (*models.StatusResponse, error) {
path := "/status/" + getChannelID(req.TxID, req.Vout)
var resp models.StatusResponse
if err := c.do(http.MethodGet, path, req, &resp); err != nil {
if err := c.do(http.MethodGet, path, authToken, req, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand Down
18 changes: 11 additions & 7 deletions cmd/mbclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ func create(args []string) error {
}

fmt.Printf("Funding address: %s\n", addr)
fmt.Printf("Fee: %d\n", s.State.Fee)
fmt.Printf("Timeout: %d\n", s.State.Timeout)

id := strconv.Itoa(n)
globalState.Channels[id] = Channel{
Expand Down Expand Up @@ -197,6 +199,9 @@ func fund(args []string) error {
return err
}

if err := storeAuthToken(id, resp.AuthToken); err != nil {
return err
}
return storeChannel(id, sender.State)
}

Expand Down Expand Up @@ -254,7 +259,7 @@ func send(args []string) error {
Vout: ch.State.FundingVout,
Payment: payment,
}
resp, err := c.Validate(req)
resp, err := c.Validate(req, ch.AuthToken)
if err != nil {
return err
}
Expand Down Expand Up @@ -293,7 +298,6 @@ func flush(id string) error {
if err != nil {
return err
}
//sendReq.ID = ch.RemoteID

// Either the payment has been sent or it hasn't. Find out which one.

Expand All @@ -305,7 +309,7 @@ func flush(id string) error {
TxID: ch.State.FundingTxID,
Vout: ch.State.FundingVout,
}
resp, err := c.Status(req)
resp, err := c.Status(req, ch.AuthToken)
if err != nil {
return err
}
Expand All @@ -315,7 +319,7 @@ func flush(id string) error {
if serverBal == sender.State.Balance {
// Pending payment doesn't reflect yet. We have to retry.

if _, err := c.Send(*sendReq); err != nil {
if _, err := c.Send(*sendReq, ch.AuthToken); err != nil {
return err
}

Expand Down Expand Up @@ -346,7 +350,7 @@ func flushAction(args []string) error {
func closeAction(args []string) error {
id := args[0]

_, sender, err := getChannel(id)
ch, sender, err := getChannel(id)
if err != nil {
return err
}
Expand All @@ -360,7 +364,7 @@ func closeAction(args []string) error {
if err != nil {
return err
}
resp, err := c.Close(*req)
resp, err := c.Close(*req, ch.AuthToken)
if err != nil {
return err
}
Expand Down Expand Up @@ -394,7 +398,7 @@ func status(args []string) error {
TxID: ch.State.FundingTxID,
Vout: ch.State.FundingVout,
}
resp, err := c.Status(req)
resp, err := c.Status(req, ch.AuthToken)
if err != nil {
return err
}
Expand Down
11 changes: 11 additions & 0 deletions cmd/mbclient/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Channel struct {
Host string
KeyPath int
ReceiverData []byte
AuthToken string

PendingPayment []byte

Expand Down Expand Up @@ -166,6 +167,16 @@ func storePendingPayment(id string, state channels.SharedState, p []byte) error
return nil
}

func storeAuthToken(id string, authToken string) error {
c, ok := globalState.Channels[id]
if !ok {
return errors.New("channel does not exist")
}
c.AuthToken = authToken
globalState.Channels[id] = c
return nil
}

func findForDomain(domain string) []string {
var ids []string
for id, c := range globalState.Channels {
Expand Down
24 changes: 20 additions & 4 deletions cmd/mbserver/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ func checkID(w http.ResponseWriter, atxid string, avout uint32, btxid string, bv
return true
}

func checkAuthToken(s *ServerState, r *http.Request, txid string, vout uint32) bool {
h := r.Header.Get("Authorization")
const prefix = "Bearer "
if !strings.HasPrefix(h, prefix) {
return false
}
h = h[len(prefix):]
return s.Receiver.ValidateToken(txid, vout, h)
}

func respond(w http.ResponseWriter, r *http.Request, resp interface{}, err error) {
if err != nil {
if *debugServerRPC {
Expand Down Expand Up @@ -180,18 +190,24 @@ func rpcHandler(s *ServerState, w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}
log.Printf("%s\t%s\t%s", path, path[:i], path[i+1:])
call := path[:i]
txid, vout, ok := splitTxIDVout(path[i+1:])
if !ok {
log.Printf("!ok")
http.Error(w, "Invalid channel ID", http.StatusNotFound)
return
}

switch call {
case "open":
if call == "open" {
rpcOpenHandler(s, w, r, txid, vout)
return
}

if !checkAuthToken(s, r, txid, vout) {
http.Error(w, "invalid auth token", http.StatusUnauthorized)
return
}

switch call {
case "validate":
rpcValidateHandler(s, w, r, txid, vout)
case "send":
Expand Down
3 changes: 2 additions & 1 deletion cmd/mbserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var externalURL = flag.String("external_url", "https://example.com:3211", "Exter
var domain = flag.String("domain", "example.com", "Domain to accept payments for")
var tlsCert = flag.String("tls_cert", "tls/cert.pem", "TLS certificate")
var tlsKey = flag.String("tls_key", "tls/key.pem", "TLS key")
var authToken = flag.String("auth_token", "38a9cba31aed7e655b8d6d7014efc9bbc8ed9a961b708e90dc05e3b70994c5df", "Secret used to issue auth tokens")

func getnet() *chaincfg.Params {
if *testnet {
Expand Down Expand Up @@ -99,7 +100,7 @@ func main() {
defer bc.Shutdown()

dir := receiver.NewDirectory(*domain)
s := receiver.NewReceiver(net, ek, bc, storage, dir, *destination)
s := receiver.NewReceiver(net, ek, bc, storage, dir, *destination, *authToken)

go s.WatchBlockchainForever()

Expand Down
6 changes: 3 additions & 3 deletions models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type CreateResponse struct {
}

type OpenRequest struct {
TxID string `json:"txid"`
Vout uint32 `json:"vout"`

ReceiverData []byte `json:"receiverData"`

Version int `json:"version"`
Expand All @@ -36,9 +39,6 @@ type OpenRequest struct {
ReceiverPubKey []byte `json:"receiverPubKey"`
ReceiverOutput string `json:"receiverOutput"`

TxID string `json:"txid"`
Vout uint32 `json:"vout"`

SenderSig []byte `json:"senderSig"`
}

Expand Down
40 changes: 28 additions & 12 deletions receiver/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package receiver

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -32,6 +34,7 @@ type Receiver struct {
db storage.Storage
dir *Directory
receiverOutput string
authKey []byte
config channels.ReceiverConfig
}

Expand All @@ -40,7 +43,8 @@ func NewReceiver(net *chaincfg.Params,
bc *btcrpcclient.Client,
db storage.Storage,
dir *Directory,
destination string) *Receiver {
destination string,
authKey string) *Receiver {

config := channels.DefaultReceiverConfig
config.Net = net.Name
Expand All @@ -52,6 +56,7 @@ func NewReceiver(net *chaincfg.Params,
db: db,
dir: dir,
receiverOutput: destination,
authKey: []byte(authKey),
config: config,
}
}
Expand All @@ -77,6 +82,26 @@ func (r *Receiver) ListPayments(txid string, vout uint32) ([][]byte, error) {
return r.db.ListPayments(id)
}

func (r *Receiver) issue(txid string, vout uint32) []byte {
id := getChannelID(txid, vout)
mac := hmac.New(sha256.New, r.authKey)
mac.Write([]byte(id))
return mac.Sum(nil)
}

func (r *Receiver) issueToken(txid string, vout uint32) string {
return base64.StdEncoding.EncodeToString(r.issue(txid, vout))
}

func (r *Receiver) ValidateToken(txid string, vout uint32, token string) bool {
actual, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return false
}
expected := r.issue(txid, vout)
return hmac.Equal(actual, expected)
}

func (r *Receiver) getKey(n int) (*btcec.PrivateKey, error) {
ek, err := r.ek.Child(uint32(n))
if err != nil {
Expand Down Expand Up @@ -200,17 +225,6 @@ func (r *Receiver) getPolicy() policy {
}

func (r *Receiver) Open(req models.OpenRequest) (*models.OpenResponse, error) {
//c, err := r.get(req.ID)
//if err != nil {
// return nil, err
//}
//prevState := c.State

//_, addr, err := c.State.GetFundingScript()
//if err != nil {
// return nil, err
//}

// TODO: sign receiverData with expiry
keyPath, err := strconv.Atoi(string(req.ReceiverData))
if err != nil {
Expand Down Expand Up @@ -264,6 +278,8 @@ func (r *Receiver) Open(req models.OpenRequest) (*models.OpenResponse, error) {
return nil, err
}

resp.AuthToken = r.issueToken(req.TxID, req.Vout)

return resp, nil
}

Expand Down

0 comments on commit 891076a

Please sign in to comment.