diff --git a/iptables/iptables.go b/iptables/iptables.go index cefbb87..8d6f689 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -73,6 +73,7 @@ type IPTables struct { v2 int v3 int mode string // the underlying iptables operating mode, e.g. nf_tables + timeout int // time to wait for the iptables lock, default waits forever } // Stat represents a structured statistic entry. @@ -89,19 +90,42 @@ type Stat struct { Options string `json:"options"` } -// New creates a new IPTables. -// For backwards compatibility, this always uses IPv4, i.e. "iptables". -func New() (*IPTables, error) { - return NewWithProtocol(ProtocolIPv4) +type option func(*IPTables) + +func IPFamily(proto Protocol) option { + return func(ipt *IPTables) { + ipt.proto = proto + } } -// New creates a new IPTables for the given proto. -// The proto will determine which command is used, either "iptables" or "ip6tables". -func NewWithProtocol(proto Protocol) (*IPTables, error) { - path, err := exec.LookPath(getIptablesCommand(proto)) +func Timeout(timeout int) option { + return func(ipt *IPTables) { + ipt.timeout = timeout + } +} + +// New creates a new IPTables configured with the options passed as parameter. +// For backwards compatibility, by default always uses IPv4 and timeout 0. +// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing +// the IPFamily and Timeout options as follow: +// ip6t := New(IPFamily(ProtocolIPv6), Timeout(5)) +func New(opts ...option) (*IPTables, error) { + + ipt := &IPTables{ + proto: ProtocolIPv4, + timeout: 0, + } + + for _, opt := range opts { + opt(ipt) + } + + path, err := exec.LookPath(getIptablesCommand(ipt.proto)) if err != nil { return nil, err } + ipt.path = path + vstring, err := getIptablesVersionString(path) if err != nil { return nil, fmt.Errorf("could not get iptables version: %v", err) @@ -110,21 +134,23 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) { if err != nil { return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err) } + ipt.v1 = v1 + ipt.v2 = v2 + ipt.v3 = v3 + ipt.mode = mode checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3) + ipt.hasCheck = checkPresent + ipt.hasWait = waitPresent + ipt.hasRandomFully = randomFullyPresent - ipt := IPTables{ - path: path, - proto: proto, - hasCheck: checkPresent, - hasWait: waitPresent, - hasRandomFully: randomFullyPresent, - v1: v1, - v2: v2, - v3: v3, - mode: mode, - } - return &ipt, nil + return ipt, nil +} + +// New creates a new IPTables for the given proto. +// The proto will determine which command is used, either "iptables" or "ip6tables". +func NewWithProtocol(proto Protocol) (*IPTables, error) { + return New(IPFamily(proto), Timeout(0)) } // Proto returns the protocol used by this IPTables. @@ -461,6 +487,9 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { args = append([]string{ipt.path}, args...) if ipt.hasWait { args = append(args, "--wait") + if ipt.timeout != 0 { + args = append(args, strconv.Itoa(ipt.timeout)) + } } else { fmu, err := newXtablesFileLock() if err != nil { diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go index 5f0fb4a..624e2da 100644 --- a/iptables/iptables_test.go +++ b/iptables/iptables_test.go @@ -50,6 +50,25 @@ func TestProto(t *testing.T) { } } +func TestTimeout(t *testing.T) { + ipt, err := New() + if err != nil { + t.Fatalf("New failed: %v", err) + } + if ipt.timeout != 0 { + t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout) + } + + ipt2, err := New(Timeout(5)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + if ipt2.timeout != 5 { + t.Fatalf("Expected timeout 5, got %v", ipt.timeout) + } + +} + func randChain(t *testing.T) string { n, err := rand.Int(rand.Reader, big.NewInt(1000000)) if err != nil {