Преглед изворни кода

Add support for iptables in nftables mode.

Iptables also has the ability to work in nftables mode, where it is
supposed to act like iptables but use the nftables subsystem.
Unfortunately, it isn't exactly the same.

The biggest difference is that counter output is iptables-save style,
rather than with "-c N N".

Also, improve some tests.
Casey Callendrello пре 6 година
родитељ
комит
5c15b20bd5
3 измењених фајлова са 171 додато и 31 уклоњено
  1. 76 21
      iptables/iptables.go
  2. 90 8
      iptables/iptables_test.go
  3. 5 2
      test

+ 76 - 21
iptables/iptables.go

@@ -29,11 +29,15 @@ import (
 // Adds the output of stderr to exec.ExitError
 // Adds the output of stderr to exec.ExitError
 type Error struct {
 type Error struct {
 	exec.ExitError
 	exec.ExitError
-	cmd exec.Cmd
-	msg string
+	cmd        exec.Cmd
+	msg        string
+	exitStatus *int //for overriding
 }
 }
 
 
 func (e *Error) ExitStatus() int {
 func (e *Error) ExitStatus() int {
+	if e.exitStatus != nil {
+		return *e.exitStatus
+	}
 	return e.Sys().(syscall.WaitStatus).ExitStatus()
 	return e.Sys().(syscall.WaitStatus).ExitStatus()
 }
 }
 
 
@@ -65,6 +69,7 @@ type IPTables struct {
 	v1             int
 	v1             int
 	v2             int
 	v2             int
 	v3             int
 	v3             int
+	mode           string // the underlying iptables operating mode, e.g. nf_tables
 }
 }
 
 
 // New creates a new IPTables.
 // New creates a new IPTables.
@@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 	vstring, err := getIptablesVersionString(path)
 	vstring, err := getIptablesVersionString(path)
-	v1, v2, v3, err := extractIptablesVersion(vstring)
+	v1, v2, v3, mode, err := extractIptablesVersion(vstring)
+
+	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
 
 
-	checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
-	if err != nil {
-		return nil, fmt.Errorf("error checking iptables version: %v", err)
-	}
 	ipt := IPTables{
 	ipt := IPTables{
 		path:           path,
 		path:           path,
 		proto:          proto,
 		proto:          proto,
@@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		v1:             v1,
 		v1:             v1,
 		v2:             v2,
 		v2:             v2,
 		v3:             v3,
 		v3:             v3,
+		mode:           mode,
 	}
 	}
 	return &ipt, nil
 	return &ipt, nil
 }
 }
@@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
 	}
 	}
 
 
 	rules := strings.Split(stdout.String(), "\n")
 	rules := strings.Split(stdout.String(), "\n")
+
+	// strip trailing newline
 	if len(rules) > 0 && rules[len(rules)-1] == "" {
 	if len(rules) > 0 && rules[len(rules)-1] == "" {
 		rules = rules[:len(rules)-1]
 		rules = rules[:len(rules)-1]
 	}
 	}
 
 
+	// nftables mode doesn't return an error code when listing a non-existent
+	// chain. Patch that up.
+	if len(rules) == 0 && ipt.mode == "nf_tables" {
+		v := 1
+		return nil, &Error{
+			cmd:        exec.Cmd{Args: args},
+			msg:        "iptables: No chain/target/match by that name.",
+			exitStatus: &v,
+		}
+	}
+
+	for i, rule := range rules {
+		rules[i] = filterRuleOutput(rule)
+	}
+
 	return rules, nil
 	return rules, nil
 }
 }
 
 
@@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
 func (ipt *IPTables) ClearChain(table, chain string) error {
 func (ipt *IPTables) ClearChain(table, chain string) error {
 	err := ipt.NewChain(table, chain)
 	err := ipt.NewChain(table, chain)
 
 
+	// the exit code for "this table already exists" is different for
+	// different iptables modes
+	existsErr := 1
+	if ipt.mode == "nf_tables" {
+		existsErr = 4
+	}
+
 	eerr, eok := err.(*Error)
 	eerr, eok := err.(*Error)
 	switch {
 	switch {
 	case err == nil:
 	case err == nil:
 		return nil
 		return nil
-	case eok && eerr.ExitStatus() == 1:
+	case eok && eerr.ExitStatus() == existsErr:
 		// chain already exists. Flush (clear) it.
 		// chain already exists. Flush (clear) it.
 		return ipt.run("-t", table, "-F", chain)
 		return ipt.run("-t", table, "-F", chain)
 	default:
 	default:
@@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	if err := cmd.Run(); err != nil {
 	if err := cmd.Run(); err != nil {
 		switch e := err.(type) {
 		switch e := err.(type) {
 		case *exec.ExitError:
 		case *exec.ExitError:
-			return &Error{*e, cmd, stderr.String()}
+			return &Error{*e, cmd, stderr.String(), nil}
 		default:
 		default:
 			return err
 			return err
 		}
 		}
@@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
 }
 }
 
 
 // Checks if iptables has the "-C" and "--wait" flag
 // Checks if iptables has the "-C" and "--wait" flag
-func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {
-
-	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
+func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
+	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
 }
 }
 
 
-// getIptablesVersion returns the first three components of the iptables version.
-// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
-func extractIptablesVersion(str string) (int, int, int, error) {
-	versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+// getIptablesVersion returns the first three components of the iptables version
+// and the operating mode (e.g. nf_tables or legacy)
+// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
+func extractIptablesVersion(str string) (int, int, int, string, error) {
+	versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
 	result := versionMatcher.FindStringSubmatch(str)
 	result := versionMatcher.FindStringSubmatch(str)
 	if result == nil {
 	if result == nil {
-		return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
+		return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
 	}
 	}
 
 
 	v1, err := strconv.Atoi(result[1])
 	v1, err := strconv.Atoi(result[1])
 	if err != nil {
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 	}
 
 
 	v2, err := strconv.Atoi(result[2])
 	v2, err := strconv.Atoi(result[2])
 	if err != nil {
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 	}
 
 
 	v3, err := strconv.Atoi(result[3])
 	v3, err := strconv.Atoi(result[3])
 	if err != nil {
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 	}
 
 
-	return v1, v2, v3, nil
+	mode := "legacy"
+	if result[4] != "" {
+		mode = result[4]
+	}
+	return v1, v2, v3, mode, nil
 }
 }
 
 
 // Runs "iptables --version" to get the version string
 // Runs "iptables --version" to get the version string
@@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
 	}
 	}
 	return strings.Contains(stdout.String(), rs), nil
 	return strings.Contains(stdout.String(), rs), nil
 }
 }
