From 37a68ac56eb571bf07245eea11b81b11bb701ca3 Mon Sep 17 00:00:00 2001 From: Donald Adu-Poku Date: Tue, 20 Oct 2020 17:57:31 +0000 Subject: [PATCH] multi: update wire error types. This updates the wire error types to leverage go 1.13 errors.Is/As functionality as well as confirm to the error infrastructure best practices. --- blockchain/fullblocks_test.go | 2 +- blockchain/go.mod | 11 ++ blockchain/go.sum | 16 --- server.go | 6 +- wire/common_test.go | 16 ++- wire/error.go | 192 ++++++++++------------------------ wire/error_test.go | 42 +++----- wire/fakemessage_test.go | 3 +- wire/message.go | 4 +- wire/message_test.go | 60 ++++------- wire/msgtx_test.go | 16 +-- 11 files changed, 119 insertions(+), 249 deletions(-) diff --git a/blockchain/fullblocks_test.go b/blockchain/fullblocks_test.go index 105e01d013..8d36c10d7a 100644 --- a/blockchain/fullblocks_test.go +++ b/blockchain/fullblocks_test.go @@ -218,7 +218,7 @@ func TestFullBlocks(t *testing.T) { // Ensure there is an error due to deserializing the block. var msgBlock wire.MsgBlock err := msgBlock.BtcDecode(bytes.NewReader(item.RawBlock), 0) - var werr *wire.MessageError + var werr wire.MessageError if !errors.As(err, &werr) { t.Fatalf("block %q (hash %s, height %d) should have "+ "failed to decode", item.Name, blockHash, diff --git a/blockchain/go.mod b/blockchain/go.mod index d7a937355f..97d6a30a92 100644 --- a/blockchain/go.mod +++ b/blockchain/go.mod @@ -16,3 +16,14 @@ require ( github.com/decred/dcrd/wire v1.4.0 github.com/decred/slog v1.1.0 ) + +replace ( + github.com/decred/dcrd/blockchain/stake/v3 => ./stake + github.com/decred/dcrd/blockchain/standalone/v2 => ./standalone + github.com/decred/dcrd/chaincfg/v3 => ../chaincfg + github.com/decred/dcrd/dcrec/secp256k1/v3 => ../dcrec/secp256k1 + github.com/decred/dcrd/dcrutil/v3 => ../dcrutil + github.com/decred/dcrd/gcs/v2 => ../gcs + github.com/decred/dcrd/txscript/v3 => ../txscript + github.com/decred/dcrd/wire => ../wire +) diff --git a/blockchain/go.sum b/blockchain/go.sum index d1f3d79200..f59a5994c1 100644 --- a/blockchain/go.sum +++ b/blockchain/go.sum @@ -6,14 +6,8 @@ github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4= github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/decred/base58 v1.0.3 h1:KGZuh8d1WEMIrK0leQRM47W85KqCAdl2N+uagbctdDI= github.com/decred/base58 v1.0.3/go.mod h1:pXP9cXCfM2sFLb2viz2FNIdeMWmZDBKG3ZBYbiSM78E= -github.com/decred/dcrd/blockchain/stake/v3 v3.0.0 h1:vr0o0ICjuEzg1End6YtBfwgDuPkg+FYIwGVEz18kFg0= -github.com/decred/dcrd/blockchain/stake/v3 v3.0.0/go.mod h1:5GIUwsrHQCJauacgCegIR6t92SaeVi28Qls/BLN9vOw= -github.com/decred/dcrd/blockchain/standalone/v2 v2.0.0 h1:9gUuH0u/IZNPWBK9K3CxgAWPG7nTqVSsZefpGY4Okns= -github.com/decred/dcrd/blockchain/standalone/v2 v2.0.0/go.mod h1:t2qaZ3hNnxHZ5kzVJDgW5sp47/8T5hYJt7SR+/JtRhI= github.com/decred/dcrd/chaincfg/chainhash v1.0.2 h1:rt5Vlq/jM3ZawwiacWjPa+smINyLRN07EO0cNBV6DGU= github.com/decred/dcrd/chaincfg/chainhash v1.0.2/go.mod h1:BpbrGgrPTr3YJYRN3Bm+D9NuaFd+zGyNeIKgrhCXK60= -github.com/decred/dcrd/chaincfg/v3 v3.0.0 h1:+TFbu7ZmvBwM+SZz5mrj6cun9ts/6DAL5sqnsaFBHGQ= -github.com/decred/dcrd/chaincfg/v3 v3.0.0/go.mod h1:EspyubQ7D2w6tjP7rBGDIE7OTbuMgBjR2F2kZFnh31A= github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0= github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= github.com/decred/dcrd/crypto/ripemd160 v1.0.1 h1:TjRL4LfftzTjXzaufov96iDAkbY2R3aTvH2YMYa1IOc= @@ -24,16 +18,6 @@ github.com/decred/dcrd/dcrec v1.0.0 h1:W+z6Es+Rai3MXYVoPAxYr5U1DGis0Co33scJ6uH2J github.com/decred/dcrd/dcrec v1.0.0/go.mod h1:HIaqbEJQ+PDzQcORxnqen5/V1FR3B4VpIfmePklt8Q8= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.1 h1:V6eqU1crZzuoFT4KG2LhaU5xDSdkHuvLQsj25wd7Wb4= github.com/decred/dcrd/dcrec/edwards/v2 v2.0.1/go.mod h1:d0H8xGMWbiIQP7gN3v2rByWUcuZPm9YsgmnfoxgbINc= -github.com/decred/dcrd/dcrec/secp256k1/v3 v3.0.0 h1:sgNeV1VRMDzs6rzyPpxyM0jp317hnwiq58Filgag2xw= -github.com/decred/dcrd/dcrec/secp256k1/v3 v3.0.0/go.mod h1:J70FGZSbzsjecRTiTzER+3f1KZLNaXkuv+yeFTKoxM8= -github.com/decred/dcrd/dcrutil/v3 v3.0.0 h1:n6uQaTQynIhCY89XsoDk2WQqcUcnbD+zUM9rnZcIOZo= -github.com/decred/dcrd/dcrutil/v3 v3.0.0/go.mod h1:iVsjcqVzLmYFGCZLet2H7Nq+7imV9tYcuY+0lC2mNsY= -github.com/decred/dcrd/gcs/v2 v2.1.0 h1:foECqwfE3UJztU4CYtqUYqvR254x1Z9clXVfNdOjBQ8= -github.com/decred/dcrd/gcs/v2 v2.1.0/go.mod h1:MbnJOINFcp42NMRAQ+CjX/xGz+53AwNgMzKZhwBibdM= -github.com/decred/dcrd/txscript/v3 v3.0.0 h1:74NmirXAIskbGP0g9OWtrmN7OxDbWJ9G73a5uoxTkcM= -github.com/decred/dcrd/txscript/v3 v3.0.0/go.mod h1:pdvnlD4KGdDoc09cvWRJ8EoRQUaiUz41uDevOWuEfII= -github.com/decred/dcrd/wire v1.4.0 h1:KmSo6eTQIvhXS0fLBQ/l7hG7QLcSJQKSwSyzSqJYDk0= -github.com/decred/dcrd/wire v1.4.0/go.mod h1:WxC/0K+cCAnBh+SKsRjIX9YPgvrjhmE+6pZlel1G7Ro= github.com/decred/slog v1.1.0 h1:uz5ZFfmaexj1rEDgZvzQ7wjGkoSPjw2LCh8K+K1VrW4= github.com/decred/slog v1.1.0/go.mod h1:kVXlGnt6DHy2fV5OjSeuvCJ0OmlmTF6LFpEPMu/fOY0= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= diff --git a/server.go b/server.go index 7532ba91a9..b65988e65f 100644 --- a/server.go +++ b/server.go @@ -1506,9 +1506,9 @@ func (sp *serverPeer) OnAddr(p *peer.Peer, msg *wire.MsgAddr) { // the bytes received by the server. func (sp *serverPeer) OnRead(p *peer.Peer, bytesRead int, msg wire.Message, err error) { // Ban peers sending messages that do not conform to the wire protocol. - var errCode wire.ErrorCode - if errors.As(err, &errCode) { - peerLog.Errorf("Unable to read wire message from %s: %v", sp, err) + var werr wire.MessageError + if errors.As(err, &werr) { + peerLog.Errorf("Unable to read wire message from %s: %v", sp, werr) sp.server.BanPeer(sp) } diff --git a/wire/common_test.go b/wire/common_test.go index 4901c67492..241640eb37 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -403,7 +403,7 @@ func TestVarIntNonCanonical(t *testing.T) { // Decode from wire format. rbuf := bytes.NewReader(test.in) val, err := ReadVarInt(rbuf, test.pver) - var merr *MessageError + var merr MessageError if !errors.As(err, &merr) { t.Errorf("ReadVarInt #%d (%s) unexpected error %v", i, test.name, err) @@ -565,11 +565,11 @@ func TestVarStringOverflowErrors(t *testing.T) { }{ { []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - pver, &MessageError{}, + pver, MessageError{}, }, { []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - pver, &MessageError{}, + pver, MessageError{}, }, } @@ -835,13 +835,11 @@ func TestVarBytesOverflowErrors(t *testing.T) { }{ { []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - pver, - &MessageError{}, + pver, ErrVarBytesTooLong, }, { []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, - pver, - &MessageError{}, + pver, ErrVarBytesTooLong, }, } @@ -851,9 +849,9 @@ func TestVarBytesOverflowErrors(t *testing.T) { rbuf := bytes.NewReader(test.buf) _, err := ReadVarBytes(rbuf, test.pver, MaxMessagePayload, "test payload") - if reflect.TypeOf(err) != reflect.TypeOf(test.err) { + if !errors.Is(err, test.err) { t.Errorf("ReadVarBytes #%d wrong error got: %v, "+ - "want: %v", i, err, reflect.TypeOf(test.err)) + "want: %v", i, err, test.err) continue } } diff --git a/wire/error.go b/wire/error.go index 9497587c02..60298e7515 100644 --- a/wire/error.go +++ b/wire/error.go @@ -5,246 +5,160 @@ package wire -import ( - "fmt" -) - -// ErrorCode describes a kind of message error. -type ErrorCode int +// ErrorKind identifies a kind of error. It has full support for errors.Is and +// errors.As, so the caller can directly check against an error kind when +// determining the reason for an error. +type ErrorKind string // These constants are used to identify a specific Error. const ( // ErrNonCanonicalVarInt is returned when a variable length integer is // not canonically encoded. - ErrNonCanonicalVarInt ErrorCode = iota + ErrNonCanonicalVarInt = ErrorKind("ErrNonCanonicalVarInt") // ErrVarStringTooLong is returned when a variable string exceeds the // maximum message size allowed. - ErrVarStringTooLong + ErrVarStringTooLong = ErrorKind("ErrVarStringTooLong") // ErrVarBytesTooLong is returned when a variable-length byte slice // exceeds the maximum message size allowed. - ErrVarBytesTooLong + ErrVarBytesTooLong = ErrorKind("ErrVarBytesTooLong") // ErrCmdTooLong is returned when a command exceeds the maximum command // size allowed. - ErrCmdTooLong + ErrCmdTooLong = ErrorKind("ErrCmdTooLong") // ErrPayloadTooLarge is returned when a payload exceeds the maximum // payload size allowed. - ErrPayloadTooLarge + ErrPayloadTooLarge = ErrorKind("ErrPayloadTooLarge") // ErrWrongNetwork is returned when a message intended for a different // network is received. - ErrWrongNetwork + ErrWrongNetwork = ErrorKind("ErrWrongNetwork") // ErrMalformedCmd is returned when a malformed command is received. - ErrMalformedCmd + ErrMalformedCmd = ErrorKind("ErrMalformedCmd") // ErrUnknownCmd is returned when an unknown command is received. - ErrUnknownCmd + ErrUnknownCmd = ErrorKind("ErrUnknownCmd") // ErrPayloadChecksum is returned when a message with an invalid checksum // is received. - ErrPayloadChecksum + ErrPayloadChecksum = ErrorKind("ErrPayloadChecksum") // ErrTooManyAddrs is returned when an address list exceeds the maximum // allowed. - ErrTooManyAddrs + ErrTooManyAddrs = ErrorKind("ErrTooManyAddrs") // ErrTooManyTxs is returned when a the number of transactions exceed the // maximum allowed. - ErrTooManyTxs + ErrTooManyTxs = ErrorKind("ErrTooManyTxs") // ErrMsgInvalidForPVer is returned when a message is invalid for // the expected protocol version. - ErrMsgInvalidForPVer + ErrMsgInvalidForPVer = ErrorKind("ErrMsgInvalidForPVer") // ErrFilterTooLarge is returned when a committed filter exceeds // the maximum size allowed. - ErrFilterTooLarge + ErrFilterTooLarge = ErrorKind("ErrFilterTooLarge") // ErrTooManyProofs is returned when the numeber of proof hashes // exceeds the maximum allowed. - ErrTooManyProofs + ErrTooManyProofs = ErrorKind("ErrTooManyProofs") // ErrTooManyFilterTypes is returned when the number of filter types // exceeds the maximum allowed. - ErrTooManyFilterTypes + ErrTooManyFilterTypes = ErrorKind("ErrTooManyFilterTypes") // ErrTooManyLocators is returned when the number of block locators exceed // the maximum allowed. - ErrTooManyLocators + ErrTooManyLocators = ErrorKind("ErrTooManyLocators") // ErrTooManyVectors is returned when the number of inventory vectors // exceed the maximum allowed. - ErrTooManyVectors + ErrTooManyVectors = ErrorKind("ErrTooManyVectors") // ErrTooManyHeaders is returned when the number of block headers exceed // the maximum allowed. - ErrTooManyHeaders + ErrTooManyHeaders = ErrorKind("ErrTooManyHeaders") // ErrHeaderContainsTxs is returned when a header's transactions // count is greater than zero. - ErrHeaderContainsTxs + ErrHeaderContainsTxs = ErrorKind("ErrHeaderContainsTxs") // ErrTooManyVotes is returned when the number of vote hashes exceed the // maximum allowed. - ErrTooManyVotes + ErrTooManyVotes = ErrorKind("ErrTooManyVotes") // ErrTooManyBlocks is returned when the number of block hashes exceed the // maximum allowed. - ErrTooManyBlocks + ErrTooManyBlocks = ErrorKind("ErrTooManyBlocks") // ErrMismatchedWitnessCount returned when a transaction has unequal witness // and prefix txin quantities. - ErrMismatchedWitnessCount + ErrMismatchedWitnessCount = ErrorKind("ErrMismatchedWitnessCount") // ErrUnknownTxType is returned when a transaction type is unknown. - ErrUnknownTxType + ErrUnknownTxType = ErrorKind("ErrUnknownTxType") // ErrReadInPrefixFromWitnessOnlyTx is returned when attempting to read a // transaction input prefix from a witness only transaction. - ErrReadInPrefixFromWitnessOnlyTx + ErrReadInPrefixFromWitnessOnlyTx = ErrorKind("ErrReadInPrefixFromWitnessOnlyTx") // ErrInvalidMsg is returned for an invalid message structure. - ErrInvalidMsg + ErrInvalidMsg = ErrorKind("ErrInvalidMsg") // ErrUserAgentTooLong is returned when the provided user agent exceeds // the maximum allowed. - ErrUserAgentTooLong + ErrUserAgentTooLong = ErrorKind("ErrUserAgentTooLong") // ErrTooManyFilterHeaders is returned when the number of committed filter // headers exceed the maximum allowed. - ErrTooManyFilterHeaders + ErrTooManyFilterHeaders = ErrorKind("ErrTooManyFilterHeaders") // ErrMalformedStrictString is returned when a string that has strict // formatting requirements does not conform to the requirements. - ErrMalformedStrictString + ErrMalformedStrictString = ErrorKind("ErrMalformedStrictString") // ErrTooManyInitialStateTypes is returned when the number of initial // state types is larger than the maximum allowed by the protocol. - ErrTooManyInitStateTypes + ErrTooManyInitStateTypes = ErrorKind("ErrTooManyInitStateTypes") // ErrInitialStateTypeTooLong is returned when an individual initial // state type is longer than allowed by the protocol. - ErrInitStateTypeTooLong + ErrInitStateTypeTooLong = ErrorKind("ErrInitStateTypeTooLong") // ErrTooManyTSpends is returned when the number of tspend hashes // exceeds the maximum allowed. - ErrTooManyTSpends + ErrTooManyTSpends = ErrorKind("ErrTooManyTSpends") ) -// Map of ErrorCode values back to their constant names for pretty printing. -var errorCodeStrings = map[ErrorCode]string{ - ErrNonCanonicalVarInt: "ErrNonCanonicalVarInt", - ErrVarStringTooLong: "ErrVarStringTooLong", - ErrVarBytesTooLong: "ErrVarBytesTooLong", - ErrCmdTooLong: "ErrCmdTooLong", - ErrPayloadTooLarge: "ErrPayloadTooLarge", - ErrWrongNetwork: "ErrWrongNetwork", - ErrMalformedCmd: "ErrMalformedCmd", - ErrUnknownCmd: "ErrUnknownCmd", - ErrPayloadChecksum: "ErrPayloadChecksum", - ErrTooManyAddrs: "ErrTooManyAddrs", - ErrTooManyTxs: "ErrTooManyTxs", - ErrMsgInvalidForPVer: "ErrMsgInvalidForPVer", - ErrFilterTooLarge: "ErrFilterTooLarge", - ErrTooManyProofs: "ErrTooManyProofs", - ErrTooManyFilterTypes: "ErrTooManyFilterTypes", - ErrTooManyLocators: "ErrTooManyLocators", - ErrTooManyVectors: "ErrTooManyVectors", - ErrTooManyHeaders: "ErrTooManyHeaders", - ErrHeaderContainsTxs: "ErrHeaderContainsTxs", - ErrTooManyVotes: "ErrTooManyVotes", - ErrTooManyBlocks: "ErrTooManyBlocks", - ErrMismatchedWitnessCount: "ErrMismatchedWitnessCount", - ErrUnknownTxType: "ErrUnknownTxType", - ErrReadInPrefixFromWitnessOnlyTx: "ErrReadInPrefixFromWitnessOnlyTx", - ErrInvalidMsg: "ErrInvalidMsg", - ErrUserAgentTooLong: "ErrUserAgentTooLong", - ErrTooManyFilterHeaders: "ErrTooManyFilterHeaders", - ErrMalformedStrictString: "ErrMalformedStrictString", - ErrTooManyInitStateTypes: "ErrTooManyInitStateTypes", - ErrInitStateTypeTooLong: "ErrInitStateTypeTooLong", - ErrTooManyTSpends: "ErrTooManyTSpends", -} - -// String returns the ErrorCode as a human-readable name. -func (e ErrorCode) String() string { - if s := errorCodeStrings[e]; s != "" { - return s - } - return fmt.Sprintf("Unknown ErrorCode (%d)", int(e)) -} - -// Error implements the error interface. -func (e ErrorCode) Error() string { - return e.String() -} - -// Is implements the interface to work with the standard library's errors.Is. -// -// It returns true in the following cases: -// - The target is a *MessageError and the error codes match -// - The target is an ErrorCode and it the error codes match -func (e ErrorCode) Is(target error) bool { - switch target := target.(type) { - case *MessageError: - return e == target.ErrorCode - - case ErrorCode: - return e == target - } - - return false +// Error satisfies the error interface and prints human-readable errors. +func (e ErrorKind) Error() string { + return string(e) } -// MessageError describes an issue with a message. -// An example of some potential issues are messages from the wrong decred -// network, invalid commands, mismatched checksums, and exceeding max payloads. -// -// This provides a mechanism for the caller to type assert the error to -// differentiate between general io errors such as io.EOF and issues that -// resulted from malformed messages. +// MessageError identifies an error related to wire messages. It has +// full support for errors.Is and errors.As, so the caller can +// ascertain the specific reason for the error by checking the +// underlying error. type MessageError struct { - Func string // Function name - ErrorCode ErrorCode // Describes the kind of error - Description string // Human readable description of the issue + Func string + Err error + Description string } // Error satisfies the error interface and prints human-readable errors. -func (m MessageError) Error() string { - if m.Func != "" { - return fmt.Sprintf("%v: %v", m.Func, m.Description) - } - return m.Description -} - -// messageError creates an Error given a set of arguments. -func messageError(Func string, c ErrorCode, desc string) *MessageError { - return &MessageError{Func: Func, ErrorCode: c, Description: desc} +func (e MessageError) Error() string { + return e.Description } -// Is implements the interface to work with the standard library's errors.Is. -// -// It returns true in the following cases: -// - The target is a *MessageError and the error codes match -// - The target is an ErrorCode and it the error codes match -func (m *MessageError) Is(target error) bool { - switch target := target.(type) { - case *MessageError: - return m.ErrorCode == target.ErrorCode - - case ErrorCode: - return target == m.ErrorCode - } - - return false +// Unwrap returns the underlying wrapped error. +func (e MessageError) Unwrap() error { + return e.Err } -// Unwrap returns the underlying wrapped error if it is not ErrOther. -// Unwrap returns the ErrorCode. Else, it returns nil. -func (m *MessageError) Unwrap() error { - return m.ErrorCode +// messageError creates a MessageError given a set of arguments. +func messageError(fn string, kind ErrorKind, desc string) MessageError { + return MessageError{Func: fn, Err: kind, Description: desc} } diff --git a/wire/error_test.go b/wire/error_test.go index 155738518f..4b2861de64 100644 --- a/wire/error_test.go +++ b/wire/error_test.go @@ -10,13 +10,12 @@ import ( "testing" ) -// TestMessageErrorCodeStringer tests the stringized output for -// the ErrorCode type. -func TestMessageErrorCodeStringer(t *testing.T) { +// TestErrorKindStringer tests the stringized output for the ErrorKind type. +func TestErrorKindStringer(t *testing.T) { t.Parallel() tests := []struct { - in ErrorCode + in ErrorKind want string }{ {ErrNonCanonicalVarInt, "ErrNonCanonicalVarInt"}, @@ -50,15 +49,12 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrTooManyInitStateTypes, "ErrTooManyInitStateTypes"}, {ErrInitStateTypeTooLong, "ErrInitStateTypeTooLong"}, {ErrTooManyTSpends, "ErrTooManyTSpends"}, - - {0xffff, "Unknown ErrorCode (65535)"}, } - t.Logf("Running %d tests", len(tests)) for i, test := range tests { - result := test.in.String() + result := test.in.Error() if result != test.want { - t.Errorf("String #%d\n got: %s want: %s", i, result, + t.Errorf("%d: got: %s want: %s", i, result, test.want) continue } @@ -78,12 +74,8 @@ func TestMessageError(t *testing.T) { }, { MessageError{Description: "human-readable error"}, "human-readable error", - }, { - MessageError{Func: "foo", Description: "something bad happened"}, - "foo: something bad happened", }} - t.Logf("Running %d tests", len(tests)) for i, test := range tests { result := test.in.Error() if result != test.want { @@ -93,7 +85,7 @@ func TestMessageError(t *testing.T) { } } -// TestErrorCodeIsAs ensures both ErrorCode and MessageError can be identified +// TestErrorKindIsAs ensures both ErrorKind and MessageError can be identified // as being a specific error code via errors.Is and unwrapped via errors.As. func TestErrorCodeIsAs(t *testing.T) { tests := []struct { @@ -101,7 +93,7 @@ func TestErrorCodeIsAs(t *testing.T) { err error target error wantMatch bool - wantAs ErrorCode + wantAs ErrorKind }{{ name: "ErrTooManyAddrs == ErrTooManyAddrs", err: ErrTooManyAddrs, @@ -114,12 +106,6 @@ func TestErrorCodeIsAs(t *testing.T) { target: ErrTooManyAddrs, wantMatch: true, wantAs: ErrTooManyAddrs, - }, { - name: "ErrTooManyAddrs == MessageError.ErrTooManyAddrs", - err: ErrTooManyAddrs, - target: messageError("", ErrTooManyAddrs, ""), - wantMatch: true, - wantAs: ErrTooManyAddrs, }, { name: "MessageError.ErrTooManyAddrs == MessageError.ErrTooManyAddrs", err: messageError("", ErrTooManyAddrs, ""), @@ -161,16 +147,16 @@ func TestErrorCodeIsAs(t *testing.T) { continue } - // Ensure the underlying error code can be unwrapped and is the expected - // code. - var code ErrorCode - if !errors.As(test.err, &code) { + // Ensure the underlying error kind can be unwrapped and is the + // expected error. + var kind ErrorKind + if !errors.As(test.err, &kind) { t.Errorf("%s: unable to unwrap to error code", test.name) continue } - if code != test.wantAs { - t.Errorf("%s: unexpected unwrapped error code -- got %v, want %v", - test.name, code, test.wantAs) + if kind != test.wantAs { + t.Errorf("%s: unexpected unwrapped error kind -- got %v, want %v", + test.name, kind, test.wantAs) continue } } diff --git a/wire/fakemessage_test.go b/wire/fakemessage_test.go index 67f701d56d..ea1812d436 100644 --- a/wire/fakemessage_test.go +++ b/wire/fakemessage_test.go @@ -26,8 +26,9 @@ func (msg *fakeMessage) BtcDecode(r io.Reader, pver uint32) error { // Message interface. func (msg *fakeMessage) BtcEncode(w io.Writer, pver uint32) error { if msg.forceEncodeErr { - err := &MessageError{ + err := MessageError{ Func: "fakeMessage.BtcEncode", + Err: ErrInvalidMsg, Description: "intentional error", } return err diff --git a/wire/message.go b/wire/message.go index 4cd5f91e44..84f1d3c97d 100644 --- a/wire/message.go +++ b/wire/message.go @@ -263,10 +263,10 @@ func WriteMessageN(w io.Writer, msg Message, pver uint32, dcrnet CurrencyNet) (i // Enforce maximum message payload based on the message type. mpl := msg.MaxPayloadLength(pver) if uint32(lenp) > mpl { - str := fmt.Sprintf("message payload is too large - encoded "+ + msg := fmt.Sprintf("message payload is too large - encoded "+ "%d bytes, but maximum message payload size for "+ "messages of type [%s] is %d.", lenp, cmd, mpl) - return totalBytes, messageError(op, ErrPayloadTooLarge, str) + return totalBytes, messageError(op, ErrPayloadTooLarge, msg) } // Create header for the message. diff --git a/wire/message_test.go b/wire/message_test.go index ee4b32cd69..95e795c370 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -250,7 +250,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(testNetBytes), - &MessageError{}, + ErrWrongNetwork, 24, }, @@ -260,7 +260,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(exceedMaxPayloadBytes), - &MessageError{}, + ErrPayloadTooLarge, 24, }, @@ -270,7 +270,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(badCommandBytes), - &MessageError{}, + ErrMalformedCmd, 24, }, @@ -280,7 +280,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(unsupportedCommandBytes), - &MessageError{}, + ErrUnknownCmd, 24, }, @@ -290,7 +290,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(exceedTypePayloadBytes), - &MessageError{}, + ErrPayloadTooLarge, 24, }, @@ -310,7 +310,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(badChecksumBytes), - &MessageError{}, + ErrPayloadChecksum, 26, }, @@ -320,7 +320,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(badMessageBytes), - &MessageError{}, + ErrPayloadChecksum, 25, }, @@ -330,7 +330,7 @@ func TestReadMessageWireErrors(t *testing.T) { pver, dcrnet, len(discardBytes), - &MessageError{}, + ErrUnknownCmd, 24, }, } @@ -340,9 +340,9 @@ func TestReadMessageWireErrors(t *testing.T) { // Decode from wire format. r := newFixedReader(test.max, test.buf) nr, _, _, err := ReadMessageN(r, test.pver, test.dcrnet) - if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) { - t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ - "want: %T", i, err, err, test.readErr) + if !errors.Is(err, test.readErr) { + t.Errorf("ReadMessage #%d wrong error got: %v, "+ + "want: %v", i, err.Error(), test.readErr) continue } @@ -351,18 +351,6 @@ func TestReadMessageWireErrors(t *testing.T) { t.Errorf("ReadMessage #%d unexpected num bytes read - "+ "got %d, want %d", i, nr, test.bytes) } - - // For errors which are not of type MessageError, check them for - // equality. - var merr *MessageError - if !errors.As(err, &merr) { - if !errors.Is(err, test.readErr) { - t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ - "want: %v <%T>", i, err, err, - test.readErr, test.readErr) - continue - } - } } } @@ -371,7 +359,6 @@ func TestReadMessageWireErrors(t *testing.T) { func TestWriteMessageWireErrors(t *testing.T) { pver := ProtocolVersion dcrnet := MainNet - wireErr := &MessageError{} // Fake message with a command that is too long. badCommandMsg := &fakeMessage{command: "somethingtoolong"} @@ -401,13 +388,13 @@ func TestWriteMessageWireErrors(t *testing.T) { bytes int // Expected num bytes written }{ // Command too long. - {badCommandMsg, pver, dcrnet, 0, wireErr, 0}, + {badCommandMsg, pver, dcrnet, 0, ErrCmdTooLong, 0}, // Force error in payload encode. - {encodeErrMsg, pver, dcrnet, 0, wireErr, 0}, + {encodeErrMsg, pver, dcrnet, 0, ErrInvalidMsg, 0}, // Force error due to exceeding max overall message payload size. - {exceedOverallPayloadErrMsg, pver, dcrnet, 0, wireErr, 0}, + {exceedOverallPayloadErrMsg, pver, dcrnet, 0, ErrPayloadTooLarge, 0}, // Force error due to exceeding max payload for message type. - {exceedPayloadErrMsg, pver, dcrnet, 0, wireErr, 0}, + {exceedPayloadErrMsg, pver, dcrnet, 0, ErrPayloadTooLarge, 0}, // Force error in header write. {bogusMsg, pver, dcrnet, 0, io.ErrShortWrite, 0}, // Force error in payload write. @@ -419,9 +406,9 @@ func TestWriteMessageWireErrors(t *testing.T) { // Encode wire format. w := newFixedWriter(test.max) nw, err := WriteMessageN(w, test.msg, test.pver, test.dcrnet) - if reflect.TypeOf(err) != reflect.TypeOf(test.err) { - t.Errorf("WriteMessage #%d wrong error got: %v <%T>, "+ - "want: %T", i, err, err, test.err) + if !errors.Is(err, test.err) { + t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ + "want: %v <%T>", i, err, err, test.err, test.err) continue } @@ -430,16 +417,5 @@ func TestWriteMessageWireErrors(t *testing.T) { t.Errorf("WriteMessage #%d unexpected num bytes "+ "written - got %d, want %d", i, nw, test.bytes) } - - // For errors which are not of type MessageError, check them for - // equality. - var merr *MessageError - if !errors.As(err, &merr) { - if !errors.Is(err, test.err) { - t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+ - "want: %v <%T>", i, err, err, test.err, test.err) - continue - } - } } } diff --git a/wire/msgtx_test.go b/wire/msgtx_test.go index 919e1b978e..23b3d94599 100644 --- a/wire/msgtx_test.go +++ b/wire/msgtx_test.go @@ -719,7 +719,7 @@ func TestTxOverflowErrors(t *testing.T) { 0x01, 0x00, 0x00, 0x00, // Version 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Varint for number of input transactions - }, pver, txVer, &MessageError{}, + }, pver, txVer, ErrTooManyTxs, }, // Transaction that claims to have ~uint64(0) outputs. [1] @@ -729,7 +729,7 @@ func TestTxOverflowErrors(t *testing.T) { 0x00, // Varint for number of input transactions 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Varint for number of output transactions - }, pver, txVer, &MessageError{}, + }, pver, txVer, ErrTooManyTxs, }, // Transaction that has an input with a signature script that [2] @@ -777,7 +777,7 @@ func TestTxOverflowErrors(t *testing.T) { 0x00, 0x00, 0x00, 0x00, // Expiry 0x01, // Varint for number of input signature 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Varint for sig script length (overflows) - }, pver, txVer, &MessageError{}, + }, pver, txVer, ErrTooManyTxs, }, // Transaction that has an output with a public key script [3] @@ -798,7 +798,7 @@ func TestTxOverflowErrors(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Transaction amount 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Varint for length of public key script - }, pver, txVer, &MessageError{}, + }, pver, txVer, ErrNonCanonicalVarInt, }, } @@ -808,17 +808,17 @@ func TestTxOverflowErrors(t *testing.T) { var msg MsgTx r := bytes.NewReader(test.buf) err := msg.BtcDecode(r, test.pver) - if reflect.TypeOf(err) != reflect.TypeOf(test.err) { + if !errors.Is(err, test.err) { t.Errorf("BtcDecode #%d wrong error got: %v, want: %v", - i, err, reflect.TypeOf(test.err)) + i, err, test.err) } // Decode from wire format. r = bytes.NewReader(test.buf) err = msg.Deserialize(r) - if reflect.TypeOf(err) != reflect.TypeOf(test.err) { + if !errors.Is(err, test.err) { t.Errorf("Deserialize #%d wrong error got: %v, want: %v", - i, err, reflect.TypeOf(test.err)) + i, err, test.err) continue } }