Browse Source

Add support for iptables in nftables mode.

Iptables also has the ability to work in nftables mode, where it is
supposed to act like iptables but use the nftables subsystem.
Unfortunately, it isn't exactly the same.

The biggest difference is that counter output is iptables-save style,
rather than with "-c N N".

Also, improve some tests.
Casey Callendrello 6 years ago
parent
commit
5c15b20bd5
3 changed files with 171 additions and 31 deletions
  1. 76 21
      iptables/iptables.go
  2. 90 8
      iptables/iptables_test.go
  3. 5 2
      test

+ 76 - 21
iptables/iptables.go

@@ -29,11 +29,15 @@ import (
 // Adds the output of stderr to exec.ExitError
 type Error struct {
 	exec.ExitError
-	cmd exec.Cmd
-	msg string
+	cmd        exec.Cmd
+	msg        string
+	exitStatus *int //for overriding
 }
 
 func (e *Error) ExitStatus() int {
+	if e.exitStatus != nil {
+		return *e.exitStatus
+	}
 	return e.Sys().(syscall.WaitStatus).ExitStatus()
 }
 
@@ -65,6 +69,7 @@ type IPTables struct {
 	v1             int
 	v2             int
 	v3             int
+	mode           string // the underlying iptables operating mode, e.g. nf_tables
 }
 
 // New creates a new IPTables.
@@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		return nil, err
 	}
 	vstring, err := getIptablesVersionString(path)
-	v1, v2, v3, err := extractIptablesVersion(vstring)
+	v1, v2, v3, mode, err := extractIptablesVersion(vstring)
+
+	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
 
-	checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
-	if err != nil {
-		return nil, fmt.Errorf("error checking iptables version: %v", err)
-	}
 	ipt := IPTables{
 		path:           path,
 		proto:          proto,
@@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 		v1:             v1,
 		v2:             v2,
 		v3:             v3,
+		mode:           mode,
 	}
 	return &ipt, nil
 }
@@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
 	}
 
 	rules := strings.Split(stdout.String(), "\n")
+
+	// strip trailing newline
 	if len(rules) > 0 && rules[len(rules)-1] == "" {
 		rules = rules[:len(rules)-1]
 	}
 
+	// nftables mode doesn't return an error code when listing a non-existent
+	// chain. Patch that up.
+	if len(rules) == 0 && ipt.mode == "nf_tables" {
+		v := 1
+		return nil, &Error{
+			cmd:        exec.Cmd{Args: args},
+			msg:        "iptables: No chain/target/match by that name.",
+			exitStatus: &v,
+		}
+	}
+
+	for i, rule := range rules {
+		rules[i] = filterRuleOutput(rule)
+	}
+
 	return rules, nil
 }
 
@@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
 func (ipt *IPTables) ClearChain(table, chain string) error {
 	err := ipt.NewChain(table, chain)
 
+	// the exit code for "this table already exists" is different for
+	// different iptables modes
+	existsErr := 1
+	if ipt.mode == "nf_tables" {
+		existsErr = 4
+	}
+
 	eerr, eok := err.(*Error)
 	switch {
 	case err == nil:
 		return nil
-	case eok && eerr.ExitStatus() == 1:
+	case eok && eerr.ExitStatus() == existsErr:
 		// chain already exists. Flush (clear) it.
 		return ipt.run("-t", table, "-F", chain)
 	default:
@@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	if err := cmd.Run(); err != nil {
 		switch e := err.(type) {
 		case *exec.ExitError:
-			return &Error{*e, cmd, stderr.String()}
+			return &Error{*e, cmd, stderr.String(), nil}
 		default:
 			return err
 		}
@@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
 }
 
 // Checks if iptables has the "-C" and "--wait" flag
-func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {
-
-	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
+func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
+	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
 }
 
-// getIptablesVersion returns the first three components of the iptables version.
-// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
-func extractIptablesVersion(str string) (int, int, int, error) {
-	versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+// getIptablesVersion returns the first three components of the iptables version
+// and the operating mode (e.g. nf_tables or legacy)
+// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
+func extractIptablesVersion(str string) (int, int, int, string, error) {
+	versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
 	result := versionMatcher.FindStringSubmatch(str)
 	if result == nil {
-		return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
+		return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
 	}
 
 	v1, err := strconv.Atoi(result[1])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
 	v2, err := strconv.Atoi(result[2])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
 	v3, err := strconv.Atoi(result[3])
 	if err != nil {
-		return 0, 0, 0, err
+		return 0, 0, 0, "", err
 	}
 
-	return v1, v2, v3, nil
+	mode := "legacy"
+	if result[4] != "" {
+		mode = result[4]
+	}
+	return v1, v2, v3, mode, nil
 }
 
 // Runs "iptables --version" to get the version string
@@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
 	}
 	return strings.Contains(stdout.String(), rs), nil
 }
+
+// counterRegex is the regex used to detect nftables counter format
+var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)
+
+// filterRuleOutput works around some inconsistencies in output.
+// For example, when iptables is in legacy vs. nftables mode, it produces
+// different results.
+func filterRuleOutput(rule string) string {
+	out := rule
+
+	// work around an output difference in nftables mode where counters
+	// are output in iptables-save format, rather than iptables -S format
+	// The string begins with "[0:0]"
+	//
+	// Fixes #49
+	if groups := counterRegex.FindStringSubmatch(out); groups != nil {
+		// drop the brackets
+		out = out[len(groups[0]):]
+		out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
+	}
+
+	return out
+}

