diff --git a/parser.go b/parser.go index 405d72f..43b747f 100644 --- a/parser.go +++ b/parser.go @@ -8,17 +8,19 @@ import ( "log" "reflect" "strconv" + "strings" ) var ( - config = map[string]any{} - defaults = map[string]any{} + config = map[string]any{} + flattenConfig = map[string]any{} + defaults = map[string]any{} ) // ParseConfig reads the config from an io.Reader. // // The config may be in JSON or in JSONC ( json with comment) -// Allowed format for comments are +// Allowed format for comments are // * single line ( //... ) // * multiligne ( /*...*/ ) func ParseConfig(jsoncConfig io.Reader) error { @@ -33,13 +35,12 @@ func ParseConfig(jsoncConfig io.Reader) error { d := json.NewDecoder(bytes.NewReader(cleanJson)) d.UseNumber() - var src map[string]any - err = d.Decode(&src) + err = d.Decode(&config) if err != nil { return fmt.Errorf("fail to parse JSON: %v", err) } - flatten("", src, config) + flatten("", config, flattenConfig) return nil } @@ -85,7 +86,6 @@ func removeComment(src []byte) []byte { output = append(output, b) } } - return output } @@ -124,41 +124,26 @@ func SetDefault(key string, value any) { default: defaults[key] = value } - } // GetString returns the value associated with the key as a string. func GetString(key string) string { - - val, ok := getValue(key) - if !ok { - return "" - } - return cast[string](val) + return getValue[string](key) } // GetBool returns the value associated with the key as a bool. func GetBool(key string) bool { - - val, ok := getValue(key) - if !ok { - return false - } - return cast[bool](val) + return getValue[bool](key) } // GetInt returns the value associated with the key as an int. func GetInt(key string) int { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - i, err := strconv.ParseInt(string(n), 10, 64) + i, err := strconv.ParseInt(string(val), 10, 64) if err != nil { - panic("Can't parse number '%s' as an int") + panic(fmt.Sprintf("Can't parse number '%s' as an int", val)) } return int(i) } @@ -166,15 +151,11 @@ func GetInt(key string) int { // GetInt32 returns the value associated with the key as an int32 func GetInt32(key string) int32 { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - i, err := strconv.ParseInt(string(n), 10, 32) + i, err := strconv.ParseInt(string(val), 10, 32) if err != nil { - panic("Can't parse number '%s' as an int32") + panic(fmt.Sprintf("Can't parse number '%s' as an int32", val)) } return int32(i) } @@ -182,15 +163,11 @@ func GetInt32(key string) int32 { // GetInt64 returns the value associated with the key as an int64. func GetInt64(key string) int64 { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - i, err := strconv.ParseInt(string(n), 10, 64) + i, err := strconv.ParseInt(string(val), 10, 64) if err != nil { - panic("Can't parse number '%s' as an int64") + panic(fmt.Sprintf("Can't parse number '%s' as an int64", val)) } return i } @@ -198,15 +175,12 @@ func GetInt64(key string) int64 { // GetUint returns the value associated with the key as an uint. func GetUint(key string) uint { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) n := cast[json.Number](val) i, err := strconv.ParseUint(string(n), 10, 64) if err != nil { - panic("Can't parse number '%s' as an uint") + panic(fmt.Sprintf("Can't parse number '%s' as an uint", val)) } return uint(i) } @@ -214,15 +188,11 @@ func GetUint(key string) uint { // GetUint32 returns the value associated with the key as an uint32. func GetUint32(key string) uint32 { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - i, err := strconv.ParseUint(string(n), 10, 32) + i, err := strconv.ParseUint(string(val), 10, 32) if err != nil { - panic("Can't parse number '%s' as an uint32") + panic(fmt.Sprintf("Can't parse number '%s' as an uint32", val)) } return uint32(i) } @@ -230,15 +200,11 @@ func GetUint32(key string) uint32 { // GetUint64 returns the value associated with the key as an uint64. func GetUint64(key string) uint64 { - val, ok := getValue(key) - if !ok { - return 0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - i, err := strconv.ParseUint(string(n), 10, 64) + i, err := strconv.ParseUint(string(val), 10, 64) if err != nil { - panic("Can't parse number '%s' as an uint64") + panic(fmt.Sprintf("Can't parse number '%s' as an uint64", val)) } return uint64(i) } @@ -246,42 +212,69 @@ func GetUint64(key string) uint64 { // GetFloat64 returns the value associated with the key as a float64. func GetFloat64(key string) float64 { - val, ok := getValue(key) - if !ok { - return 0.0 - } + val := getValue[json.Number](key) - n := cast[json.Number](val) - f, err := strconv.ParseFloat(string(n), 64) + f, err := strconv.ParseFloat(string(val), 64) if err != nil { - panic("Can't parse number '%s' as a float64") + panic(fmt.Sprintf("Can't parse number '%s' as an float64", val)) } return f } // GetAny returns any value associated with the key. func GetAny(key string) any { - v, _ := getValue(key) - return v + return getValue[any](key) } -func getValue(key string) (val any, ok bool) { - v, ok := config[key] - if !ok { - v, ok = getDefault(key) +// GetMap returns the map associated with the key. +// Numbers will be of type json.Number +// returns nil if the key does not exist +func GetMap(key string) map[string]any { + + nProp := strings.Split(key, ".") + nested := config + + for i, prop := range nProp { + + if i == len(nProp)-1 { + + if v, ok := nested[prop]; ok { + return cast[map[string]any](v) + } + continue + } + nested, _ = nested[prop].(map[string]any) + } + return getDefault[map[string]any](key) +} + +func getValue[T any](key string) (val T) { + + v, ok := flattenConfig[key] + if ok { + return cast[T](v) } - return v, ok + return getDefault[T](key) } -func getDefault(key string) (val any, ok bool) { +func getDefault[T any](key string) (val T) { + v, ok := defaults[key] - if !ok { - log.Printf("no value found for key '%s', using nil value instead", key) + if ok { + return cast[T](v) } - return v, ok + + log.Printf("no value found for key '%s', using nil value instead", key) + + var zeroVal T + if reflect.TypeOf(zeroVal) == reflect.TypeOf(json.Number("")) { + return cast[T](json.Number("0")) + } + return zeroVal } func cast[T any](v any) T { + s, ok := v.(T) if !ok { panic(fmt.Sprintf("'%v' is not a %s", v, reflect.TypeOf(s))) diff --git a/parser_test.go b/parser_test.go index 0050ddd..083af3b 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,6 +1,7 @@ package boa_test import ( + "encoding/json" "strings" "testing" @@ -367,6 +368,55 @@ func TestRemoveComment(t *testing.T) { } } +func TestGetMap(t *testing.T) { + + config := `{ + "root": { + "object": { + "key1": "val", + "key2": true, + "key3": 1 + } + } + }` + + loadConfig(t, config) + + result := boa.GetMap("root.object") + if want, got := "val", result["key1"]; want != got { + t.Errorf("expected '%v' but got '%v'", want, got) + } + if want, got := true, result["key2"]; want != got { + t.Errorf("expected '%v' but got '%v'", want, got) + } + if want, got := json.Number("1"), result["key3"]; want != got { + t.Errorf("expected '%v' but got '%v'", want, got) + } + + nonExistingMap := boa.GetMap("non_existing") + if nonExistingMap != nil { + t.Errorf("non existing key should return nil, but got '%v'", nonExistingMap) + } + + defer func(t *testing.T) { + got := recover() + if got == nil { + t.Error("val not of type map should trigger a panic") + } + want := "'val' is not a map[string]interface {}" + if want != got { + t.Errorf("expected '%s' but got '%s'", want, got) + } + }(t) + _ = boa.GetMap("root.object.key1") + + boa.SetDefault("non_existing", map[string]any{"k": "v"}) + nonExistingMap = boa.GetMap("non_existing") + if len(nonExistingMap) != 1 { + t.Errorf("after default set, non_existing should be a map of length 1, but got '%v'", nonExistingMap) + } +} + func loadConfig(t *testing.T, config string) { err := boa.ParseConfig(strings.NewReader(config)) if err != nil {