Skip to content

Commit

Permalink
add backward compatbility
Browse files Browse the repository at this point in the history
  • Loading branch information
huangzhen1997 committed Nov 5, 2024
1 parent 3c0ed77 commit 3291f54
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 25 deletions.
97 changes: 80 additions & 17 deletions selectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chain_selectors
import (
_ "embed"
"fmt"
"strconv"

"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -32,9 +33,9 @@ const (
var chainIDToSelectorMapForFamily = make(map[string]map[string]uint64)
var selectorsMap = loadYML(selectorYml)
var testSelectorsMap = loadYML(testSelectorsYml)
var chainIdToChainSelector = loadAllChainIDToChainSelector()
var chainSelectorToDetails = loadAllChainSelector()

func loadAllChainIDToChainSelector() map[uint64]ChainDetails {
func loadAllChainSelector() map[uint64]ChainDetails {
output := make(map[uint64]ChainDetails, len(selectorsMap)+len(testSelectorsMap))
for k, v := range selectorsMap {
output[k] = v
Expand Down Expand Up @@ -75,7 +76,7 @@ func loadYML(yml []byte) map[uint64]ChainDetails {

func ChainSelectorToChainDetails() map[uint64]ChainDetails {
copyMap := make(map[uint64]ChainDetails, len(selectorsMap))
for k, v := range chainIdToChainSelector {
for k, v := range chainSelectorToDetails {
copyMap[k] = v
}

Expand All @@ -84,26 +85,42 @@ func ChainSelectorToChainDetails() map[uint64]ChainDetails {

func GetSelectorFamily(selector uint64) (string, error) {
// previously selector_families.yml includes both real and test chains, therefore we check both maps
details, exist := chainIdToChainSelector[selector]
details, exist := chainSelectorToDetails[selector]
if exist {
return details.Family, nil
}

return "", fmt.Errorf("chain detail not found for selector %d", selector)
}

func ChainIdFromSelector(chainSelectorId uint64) (string, error) {
chainDetail, ok := chainIdToChainSelector[chainSelectorId]
// ChainIdFromSelector is for backward compatibility support, it used to return uint64 for chainID so we preserve the behavior
// Deprecated: Call GetChainIdFromSelector directly
func ChainIdFromSelector(chainSelectorId uint64) (uint64, error) {
chainId, err := GetChainIdFromSelector(chainSelectorId)
if err != nil {
return 0, err
}

parseInt, err := strconv.ParseInt(chainId, 10, 64)
if err != nil {
return 0, err
}
return uint64(parseInt), fmt.Errorf("chain not found for chain selector %d", chainSelectorId)
}

func GetChainIdFromSelector(chainSelectorId uint64) (string, error) {
chainDetail, ok := chainSelectorToDetails[chainSelectorId]
if ok {
return chainDetail.ChainID, nil
}

return "0", fmt.Errorf("chain not found for chain selector %d", chainSelectorId)
}

// SelectorFromChainId is for backward compatibility support
func SelectorFromChainId(chainId string) (uint64, error) {
return SelectorFromChainIdAndFamily(chainId, FamilyEVM)
// SelectorFromChainId is for backward compatibility support, it used to take uint64 as chainID so we preserve the behavior
// Deprecated: Call SelectorFromChainIdAndFamily directly
func SelectorFromChainId(chainId uint64) (uint64, error) {
return SelectorFromChainIdAndFamily(strconv.FormatUint(chainId, 10), FamilyEVM)
}

func SelectorFromChainIdAndFamily(chainId string, family string) (uint64, error) {
Expand All @@ -125,9 +142,40 @@ func SelectorFromChainIdAndFamily(chainId string, family string) (uint64, error)
return selector, nil
}

// ChainIdFromName is for backward compatibility support
func ChainIdFromName(name string) (string, error) {
return ChainIdFromNameAndFamily(name, FamilyEVM)
// ChainIdFromName is for backward compatibility support, it used to return uint64 as chain ID so we preserve the behavior
// Deprecated: Call ChainIdFromNameAndFamily directly
func ChainIdFromName(name string) (uint64, error) {
chainID, err := ChainIdFromNameAndFamily(name, FamilyEVM)
if err != nil {
return 0, err
}

parseInt, err := strconv.ParseInt(chainID, 10, 64)
if err != nil {
return 0, err
}

return uint64(parseInt), nil
}

// NameFromChainId is for backward compatibility support
// Deprecated: Call SelectorFromChainId directly
func NameFromChainId(chainId uint64) (string, error) {
selector, err := SelectorFromChainIdAndFamily(strconv.FormatUint(chainId, 10), FamilyEVM)
if err != nil {
return "", fmt.Errorf("chain name not found for chain %d", chainId)
}

details, exist := chainSelectorToDetails[selector]
if !exist {
return "", fmt.Errorf("chain selector not found for chain %d", chainId)
}

if details.Name == "" {
return strconv.FormatUint(chainId, 10), nil
}

return details.Name, nil
}

func ChainIdFromNameAndFamily(name string, family string) (string, error) {
Expand All @@ -136,7 +184,7 @@ func ChainIdFromNameAndFamily(name string, family string) (string, error) {
family = FamilyEVM
}

for _, v := range chainIdToChainSelector {
for _, v := range chainSelectorToDetails {
if v.Name == name && family == v.Family {
return v.ChainID, nil
}
Expand All @@ -145,28 +193,43 @@ func ChainIdFromNameAndFamily(name string, family string) (string, error) {
return "0", fmt.Errorf("chain not found for name %s and family %s", name, family)
}

// TestChainIds is for backward compatibility support, it used to return uint64 as chain ID so we preserve the behavior
func TestChainIds() []uint64 {
chainIds := make([]uint64, 0, len(testSelectorsMap))
for k := range testSelectorsMap {
chainIds = append(chainIds, k)
for _, details := range testSelectorsMap {
parseInt, err := strconv.ParseInt(details.ChainID, 10, 64)
if err != nil {
continue
}

chainIds = append(chainIds, uint64(parseInt))
}
return chainIds
}

var chainsBySelector = make(map[uint64]Chain)
var chainsByChainID = make(map[string]Chain)
var chainsByEvmChainID = make(map[string]Chain)

func init() {
for _, ch := range ALL {
chainsBySelector[ch.Selector] = ch
chainsByChainID[ch.ChainID] = ch
if ch.Family == FamilyEVM {
chainsByEvmChainID[ch.ChainID] = ch
}
}
}

func ChainByEvmChainID(evmChainID string) (Chain, bool) {
ch, exists := chainsByEvmChainID[evmChainID]
func chainByChainID(chainID string) (Chain, bool) {
ch, exists := chainsByChainID[chainID]
return ch, exists
}

// Deprecated: Call chainByChainID directly
func ChainByEvmChainID(evmChainID uint64) (Chain, bool) {
chainID := strconv.FormatUint(evmChainID, 10)
ch, exists := chainsByEvmChainID[chainID]
return ch, exists
}

Expand Down
20 changes: 12 additions & 8 deletions selectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func TestNoSameChainSelectorsAreGenerated(t *testing.T) {
chainSelectors := map[uint64]struct{}{}

for selector := range chainIdToChainSelector {
for selector := range chainSelectorToDetails {
_, exist := chainSelectors[selector]
assert.False(t, exist, "Chain Selectors should be unique. Selector %d is duplicated for chain %d", selector)
chainSelectors[selector] = struct{}{}
Expand All @@ -23,7 +23,7 @@ func TestNoSameChainSelectorsAreGenerated(t *testing.T) {
func TestNoSameChainIDAndFamilyAreGenerated(t *testing.T) {
chainIDAndFamily := map[string]struct{}{}

for _, details := range chainIdToChainSelector {
for _, details := range chainSelectorToDetails {
key := fmt.Sprintf("%s:%s", details.ChainID, details.Family)
_, exist := chainIDAndFamily[key]
assert.False(t, exist, "ChainID within single family should be unique. chainID %s is duplicated for family", details.ChainID, details.Family)
Expand Down Expand Up @@ -66,7 +66,7 @@ func TestChainIdToChainSelectorReturningCopiedMap(t *testing.T) {
tmp.ChainID = "2"
selectors[5009297550715157269] = tmp

chainID, err := ChainIdFromSelector(5009297550715157269)
chainID, err := GetChainIdFromSelector(5009297550715157269)
assert.NoError(t, err)
assert.NotEqual(t, chainID, tmp)
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func Test_ChainSelectors(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
chainId, err1 := ChainIdFromSelector(test.chainSelector)
chainId, err1 := GetChainIdFromSelector(test.chainSelector)
chainSelector, err2 := SelectorFromChainIdAndFamily(test.chainId, "")
if test.expectErr {
require.Error(t, err1)
Expand All @@ -135,7 +135,11 @@ func Test_TestChainIds(t *testing.T) {
assert.Equal(t, len(chainIds), len(testSelectorsMap), "Should return correct number of test chain ids")

for _, chainId := range chainIds {
_, exist := testSelectorsMap[chainId]
selector, err := SelectorFromChainId(chainId)
if err != nil {
return
}
_, exist := testSelectorsMap[selector]
assert.True(t, exist)
}
}
Expand Down Expand Up @@ -177,15 +181,15 @@ func Test_ChainNames(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
chainId, err1 := ChainIdFromName(test.chainName)
selector, err2 := SelectorFromChainIdAndFamily(chainId, "")
chainID, err1 := ChainIdFromNameAndFamily(test.chainName, FamilyEVM)
selector, err2 := SelectorFromChainIdAndFamily(chainID, "")
if test.expectErr {
require.Error(t, err1)
require.Error(t, err2)
return
}
require.NoError(t, err1)
assert.Equal(t, test.chainId, chainId)
assert.Equal(t, test.chainId, chainID)

require.NoError(t, err2)
detail, _ := selectorsMap[selector]
Expand Down

0 comments on commit 3291f54

Please sign in to comment.