diff --git a/selectors.go b/selectors.go index aca7a97..b360b48 100644 --- a/selectors.go +++ b/selectors.go @@ -3,6 +3,7 @@ package chain_selectors import ( _ "embed" "fmt" + "strconv" "gopkg.in/yaml.v3" ) @@ -32,9 +33,9 @@ const ( var chainIDToSelectorMapForFamily = make(map[string]map[string]uint64) var selectorsMap = loadYML(selectorYml) var testSelectorsMap = loadYML(testSelectorsYml) -var chainIdToChainSelector = loadAllChainIDToChainSelector() +var chainSelectorToDetails = loadAllChainSelector() -func loadAllChainIDToChainSelector() map[uint64]ChainDetails { +func loadAllChainSelector() map[uint64]ChainDetails { output := make(map[uint64]ChainDetails, len(selectorsMap)+len(testSelectorsMap)) for k, v := range selectorsMap { output[k] = v @@ -75,7 +76,7 @@ func loadYML(yml []byte) map[uint64]ChainDetails { func ChainSelectorToChainDetails() map[uint64]ChainDetails { copyMap := make(map[uint64]ChainDetails, len(selectorsMap)) - for k, v := range chainIdToChainSelector { + for k, v := range chainSelectorToDetails { copyMap[k] = v } @@ -84,7 +85,7 @@ func ChainSelectorToChainDetails() map[uint64]ChainDetails { func GetSelectorFamily(selector uint64) (string, error) { // previously selector_families.yml includes both real and test chains, therefore we check both maps - details, exist := chainIdToChainSelector[selector] + details, exist := chainSelectorToDetails[selector] if exist { return details.Family, nil } @@ -92,8 +93,23 @@ func GetSelectorFamily(selector uint64) (string, error) { return "", fmt.Errorf("chain detail not found for selector %d", selector) } -func ChainIdFromSelector(chainSelectorId uint64) (string, error) { - chainDetail, ok := chainIdToChainSelector[chainSelectorId] +// ChainIdFromSelector is for backward compatibility support, it used to return uint64 for chainID so we preserve the behavior +// Deprecated: Call GetChainIdFromSelector directly +func ChainIdFromSelector(chainSelectorId uint64) (uint64, error) { + chainId, err := GetChainIdFromSelector(chainSelectorId) + if err != nil { + return 0, err + } + + parseInt, err := strconv.ParseInt(chainId, 10, 64) + if err != nil { + return 0, err + } + return uint64(parseInt), fmt.Errorf("chain not found for chain selector %d", chainSelectorId) +} + +func GetChainIdFromSelector(chainSelectorId uint64) (string, error) { + chainDetail, ok := chainSelectorToDetails[chainSelectorId] if ok { return chainDetail.ChainID, nil } @@ -101,9 +117,10 @@ func ChainIdFromSelector(chainSelectorId uint64) (string, error) { return "0", fmt.Errorf("chain not found for chain selector %d", chainSelectorId) } -// SelectorFromChainId is for backward compatibility support -func SelectorFromChainId(chainId string) (uint64, error) { - return SelectorFromChainIdAndFamily(chainId, FamilyEVM) +// SelectorFromChainId is for backward compatibility support, it used to take uint64 as chainID so we preserve the behavior +// Deprecated: Call SelectorFromChainIdAndFamily directly +func SelectorFromChainId(chainId uint64) (uint64, error) { + return SelectorFromChainIdAndFamily(strconv.FormatUint(chainId, 10), FamilyEVM) } func SelectorFromChainIdAndFamily(chainId string, family string) (uint64, error) { @@ -125,9 +142,40 @@ func SelectorFromChainIdAndFamily(chainId string, family string) (uint64, error) return selector, nil } -// ChainIdFromName is for backward compatibility support -func ChainIdFromName(name string) (string, error) { - return ChainIdFromNameAndFamily(name, FamilyEVM) +// ChainIdFromName is for backward compatibility support, it used to return uint64 as chain ID so we preserve the behavior +// Deprecated: Call ChainIdFromNameAndFamily directly +func ChainIdFromName(name string) (uint64, error) { + chainID, err := ChainIdFromNameAndFamily(name, FamilyEVM) + if err != nil { + return 0, err + } + + parseInt, err := strconv.ParseInt(chainID, 10, 64) + if err != nil { + return 0, err + } + + return uint64(parseInt), nil +} + +// NameFromChainId is for backward compatibility support +// Deprecated: Call SelectorFromChainId directly +func NameFromChainId(chainId uint64) (string, error) { + selector, err := SelectorFromChainIdAndFamily(strconv.FormatUint(chainId, 10), FamilyEVM) + if err != nil { + return "", fmt.Errorf("chain name not found for chain %d", chainId) + } + + details, exist := chainSelectorToDetails[selector] + if !exist { + return "", fmt.Errorf("chain selector not found for chain %d", chainId) + } + + if details.Name == "" { + return strconv.FormatUint(chainId, 10), nil + } + + return details.Name, nil } func ChainIdFromNameAndFamily(name string, family string) (string, error) { @@ -136,7 +184,7 @@ func ChainIdFromNameAndFamily(name string, family string) (string, error) { family = FamilyEVM } - for _, v := range chainIdToChainSelector { + for _, v := range chainSelectorToDetails { if v.Name == name && family == v.Family { return v.ChainID, nil } @@ -145,28 +193,43 @@ func ChainIdFromNameAndFamily(name string, family string) (string, error) { return "0", fmt.Errorf("chain not found for name %s and family %s", name, family) } +// TestChainIds is for backward compatibility support, it used to return uint64 as chain ID so we preserve the behavior func TestChainIds() []uint64 { chainIds := make([]uint64, 0, len(testSelectorsMap)) - for k := range testSelectorsMap { - chainIds = append(chainIds, k) + for _, details := range testSelectorsMap { + parseInt, err := strconv.ParseInt(details.ChainID, 10, 64) + if err != nil { + continue + } + + chainIds = append(chainIds, uint64(parseInt)) } return chainIds } var chainsBySelector = make(map[uint64]Chain) +var chainsByChainID = make(map[string]Chain) var chainsByEvmChainID = make(map[string]Chain) func init() { for _, ch := range ALL { chainsBySelector[ch.Selector] = ch + chainsByChainID[ch.ChainID] = ch if ch.Family == FamilyEVM { chainsByEvmChainID[ch.ChainID] = ch } } } -func ChainByEvmChainID(evmChainID string) (Chain, bool) { - ch, exists := chainsByEvmChainID[evmChainID] +func chainByChainID(chainID string) (Chain, bool) { + ch, exists := chainsByChainID[chainID] + return ch, exists +} + +// Deprecated: Call chainByChainID directly +func ChainByEvmChainID(evmChainID uint64) (Chain, bool) { + chainID := strconv.FormatUint(evmChainID, 10) + ch, exists := chainsByEvmChainID[chainID] return ch, exists } diff --git a/selectors_test.go b/selectors_test.go index 805ea99..9ce792f 100644 --- a/selectors_test.go +++ b/selectors_test.go @@ -13,7 +13,7 @@ import ( func TestNoSameChainSelectorsAreGenerated(t *testing.T) { chainSelectors := map[uint64]struct{}{} - for selector := range chainIdToChainSelector { + for selector := range chainSelectorToDetails { _, exist := chainSelectors[selector] assert.False(t, exist, "Chain Selectors should be unique. Selector %d is duplicated for chain %d", selector) chainSelectors[selector] = struct{}{} @@ -23,7 +23,7 @@ func TestNoSameChainSelectorsAreGenerated(t *testing.T) { func TestNoSameChainIDAndFamilyAreGenerated(t *testing.T) { chainIDAndFamily := map[string]struct{}{} - for _, details := range chainIdToChainSelector { + for _, details := range chainSelectorToDetails { key := fmt.Sprintf("%s:%s", details.ChainID, details.Family) _, exist := chainIDAndFamily[key] assert.False(t, exist, "ChainID within single family should be unique. chainID %s is duplicated for family", details.ChainID, details.Family) @@ -66,7 +66,7 @@ func TestChainIdToChainSelectorReturningCopiedMap(t *testing.T) { tmp.ChainID = "2" selectors[5009297550715157269] = tmp - chainID, err := ChainIdFromSelector(5009297550715157269) + chainID, err := GetChainIdFromSelector(5009297550715157269) assert.NoError(t, err) assert.NotEqual(t, chainID, tmp) } @@ -114,7 +114,7 @@ func Test_ChainSelectors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - chainId, err1 := ChainIdFromSelector(test.chainSelector) + chainId, err1 := GetChainIdFromSelector(test.chainSelector) chainSelector, err2 := SelectorFromChainIdAndFamily(test.chainId, "") if test.expectErr { require.Error(t, err1) @@ -135,7 +135,11 @@ func Test_TestChainIds(t *testing.T) { assert.Equal(t, len(chainIds), len(testSelectorsMap), "Should return correct number of test chain ids") for _, chainId := range chainIds { - _, exist := testSelectorsMap[chainId] + selector, err := SelectorFromChainId(chainId) + if err != nil { + return + } + _, exist := testSelectorsMap[selector] assert.True(t, exist) } } @@ -177,15 +181,15 @@ func Test_ChainNames(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - chainId, err1 := ChainIdFromName(test.chainName) - selector, err2 := SelectorFromChainIdAndFamily(chainId, "") + chainID, err1 := ChainIdFromNameAndFamily(test.chainName, FamilyEVM) + selector, err2 := SelectorFromChainIdAndFamily(chainID, "") if test.expectErr { require.Error(t, err1) require.Error(t, err2) return } require.NoError(t, err1) - assert.Equal(t, test.chainId, chainId) + assert.Equal(t, test.chainId, chainID) require.NoError(t, err2) detail, _ := selectorsMap[selector]