Browse Source

*: refactor iptables invocation

Centralise all invocations of iptables, so that the lock is used every
time it is called.
Jonathan Boulle 9 years ago
parent
commit
591ab2760d
1 changed files with 26 additions and 34 deletions
  1. 26 34
      iptables/iptables.go

+ 26 - 34
iptables/iptables.go

@@ -17,6 +17,7 @@ package iptables
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"log"
 	"os/exec"
 	"regexp"
@@ -75,8 +76,8 @@ func New() (*IPTables, error) {
 // Exists checks if given rulespec in specified table/chain exists
 func (ipt *IPTables) Exists(table, chain string, rulespec ...string) (bool, error) {
 	if !ipt.hasCheck {
-		cmd := append([]string{"-A", chain}, rulespec...)
-		return existsForOldIpTables(table, strings.Join(cmd, " "))
+		return ipt.existsForOldIptables(table, chain, rulespec)
+
 	}
 	cmd := append([]string{"-t", table, "-C", chain}, rulespec...)
 	err := ipt.run(cmd...)
@@ -124,28 +125,10 @@ func (ipt *IPTables) Delete(table, chain string, rulespec ...string) error {
 
 // List rules in specified table/chain
 func (ipt *IPTables) List(table, chain string) ([]string, error) {
-	var stdout, stderr bytes.Buffer
-	args := []string{ipt.path, "-t", table, "-S", chain}
-
-	if ipt.hasWait {
-		args = append(args, "--wait")
-	} else {
-		ul, err := ipt.fmu.tryLock()
-		if err != nil {
-			return nil, err
-		}
-		defer ul.Unlock()
-	}
-
-	cmd := exec.Cmd{
-		Path:   ipt.path,
-		Args:   args,
-		Stdout: &stdout,
-		Stderr: &stderr,
-	}
-
-	if err := cmd.Run(); err != nil {
-		return nil, &Error{*(err.(*exec.ExitError)), stderr.String()}
+	args := []string{"-t", table, "-S", chain}
+	var stdout bytes.Buffer
+	if err := ipt.runWithOutput(args, &stdout); err != nil {
+		return nil, err
 	}
 
 	rules := strings.Split(stdout.String(), "\n")
@@ -182,10 +165,18 @@ func (ipt *IPTables) DeleteChain(table, chain string) error {
 	return ipt.run("-t", table, "-X", chain)
 }
 
+// run runs an iptables command with the given arguments, ignoring
+// any stdout output
 func (ipt *IPTables) run(args ...string) error {
-	var stderr bytes.Buffer
+	return ipt.runWithOutput(args, nil)
+}
+
+// runWithOutput runs an iptables command with the given arguments,
+// writing any stdout output to the given writer
+func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
+	args = append([]string{ipt.path}, args...)
 	if ipt.hasWait {
-		args = append([]string{"--wait"}, args...)
+		args = append(args, "--wait")
 	} else {
 		ul, err := ipt.fmu.tryLock()
 		if err != nil {
@@ -194,9 +185,11 @@ func (ipt *IPTables) run(args ...string) error {
 		defer ul.Unlock()
 	}
 
+	var stderr bytes.Buffer
 	cmd := exec.Cmd{
 		Path:   ipt.path,
-		Args:   append([]string{ipt.path}, args...),
+		Args:   args,
+		Stdout: stdout,
 		Stderr: &stderr,
 	}
 
@@ -290,14 +283,13 @@ func iptablesHasWaitCommand(v1 int, v2 int, v3 int) bool {
 }
 
 // Checks if a rule specification exists for a table
-func existsForOldIpTables(table string, ruleSpec string) (bool, error) {
-	cmd := exec.Command("iptables", "-t", table, "-S")
-	var out bytes.Buffer
-	cmd.Stdout = &out
-	err := cmd.Run()
+func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string) (bool, error) {
+	rs := strings.Join(append([]string{"-A", chain}, rulespec...), " ")
+	args := []string{"-t", table, "-S"}
+	var stdout bytes.Buffer
+	err := ipt.runWithOutput(args, &stdout)
 	if err != nil {
 		return false, err
 	}
-	rules := out.String()
-	return strings.Contains(rules, ruleSpec), nil
+	return strings.Contains(stdout.String(), rs), nil
 }