diff --git a/list.go b/list.go index 0e1186b..2070ba4 100644 --- a/list.go +++ b/list.go @@ -30,7 +30,7 @@ interface. func (list *List) UnmarshalJSON(rawData []byte) error { // Create a sub-type here so when we call Unmarshal below, we don't recursively // call this function over and over - type MarshalList List + type UnmarshalList List // if our "List" is a single object, modify the JSON to make it into a list // by wrapping with "[ ]" @@ -38,7 +38,7 @@ func (list *List) UnmarshalJSON(rawData []byte) error { rawData = []byte(fmt.Sprintf("[%s]", rawData)) } - newList := MarshalList{} + newList := UnmarshalList{} err := json.Unmarshal(rawData, &newList) if err != nil { @@ -50,3 +50,21 @@ func (list *List) UnmarshalJSON(rawData []byte) error { return nil } + +/* +MarshalJSON returns a top level object for the "data" attribute if a single object. In +all other cases returns a JSON encoded list for "data". +*/ +func (list List) MarshalJSON() ([]byte, error) { + // avoid stack overflow by using this subtype for marshaling + type MarshalList List + marshalList := MarshalList(list) + count := len(marshalList) + + switch { + case count == 1: + return json.Marshal(marshalList[0]) + default: + return json.Marshal(marshalList) + } +}