Skip to content

Commit 99d6c94

Browse files
kerumeto authored and gvisor-bot committed
Hookup nftables filtering into IPv4 and IPv6 paths in netstack.
This change adds functionality to traverse the nftables ruleset in the IPv4 and IPv6 network endpoints within netstack.

Updates #11778
PiperOrigin-RevId: 799463692
1 parent de2484a commit 99d6c94

File tree

4 files changed

+158
-25
lines changed

4 files changed

+158
-25
lines changed

pkg/sentry/socket/netlink/netfilter/protocol.go

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -678,7 +678,7 @@ func (p *Protocol) deleteChain(nft *nftables.NFTables, attrs map[uint16]nlmsg.By
678678
const NFT_RULE_MAXEXPRS = 128
679679

680680
// newRule creates a new rule in the given chain.
681-
func (p *Protocol) newRule(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, msgFlags uint16, ms *nlmsg.MessageSet) *syserr.AnnotatedError {
681+
func (p *Protocol) newRule(nft *nftables.NFTables, st *stack.Stack, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, msgFlags uint16, ms *nlmsg.MessageSet) *syserr.AnnotatedError {
682682
tabNameBytes, ok := attrs[linux.NFTA_RULE_TABLE]
683683
if !ok {
684684
return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, "Nftables: NFTA_CHAIN_TABLE attribute is malformed or not found")
@@ -818,6 +818,10 @@ func (p *Protocol) newRule(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesV
818818
return err
819819
}
820820

821+
// Once we have at least one rule registered on a base chain, nftables can
822+
// be called to potentially filter the packet.
823+
st.SetNFTablesConfigured(chain.IsBaseChain())
824+
821825
// TODO - b/434244017: Support validating the entire table before returning.
822826
return nil
823827
}
@@ -1141,7 +1145,7 @@ func (p *Protocol) processBatchMessage(ctx context.Context, s *netlink.Socket, b
11411145
case linux.NFT_MSG_DELCHAIN, linux.NFT_MSG_DESTROYCHAIN:
11421146
subErr = p.deleteChain(nftCopy, attrs, family, hdr.Flags, hdr.NetFilterMsgType(), ms)
11431147
case linux.NFT_MSG_NEWRULE:
1144-
subErr = p.newRule(nftCopy, attrs, family, hdr.Flags, ms)
1148+
subErr = p.newRule(nftCopy, st, attrs, family, hdr.Flags, ms)
11451149
case linux.NFT_MSG_DELRULE, linux.NFT_MSG_DESTROYRULE, linux.NFT_MSG_NEWSET,
11461150
linux.NFT_MSG_DELSET, linux.NFT_MSG_DESTROYSET, linux.NFT_MSG_NEWSETELEM,
11471151
linux.NFT_MSG_DELSETELEM, linux.NFT_MSG_DESTROYSETELEM,

pkg/tcpip/network/ipv4/ipv4.go

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -529,16 +529,27 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
529529
func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
530530
netHeader := header.IPv4(pkt.NetworkHeader().Slice())
531531
dstAddr := netHeader.DestinationAddress()
532+
stk := e.protocol.stack
532533

533-
// iptables filtering. All packets that reach here are locally
534-
// generated.
535-
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
536-
if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
534+
// iptables filtering. All packets that reach here are locally generated.
535+
outNicName := stk.FindNICNameFromID(e.nic.ID())
536+
if ok := stk.IPTables().CheckOutput(pkt, r, outNicName); !ok {
537537
// iptables is telling us to drop the packet.
538538
e.stats.ip.IPTablesOutputDropped.Increment()
539539
return nil
540540
}
541541

542+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
543+
ipCheck := nft.CheckOutput(pkt, stack.IP)
544+
// nftables allows us to use the inet family to apply rules to both IPv4
545+
// and IPv6 packets.
546+
inetCheck := nft.CheckOutput(pkt, stack.Inet)
547+
if !ipCheck || !inetCheck {
548+
// nftables is telling us to drop the packet.
549+
return nil
550+
}
551+
}
552+
542553
// If the packet is manipulated as per DNAT Output rules, handle packet
543554
// based on destination address and do not send the packet to link
544555
// layer.
@@ -569,15 +580,25 @@ func (e *endpoint) writePacketPostRouting(r *stack.Route, pkt *stack.PacketBuffe
569580
return nil
570581
}
571582

583+
stk := e.protocol.stack
572584
// Postrouting NAT can only change the source address, and does not alter the
573585
// route or outgoing interface of the packet.
574-
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
575-
if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
586+
outNicName := stk.FindNICNameFromID(e.nic.ID())
587+
if ok := stk.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
576588
// iptables is telling us to drop the packet.
577589
e.stats.ip.IPTablesPostroutingDropped.Increment()
578590
return nil
579591
}
580592

593+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
594+
ipCheck := nft.CheckOutput(pkt, stack.IP)
595+
inetCheck := nft.CheckOutput(pkt, stack.Inet)
596+
if !ipCheck || !inetCheck {
597+
// nftables is telling us to drop the packet.
598+
return nil
599+
}
600+
}
601+
581602
stats := e.stats.ip
582603

