فهرست منبع

Merge pull request #51 from squeed/iptables-nft

Add support for iptables in nftables mode.
Casey Callendrello 6 سال پیش
والد
کامیت
ee9f8ee3f1
3فایلهای تغییر یافته به همراه171 افزوده شده و 31 حذف شده
  1. 76 21
      iptables/iptables.go
  2. 90 8
      iptables/iptables_test.go
  3. 5 2
      test

+ 76 - 21
iptables/iptables.go

@@ -29,11 +29,15 @@ import (
 // Adds the output of stderr to exec.ExitError
 type Error struct {
 	exec.ExitError
-	cmd exec.Cmd
-	msg string
+	cmd        exec.Cmd
+	msg        string
+	exitStatus *int //for overriding
 }
 
 func (e *Error) ExitStatus() int {
+	if e.exitStatus != nil {
+		return *e.exitStatus
+	}
 	return e.Sys().(syscall.WaitStatus).ExitStatus()
 }
 
@@ -65,6 +69,7 @@ type IPTables struct {
 	v1             int
 	v2             int
 	v3             int
+	mode           string // the underlying iptables operating mode, e.g. nf_tables
 }
 
 // New creates a new IPTables.
@@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		return nil, err
 	}
 	vstring, err := getIptablesVersionString(path)
-	v1, v2, v3, err := extractIptablesVersion(vstring)
+	v1, v2, v3, mode, err := extractIptablesVersion(vstring)
+
+	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
 
-	checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
-	if err != nil {
-		return nil, fmt.Errorf("error checking iptables version: %v", err)
-	}
 	ipt := IPTables{
 		path:           path,
 		proto:          proto,
@@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		v1:             v1,
 		v2:             v2,
 		v3:             v3,
+		mode:           mode,
 	}
 	return &ipt, nil
 }
@@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
 	}
 
 	rules := strings.Split(stdout.String(), "\n")
+
+	// strip trailing newline
 	if len(rules) > 0 && rules[len(rules)-1] == "" {
 		rules = rules[:len(rules)-1]
 	}
 
+	// nftables mode doesn't return an error code when listing a non-existent
+	// chain. Patch that up.
+	if len(rules) == 0 && ipt.mode == "nf_tables" {
+		v := 1
+		return nil, &Error{
+			cmd:        exec.Cmd{Args: args},
+			msg:        "iptables: No chain/target/match by that name.",
+			exitStatus: &v,
+		}
+	}
+
+	for i, rule := range rules {
+		rules[i] = filterRuleOutput(rule)
+	}
+
 	return rules, nil
 }
 
@@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
 func (ipt *IPTables) ClearChain(table, chain string) error {
 	err := ipt.NewChain(table, chain)
 
+	// the exit code for "this table already exists" is different for
+	// different iptables modes
+	existsErr := 1
+	if ipt.mode == "nf_tables" {
+		existsErr = 4
+	}
+
 	eerr, eok := err.(*Error)
 	switch {
 	case err == nil:
 		return nil
-	case eok && eerr.ExitStatus() == 1:
+	case eok && eerr.ExitStatus() == existsErr:
 		// chain already exists. Flush (clear) it.
 		return ipt.run("-t", table, "-F", chain)
 	default:
@@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	if err := cmd.Run(); err != nil {
 		switch e := err.(type) {
 		case *exec.ExitError:
-			return &Error{*e, cmd, stderr.String()}
+			return &Error{*e, cmd, stderr.String(), nil}
 		default:
 			return err
 		}
@@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
 }
 
 // Checks if iptables has the "-C" and "--wait" flag
-func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {
-
-	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
+func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
+	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
 }
 
-// getIptablesVersion returns the first three components of the iptables version.
-// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
-func extractIptablesVersion(str string) (int, int, int, error) {
-	versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+// getIptablesVersion returns the first three components of the iptables version
+// and the operating mode (e.g. nf_tables or legacy)
+// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
+func extractIptablesVersion(str string) (int, int, int, string, error) {
+	versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
 	result := versionMatcher.FindStringSubmatch(str)
 	if result == nil {
-		return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
+		return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
 	}
 
 	v1, err := strconv.Atoi(result[1])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
 	v2, err := strconv.Atoi(result[2])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
 	v3, err := strconv.Atoi(result[3])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
-	return v1, v2, v3, nil
+	mode := "legacy"
+	if result[4] != "" {
+		mode = result[4]
+	}
+	return v1, v2, v3, mode, nil
 }
 
 // Runs "iptables --version" to get the version string
@@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
 	}
 	return strings.Contains(stdout.String(), rs), nil
 }
+
+// counterRegex is the regex used to detect nftables counter format
+var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)
+
+// filterRuleOutput works around some inconsistencies in output.
+// For example, when iptables is in legacy vs. nftables mode, it produces
+// different results.
+func filterRuleOutput(rule string) string {
+	out := rule
+
+	// work around an output difference in nftables mode where counters
+	// are output in iptables-save format, rather than iptables -S format
+	// The string begins with "[0:0]"
+	//
+	// Fixes #49
+	if groups := counterRegex.FindStringSubmatch(out); groups != nil {
+		// drop the brackets
+		out = out[len(groups[0]):]
+		out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
+	}
+
+	return out
+}

