Skip to content

Commit

Permalink
iptables-helper: Use iptables-restore to maintain rules.
Browse files Browse the repository at this point in the history
Replace usage in cloud-agent

Signed-off-by: Zhen Tang <[email protected]>
  • Loading branch information
lostcharlie committed Sep 20, 2023
1 parent 389308b commit 27a370c
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 37 deletions.
41 changes: 9 additions & 32 deletions pkg/cloud-agent/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,44 +68,21 @@ func (h IptablesHandler) maintainRules(remotePodCIDRs []string) {
}

h.helper.Mutex.Lock()
if err := h.syncForwardRules(); err != nil {
logger.Error(err, "failed to sync iptables forward chain")
} else {
logger.V(5).Info("iptables forward chain is synced")
}

if err := h.syncPostRoutingRules(); err != nil {
logger.Error(err, "failed to sync iptables post-routing chain")
h.helper.CreateFabEdgeForwardChain()
h.helper.NewMaintainForwardRulesForIPSet([]string{h.ipsetName})
h.helper.NewPreparePostRoutingChain()
h.helper.NewAddPostRoutingRuleForKubernetes()
h.helper.NewAddPostRoutingRulesForIPSet(h.ipsetName)

if err := h.helper.ReplaceRules(); err != nil {
logger.Error(err, "failed to sync iptables rules")
} else {
logger.V(5).Info("iptables post-routing chain is synced")
logger.V(5).Info("iptables rules is synced")
}
h.helper.Mutex.Unlock()
}

func (h IptablesHandler) syncForwardRules() (err error) {
if err = h.helper.ClearOrCreateFabEdgeForwardChain(); err != nil {
return err
}

if err = h.helper.MaintainForwardRulesForIPSet([]string{h.ipsetName}); err != nil {
return err
}

return nil
}

func (h IptablesHandler) syncPostRoutingRules() (err error) {
if err = h.helper.PreparePostRoutingChain(); err != nil {
return err
}

if err = h.helper.AddPostRoutingRuleForKubernetes(); err != nil {
return err
}

return h.helper.AddPostRoutingRulesForIPSet(h.ipsetName)
}