583604
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(len(pkt.NetworkHeader().Slice())))
@@ -688,6 +709,15 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB
688709
return nil
689710
}
690711

712+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
713+
ipCheck := nft.CheckForward(pkt, stack.IP)
714+
inetCheck := nft.CheckForward(pkt, stack.Inet)
715+
if !ipCheck || !inetCheck {
716+
// nftables is telling us to drop the packet.
717+
return nil
718+
}
719+
}
720+
691721
// We need to do a deep copy of the IP packet because
692722
// WriteHeaderIncludedPacket may modify the packet buffer, but we do
693723
// not own it.
@@ -788,6 +818,15 @@ func (e *endpoint) forwardUnicastPacket(pkt *stack.PacketBuffer) ip.ForwardingEr
788818
return nil
789819
}
790820

821+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
822+
ipCheck := nft.CheckForward(pkt, stack.IP)
823+
inetCheck := nft.CheckForward(pkt, stack.Inet)
824+
if !ipCheck || !inetCheck {
825+
// nftables is telling us to drop the packet.
826+
return nil
827+
}
828+
}
829+
791830
// The packet originally arrived on e so provide its NIC as the input NIC.
792831
ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
793832
return nil
@@ -855,7 +894,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
855894
}
856895
}
857896

858-
if e.protocol.stack.HandleLocal() {
897+
stk := e.protocol.stack
898+
if stk.HandleLocal() {
859899
addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().Slice()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint, true /* readOnly */)
860900
if addressEndpoint != nil {
861901
// The source address is one of our own, so we never should have gotten
@@ -867,12 +907,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
867907
}
868908

869909
// Loopback traffic skips the prerouting chain.
870-
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
871-
if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
910+
inNicName := stk.FindNICNameFromID(e.nic.ID())
911+
if ok := stk.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
872912
// iptables is telling us to drop the packet.
873913
stats.IPTablesPreroutingDropped.Increment()
874914
return
875915
}
916+
917+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
918+
ipCheck := nft.CheckPrerouting(pkt, stack.IP)
919+
inetCheck := nft.CheckPrerouting(pkt, stack.Inet)
920+
if !ipCheck || !inetCheck {
921+
// nftables is telling us to drop the packet.
922+
return
923+
}
924+
}
876925
}
877926
// CheckPrerouting can modify the backing storage of the packet, so refresh
878927
// the header.
@@ -1206,14 +1255,24 @@ func (e *endpoint) handleForwardingError(err ip.ForwardingError) {
12061255

12071256
func (e *endpoint) deliverPacketLocally(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
12081257
stats := e.stats
1258+
stk := e.protocol.stack
12091259
// iptables filtering. All packets that reach here are intended for
12101260
// this machine and will not be forwarded.
1211-
if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
1261+
if ok := stk.IPTables().CheckInput(pkt, inNICName); !ok {
12121262
// iptables is telling us to drop the packet.
12131263
stats.ip.IPTablesInputDropped.Increment()
12141264
return
12151265
}
12161266

1267+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
1268+
ipCheck := nft.CheckInput(pkt, stack.IP)
1269+
inetCheck := nft.CheckInput(pkt, stack.Inet)
1270+
if !ipCheck || !inetCheck {
1271+
// nftables is telling us to drop the packet.
1272+
return
1273+
}
1274+
}
1275+
12171276
if h.More() || h.FragmentOffset() != 0 {
12181277
if pkt.Data().Size()+len(pkt.TransportHeader().Slice()) == 0 {
12191278
// Drop the packet as it's marked as a fragment but has

pkg/tcpip/network/ipv6/ipv6.go

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -817,15 +817,24 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
817817
return err
818818
}
819819

820-
// iptables filtering. All packets that reach here are locally
821-
// generated.
822-
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
823-
if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
820+
stk := e.protocol.stack
821+
// iptables filtering. All packets that reach here are locally generated.
822+
outNicName := stk.FindNICNameFromID(e.nic.ID())
823+
if ok := stk.IPTables().CheckOutput(pkt, r, outNicName); !ok {
824824
// iptables is telling us to drop the packet.
825825
e.stats.ip.IPTablesOutputDropped.Increment()
826826
return nil
827827
}
828828

829+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
830+
ip6Check := nft.CheckOutput(pkt, stack.IP6)
831+
inetCheck := nft.CheckOutput(pkt, stack.Inet)
832+
if !ip6Check || !inetCheck {
833+
// nftables is telling us to drop the packet.
834+
return nil
835+
}
836+
}
837+
829838
// If the packet is manipulated as per DNAT Output rules, handle packet
830839
// based on destination address and do not send the packet to link
831840
// layer.
@@ -856,15 +865,25 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
856865
return nil
857866
}
858867

