Browse Source

Merge pull request #96 from alegrey91/main

feat: add ListById function and test
Casey Callendrello 2 years ago
parent
commit
d2b8608923
2 changed files with 77 additions and 0 deletions
  1. 10 0
      iptables/iptables.go
  2. 67 0
      iptables/iptables_test.go

+ 10 - 0
iptables/iptables.go

@@ -234,6 +234,16 @@ func (ipt *IPTables) DeleteIfExists(table, chain string, rulespec ...string) err
 	return err
 }
 
+// List rules in specified table/chain
+func (ipt *IPTables) ListById(table, chain string, id int) (string, error) {
+	args := []string{"-t", table, "-S", chain, strconv.Itoa(id)}
+	rule, err := ipt.executeList(args)
+	if err != nil {
+		return "", err
+	}
+	return rule[0], nil
+}
+
 // List rules in specified table/chain
 func (ipt *IPTables) List(table, chain string) ([]string, error) {
 	args := []string{"-t", table, "-S", chain}

+ 67 - 0
iptables/iptables_test.go

@@ -21,6 +21,7 @@ import (
 	"net"
 	"os"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -683,3 +684,69 @@ func TestExtractIptablesVersion(t *testing.T) {
 		})
 	}
 }
+
+func TestListById(t *testing.T) {
+	testCases := []struct {
+		in       string
+		id       int
+		out      string
+		expected bool
+	}{
+		{
+			"-i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3000",
+			1,
+			"-A PREROUTING -i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3000",
+			true,
+		},
+		{
+			"-i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3001",
+			2,
+			"-A PREROUTING -i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3001",
+			true,
+		},
+		{
+			"-i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3002",
+			3,
+			"-A PREROUTING -i lo -p tcp -m tcp --dport 3000 -j DNAT --to-destination 127.0.0.1:3003",
+			false,
+		},
+	}
+
+	ipt, err := New()
+	if err != nil {
+		t.Fatalf("failed to init: %v", err)
+	}
+	// ensure to test in a clear environment
+	err = ipt.ClearChain("nat", "PREROUTING")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err = ipt.ClearChain("nat", "PREROUTING")
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	for _, tt := range testCases {
+		t.Run(fmt.Sprintf("Checking rule with id %d", tt.id), func(t *testing.T) {
+			err = ipt.Append("nat", "PREROUTING", strings.Split(tt.in, " ")...)
+			if err != nil {
+				t.Fatal(err)
+			}
+			rule, err := ipt.ListById("nat", "PREROUTING", tt.id)
+			if err != nil {
+				t.Fatal(err)
+			}
+			fmt.Println(rule)
+			test_result := false
+			if rule == tt.out {
+				test_result = true
+			}
+			if test_result != tt.expected {
+				t.Fatal("Test failed")
+			}
+		})
+	}
+}