Browse Source

tests: ensure we test without hasWait/hasCheck

Jonathan Boulle 9 years ago
parent
commit
e1427d0cec
1 changed files with 45 additions and 11 deletions
  1. 45 11
      iptables/iptables_test.go

+ 45 - 11
iptables/iptables_test.go

@@ -16,6 +16,7 @@ package iptables
 
 import (
 	"crypto/rand"
+	"fmt"
 	"math/big"
 	"reflect"
 	"testing"
@@ -30,16 +31,46 @@ func randChain(t *testing.T) string {
 	return "TEST-" + n.String()
 }
 
-func TestChain(t *testing.T) {
-	chain := randChain(t)
-
+// Create an array of IPTables with different hasWait/hasCheck to
+// test different behaviours
+func mustTestableIptables() []*IPTables {
 	ipt, err := New()
 	if err != nil {
-		t.Fatalf("New failed: %v", err)
+		panic(fmt.Sprintf("New failed: %v", err))
+	}
+	ipts := []*IPTables{ipt}
+	// ensure we check one variant without built-in locking
+	if ipt.hasWait {
+		iptNoWait := &IPTables{
+			path:    ipt.path,
+			hasWait: false,
+		}
+		ipts = append(ipts, iptNoWait)
+	}
+	// ensure we check one variant without built-in checking
+	if ipt.hasCheck {
+		iptNoCheck := &IPTables{
+			path:     ipt.path,
+			hasCheck: false,
+		}
+		ipts = append(ipts, iptNoCheck)
+	}
+	return ipts
+}
+
+func TestChain(t *testing.T) {
+	for _, ipt := range mustTestableIptables() {
+		runChainTests(t, ipt)
 	}
+}
+
+func runChainTests(t *testing.T, ipt *IPTables) {
+	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+
+	chain := randChain(t)
 
 	// chain shouldn't exist, this will create new
-	err = ipt.ClearChain("filter", chain)
+	err := ipt.ClearChain("filter", chain)
 	if err != nil {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}
@@ -82,15 +113,18 @@ func TestChain(t *testing.T) {
 }
 
 func TestRules(t *testing.T) {
-	chain := randChain(t)
-
-	ipt, err := New()
-	if err != nil {
-		t.Fatalf("New failed: %v", err)
+	for _, ipt := range mustTestableIptables() {
+		runRulesTests(t, ipt)
 	}
+}
+
+func runRulesTests(t *testing.T, ipt *IPTables) {
+	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+
+	chain := randChain(t)
 
 	// chain shouldn't exist, this will create new
-	err = ipt.ClearChain("filter", chain)
+	err := ipt.ClearChain("filter", chain)
 	if err != nil {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}