868+
stk := e.protocol.stack
859869
// Postrouting NAT can only change the source address, and does not alter the
860870
// route or outgoing interface of the packet.
861-
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
862-
if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
871+
outNicName := stk.FindNICNameFromID(e.nic.ID())
872+
if ok := stk.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
863873
// iptables is telling us to drop the packet.
864874
e.stats.ip.IPTablesPostroutingDropped.Increment()
865875
return nil
866876
}
867877

878+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
879+
ip6Check := nft.CheckPostrouting(pkt, stack.IP6)
880+
inetCheck := nft.CheckPostrouting(pkt, stack.Inet)
881+
if !ip6Check || !inetCheck {
882+
// nftables is telling us to drop the packet.
883+
return nil
884+
}
885+
}
886+
868887
stats := e.stats.ip
869888
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(len(pkt.NetworkHeader().Slice())))
870889
if err != nil {
@@ -1005,6 +1024,15 @@ func (e *endpoint) forwardUnicastPacket(pkt *stack.PacketBuffer) ip.ForwardingEr
10051024
return nil
10061025
}
10071026

1027+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
1028+
ip6Check := nft.CheckForward(pkt, stack.IP6)
1029+
inetCheck := nft.CheckForward(pkt, stack.Inet)
1030+
if !ip6Check || !inetCheck {
1031+
// nftables is telling us to drop the packet.
1032+
return nil
1033+
}
1034+
}
1035+
10081036
// The packet originally arrived on e so provide its NIC as the input NIC.
10091037
ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
10101038
return nil
@@ -1046,6 +1074,15 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB
10461074
return nil
10471075
}
10481076

1077+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
1078+
ip6Check := nft.CheckForward(pkt, stack.IP6)
1079+
inetCheck := nft.CheckForward(pkt, stack.Inet)
1080+
if !ip6Check || !inetCheck {
1081+
// nftables is telling us to drop the packet.
1082+
return nil
1083+
}
1084+
}
1085+
10491086
hopLimit := h.HopLimit()
10501087

