Browse Source

Merge pull request #19 from jonboulle/master

*: fix locking operation and add tests
Jonathan Boulle 9 years ago
parent
commit
ed1329de7a
2 changed files with 54 additions and 22 deletions
  1. 9 11
      iptables/iptables.go
  2. 45 11
      iptables/iptables_test.go

+ 9 - 11
iptables/iptables.go

@@ -44,8 +44,6 @@ type IPTables struct {
 	path     string
 	hasCheck bool
 	hasWait  bool
-
-	fmu *fileLock
 }
 
 func New() (*IPTables, error) {
@@ -64,12 +62,6 @@ func New() (*IPTables, error) {
 		hasCheck: checkPresent,
 		hasWait:  waitPresent,
 	}
-	if !waitPresent {
-		ipt.fmu, err = newXtablesFileLock()
-		if err != nil {
-			return nil, err
-		}
-	}
 	return &ipt, nil
 }
 
@@ -81,10 +73,11 @@ func (ipt *IPTables) Exists(table, chain string, rulespec ...string) (bool, erro
 	}
 	cmd := append([]string{"-t", table, "-C", chain}, rulespec...)
 	err := ipt.run(cmd...)
+	eerr, eok := err.(*Error)
 	switch {
 	case err == nil:
 		return true, nil
-	case err.(*Error).ExitStatus() == 1:
+	case eok && eerr.ExitStatus() == 1:
 		return false, nil
 	default:
 		return false, err
@@ -148,10 +141,11 @@ func (ipt *IPTables) NewChain(table, chain string) error {
 func (ipt *IPTables) ClearChain(table, chain string) error {
 	err := ipt.NewChain(table, chain)
 
+	eerr, eok := err.(*Error)
 	switch {
 	case err == nil:
 		return nil
-	case err.(*Error).ExitStatus() == 1:
+	case eok && eerr.ExitStatus() == 1:
 		// chain already exists. Flush (clear) it.
 		return ipt.run("-t", table, "-F", chain)
 	default:
@@ -183,7 +177,11 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	if ipt.hasWait {
 		args = append(args, "--wait")
 	} else {
-		ul, err := ipt.fmu.tryLock()
+		fmu, err := newXtablesFileLock()
+		if err != nil {
+			return err
+		}
+		ul, err := fmu.tryLock()
 		if err != nil {
 			return err
 		}

+ 45 - 11
iptables/iptables_test.go

@@ -16,6 +16,7 @@ package iptables
 
 import (
 	"crypto/rand"
+	"fmt"
 	"math/big"
 	"reflect"
 	"testing"
@@ -30,16 +31,46 @@ func randChain(t *testing.T) string {
 	return "TEST-" + n.String()
 }
 
-func TestChain(t *testing.T) {
-	chain := randChain(t)
-
+// Create an array of IPTables with different hasWait/hasCheck to
+// test different behaviours
+func mustTestableIptables() []*IPTables {
 	ipt, err := New()
 	if err != nil {
-		t.Fatalf("New failed: %v", err)
+		panic(fmt.Sprintf("New failed: %v", err))
+	}
+	ipts := []*IPTables{ipt}
+	// ensure we check one variant without built-in locking
+	if ipt.hasWait {
+		iptNoWait := &IPTables{
+			path:    ipt.path,
+			hasWait: false,
+		}
+		ipts = append(ipts, iptNoWait)
+	}
+	// ensure we check one variant without built-in checking
+	if ipt.hasCheck {
+		iptNoCheck := &IPTables{
+			path:     ipt.path,
+			hasCheck: false,
+		}
+		ipts = append(ipts, iptNoCheck)
+	}
+	return ipts
+}
+
+func TestChain(t *testing.T) {
+	for _, ipt := range mustTestableIptables() {
+		runChainTests(t, ipt)
 	}
+}
+
+func runChainTests(t *testing.T, ipt *IPTables) {
+	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+
+	chain := randChain(t)
 
 	// chain shouldn't exist, this will create new
-	err = ipt.ClearChain("filter", chain)
+	err := ipt.ClearChain("filter", chain)
 	if err != nil {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}
@@ -82,15 +113,18 @@ func TestChain(t *testing.T) {
 }
 
 func TestRules(t *testing.T) {
-	chain := randChain(t)
-
-	ipt, err := New()
-	if err != nil {
-		t.Fatalf("New failed: %v", err)
+	for _, ipt := range mustTestableIptables() {
+		runRulesTests(t, ipt)
 	}
+}
+
+func runRulesTests(t *testing.T, ipt *IPTables) {
+	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+
+	chain := randChain(t)
 
 	// chain shouldn't exist, this will create new
-	err = ipt.ClearChain("filter", chain)
+	err := ipt.ClearChain("filter", chain)
 	if err != nil {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}