Browse Source

Merge pull request #23 from robbertkl/master

Add support for IPv6 with ip6tables
Luca Bruno 8 years ago
parent
commit
18855ec8ac
2 changed files with 96 additions and 19 deletions
  1. 40 6
      iptables/iptables.go
  2. 56 13
      iptables/iptables_test.go

+ 40 - 6
iptables/iptables.go

@@ -39,29 +39,52 @@ func (e *Error) Error() string {
 	return fmt.Sprintf("exit status %v: %v", e.ExitStatus(), e.msg)
 }
 
+// Protocol to differentiate between IPv4 and IPv6
+type Protocol byte
+
+const (
+	ProtocolIPv4 Protocol = iota
+	ProtocolIPv6
+)
+
 type IPTables struct {
 	path     string
+	proto    Protocol
 	hasCheck bool
 	hasWait  bool
 }
 
+// New creates a new IPTables.
+// For backwards compatibility, this always uses IPv4, i.e. "iptables".
 func New() (*IPTables, error) {
-	path, err := exec.LookPath("iptables")
+	return NewWithProtocol(ProtocolIPv4)
+}
+
+// New creates a new IPTables for the given proto.
+// The proto will determine which command is used, either "iptables" or "ip6tables".
+func NewWithProtocol(proto Protocol) (*IPTables, error) {
+	path, err := exec.LookPath(getIptablesCommand(proto))
 	if err != nil {
 		return nil, err
 	}
-	checkPresent, waitPresent, err := getIptablesCommandSupport()
+	checkPresent, waitPresent, err := getIptablesCommandSupport(path)
 	if err != nil {
 		return nil, fmt.Errorf("error checking iptables version: %v", err)
 	}
 	ipt := IPTables{
 		path:     path,
+		proto:    proto,
 		hasCheck: checkPresent,
 		hasWait:  waitPresent,
 	}
 	return &ipt, nil
 }
 
