diff --git a/pkg/sentry/socket/netlink/netfilter/protocol.go b/pkg/sentry/socket/netlink/netfilter/protocol.go index 5deadd180d..adacd29b9e 100644 --- a/pkg/sentry/socket/netlink/netfilter/protocol.go +++ b/pkg/sentry/socket/netlink/netfilter/protocol.go @@ -442,7 +442,7 @@ func (p *Protocol) addChain(attrs map[uint16]nlmsg.BytesView, tab *nftables.Tabl return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Chain binding attribute is not supported for chains with a hook")) } - bcInfo, err = p.chainParseHook(nil, family, nlmsg.AttrsView(hookDataBytes)) + bcInfo, err = p.chainParseHook(nil, family, nlmsg.AttrsView(hookDataBytes), attrs) if err != nil { return err } @@ -494,7 +494,7 @@ func (p *Protocol) addChain(attrs map[uint16]nlmsg.BytesView, tab *nftables.Tabl // chainParseHook parses the hook attributes and returns a complete // BaseChainInfo. -func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFamily, hdata nlmsg.AttrsView) (*nftables.BaseChainInfo, *syserr.AnnotatedError) { +func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFamily, hdata nlmsg.AttrsView, attrs map[uint16]nlmsg.BytesView) (*nftables.BaseChainInfo, *syserr.AnnotatedError) { hookAttrs, ok := nftables.NfParse(hdata) if !ok { return nil, syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Failed to parse hook attributes")) @@ -530,7 +530,7 @@ func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFam // All families default to filter type. hookInfo.ChainType = nftables.BaseChainTypeFilter - if chainTypeBytes, ok := hookAttrs[linux.NFTA_CHAIN_TYPE]; ok { + if chainTypeBytes, ok := attrs[linux.NFTA_CHAIN_TYPE]; ok { // TODO - b/434243967: Support base chain types other than filter. switch chainType := chainTypeBytes.String(); chainType { case "filter": diff --git a/test/syscalls/linux/socket_netlink_netfilter.cc b/test/syscalls/linux/socket_netlink_netfilter.cc index 9998001e16..f981f4e218 100644 --- a/test/syscalls/linux/socket_netlink_netfilter.cc +++ b/test/syscalls/linux/socket_netlink_netfilter.cc @@ -932,7 +932,6 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidPolicy) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_table_request_buffer = @@ -955,6 +954,7 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidPolicy) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 5) @@ -1122,7 +1122,6 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidChainType) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1139,6 +1138,7 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidChainType) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1167,7 +1167,6 @@ TEST(NetlinkNetfilterTest, ErrNewNATBaseChainWithInvalidPriority) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1184,6 +1183,7 @@ TEST(NetlinkNetfilterTest, ErrNewNATBaseChainWithInvalidPriority) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1212,7 +1212,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewNetDevBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1229,6 +1228,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewNetDevBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1257,7 +1257,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewInetBaseChainAtIngress) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1274,6 +1273,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewInetBaseChainAtIngress) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1302,7 +1302,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewBaseChainWithChainCounters) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1319,6 +1318,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewBaseChainWithChainCounters) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .RawAttr(NFTA_CHAIN_COUNTERS, nullptr, 0) .Build()) @@ -1540,7 +1540,6 @@ TEST(NetlinkNetfilterTest, AddBaseChainWithDropPolicy) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1557,6 +1556,7 @@ TEST(NetlinkNetfilterTest, AddBaseChainWithDropPolicy) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1787,7 +1787,6 @@ TEST(NetlinkNetfilterTest, GetBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1804,6 +1803,7 @@ TEST(NetlinkNetfilterTest, GetBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .RawAttr(NFTA_CHAIN_USERDATA, test_user_data, expected_udata_size) @@ -1857,7 +1857,6 @@ TEST(NetlinkNetfilterTest, ErrDeleteChainWithNoTableNameSpecified) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1874,6 +1873,7 @@ TEST(NetlinkNetfilterTest, ErrDeleteChainWithNoTableNameSpecified) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1984,7 +1984,6 @@ TEST(NetlinkNetfilterTest, DeleteBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -2001,6 +2000,7 @@ TEST(NetlinkNetfilterTest, DeleteBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -2041,7 +2041,6 @@ TEST(NetlinkNetfilterTest, DeleteBaseChainByHandle) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -2058,6 +2057,7 @@ TEST(NetlinkNetfilterTest, DeleteBaseChainByHandle) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3)