10511088
// We need to do a deep copy of the IP packet because
@@ -1121,7 +1158,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
11211158
}
11221159
}
11231160

1124-
if e.protocol.stack.HandleLocal() {
1161+
stk := e.protocol.stack
1162+
if stk.HandleLocal() {
11251163
addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().Slice()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint, true /* readOnly */)
11261164
if addressEndpoint != nil {
11271165
// The source address is one of our own, so we never should have gotten
@@ -1133,12 +1171,23 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
11331171
}
11341172

11351173
// Loopback traffic skips the prerouting chain.
1136-
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
1137-
if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
1174+
inNicName := stk.FindNICNameFromID(e.nic.ID())
1175+
if ok := stk.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
11381176
// iptables is telling us to drop the packet.
11391177
stats.IPTablesPreroutingDropped.Increment()
11401178
return
11411179
}
1180+
1181+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
1182+
ipv6Check := nft.CheckPrerouting(pkt, stack.IP6)
1183+
// nftables allows us to use the inet family to apply rules to both IPv4
1184+
// and IPv6 packets.
1185+
inetCheck := nft.CheckPrerouting(pkt, stack.Inet)
1186+
if !ipv6Check || !inetCheck {
1187+
// nftables is telling us to drop the packet.
1188+
return
1189+
}
1190+
}
11421191
}
11431192

11441193
// CheckPrerouting can modify the backing storage of the packet, so refresh
@@ -1384,15 +1433,22 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
13841433

13851434
func (e *endpoint) deliverPacketLocally(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
13861435
stats := e.stats.ip
1387-
1388-
// iptables filtering. All packets that reach here are intended for
1389-
// this machine and need not be forwarded.
1390-
if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
1436+
stk := e.protocol.stack
1437+
if ok := stk.IPTables().CheckInput(pkt, inNICName); !ok {
13911438
// iptables is telling us to drop the packet.
13921439
stats.IPTablesInputDropped.Increment()
13931440
return
13941441
}
13951442

1443+
if nft := stk.NFTables(); nft != nil && stk.IsNFTablesConfigured() {
1444+
ip6Check := nft.CheckInput(pkt, stack.IP6)
1445+
inetCheck := nft.CheckInput(pkt, stack.Inet)
1446+
if !ip6Check || !inetCheck {
1447+
// nftables is telling us to drop the packet.
1448+
return
1449+
}
1450+
}
1451+
13961452
// Any returned error is only useful for terminating execution early, but
13971453
// we have nothing left to do, so we can drop it.
13981454
_ = e.processExtensionHeaders(h, pkt, false /* forwarding */)

pkg/tcpip/stack/stack.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ type Stack struct {
123123
// nftables is the nftables interface for packet filtering and manipulation rules.
124124
nftables NFTablesInterface `state:"nosave"`
125125

126+
// nftablesConfigured indicates whether NFTables is configured with at
127+
// least one rule on a chain at a network hook.
128+
nftablesConfigured atomicbitops.Bool
129+
126130
// restoredEndpoints is a list of endpoints that need to be restored if the
127131
// stack is being restored.
128132
restoredEndpoints []RestoredEndpoint
@@ -2254,6 +2258,16 @@ func (s *Stack) SetNFTables(nft NFTablesInterface) {
22542258
s.nftables = nft
22552259
}
22562260

2261+
// IsNFTablesConfigured returns true if the stack has nftables configured.
2262+
func (s *Stack) IsNFTablesConfigured() bool {
2263+
return s.nftablesConfigured.Load()
2264+
}
2265+
2266+
// SetNFTablesConfigured sets whether the stack has nftables configured.
2267+
func (s *Stack) SetNFTablesConfigured(configured bool) {
2268+
s.nftablesConfigured.Store(configured)
2269+
}
2270+
22572271
// ICMPLimit returns the maximum number of ICMP messages that can be sent
22582272
// in one second.
22592273
func (s *Stack) ICMPLimit() rate.Limit {

0 commit comments

Comments
 (0)