Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client/v2): prompt ui for any command #18555

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 0 additions & 68 deletions client/prompt_validation.go

This file was deleted.

39 changes: 0 additions & 39 deletions client/prompt_validation_test.go

This file was deleted.

6 changes: 4 additions & 2 deletions client/v2/autocli/flag/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ const (
AddressStringScalarType = "cosmos.AddressString"
ValidatorAddressStringScalarType = "cosmos.ValidatorAddressString"
ConsensusAddressStringScalarType = "cosmos.ConsensusAddressString"
PubkeyScalarType = "cosmos.Pubkey"

CoinScalarType = "cosmos.base.v1beta1.Coin"
PubkeyScalarType = "cosmos.Pubkey"
)

// Builder manages options for building pflag flags for protobuf messages.
Expand Down Expand Up @@ -64,7 +66,7 @@ func (b *Builder) init() {
b.messageFlagTypes = map[protoreflect.FullName]Type{}
b.messageFlagTypes["google.protobuf.Timestamp"] = timestampType{}
b.messageFlagTypes["google.protobuf.Duration"] = durationType{}
b.messageFlagTypes["cosmos.base.v1beta1.Coin"] = coinType{}
b.messageFlagTypes[CoinScalarType] = coinType{}
}

if b.scalarFlagTypes == nil {
Expand Down
2 changes: 1 addition & 1 deletion client/v2/autocli/flag/coin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ func (c *coinValue) Set(stringValue string) error {
}

func (c *coinValue) Type() string {
return "cosmos.base.v1beta1.Coin"
return CoinScalarType
}
193 changes: 193 additions & 0 deletions client/v2/autocli/prompt/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package prompt

import (
"fmt"
"reflect"
"strconv"
"strings"

"cosmossdk.io/client/v2/autocli/flag"
addresscodec "cosmossdk.io/core/address"
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/cosmos/cosmos-sdk/types/address"
"github.com/manifoldco/promptui"
)

const GovModuleName = "gov"

func Prompt(
addressCodec addresscodec.Codec,
validatorAddressCodec addresscodec.Codec,
consensusAddressCodec addresscodec.Codec,
promptPrefix string,
msg protoreflect.Message,
) (protoreflect.Message, error) {
fields := msg.Descriptor().Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
fieldName := string(field.Name())

// create prompts
prompt := promptui.Prompt{
Label: fmt.Sprintf("Enter %s %s", promptPrefix, fieldName),
Validate: ValidatePromptNotEmpty,
}

// signer field
if strings.EqualFold(fieldName, flag.GetSignerFieldName(msg.Descriptor())) {
// pre-fill with gov address
govAddr := address.Module(GovModuleName)
govAddrStr, err := addressCodec.BytesToString(govAddr)
if err != nil {
return msg, fmt.Errorf("failed to convert gov address to string: %w", err)
}

// note, we don't set prompt.Validate here because we need to get the scalar annotation
prompt.Default = govAddrStr
}

// validate address fields
scalarField, ok := flag.GetScalarType(field)
if ok {
switch scalarField {
case flag.AddressStringScalarType:
prompt.Validate = func(input string) error {
if _, err := addressCodec.StringToBytes(input); err != nil {
return fmt.Errorf("invalid address")
}

return nil
}
case flag.ValidatorAddressStringScalarType:
prompt.Validate = func(input string) error {
if _, err := validatorAddressCodec.StringToBytes(input); err != nil {
return fmt.Errorf("invalid validator address")
}

return nil
}
case flag.ConsensusAddressStringScalarType:
prompt.Validate = func(input string) error {
if _, err := consensusAddressCodec.StringToBytes(input); err != nil {
return fmt.Errorf("invalid consensus address")
}

return nil
}
case flag.CoinScalarType:
prompt.Validate = ValidatePromptCoins
default:
// prompt.Validate = ValidatePromptNotEmpty (we possibly don't want to force all fields to be non-empty)
prompt.Validate = nil
}
}

result, err := prompt.Run()
if err != nil {
return msg, fmt.Errorf("failed to prompt for %s: %w", fieldName, err)
}

switch field.Kind() {
case protoreflect.StringKind:
msg.Set(field, protoreflect.ValueOfString(result))
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
resultUint, err := strconv.ParseUint(result, 10, 0)
if err != nil {
return msg, fmt.Errorf("invalid value for int: %w", err)
}

msg.Set(field, protoreflect.ValueOfUint64(resultUint))
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
resultInt, err := strconv.ParseInt(result, 10, 0)
if err != nil {
return msg, fmt.Errorf("invalid value for int: %w", err)
}
// If a value was successfully parsed the ranges of:
// [minInt, maxInt]
// are within the ranges of:
// [minInt64, maxInt64]
// of which on 64-bit machines, which are most common,
// int==int64
msg.Set(field, protoreflect.ValueOfInt64(resultInt))
case protoreflect.BoolKind:
resultBool, err := strconv.ParseBool(result)
if err != nil {
return msg, fmt.Errorf("invalid value for bool: %w", err)
}

msg.Set(field, protoreflect.ValueOfBool(resultBool))
case protoreflect.MessageKind:
// TODO
default:
// skip any other types
continue // TODO(@julienrbrt) add support for other types
}
}

return msg, nil
}