func (h IptablesHandler) syncRemotePodCIDRSet(remotePodCIDRs []string) error {
set := &ipsetutil.IPSet{
Name: h.ipsetName,
Expand Down
214 changes: 209 additions & 5 deletions pkg/util/iptables/iptables_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
package iptables

import (
"bytes"
"fmt"
"github.com/coreos/go-iptables/iptables"
"io"
"os/exec"
"strings"
"sync"
)

Expand All @@ -42,8 +46,11 @@ const (
)

type IPTablesHelper struct {
ipt *iptables.IPTables
Mutex sync.RWMutex
ipt *iptables.IPTables
protocol iptables.Protocol
restoreCommand string
ruleSets []IPTablesRuleSet
Mutex sync.RWMutex
}

func NewIPTablesHelper() (*IPTablesHelper, error) {
Expand All @@ -54,16 +61,171 @@ func NewIP6TablesHelper() (*IPTablesHelper, error) {
return doCreateIPTablesHelper(iptables.ProtocolIPv6)
}

func doCreateIPTablesHelper(protocol iptables.Protocol) (*IPTablesHelper, error) {
t, err := iptables.NewWithProtocol(protocol)
func doCreateIPTablesHelper(proto iptables.Protocol) (*IPTablesHelper, error) {
t, err := iptables.NewWithProtocol(proto)
if err != nil {
return nil, err
}
var command string
switch proto {
case iptables.ProtocolIPv4:
command = IPTablesRestoreCommand
case iptables.ProtocolIPv6:
command = IP6TablesRestoreCommand
}
return &IPTablesHelper{
ipt: t,
ipt: t,
protocol: proto,
restoreCommand: command,
ruleSets: NewRuleSets(),
}, err
}

func (h *IPTablesHelper) runRestoreCommand(args []string, stdin io.Reader) (string, string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer

args = append(args, "--wait")

cmd := exec.Command(h.restoreCommand, args...)
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cmd.Stdin = stdin

if err := cmd.Run(); err != nil {
return stdout.String(), stderr.String(), err
}

return stdout.String(), stderr.String(), nil
}

func (h *IPTablesHelper) ReplaceRules() error {
rules := h.GenerateInputFromRuleSet()

stdout, stderr, err := h.runRestoreCommand([]string{}, bytes.NewBuffer([]byte(rules)))
if err != nil {
print(err)
print("out:", stdout)
print("err:", stderr)
return err
}
return nil
}

func (h *IPTablesHelper) isInternalChain(table string, chain string) bool {
if table == "filter" {
if chain == "INPUT" || chain == "OUTPUT" || chain == "FORWARD" {
return true
}
}
if table == "nat" {
if chain == "PREROUTING" || chain == "POSTROUTING" || chain == "OUTPUT" {
return true
}
}
if table == "mangle" {
if chain == "PREROUTING" || chain == "OUTPUT" || chain == "FORWARD" || chain == "INPUT" || chain == "POSTROUTING" {
return true
}
}
if table == "raw" {
if chain == "PREROUTING" || chain == "OUTPUT" {
return true
}
}
return false
}

func (h *IPTablesHelper) GenerateInputFromRuleSet() string {
ret := ""
for _, ruleSet := range h.ruleSets {
ret += "*" + ruleSet.table + "\n"
for _, chain := range ruleSet.chains {
var policy string
// For custom chains, we do not set default policy
if h.isInternalChain(ruleSet.table, chain) {
policy = "ACCEPT"
} else {
policy = "-"
}
ret += strings.Join([]string{":", chain, " " + policy + " [0:0]\n"}, "")
}

for _, ruleEntry := range ruleSet.rules {
line := strings.Join(append([]string{"-A", ruleEntry.chain}, ruleEntry.rule...), " ")
ret += line
ret += "\n"
}

ret += "COMMIT\n"
}
return ret
}

func (h *IPTablesHelper) findTable(table string) int {
for i, ruleSet := range h.ruleSets {
if ruleSet.table == table {
return i
}
}
return -1
}

func (h *IPTablesHelper) findChain(tableIndex int, chain string) int {
for i, elem := range h.ruleSets[tableIndex].chains {
if chain == elem {
return i
}
}
return -1
}

func (h *IPTablesHelper) CreateChain(table string, chain string) {
tableIndex := h.findTable(table)
if tableIndex == -1 {
h.ruleSets = append(h.ruleSets, IPTablesRuleSet{table: table, chains: []string{}, rules: []IPTablesRule{}})
tableIndex = len(h.ruleSets) - 1
}
chainIndex := h.findChain(tableIndex, chain)
if chainIndex == -1 {
h.ruleSets[tableIndex].chains = append(h.ruleSets[tableIndex].chains, chain)
}
}

func (h *IPTablesHelper) AppendUniqueRule(table string, chain string, rule ...string) {
// Prepare chain and table if not exist
tableIndex := h.findTable(table)
if tableIndex == -1 {
h.CreateChain(table, chain)
tableIndex = h.findTable(table)
}
chainIndex := h.findChain(tableIndex, chain)
if chainIndex == -1 {
h.CreateChain(table, chain)
chainIndex = h.findChain(tableIndex, chain)
}

for _, elem := range h.ruleSets[tableIndex].rules {
if elem.chain == chain && h.rulesEqual(elem.rule, rule) {
// Already Exist
return
}
}
h.ruleSets[tableIndex].rules = append(h.ruleSets[tableIndex].rules, IPTablesRule{chain: chain, rule: rule})
}

func (h *IPTablesHelper) rulesEqual(one, other []string) bool {
if len(one) != len(other) {
return false
}
for i, elem := range one {
if elem != other[i] {
return false
}
}
return true
}

func (h *IPTablesHelper) ClearOrCreateFabEdgePostRoutingChain() (err error) {
return h.ipt.ClearChain(TableNat, ChainFabEdgePostRouting)
}
Expand All @@ -72,10 +234,15 @@ func (h *IPTablesHelper) ClearOrCreateFabEdgeInputChain() (err error) {
return h.ipt.ClearChain(TableFilter, ChainFabEdgeInput)
}

// To remove
func (h *IPTablesHelper) ClearOrCreateFabEdgeForwardChain() (err error) {
return h.ipt.ClearChain(TableFilter, ChainFabEdgeForward)
}

func (h *IPTablesHelper) CreateFabEdgeForwardChain() {
h.CreateChain(TableFilter, ChainFabEdgeForward)
}

func (h *IPTablesHelper) ClearOrCreateFabEdgeNatOutgoingChain() (err error) {
return h.ipt.ClearChain(TableNat, ChainFabEdgeNatOutgoing)
}
Expand All @@ -101,6 +268,12 @@ func (h *IPTablesHelper) CheckOrCreateFabEdgeNatOutgoingChain() (err error) {
return h.checkOrCreateChain(TableNat, ChainFabEdgeNatOutgoing)
}

func (h *IPTablesHelper) NewPreparePostRoutingChain() {
h.CreateChain(TableNat, ChainFabEdgePostRouting)
h.AppendUniqueRule(TableNat, ChainPostRouting, "-j", ChainFabEdgePostRouting)
}

// To remove
func (h *IPTablesHelper) PreparePostRoutingChain() (err error) {
if err = h.ClearOrCreateFabEdgePostRoutingChain(); err != nil {
return err
Expand All @@ -118,6 +291,7 @@ func (h *IPTablesHelper) PreparePostRoutingChain() (err error) {
return nil
}

// To remove
func (h *IPTablesHelper) PrepareForwardChain() (err error) {
exists, err := h.ipt.Exists(TableFilter, ChainForward, "-j", ChainFabEdgeForward)
if err != nil {
Expand All @@ -132,6 +306,19 @@ func (h *IPTablesHelper) PrepareForwardChain() (err error) {
return nil
}

func (h *IPTablesHelper) NewMaintainForwardRulesForIPSet(ipsetNames []string) {
// Prepare
h.AppendUniqueRule(TableFilter, ChainForward, "-j", ChainFabEdgeForward)
// Add connection track rule
h.AppendUniqueRule(TableFilter, ChainFabEdgeForward, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
// Accept forward packets for ipset
for _, ipsetName := range ipsetNames {
h.AppendUniqueRule(TableFilter, ChainFabEdgeForward, "-m", "set", "--match-set", ipsetName, "src", "-j", "ACCEPT")
h.AppendUniqueRule(TableFilter, ChainFabEdgeForward, "-m", "set", "--match-set", ipsetName, "dst", "-j", "ACCEPT")
}
}

// To remove
func (h *IPTablesHelper) acceptForward(ipsetName string) (err error) {
if err = h.ipt.AppendUnique(TableFilter, ChainFabEdgeForward, "-m", "set", "--match-set", ipsetName, "src", "-j", "ACCEPT"); err != nil {
return err
Expand All @@ -144,6 +331,7 @@ func (h *IPTablesHelper) acceptForward(ipsetName string) (err error) {
return nil
}

// To remove
func (h *IPTablesHelper) addConnectionTrackRule() (err error) {
if err = h.ipt.AppendUnique(TableFilter, ChainFabEdgeForward, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"); err != nil {
return err
Expand All @@ -152,6 +340,7 @@ func (h *IPTablesHelper) addConnectionTrackRule() (err error) {
return nil
}

// To remove
func (h *IPTablesHelper) MaintainForwardRulesForIPSet(ipsetNames []string) (err error) {
if err = h.PrepareForwardChain(); err != nil {
return err
Expand Down Expand Up @@ -203,6 +392,15 @@ func (h *IPTablesHelper) MaintainNatOutgoingRulesForSubnets(subnets []string, ip
return nil, ""
}

func (h *IPTablesHelper) NewAddPostRoutingRuleForKubernetes() {
// If packets have 0x4000/0x4000 mark, then traffic should be handled by KUBE-POSTROUTING chain,
// otherwise traffic to nodePort service, sometimes load balancer service, won't be masqueraded,
// and this would cause response packets are dropped
h.CreateChain(TableNat, "KUBE-POSTROUTING")
h.AppendUniqueRule(TableNat, ChainFabEdgePostRouting, "-m", "mark", "--mark", "0x4000/0x4000", "-j", "KUBE-POSTROUTING")
}

// To remove
func (h *IPTablesHelper) AddPostRoutingRuleForKubernetes() (err error) {
// If packets have 0x4000/0x4000 mark, then traffic should be handled by KUBE-POSTROUTING chain,
// otherwise traffic to nodePort service, sometimes load balancer service, won't be masqueraded,
Expand All @@ -213,6 +411,12 @@ func (h *IPTablesHelper) AddPostRoutingRuleForKubernetes() (err error) {
return nil
}

func (h *IPTablesHelper) NewAddPostRoutingRulesForIPSet(ipsetName string) {
h.AppendUniqueRule(TableNat, ChainFabEdgePostRouting, "-m", "set", "--match-set", ipsetName, "dst", "-j", "ACCEPT")
h.AppendUniqueRule(TableNat, ChainFabEdgePostRouting, "-m", "set", "--match-set", ipsetName, "src", "-j", "ACCEPT")
}

// To remove
func (h *IPTablesHelper) AddPostRoutingRulesForIPSet(ipsetName string) (err error) {
if err = h.ipt.AppendUnique(TableNat, ChainFabEdgePostRouting, "-m", "set", "--match-set", ipsetName, "dst", "-j", "ACCEPT"); err != nil {
return err
Expand Down
30 changes: 30 additions & 0 deletions pkg/util/iptables/iptables_helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package iptables

import (
"testing"
)

func TestGenerateCloudAgentRules(t *testing.T) {
ipt, err := NewIPTablesHelper()
if err != nil {
t.Error(err)
}

// Sync forward
ipsetName := "FABEDGE-REMOTE-POD-CIDR"
ipt.CreateFabEdgeForwardChain()
ipt.NewMaintainForwardRulesForIPSet([]string{ipsetName})

// Sync PostRouting
ipt.NewPreparePostRoutingChain()
ipt.NewAddPostRoutingRuleForKubernetes()
ipt.NewAddPostRoutingRulesForIPSet(ipsetName)

str := ipt.GenerateInputFromRuleSet()
println(str)

//err = ipt.ReplaceRules()
//if err != nil {
// t.Error(err)
//}
}
Loading

0 comments on commit 27a370c

Please sign in to comment.