+// Proto returns the protocol used by this IPTables.
+func (ipt *IPTables) Proto() Protocol {
+	return ipt.proto
+}
+
 // Exists checks if given rulespec in specified table/chain exists
 func (ipt *IPTables) Exists(table, chain string, rulespec ...string) (bool, error) {
 	if !ipt.hasCheck {
@@ -129,6 +152,8 @@ func (ipt *IPTables) List(table, chain string) ([]string, error) {
 	return rules, nil
 }
 
+// NewChain creates a new chain in the specified table.
+// If the chain already exists, it will result in an error.
 func (ipt *IPTables) NewChain(table, chain string) error {
 	return ipt.run("-t", table, "-N", chain)
 }
@@ -200,9 +225,18 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	return nil
 }
 
+// getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
+func getIptablesCommand(proto Protocol) string {
+	if proto == ProtocolIPv6 {
+		return "ip6tables"
+	} else {
+		return "iptables"
+	}
+}
+
 // Checks if iptables has the "-C" and "--wait" flag
-func getIptablesCommandSupport() (bool, bool, error) {
-	vstring, err := getIptablesVersionString()
+func getIptablesCommandSupport(path string) (bool, bool, error) {
+	vstring, err := getIptablesVersionString(path)
 	if err != nil {
 		return false, false, err
 	}
@@ -243,8 +277,8 @@ func extractIptablesVersion(str string) (int, int, int, error) {
 }
 
 // Runs "iptables --version" to get the version string
-func getIptablesVersionString() (string, error) {
-	cmd := exec.Command("iptables", "--version")
+func getIptablesVersionString(path string) (string, error) {
+	cmd := exec.Command(path, "--version")
 	var out bytes.Buffer
 	cmd.Stdout = &out
 	err := cmd.Run()

+ 56 - 13
iptables/iptables_test.go

@@ -22,6 +22,32 @@ import (
 	"testing"
 )
 
+func TestProto(t *testing.T) {
+	ipt, err := New()
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+	if ipt.Proto() != ProtocolIPv4 {
+		t.Fatalf("Expected default protocol IPv4, got %v", ipt.Proto())
+	}
+
+	ip4t, err := NewWithProtocol(ProtocolIPv4)
+	if err != nil {
+		t.Fatalf("NewWithProtocol(ProtocolIPv4) failed: %v", err)
+	}
+	if ip4t.Proto() != ProtocolIPv4 {
+		t.Fatalf("Expected protocol IPv4, got %v", ip4t.Proto())
+	}
+
+	ip6t, err := NewWithProtocol(ProtocolIPv6)
+	if err != nil {
+		t.Fatalf("NewWithProtocol(ProtocolIPv6) failed: %v", err)
+	}
+	if ip6t.Proto() != ProtocolIPv6 {
+		t.Fatalf("Expected protocol IPv6, got %v", ip6t.Proto())
+	}
+}
+
 func randChain(t *testing.T) string {
 	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
 	if err != nil {
@@ -38,7 +64,11 @@ func mustTestableIptables() []*IPTables {
 	if err != nil {
 		panic(fmt.Sprintf("New failed: %v", err))
 	}
-	ipts := []*IPTables{ipt}
+	ip6t, err := NewWithProtocol(ProtocolIPv6)
+	if err != nil {
+		panic(fmt.Sprintf("NewWithProtocol(ProtocolIPv6) failed: %v", err))
+	}
+	ipts := []*IPTables{ipt, ip6t}
 	// ensure we check one variant without built-in locking
 	if ipt.hasWait {
 		iptNoWait := &IPTables{
@@ -65,7 +95,7 @@ func TestChain(t *testing.T) {
 }
 
 func runChainTests(t *testing.T, ipt *IPTables) {
-	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+	t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck)
 
 	chain := randChain(t)
 
@@ -82,7 +112,7 @@ func runChainTests(t *testing.T, ipt *IPTables) {
 	}
 
 	// put a simple rule in
-	err = ipt.Append("filter", chain, "-s", "0.0.0.0/0", "-j", "ACCEPT")
+	err = ipt.Append("filter", chain, "-s", "0/0", "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Append failed: %v", err)
 	}
@@ -119,7 +149,20 @@ func TestRules(t *testing.T) {
 }
 
 func runRulesTests(t *testing.T, ipt *IPTables) {
-	t.Logf("testing iptables (hasWait=%t, hasCheck=%t)", ipt.hasWait, ipt.hasCheck)
+	t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck)
+
+	var address1, address2, subnet1, subnet2 string
+	if ipt.Proto() == ProtocolIPv6 {
+		address1 = "2001:db8::1/128"
+		address2 = "2001:db8::2/128"
+		subnet1 = "2001:db8:a::/48"
+		subnet2 = "2001:db8:b::/48"
+	} else {
+		address1 = "203.0.113.1/32"
+		address2 = "203.0.113.2/32"
+		subnet1 = "192.0.2.0/24"
+		subnet2 = "198.51.100.0/24"
+	}
 
 	chain := randChain(t)
 
@@ -129,32 +172,32 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 		t.Fatalf("ClearChain (of missing) failed: %v", err)
 	}
 
-	err = ipt.Append("filter", chain, "-s", "10.1.0.0/16", "-d", "8.8.8.8/32", "-j", "ACCEPT")
+	err = ipt.Append("filter", chain, "-s", subnet1, "-d", address1, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Append failed: %v", err)
 	}
 
-	err = ipt.AppendUnique("filter", chain, "-s", "10.1.0.0/16", "-d", "8.8.8.8/32", "-j", "ACCEPT")
+	err = ipt.AppendUnique("filter", chain, "-s", subnet1, "-d", address1, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("AppendUnique failed: %v", err)
 	}
 
-	err = ipt.Append("filter", chain, "-s", "10.2.0.0/16", "-d", "8.8.8.8/32", "-j", "ACCEPT")
+	err = ipt.Append("filter", chain, "-s", subnet2, "-d", address1, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Append failed: %v", err)
 	}
 
-	err = ipt.Insert("filter", chain, 2, "-s", "10.2.0.0/16", "-d", "9.9.9.9/32", "-j", "ACCEPT")
+	err = ipt.Insert("filter", chain, 2, "-s", subnet2, "-d", address2, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Insert failed: %v", err)
 	}
 
-	err = ipt.Insert("filter", chain, 1, "-s", "10.1.0.0/16", "-d", "9.9.9.9/32", "-j", "ACCEPT")
+	err = ipt.Insert("filter", chain, 1, "-s", subnet1, "-d", address2, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Insert failed: %v", err)
 	}
 
-	err = ipt.Delete("filter", chain, "-s", "10.1.0.0/16", "-d", "9.9.9.9/32", "-j", "ACCEPT")
+	err = ipt.Delete("filter", chain, "-s", subnet1, "-d", address2, "-j", "ACCEPT")
 	if err != nil {
 		t.Fatalf("Delete failed: %v", err)
 	}
@@ -166,9 +209,9 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
 
 	expected := []string{
 		"-N " + chain,
-		"-A " + chain + " -s 10.1.0.0/16 -d 8.8.8.8/32 -j ACCEPT",
-		"-A " + chain + " -s 10.2.0.0/16 -d 9.9.9.9/32 -j ACCEPT",
-		"-A " + chain + " -s 10.2.0.0/16 -d 8.8.8.8/32 -j ACCEPT",
+		"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -j ACCEPT",
+		"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -j ACCEPT",
+		"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -j ACCEPT",
 	}
 
 	if !reflect.DeepEqual(rules, expected) {