func PromptStruct[T any](promptPrefix string, data T) (T, error) {
v := reflect.ValueOf(&data).Elem()
if v.Kind() == reflect.Interface {
v = reflect.ValueOf(data)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
}

for i := 0; i < v.NumField(); i++ {
// if the field is a struct skip or not slice of string or int then skip
switch v.Field(i).Kind() {
case reflect.Struct:
// TODO(@julienrbrt) in the future we can add a recursive call to Prompt
continue
case reflect.Slice:
if v.Field(i).Type().Elem().Kind() != reflect.String && v.Field(i).Type().Elem().Kind() != reflect.Int {
continue
}
}

// create prompts
prompt := promptui.Prompt{
Label: fmt.Sprintf("Enter %s %s", promptPrefix, strings.Title(v.Type().Field(i).Name)), // nolint:staticcheck // strings.Title has a better API
Validate: ValidatePromptNotEmpty,
}

fieldName := strings.ToLower(v.Type().Field(i).Name)

result, err := prompt.Run()
if err != nil {
return data, fmt.Errorf("failed to prompt for %s: %w", fieldName, err)
}

switch v.Field(i).Kind() {
case reflect.String:
v.Field(i).SetString(result)
case reflect.Int:
resultInt, err := strconv.ParseInt(result, 10, 0)
if err != nil {
return data, fmt.Errorf("invalid value for int: %w", err)
}
v.Field(i).SetInt(resultInt)
case reflect.Slice:
switch v.Field(i).Type().Elem().Kind() {
case reflect.String:
v.Field(i).Set(reflect.ValueOf([]string{result}))
case reflect.Int:
resultInt, err := strconv.ParseInt(result, 10, 0)
if err != nil {
return data, fmt.Errorf("invalid value for int: %w", err)
}

v.Field(i).Set(reflect.ValueOf([]int{int(resultInt)}))
}
default:
// skip any other types
continue
}
}

return data, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// has a data race and this code exposes it, but fixing it would require
// holding up the associated change to this.

package cli_test
package prompt_test

import (
"fmt"
Expand All @@ -14,16 +14,12 @@ import (
"testing"

"github.com/chzyer/readline"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"cosmossdk.io/x/gov/client/cli"
"cosmossdk.io/client/v2/autocli/prompt"
"cosmossdk.io/client/v2/internal/testpb"
)

type st struct {
I int
}

// Tests that we successfully report overflows in parsing ints
// See https://github.com/cosmos/cosmos-sdk/issues/13346
func TestPromptIntegerOverflow(t *testing.T) {
Expand All @@ -48,10 +44,10 @@ func TestPromptIntegerOverflow(t *testing.T) {
fin, fw := readline.NewFillableStdin(os.Stdin)
readline.Stdin = fin
_, err := fw.Write([]byte(overflowStr + "\n"))
assert.NoError(t, err)
require.NoError(t, err)

v, err := cli.Prompt(st{}, "")
assert.Equal(t, st{}, v, "expected a value of zero")
v, err := prompt.Prompt(mockAddressCodec{}, mockAddressCodec{}, mockAddressCodec{}, "", (&testpb.MsgRequest{}).ProtoReflect())
require.Equal(t, (&testpb.MsgRequest{}).ProtoReflect(), v, "expected a value of zero")
require.NotNil(t, err, "expected a report of an overflow")
require.Contains(t, err.Error(), "range")
})
Expand Down Expand Up @@ -80,10 +76,21 @@ func TestPromptParseInteger(t *testing.T) {
fin, fw := readline.NewFillableStdin(os.Stdin)
readline.Stdin = fin
_, err := fw.Write([]byte(tc.in + "\n"))
assert.NoError(t, err)
v, err := cli.Prompt(st{}, "")
assert.Nil(t, err, "expected a nil error")
assert.Equal(t, tc.want, v.I, "expected %d = %d", tc.want, v.I)
require.NoError(t, err)
v, err := prompt.Prompt(mockAddressCodec{}, mockAddressCodec{}, mockAddressCodec{}, "", (&testpb.MsgRequest{}).ProtoReflect())
require.Nil(t, err, "expected a nil error")
require.NotNil(t, v)
// require.Equal(t, tc.want, v.I, "expected %d = %d", tc.want, v.I)
})
}
}

type mockAddressCodec struct{}

func (mockAddressCodec) BytesToString([]byte) (string, error) {
return "cosmos1y74p8wyy4enfhfn342njve6cjmj5c8dtl6emdk", nil
}

func (mockAddressCodec) StringToBytes(string) ([]byte, error) {
return nil, nil
}
Loading
Loading