diff --git a/internal/report/report.go b/internal/report/report.go index 2d0d6c6..4a75a14 100644 --- a/internal/report/report.go +++ b/internal/report/report.go @@ -491,14 +491,13 @@ type ApplyAction struct { } func (a *ApplyAction) Unmarshal(b []byte) error { - var v []byte if len(b) < 1 { return errors.Errorf("ApplyAction Unmarshal: less than 1 bytes") - } else if len(b) < 2 { - // slice len might be 1 or 2; enlarge slice to 2 bytes at least - v = make([]byte, len(b)+1) - copy(v, b) } + + // slice len might be 1 or 2; enlarge slice to 2 bytes at least + v := make([]byte, max(2, len(b))) + copy(v, b) a.Flags = binary.LittleEndian.Uint16(v) return nil } diff --git a/internal/report/report_test.go b/internal/report/report_test.go new file mode 100644 index 0000000..8a08eff --- /dev/null +++ b/internal/report/report_test.go @@ -0,0 +1,55 @@ +package report_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/free5gc/go-upf/internal/report" +) + +func TestApplyAction0(t *testing.T) { + var act report.ApplyAction + e := act.Unmarshal([]byte{}) + assert.Error(t, e) +} + +func TestApplyAction1(t *testing.T) { + var act report.ApplyAction + e := act.Unmarshal([]byte{0x02}) + assert.NoError(t, e) + assert.Equal(t, uint16(0x0002), act.Flags) + assert.False(t, act.DROP()) + assert.True(t, act.FORW()) + assert.False(t, act.BUFF()) + assert.False(t, act.NOCP()) + assert.False(t, act.DUPL()) + assert.False(t, act.IPMA()) + assert.False(t, act.IPMD()) + assert.False(t, act.DFRT()) + assert.False(t, act.EDRT()) + assert.False(t, act.BDPN()) + assert.False(t, act.DDPN()) + assert.False(t, act.FSSM()) + assert.False(t, act.MBSU()) +} + +func TestApplyAction2(t *testing.T) { + var act report.ApplyAction + e := act.Unmarshal([]byte{0x0C, 0x00}) + assert.NoError(t, e) + assert.Equal(t, uint16(0x000C), act.Flags) + assert.False(t, act.DROP()) + assert.False(t, act.FORW()) + assert.True(t, act.BUFF()) + assert.True(t, act.NOCP()) + assert.False(t, act.DUPL()) + assert.False(t, act.IPMA()) + assert.False(t, act.IPMD()) + assert.False(t, act.DFRT()) + assert.False(t, act.EDRT()) + assert.False(t, act.BDPN()) + assert.False(t, act.DDPN()) + assert.False(t, act.FSSM()) + assert.False(t, act.MBSU()) +}