Browse Source

Merge pull request #77 from n0-1/master

Add some convenience functions
Casey Callendrello 4 years ago
parent
commit
b014539f5d
2 changed files with 85 additions and 0 deletions
  1. 35 0
      iptables/iptables.go
  2. 50 0
      iptables/iptables_test.go

+ 35 - 0
iptables/iptables.go

@@ -183,6 +183,14 @@ func (ipt *IPTables) Delete(table, chain string, rulespec ...string) error {
 	return ipt.run(cmd...)
 }
 
+func (ipt *IPTables) DeleteIfExists(table, chain string, rulespec ...string) error {
+	exists, err := ipt.Exists(table, chain, rulespec...)
+	if err == nil && exists {
+		err = ipt.Delete(table, chain, rulespec...)
+	}
+	return err
+}
+
 // List rules in specified table/chain
 func (ipt *IPTables) List(table, chain string) ([]string, error) {
 	args := []string{"-t", table, "-S", chain}
@@ -220,6 +228,21 @@ func (ipt *IPTables) ListChains(table string) ([]string, error) {
 	return chains, nil
 }
 
+// '-S' is fine with non existing rule index as long as the chain exists
+// therefore pass index 1 to reduce overhead for large chains
+func (ipt *IPTables) ChainExists(table, chain string) (bool, error) {
+	err := ipt.run("-t", table, "-S", chain, "1")
+	eerr, eok := err.(*Error)
+	switch {
+	case err == nil:
+		return true, nil
+	case eok && eerr.ExitStatus() == 1:
+		return false, nil
+	default:
+		return false, err
+	}
+}
+
 // Stats lists rules including the byte and packet counts
 func (ipt *IPTables) Stats(table, chain string) ([][]string, error) {
 	args := []string{"-t", table, "-L", chain, "-n", "-v", "-x"}
@@ -399,6 +422,18 @@ func (ipt *IPTables) DeleteChain(table, chain string) error {
 	return ipt.run("-t", table, "-X", chain)
 }
 
+func (ipt *IPTables) ClearAndDeleteChain(table, chain string) error {
+	exists, err := ipt.ChainExists(table, chain)
+	if err != nil || !exists {
+		return err
+	}
+	err = ipt.run("-t", table, "-F", chain)
+	if err == nil {
+		err = ipt.run("-t", table, "-X", chain)
+	}
+	return err
+}
+
 // ChangePolicy changes policy on chain to target
 func (ipt *IPTables) ChangePolicy(table, chain, target string) error {
 	return ipt.run("-t", table, "-P", chain, target)

+ 50 - 0
iptables/iptables_test.go

@@ -131,6 +131,14 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ListChains doesn't contain the new chain %v", chain)
 	}
 
+	// ChainExists should find it, too
+	exists, err := ipt.ChainExists("filter", chain)
+	if err != nil {
+		t.Fatalf("ChainExists for existing chain failed: %v", err)
+	} else if !exists {
+		t.Fatalf("ChainExists doesn't find existing chain")
+	}
+
 	// chain now exists
 	err = ipt.ClearChain("filter", chain)
 	if err != nil {
@@ -179,6 +187,39 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 	if !reflect.DeepEqual(originaListChain, listChain) {
 		t.Fatalf("ListChains mismatch: \ngot  %#v \nneed %#v", originaListChain, listChain)
 	}
+
+	// ChainExists must not find it anymore
+	exists, err = ipt.ChainExists("filter", chain)
+	if err != nil {
+		t.Fatalf("ChainExists for non-existing chain failed: %v", err)
+	} else if exists {
+		t.Fatalf("ChainExists finds non-existing chain")
+	}
+
+	// test ClearAndDelete
+	err = ipt.NewChain("filter", chain)
+	if err != nil {
+		t.Fatalf("NewChain failed: %v", err)
+	}
+	err = ipt.Append("filter", chain, "-j", "ACCEPT")
+	if err != nil {
+		t.Fatalf("Append failed: %v", err)
+	}
+	err = ipt.ClearAndDeleteChain("filter", chain)
+	if err != nil {
+		t.Fatalf("ClearAndDelete failed: %v", err)
+	}
+	exists, err = ipt.ChainExists("filter", chain)
+	if err != nil {
+		t.Fatalf("ChainExists failed: %v", err)
+	}
+	if exists {
+		t.Fatalf("ClearAndDelete didn't delete the chain")
+	}
+	err = ipt.ClearAndDeleteChain("filter", chain)
+	if err != nil {
+		t.Fatalf("ClearAndDelete failed for non-existing chain: %v", err)
+	}
 }
 
 func TestRules(t *testing.T) {
@@ -343,6 +384,15 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 		}
 	}
 
+	err = ipt.DeleteIfExists("filter", chain, "-s", address1, "-d", subnet2, "-j", "ACCEPT")
+	if err != nil {
+		t.Fatalf("DeleteIfExists failed for existing rule: %v", err)
+	}
+	err = ipt.DeleteIfExists("filter", chain, "-s", address1, "-d", subnet2, "-j", "ACCEPT")
+	if err != nil {
+		t.Fatalf("DeleteIfExists failed for non-existing rule: %v", err)
+	}
+
 	// Clear the chain that was created.
 	err = ipt.ClearChain("filter", chain)
 	if err != nil {