Skip to content

Commit

Permalink
feat: add more default type for binding (#1056)
Browse files Browse the repository at this point in the history
#### What type of PR is this?
<!--
Add one of the following kinds:

build: Changes that affect the build system or external dependencies
(example scopes: gulp, broccoli, npm)
ci: Changes to our CI configuration files and scripts (example scopes:
Travis, Circle, BrowserStack, SauceLabs)
docs: Documentation only changes
feat: A new feature
optimize: A new optimization
fix: A bug fix
perf: A code change that improves performance
refactor: A code change that neither fixes a bug nor adds a feature
style: Changes that do not affect the meaning of the code (white space,
formatting, missing semi-colons, etc)
test: Adding missing tests or correcting existing tests
chore: Changes to the build process or auxiliary tools and libraries
such as documentation generation
-->
feat
#### Check the PR title.
<!--
The description of the title will be attached in Release Notes, 
so please describe it from user-oriented, what this PR does / why we
need it.
Please check your PR title with the below requirements:
-->
- [ ] This PR title match the format: \<type\>(optional scope):
\<description\>
- [ ] The description of this PR title is user-oriented and clear enough
for others to understand.
- [ ] Attach the PR updating the user documentation if the current PR
requires user awareness at the usage level. [User docs
repo](https://github.com/cloudwego/cloudwego.github.io)


#### (Optional) Translate the PR title into Chinese.
为参数绑定提供更多的默认类型支持

#### (Optional) More detailed description for this PR(en: English/zh:
Chinese).
<!--
Provide more detailed info for review(e.g., it's recommended to provide
perf data if this is a perf type PR).
-->
en: 为参数绑定提供更多的默认类型支持
zh(optional): add more default type for binding

- 增加了 map、struct、slice 等类型的默认值处理
- 降低 default value 的优先级,避免在做完 json.unmarshal() 后,default value 将其值覆盖

#### (Optional) Which issue(s) this PR fixes:
<!--
Automatically closes linked issue when PR is merged.
Eg: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->

#### (Optional) The PR that updates user documentation:
<!--
If the current PR requires user awareness at the usage level, please
submit a PR to update user docs. [User docs
repo](https://github.com/cloudwego/cloudwego.github.io)
-->
  • Loading branch information
FGYFFFF authored Apr 12, 2024
2 parents b7cbc9d + 4742b2e commit 3b3296c
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 108 deletions.
6 changes: 5 additions & 1 deletion _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ HeaderReferer = "HeaderReferer"
expectedReferer = "expectedReferer"
Referer = "Referer"
O_WRONLY = "O_WRONLY"
WRONLY = "WRONLY"
WRONLY = "WRONLY"
ome = "ome"
ifModifiedSice = "ifModifiedSice"
hd = "hd"
pn = "pn"
51 changes: 47 additions & 4 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,11 @@ func TestBind_ZeroValueBind(t *testing.T) {

func TestBind_DefaultValueBind(t *testing.T) {
var s struct {
A int `default:"15"`
B float64 `query:"b" default:"17"`
C []int `default:"15"`
D []string `default:"qwe"`
A int `default:"15"`
B float64 `query:"b" default:"17"`
C []int `default:"[15]"`
D []string `default:"['qwe','asd']"`
F [2]string `default:"['qwe','asd','zxc']"`
}
req := newMockRequest().
SetRequestURI("http://foobar.com")
Expand All @@ -432,7 +433,23 @@ func TestBind_DefaultValueBind(t *testing.T) {
assert.DeepEqual(t, 15, s.A)
assert.DeepEqual(t, float64(17), s.B)
assert.DeepEqual(t, 15, s.C[0])
assert.DeepEqual(t, 2, len(s.D))
assert.DeepEqual(t, "qwe", s.D[0])
assert.DeepEqual(t, "asd", s.D[1])
assert.DeepEqual(t, 2, len(s.F))
assert.DeepEqual(t, "qwe", s.F[0])
assert.DeepEqual(t, "asd", s.F[1])

var s2 struct {
F [2]string `default:"['qwe']"`
}
err = DefaultBinder().Bind(req.Req, &s2, nil)
if err != nil {
t.Fatal(err)
}
assert.DeepEqual(t, 2, len(s2.F))
assert.DeepEqual(t, "qwe", s2.F[0])
assert.DeepEqual(t, "", s2.F[1])

var d struct {
D [2]string `default:"qwe"`
Expand Down Expand Up @@ -1549,6 +1566,32 @@ func TestBind_Issue1015(t *testing.T) {
assert.DeepEqual(t, "asd", result.A)
}

func TestBind_JSONWithDefault(t *testing.T) {
type Req struct {
J1 string `json:"j1" default:"j1default"`
}

req := newMockRequest().
SetJSONContentType().
SetBody([]byte(`{"j1":"j1"}`))
var result Req
err := DefaultBinder().Bind(req.Req, &result, nil)
if err != nil {
t.Error(err)
}
assert.DeepEqual(t, "j1", result.J1)

result = Req{}
req = newMockRequest().
SetJSONContentType().
SetBody([]byte(`{"j2":"j2"}`))
err = DefaultBinder().Bind(req.Req, &result, nil)
if err != nil {
t.Error(err)
}
assert.DeepEqual(t, "j1default", result.J1)
}

func TestBind_WithoutPreBindForTag(t *testing.T) {
type BaseQuery struct {
Action string `query:"Action" binding:"required"`
Expand Down
7 changes: 5 additions & 2 deletions pkg/app/server/binding/internal/decoder/base_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
defaultValue = tagInfo.Default
if tagInfo.Key == jsonTag {
defaultValue = tagInfo.Default
found := checkRequireJSON(req, tagInfo)
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", d.fieldName, tagInfo.JSONName)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
}
}
continue
}
Expand All @@ -94,7 +97,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
return err
}
if len(text) == 0 && len(defaultValue) != 0 {
text = defaultValue
text = toDefaultValue(d.fieldType, defaultValue)
}
if !exist && len(text) == 0 {
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
defaultValue = tagInfo.Default
if tagInfo.Key == jsonTag {
defaultValue = tagInfo.Default
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
}
}
continue
}
text, exist = tagInfo.Getter(req, params, tagInfo.Value)
Expand All @@ -73,7 +78,7 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.
return nil
}
if len(text) == 0 && len(defaultValue) != 0 {
text = defaultValue
text = toDefaultValue(d.fieldType, defaultValue)
}

v, err := d.decodeFunc(req, params, text)
Expand Down
2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func getFieldDecoder(pInfo parentInfos, field reflect.StructField, index int, by
// JSONName is like 'a.b.c' for 'required validate'
fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, pInfo.JSONName, config)
if len(fieldTagInfos) == 0 && !config.DisableDefaultTag {
fieldTagInfos = getDefaultFieldTags(field)
fieldTagInfos, newParentJSONName = getDefaultFieldTags(field, pInfo.JSONName)
}
if len(byTag) != 0 {
fieldTagInfos = getFieldTagInfoByTag(field, byTag)
Expand Down
9 changes: 9 additions & 0 deletions pkg/app/server/binding/internal/decoder/gjson_required.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,12 @@ func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool {
}
return true
}

func keyExist(req *protocol.Request, tagInfo TagInfo) bool {
ct := bytesconv.B2s(req.Header.ContentType())
if utils.FilterContentType(ct) != consts.MIMEApplicationJSON {
return false
}
result := gjson.GetBytes(req.Body(), tagInfo.JSONName)
return result.Exists()
}
7 changes: 5 additions & 2 deletions pkg/app/server/binding/internal/decoder/map_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
defaultValue = tagInfo.Default
if tagInfo.Key == jsonTag {
defaultValue = tagInfo.Default
found := checkRequireJSON(req, tagInfo)
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
}
}
continue
}
Expand All @@ -86,7 +89,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par
return err
}
if len(text) == 0 && len(defaultValue) != 0 {
text = defaultValue
text = toDefaultValue(d.fieldType, defaultValue)
}
if !exist && len(text) == 0 {
return nil
Expand Down
47 changes: 15 additions & 32 deletions pkg/app/server/binding/internal/decoder/slice_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
var texts []string
var defaultValue string
var bindRawBody bool
var isDefault bool
for _, tagInfo := range d.tagInfos {
if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
defaultValue = tagInfo.Default
if tagInfo.Key == jsonTag {
defaultValue = tagInfo.Default
found := checkRequireJSON(req, tagInfo)
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { //
defaultValue = ""
}
}
continue
}
Expand All @@ -91,7 +95,9 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
return err
}
if len(texts) == 0 && len(defaultValue) != 0 {
defaultValue = toDefaultValue(d.fieldType, defaultValue)
texts = append(texts, defaultValue)
isDefault = true
}
if len(texts) == 0 {
return nil
Expand All @@ -113,7 +119,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
}

