Browse Source

Support overriding binaries' path

Allow overriding binaries path, via a new function, named Path(string).
This works similarity to other options such as Timeout() and configures
IPTables's scruct.

This is particularly useful in sudo-confined environments or any wrapping scenario.

while there, add a set of tests that force iptables into legacy or nft versions
and then compare it with the mode parsed through executing the utility
Costas Drogos 1 year ago
parent
commit
f61413f163
2 changed files with 71 additions and 3 deletions
  1. 23 3
      iptables/iptables.go
  2. 48 0
      iptables/iptables_test.go

+ 23 - 3
iptables/iptables.go

@@ -106,8 +106,20 @@ func Timeout(timeout int) option {
 	}
 }
 
-// New creates a new IPTables configured with the options passed as parameter.
-// For backwards compatibility, by default always uses IPv4 and timeout 0.
+func Path(path string) option {
+	return func(ipt *IPTables) {
+		ipt.path = path
+	}
+}
+
+// New creates a new IPTables configured with the options passed as parameters.
+// Supported parameters are:
+//
+//	IPFamily(Protocol)
+//	Timeout(int)
+//	Path(string)
+//
+// For backwards compatibility, by default New uses IPv4 and timeout 0.
 // i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
 // the IPFamily and Timeout options as follow:
 //
@@ -117,13 +129,21 @@ func New(opts ...option) (*IPTables, error) {
 	ipt := &IPTables{
 		proto:   ProtocolIPv4,
 		timeout: 0,
+		path:    "",
 	}
 
 	for _, opt := range opts {
 		opt(ipt)
 	}
 
-	path, err := exec.LookPath(getIptablesCommand(ipt.proto))
+	// if path wasn't preset through New(Path()), autodiscover it
+	cmd := ""
+	if ipt.path == "" {
+		cmd = getIptablesCommand(ipt.proto)
+	} else {
+		cmd = ipt.path
+	}
+	path, err := exec.LookPath(cmd)
 	if err != nil {
 		return nil, err
 	}

+ 48 - 0
iptables/iptables_test.go

@@ -70,6 +70,54 @@ func TestTimeout(t *testing.T) {
 
 }
 
+// force usage of -legacy or -nft commands and check that they're detected correctly
+func TestLegacyDetection(t *testing.T) {
+	testCases := []struct {
+		in   string
+		mode string
+		err  bool
+	}{
+		{
+			"iptables-legacy",
+			"legacy",
+			false,
+		},
+		{
+			"ip6tables-legacy",
+			"legacy",
+			false,
+		},
+		{
+			"iptables-nft",
+			"nf_tables",
+			false,
+		},
+		{
+			"ip6tables-nft",
+			"nf_tables",
+			false,
+		},
+	}
+
+	for i, tt := range testCases {
+		t.Run(fmt.Sprint(i), func(t *testing.T) {
+			ipt, err := New(Path(tt.in))
+			if err == nil && tt.err {
+				t.Fatal("expected err, got none")
+			} else if err != nil && !tt.err {
+				t.Fatalf("unexpected err %s", err)
+			}
+
+			if !strings.Contains(ipt.path, tt.in) {
+				t.Fatalf("Expected path %s in %s", tt.in, ipt.path)
+			}
+			if ipt.mode != tt.mode {
+				t.Fatalf("Expected %s iptables, but got %s", tt.mode, ipt.mode)
+			}
+		})
+	}
+}
+
 func randChain(t *testing.T) string {
 	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
 	if err != nil {