+ 90 - 8
iptables/iptables_test.go

@@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
 }
 
 func TestChain(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runChainTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runChainTests(t, ipt)
+		})
 	}
 }
 
@@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 }
 
 func TestRules(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runRulesTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runRulesTests(t, ipt)
+		})
 	}
 }
 
@@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ListWithCounters failed: %v", err)
 	}
 
+	suffix := " -c 0 0 -j ACCEPT"
+	if ipt.mode == "nf_tables" {
+		suffix = " -j ACCEPT -c 0 0"
+	}
+
 	expected = []string{
 		"-N " + chain,
-		"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
+		"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
 	}
 
 	if !reflect.DeepEqual(rules, expected) {
@@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
 		t.Fatal("IsNotExist returned false, expected true")
 	}
 }
+
+func TestFilterRuleOutput(t *testing.T) {
+	testCases := []struct {
+		name string
+		in   string
+		out  string
+	}{
+		{
+			"legacy output",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+		},
+		{
+			"nft output",
+			"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
+		},
+	}
+
+	for _, tt := range testCases {
+		t.Run(tt.name, func(t *testing.T) {
+			actual := filterRuleOutput(tt.in)
+			if actual != tt.out {
+				t.Fatalf("expect %s actual %s", tt.out, actual)
+			}
+		})
+	}
+}
+
+func TestExtractIptablesVersion(t *testing.T) {
+	testCases := []struct {
+		in         string
+		v1, v2, v3 int
+		mode       string
+		err        bool
+	}{
+		{
+			"iptables v1.8.0 (nf_tables)",
+			1, 8, 0,
+			"nf_tables",
+			false,
+		},
+		{
+			"iptables v1.8.0 (legacy)",
+			1, 8, 0,
+			"legacy",
+			false,
+		},
+		{
+			"iptables v1.6.2",
+			1, 6, 2,
+			"legacy",
+			false,
+		},
+	}
+
+	for i, tt := range testCases {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
+			if err == nil && tt.err {
+				t.Fatal("expected err, got none")
+			} else if err != nil && !tt.err {
+				t.Fatalf("unexpected err %s", err)
+			}
+
+			if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
+				t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
+					tt.v1, tt.v2, tt.v3, tt.mode,
+					v1, v2, v3, mode)
+			}
+		})
+	}
+}

+ 5 - 2
test

@@ -45,11 +45,14 @@ split=(${TEST// / })
 TEST=${split[@]/#/${REPO_PATH}/}
 
 echo "Running tests..."
-go test -i ${TEST}
+bin=$(mktemp)
+
+go test -c -o ${bin} ${COVER} -i ${TEST}
 if [[ -z "$SUDO_PERMITTED" ]]; then
     echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
     exit 1
 fi
 
-sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
+sudo -E bash -c "${bin} $@ ${TEST}"
 echo "Success"
+rm "${bin}"