ソースを参照

new iptables constructor

Use functional options to create a new iptables object, so we can
support new options to configure it.

It allows to configure the IP family and the timeout used in the
`-w` flag. Until now, if iptables version supports wait,
we used `-w` without any timeout, so it keeps waiting forever trying
to acquire the lock at a 1 second interval rate.
This can cause issues on busy environments by software that depends
on this library, because they can be waiting forever.

Signed-off-by: Antonio Ojea <aojea@redhat.com>
Antonio Ojea 4 年 前
コミット
619e48d024
2 ファイル変更68 行追加20 行削除
  1. 49 20
      iptables/iptables.go
  2. 19 0
      iptables/iptables_test.go

+ 49 - 20
iptables/iptables.go

@@ -73,6 +73,7 @@ type IPTables struct {
 	v2             int
 	v3             int
 	mode           string // the underlying iptables operating mode, e.g. nf_tables
+	timeout        int    // time to wait for the iptables lock, default waits forever
 }
 
 // Stat represents a structured statistic entry.
@@ -89,19 +90,42 @@ type Stat struct {
 	Options     string     `json:"options"`
 }
 
-// New creates a new IPTables.
-// For backwards compatibility, this always uses IPv4, i.e. "iptables".
-func New() (*IPTables, error) {
-	return NewWithProtocol(ProtocolIPv4)
+type option func(*IPTables)
+
+func IPFamily(proto Protocol) option {
+	return func(ipt *IPTables) {
+		ipt.proto = proto
+	}
 }
 
-// New creates a new IPTables for the given proto.
-// The proto will determine which command is used, either "iptables" or "ip6tables".
-func NewWithProtocol(proto Protocol) (*IPTables, error) {
-	path, err := exec.LookPath(getIptablesCommand(proto))
+func Timeout(timeout int) option {
+	return func(ipt *IPTables) {
+		ipt.timeout = timeout
+	}
+}
+
+// New creates a new IPTables configured with the options passed as parameter.
+// For backwards compatibility, by default always uses IPv4 and timeout 0.
+// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
+// the IPFamily and Timeout options as follow:
+//	ip6t := New(IPFamily(ProtocolIPv6), Timeout(5))
+func New(opts ...option) (*IPTables, error) {
+
+	ipt := &IPTables{
+		proto:   ProtocolIPv4,
+		timeout: 0,
+	}
+
+	for _, opt := range opts {
+		opt(ipt)
+	}
+
+	path, err := exec.LookPath(getIptablesCommand(ipt.proto))
 	if err != nil {
 		return nil, err
 	}
+	ipt.path = path
+
 	vstring, err := getIptablesVersionString(path)
 	if err != nil {
 		return nil, fmt.Errorf("could not get iptables version: %v", err)
@@ -110,21 +134,23 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 	if err != nil {
 		return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
 	}
+	ipt.v1 = v1
+	ipt.v2 = v2
+	ipt.v3 = v3
+	ipt.mode = mode
 
 	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
+	ipt.hasCheck = checkPresent
+	ipt.hasWait = waitPresent
+	ipt.hasRandomFully = randomFullyPresent
 
-	ipt := IPTables{
-		path:           path,
-		proto:          proto,
-		hasCheck:       checkPresent,
-		hasWait:        waitPresent,
-		hasRandomFully: randomFullyPresent,
-		v1:             v1,
-		v2:             v2,
-		v3:             v3,
-		mode:           mode,
-	}
-	return &ipt, nil
+	return ipt, nil
+}
+
+// New creates a new IPTables for the given proto.
+// The proto will determine which command is used, either "iptables" or "ip6tables".
+func NewWithProtocol(proto Protocol) (*IPTables, error) {
+	return New(IPFamily(proto), Timeout(0))
 }
 
 // Proto returns the protocol used by this IPTables.
@@ -426,6 +452,9 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	args = append([]string{ipt.path}, args...)
 	if ipt.hasWait {
 		args = append(args, "--wait")
+		if ipt.timeout != 0 {
+			args = append(args, strconv.Itoa(ipt.timeout))
+		}
 	} else {
 		fmu, err := newXtablesFileLock()
 		if err != nil {

+ 19 - 0
iptables/iptables_test.go

@@ -50,6 +50,25 @@ func TestProto(t *testing.T) {
 	}
 }
 
+func TestTimeout(t *testing.T) {
+	ipt, err := New()
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+	if ipt.timeout != 0 {
+		t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
+	}
+
+	ipt2, err := New(Timeout(5))
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+	if ipt2.timeout != 5 {
+		t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
+	}
+
+}
+
 func randChain(t *testing.T) string {
 	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
 	if err != nil {