Просмотр исходного кода

Implement ChainExists

Use 'iptables -S' to determine existence of a given chain.
Phil Sutter 4 лет назад
Родитель
Сommit
b15012c6af
2 измененных файлов с 31 добавлено и 0 удалено
  1. 15 0
      iptables/iptables.go
  2. 16 0
      iptables/iptables_test.go

+ 15 - 0
iptables/iptables.go

@@ -228,6 +228,21 @@ func (ipt *IPTables) ListChains(table string) ([]string, error) {
 	return chains, nil
 }
 
+// '-S' is fine with non existing rule index as long as the chain exists
+// therefore pass index 1 to reduce overhead for large chains
+func (ipt *IPTables) ChainExists(table, chain string) (bool, error) {
+	err := ipt.run("-t", table, "-S", chain, "1")
+	eerr, eok := err.(*Error)
+	switch {
+	case err == nil:
+		return true, nil
+	case eok && eerr.ExitStatus() == 1:
+		return false, nil
+	default:
+		return false, err
+	}
+}
+
 // Stats lists rules including the byte and packet counts
 func (ipt *IPTables) Stats(table, chain string) ([][]string, error) {
 	args := []string{"-t", table, "-L", chain, "-n", "-v", "-x"}

+ 16 - 0
iptables/iptables_test.go

@@ -131,6 +131,14 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ListChains doesn't contain the new chain %v", chain)
 	}
 
+	// ChainExists should find it, too
+	exists, err := ipt.ChainExists("filter", chain)
+	if err != nil {
+		t.Fatalf("ChainExists for existing chain failed: %v", err)
+	} else if !exists {
+		t.Fatalf("ChainExists doesn't find existing chain")
+	}
+
 	// chain now exists
 	err = ipt.ClearChain("filter", chain)
 	if err != nil {
@@ -179,6 +187,14 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 	if !reflect.DeepEqual(originaListChain, listChain) {
 		t.Fatalf("ListChains mismatch: \ngot  %#v \nneed %#v", originaListChain, listChain)
 	}
+
+	// ChainExists must not find it anymore
+	exists, err = ipt.ChainExists("filter", chain)
+	if err != nil {
+		t.Fatalf("ChainExists for non-existing chain failed: %v", err)
+	} else if exists {
+		t.Fatalf("ChainExists finds non-existing chain")
+	}
 }
 
 func TestRules(t *testing.T) {