if d.isArray {
if len(texts) != field.Len() {
if len(texts) != field.Len() && !isDefault {
return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String())
}
} else {
Expand All @@ -135,6 +141,13 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
elemKind = t.Kind()
ptrDepth++
}
if isDefault {
err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface())
if err != nil {
return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err)
}
return nil
}

for idx, text := range texts {
var vv reflect.Value
Expand Down Expand Up @@ -218,33 +231,3 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn
isArray: isArray,
}}, nil
}

func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) {
v = reflect.New(elemType).Elem()
if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist {
val, err := customizedFunc(req, params, text)
if err != nil {
return reflect.Value{}, err
}
return val, nil
}
switch elemType.Kind() {
case reflect.Struct:
err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface())
case reflect.Map:
err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface())
case reflect.Array, reflect.Slice:
// do nothing
default:
decoder, err := SelectTextDecoder(elemType)
if err != nil {
return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String())
}
err = decoder.UnmarshalString(text, v, config.LooseZeroMode)
if err != nil {
return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err)
}
}

return v, err
}
9 changes: 9 additions & 0 deletions pkg/app/server/binding/internal/decoder/sonic_required.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ func stringSliceForInterface(s string) (ret []interface{}) {
}
return
}

func keyExist(req *protocol.Request, tagInfo TagInfo) bool {
ct := bytesconv.B2s(req.Header.ContentType())
if utils.FilterContentType(ct) != consts.MIMEApplicationJSON {
return false
}
node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...)
return node.Exists()
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
defaultValue = tagInfo.Default
if tagInfo.Key == jsonTag {
defaultValue = tagInfo.Default
found := checkRequireJSON(req, tagInfo)
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
}
}
continue
}
Expand All @@ -63,7 +66,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.
return err
}
if len(text) == 0 && len(defaultValue) != 0 {
text = defaultValue
text = toDefaultValue(d.fieldType, defaultValue)
}
if !exist && len(text) == 0 {
return nil
Expand Down
8 changes: 5 additions & 3 deletions pkg/app/server/binding/internal/decoder/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func lookupFieldTags(field reflect.StructField, parentJSONName string, config *D
tagValue = field.Name
}
skip := false
jsonName := ""
jsonName := parentJSONName + "." + field.Name
if tag == jsonTag {
jsonName = parentJSONName + "." + tagValue
}
Expand Down Expand Up @@ -120,16 +120,18 @@ func lookupFieldTags(field reflect.StructField, parentJSONName string, config *D
return tagInfos, newParentJSONName, needValidate
}

func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) {
func getDefaultFieldTags(field reflect.StructField, parentJSONName string) (tagInfos []TagInfo, newParentJSONName string) {
defaultVal := ""
if val, ok := field.Tag.Lookup(defaultTag); ok {
defaultVal = val
}

tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, fileNameTag}
for _, tag := range tags {
tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal})
jsonName := strings.TrimPrefix(parentJSONName+"."+field.Name, ".")
tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal, JSONName: jsonName})
}
newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".")

return
}
Expand Down
Loading

0 comments on commit 3b3296c

Please sign in to comment.