+ 90 - 8
iptables/iptables_test.go

@@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
 }
 
 func TestChain(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runChainTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runChainTests(t, ipt)
+		})
 	}
 }
 
@@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 }
 
 func TestRules(t *testing.T) {
-	for _, ipt := range mustTestableIptables() {
-		runRulesTests(t, ipt)
+	for i, ipt := range mustTestableIptables() {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			runRulesTests(t, ipt)
+		})
 	}
 }
 
@@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ListWithCounters failed: %v", err)
 	}
 
+	suffix := " -c 0 0 -j ACCEPT"
+	if ipt.mode == "nf_tables" {
+		suffix = " -j ACCEPT -c 0 0"
+	}
+
 	expected = []string{
 		"-N " + chain,
-		"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
-		"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
+		"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
+		"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
+		"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
 	}
 
 	if !reflect.DeepEqual(rules, expected) {
@@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
 		t.Fatal("IsNotExist returned false, expected true")
 	}
 }
+
+func TestFilterRuleOutput(t *testing.T) {
+	testCases := []struct {
+		name string
+		in   string
+		out  string
+	}{
+		{
+			"legacy output",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+		},
+		{
+			"nft output",
+			"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
+			"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
+		},
+	}
+
+	for _, tt := range testCases {
+		t.Run(tt.name, func(t *testing.T) {
+			actual := filterRuleOutput(tt.in)
+			if actual != tt.out {
+				t.Fatalf("expect %s actual %s", tt.out, actual)
+			}
+		})
+	}
+}
+
+func TestExtractIptablesVersion(t *testing.T) {
+	testCases := []struct {
+		in         string
+		v1, v2, v3 int
+		mode       string
+		err        bool
+	}{
+		{
+			"iptables v1.8.0 (nf_tables)",
+			1, 8, 0,
+			"nf_tables",
+			false,
+		},
+		{
+			"iptables v1.8.0 (legacy)",
+			1, 8, 0,
+			"legacy",
+			false,
+		},
+		{
+			"iptables v1.6.2",
+			1, 6, 2,
+			"legacy",
+			false,
+		},
+	}
+
+	for i, tt := range testCases {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
+			if err == nil && tt.err {
+				t.Fatal("expected err, got none")
+			} else if err != nil && !tt.err {
+				t.Fatalf("unexpected err %s", err)
+			}
+
+			if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
+				t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
+					tt.v1, tt.v2, tt.v3, tt.mode,
+					v1, v2, v3, mode)
+			}
+		})
+	}
+}

+ 5 - 2
test

@@ -45,11 +45,14 @@ split=(${TEST// / })
 TEST=${split[@]/#/${REPO_PATH}/}
 
 echo "Running tests..."
-go test -i ${TEST}
+bin=$(mktemp)
+
+go test -c -o ${bin} ${COVER} -i ${TEST}
 if [[ -z "$SUDO_PERMITTED" ]]; then
     echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
     exit 1
 fi
 
-sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
+sudo -E bash -c "${bin} $@ ${TEST}"
 echo "Success"
+rm "${bin}"