+
+// counterRegex is the regex used to detect nftables counter format
+var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)
+
+// filterRuleOutput works around some inconsistencies in output.
+// For example, when iptables is in legacy vs. nftables mode, it produces
+// different results.
+func filterRuleOutput(rule string) string {
+	out := rule
+
+	// work around an output difference in nftables mode where counters
+	// are output in iptables-save format, rather than iptables -S format
+	// The string begins with "[0:0]"
+	//
+	// Fixes #49
+	if groups := counterRegex.FindStringSubmatch(out); groups != nil {
+		// drop the brackets
+		out = out[len(groups[0]):]
+		out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
+	}
+
+	return out
+}

+ 90 - 8
iptables/iptables_test.go

@@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
 }
 }
 
 
 func TestChain(t *testing.T) {
 func TestChain(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runChainTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runChainTests(t, ipt)
+		})
 	}
 	}
 }
 }
 
 
@@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 }
 }
 
 
 func TestRules(t *testing.T) {
 func TestRules(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runRulesTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runRulesTests(t, ipt)
+		})
 	}
 	}
 }
 }
 
 
@@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ListWithCounters failed: %v", err)
 		t.Fatalf("ListWithCounters failed: %v", err)
 	}
 	}
 
 
+	suffix := " -c 0 0 -j ACCEPT"
+	if ipt.mode == "nf_tables" {
+		suffix = " -j ACCEPT -c 0 0"
+	}
+
 	expected = []string{
 	expected = []string{
 		"-N " + chain,
 		"-N " + chain,
-		"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
+		"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
 	}
 	}
 
 
 	if !reflect.DeepEqual(rules, expected) {
 	if !reflect.DeepEqual(rules, expected) {
@@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
 		t.Fatal("IsNotExist returned false, expected true")
 		t.Fatal("IsNotExist returned false, expected true")
 	}
 	}
 }
 }
+
+func TestFilterRuleOutput(t *testing.T) {
+	testCases := []struct {
+		name string
+		in   string
+		out  string
+	}{
+		{
+			"legacy output",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+		},
+		{
+			"nft output",
+			"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
+		},
+	}
+
+	for _, tt := range testCases {
+		t.Run(tt.name, func(t *testing.T) {
+			actual := filterRuleOutput(tt.in)
+			if actual != tt.out {
+				t.Fatalf("expect %s actual %s", tt.out, actual)
+			}
+		})
+	}
+}
+
+func TestExtractIptablesVersion(t *testing.T) {
+	testCases := []struct {
+		in         string
+		v1, v2, v3 int
+		mode       string
+		err        bool
+	}{
+		{
+			"iptables v1.8.0 (nf_tables)",
+			1, 8, 0,
+			"nf_tables",
+			false,
+		},
+		{
+			"iptables v1.8.0 (legacy)",
+			1, 8, 0,
+			"legacy",
+			false,
+		},
+		{
+			"iptables v1.6.2",
+			1, 6, 2,
+			"legacy",
+			false,
+		},
+	}
+
+	for i, tt := range testCases {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
+			if err == nil && tt.err {
+				t.Fatal("expected err, got none")
+			} else if err != nil && !tt.err {
+				t.Fatalf("unexpected err %s", err)
+			}
+
+			if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
+				t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
+					tt.v1, tt.v2, tt.v3, tt.mode,
+					v1, v2, v3, mode)
+			}
+		})
+	}
+}

+ 5 - 2
test

@@ -45,11 +45,14 @@ split=(${TEST// / })
 TEST=${split[@]/#/${REPO_PATH}/}
 TEST=${split[@]/#/${REPO_PATH}/}
 
 
 echo "Running tests..."
 echo "Running tests..."
-go test -i ${TEST}
+bin=$(mktemp)
+
+go test -c -o ${bin} ${COVER} -i ${TEST}
 if [[ -z "$SUDO_PERMITTED" ]]; then
 if [[ -z "$SUDO_PERMITTED" ]]; then
     echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
     echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
     exit 1
     exit 1
 fi
 fi
 
 
-sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
+sudo -E bash -c "${bin} $@ ${TEST}"
 echo "Success"
 echo "Success"
+rm "${bin}"