diff --git a/pkg/sentry/socket/netlink/nlmsg/message.go b/pkg/sentry/socket/netlink/nlmsg/message.go index e554e09a90..8e9b9291f0 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message.go +++ b/pkg/sentry/socket/netlink/nlmsg/message.go @@ -374,6 +374,28 @@ func (v *BytesView) String() string { return string(b) } +// Uint8 converts the raw attribute value to uint8. +func (v *BytesView) Uint8() (uint8, bool) { + attr := []byte(*v) + val := primitive.Uint8(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return uint8(val), true +} + +// Uint16 converts the raw attribute value to uint16. +func (v *BytesView) Uint16() (uint16, bool) { + attr := []byte(*v) + val := primitive.Uint16(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return uint16(val), true +} + // Uint32 converts the raw attribute value to uint32. func (v *BytesView) Uint32() (uint32, bool) { attr := []byte(*v) @@ -396,6 +418,28 @@ func (v *BytesView) Uint64() (uint64, bool) { return uint64(val), true } +// Int8 converts the raw attribute value to int8. +func (v *BytesView) Int8() (int8, bool) { + attr := []byte(*v) + val := primitive.Int8(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int8(val), true +} + +// Int16 converts the raw attribute value to int32. +func (v *BytesView) Int16() (int16, bool) { + attr := []byte(*v) + val := primitive.Int16(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int16(val), true +} + // Int32 converts the raw attribute value to int32. func (v *BytesView) Int32() (int32, bool) { attr := []byte(*v) @@ -407,6 +451,17 @@ func (v *BytesView) Int32() (int32, bool) { return int32(val), true } +// Int64 converts the raw attribute value to int32. +func (v *BytesView) Int64() (int64, bool) { + attr := []byte(*v) + val := primitive.Int64(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int64(val), true +} + // NetToHostU16 converts a uint16 in network byte order to // host byte order value. func NetToHostU16(v uint16) uint16 { diff --git a/pkg/sentry/socket/netlink/nlmsg/message_test.go b/pkg/sentry/socket/netlink/nlmsg/message_test.go index 4ba142a206..218d2446d2 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message_test.go +++ b/pkg/sentry/socket/netlink/nlmsg/message_test.go @@ -342,6 +342,54 @@ func TestBytesView(t *testing.T) { ok: false, value: 0, }, + bytesViewTest[uint16]{ + desc: "Convert BytesView to uint16", + input: nlmsg.BytesView([]byte{7, 0}), + ok: true, + value: 7, + }, + bytesViewTest[uint16]{ + desc: "Failed convert BytesView to uint16", + input: nlmsg.BytesView([]byte{7}), + ok: false, + value: 0, + }, + bytesViewTest[int16]{ + desc: "Convert BytesView to int16", + input: nlmsg.BytesView([]byte{8, 0}), + ok: true, + value: 8, + }, + bytesViewTest[int16]{ + desc: "Failed convert BytesView to int16", + input: nlmsg.BytesView([]byte{8}), + ok: false, + value: 0, + }, + bytesViewTest[uint8]{ + desc: "Convert BytesView to uint8", + input: nlmsg.BytesView([]byte{7}), + ok: true, + value: 7, + }, + bytesViewTest[uint8]{ + desc: "Failed convert BytesView to uint8", + input: nlmsg.BytesView([]byte{}), + ok: false, + value: 0, + }, + bytesViewTest[int8]{ + desc: "Convert BytesView to int8", + input: nlmsg.BytesView([]byte{8}), + ok: true, + value: 8, + }, + bytesViewTest[int8]{ + desc: "Failed convert BytesView to int8", + input: nlmsg.BytesView([]byte{}), + ok: false, + value: 0, + }, } for _, test := range tests { switch test.(type) { @@ -369,6 +417,42 @@ func TestBytesView(t *testing.T) { if ok && value != tst.value { t.Errorf("%v: BytesView.Int32() got %v, want %v", tst.desc, value, tst.value) } + case bytesViewTest[uint16]: + tst := test.(bytesViewTest[uint16]) + value, ok := tst.input.Uint16() + if ok != tst.ok { + t.Errorf("%v: BytesView.Uint16() got ok = %v, want %v", tst.desc, ok, tst.ok) + } + if ok && value != tst.value { + t.Errorf("%v: BytesView.Uint16() got %v, want %v", tst.desc, value, tst.value) + } + case bytesViewTest[int16]: + tst := test.(bytesViewTest[int16]) + value, ok := tst.input.Int16() + if ok != tst.ok { + t.Errorf("%v: BytesView.Int16() got ok = %v, want %v", tst.desc, ok, tst.ok) + } + if ok && value != tst.value { + t.Errorf("%v: BytesView.Int16() got %v, want %v", tst.desc, value, tst.value) + } + case bytesViewTest[uint8]: + tst := test.(bytesViewTest[uint8]) + value, ok := tst.input.Uint8() + if ok != tst.ok { + t.Errorf("%v: BytesView.Uint8() got ok = %v, want %v", tst.desc, ok, tst.ok) + } + if ok && value != tst.value { + t.Errorf("%v: BytesView.Uint8() got %v, want %v", tst.desc, value, tst.value) + } + case bytesViewTest[int8]: + tst := test.(bytesViewTest[int8]) + value, ok := tst.input.Int8() + if ok != tst.ok { + t.Errorf("%v: BytesView.Int8() got ok = %v, want %v", tst.desc, ok, tst.ok) + } + if ok && value != tst.value { + t.Errorf("%v: BytesView.Int8() got %v, want %v", tst.desc, value, tst.value) + } default: t.Errorf("BytesView %T not support", t) }