|
@@ -16,6 +16,7 @@ package iptables
|
|
|
|
|
|
import (
|
|
|
"crypto/rand"
|
|
|
+ "fmt"
|
|
|
"math/big"
|
|
|
"reflect"
|
|
|
"testing"
|
|
@@ -30,16 +31,46 @@ func randChain(t *testing.T) string {
|
|
|
return "TEST-" + n.String()
|
|
|
}
|
|
|
|
|
|
-func TestChain(t *testing.T) {
|
|
|
- chain := randChain(t)
|
|
|
-
|
|
|
+// Create an array of IPTables with different hasWait/hasCheck to
|
|
|
+// test different behaviours
|
|
|
+func mustTestableIptables() []*IPTables {
|
|
|
ipt, err := New()
|
|
|
if err != nil {
|
|
|
- t.Fatalf("New failed: %v", err)
|
|
|
+ panic(fmt.Sprintf("New failed: %v", err))
|
|
|
+ }
|
|
|
+ ipts := []*IPTables{ipt}
|
|
|
+ // ensure we check one variant without built-in locking
|
|
|
+ if ipt.hasWait {
|
|
|
+ iptNoWait := &IPTables{
|
|
|
+ path: ipt.path,
|
|
|
+ hasWait: false,
|
|
|
+ }
|
|
|
+ ipts = append(ipts, iptNoWait)
|
|
|
+ }
|
|
|
+ // ensure we check one variant without built-in checking
|
|
|
+ if ipt.hasCheck {
|
|
|
+ iptNoCheck := &IPTables{
|
|
|
+ path: ipt.path,
|
|
|
+ hasCheck: false,
|
|
|
+ }
|
|
|
+ ipts = append(ipts, iptNoCheck)
|
|
|
+ }
|
|
|
+ return ipts
|
|
|
+}
|
|
|
+
|
|
|
+func TestChain(t *testing.T) {
|
|
|
+ for _, ipt := range mustTestableIptables() {
|
|
|
+ runChainTests(t, ipt)
|
|
|
}
|
|
|
+}
|
|
|
+
|
|
|
+func runChainTests(t *testing.T, ipt *IPTables) {
|
|
|
+ t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
|
|
|
+
|
|
|
+ chain := randChain(t)
|
|
|
|
|
|
// chain shouldn't exist, this will create new
|
|
|
- err = ipt.ClearChain("filter", chain)
|
|
|
+ err := ipt.ClearChain("filter", chain)
|
|
|
if err != nil {
|
|
|
t.Fatalf("ClearChain (of missing) failed: %v", err)
|
|
|
}
|
|
@@ -82,15 +113,18 @@ func TestChain(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
func TestRules(t *testing.T) {
|
|
|
- chain := randChain(t)
|
|
|
-
|
|
|
- ipt, err := New()
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("New failed: %v", err)
|
|
|
+ for _, ipt := range mustTestableIptables() {
|
|
|
+ runRulesTests(t, ipt)
|
|
|
}
|
|
|
+}
|
|
|
+
|
|
|
+func runRulesTests(t *testing.T, ipt *IPTables) {
|
|
|
+ t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
|
|
|
+
|
|
|
+ chain := randChain(t)
|
|
|
|
|
|
// chain shouldn't exist, this will create new
|
|
|
- err = ipt.ClearChain("filter", chain)
|
|
|
+ err := ipt.ClearChain("filter", chain)
|
|
|
if err != nil {
|
|
|
t.Fatalf("ClearChain (of missing) failed: %v", err)
|
|
|
}
|