Prechádzať zdrojové kódy

Add support for ListChains method. (#26)

ListChains lists all the existing chains on an existing table
Van De Walle Bernard 8 rokov pred
rodič
commit
5463fbac3b
2 zmenil súbory, kde vykonal 63 pridanie a 1 odobranie
  1. 29 0
      iptables/iptables.go
  2. 34 1
      iptables/iptables_test.go

+ 29 - 0
iptables/iptables.go

@@ -139,6 +139,35 @@ func (ipt *IPTables) Delete(table, chain string, rulespec ...string) error {
 // List rules in specified table/chain
 func (ipt *IPTables) List(table, chain string) ([]string, error) {
 	args := []string{"-t", table, "-S", chain}
+	return ipt.executeList(args)
+}
+
+// ListChains returns a slice containing the name of each chain in the specified table.
+func (ipt *IPTables) ListChains(table string) ([]string, error) {
+	args := []string{"-t", table, "-S"}
+
+	result, err := ipt.executeList(args)
+	if err != nil {
+		return nil, err
+	}
+
+	// Iterate over rules to find all default (-P) and user-specified (-N) chains.
+	// Chains definition always come before rules.
+	// Format is the following:
+	// -P OUTPUT ACCEPT
+	// -N Custom
+	var chains []string
+	for _, val := range result {
+		if strings.HasPrefix(val, "-P") || strings.HasPrefix(val, "-N") {
+			chains = append(chains, strings.Fields(val)[1])
+		} else {
+			break
+		}
+	}
+	return chains, nil
+}
+
+func (ipt *IPTables) executeList(args []string) ([]string, error) {
 	var stdout bytes.Buffer
 	if err := ipt.runWithOutput(args, &stdout); err != nil {
 		return nil, err

+ 34 - 1
iptables/iptables_test.go

@@ -57,6 +57,15 @@ func randChain(t *testing.T) string {
 	return "TEST-" + n.String()
 }
 
+func contains(list []string, value string) bool {
+	for _, val := range list {
+		if val == value {
+			return true
+		}
+	}
+	return false
+}
+
 // Create an array of IPTables with different hasWait/hasCheck to
 // test different behaviours
 func mustTestableIptables() []*IPTables {
@@ -99,12 +108,27 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 
 	chain := randChain(t)
 
+	// Saving the list of chains before executing tests
+	originaListChain, err := ipt.ListChains("filter")
+	if err != nil {
+		t.Fatalf("ListChains of Initial failed: %v", err)
+	}
+
 	// chain shouldn't exist, this will create new
-	err := ipt.ClearChain("filter", chain)
+	err = ipt.ClearChain("filter", chain)
 	if err != nil {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}
 
+	// chain should be in listChain
+	listChain, err := ipt.ListChains("filter")
+	if err != nil {
+		t.Fatalf("ListChains failed: %v", err)
+	}
+	if !contains(listChain, chain) {
+		t.Fatalf("ListChains doesn't contain the new chain %v", chain)
+	}
+
 	// chain now exists
 	err = ipt.ClearChain("filter", chain)
 	if err != nil {
@@ -140,6 +164,15 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 	if err != nil {
 		t.Fatalf("DeleteChain of empty chain failed: %v", err)
 	}
+
+	// check that chain is fully gone and that state similar to initial one
+	listChain, err = ipt.ListChains("filter")
+	if err != nil {
+		t.Fatalf("ListChains failed: %v", err)
+	}
+	if !reflect.DeepEqual(originaListChain, listChain) {
+		t.Fatalf("ListChains mismatch: \ngot  %#v \nneed %#v", originaListChain, listChain)
+	}
 }
 
 func TestRules(t *testing.T) {