소스 검색

Merge pull request #75 from aojea/wait

make iptables timeout configurable
Casey Callendrello 4 년 전
부모
커밋
abea47c2ca
2개의 변경된 파일68개의 추가작업 그리고 20개의 파일을 삭제
  1. 49 20
      iptables/iptables.go
  2. 19 0
      iptables/iptables_test.go

+ 49 - 20
iptables/iptables.go

@@ -73,6 +73,7 @@ type IPTables struct {
 	v2             int
 	v2             int
 	v3             int
 	v3             int
 	mode           string // the underlying iptables operating mode, e.g. nf_tables
 	mode           string // the underlying iptables operating mode, e.g. nf_tables
+	timeout        int    // time to wait for the iptables lock, default waits forever
 }
 }
 
 
 // Stat represents a structured statistic entry.
 // Stat represents a structured statistic entry.
@@ -89,19 +90,42 @@ type Stat struct {
 	Options     string     `json:"options"`
 	Options     string     `json:"options"`
 }
 }
 
 
-// New creates a new IPTables.
-// For backwards compatibility, this always uses IPv4, i.e. "iptables".
-func New() (*IPTables, error) {
-	return NewWithProtocol(ProtocolIPv4)
+type option func(*IPTables)
+
+func IPFamily(proto Protocol) option {
+	return func(ipt *IPTables) {
+		ipt.proto = proto
+	}
 }
 }
 
 
-// New creates a new IPTables for the given proto.
-// The proto will determine which command is used, either "iptables" or "ip6tables".
-func NewWithProtocol(proto Protocol) (*IPTables, error) {
-	path, err := exec.LookPath(getIptablesCommand(proto))
+func Timeout(timeout int) option {
+	return func(ipt *IPTables) {
+		ipt.timeout = timeout
+	}
+}
+
+// New creates a new IPTables configured with the options passed as parameter.
+// For backwards compatibility, by default always uses IPv4 and timeout 0.
+// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
+// the IPFamily and Timeout options as follow:
+//	ip6t := New(IPFamily(ProtocolIPv6), Timeout(5))
+func New(opts ...option) (*IPTables, error) {
+
+	ipt := &IPTables{
+		proto:   ProtocolIPv4,
+		timeout: 0,
+	}
+
+	for _, opt := range opts {
+		opt(ipt)
+	}
+
+	path, err := exec.LookPath(getIptablesCommand(ipt.proto))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	ipt.path = path
+
 	vstring, err := getIptablesVersionString(path)
 	vstring, err := getIptablesVersionString(path)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("could not get iptables version: %v", err)
 		return nil, fmt.Errorf("could not get iptables version: %v", err)
@@ -110,21 +134,23 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
 		return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
 	}
 	}
+	ipt.v1 = v1
+	ipt.v2 = v2
+	ipt.v3 = v3
+	ipt.mode = mode
 
 
 	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
 	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
+	ipt.hasCheck = checkPresent
+	ipt.hasWait = waitPresent
+	ipt.hasRandomFully = randomFullyPresent
 
 
-	ipt := IPTables{
-		path:           path,
-		proto:          proto,
-		hasCheck:       checkPresent,
-		hasWait:        waitPresent,
-		hasRandomFully: randomFullyPresent,
-		v1:             v1,
-		v2:             v2,
-		v3:             v3,
-		mode:           mode,
-	}
-	return &ipt, nil
+	return ipt, nil
+}
+
+// New creates a new IPTables for the given proto.
+// The proto will determine which command is used, either "iptables" or "ip6tables".
+func NewWithProtocol(proto Protocol) (*IPTables, error) {
+	return New(IPFamily(proto), Timeout(0))
 }
 }
 
 
 // Proto returns the protocol used by this IPTables.
 // Proto returns the protocol used by this IPTables.
@@ -461,6 +487,9 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 	args = append([]string{ipt.path}, args...)
 	args = append([]string{ipt.path}, args...)
 	if ipt.hasWait {
 	if ipt.hasWait {
 		args = append(args, "--wait")
 		args = append(args, "--wait")
+		if ipt.timeout != 0 {
+			args = append(args, strconv.Itoa(ipt.timeout))
+		}
 	} else {
 	} else {
 		fmu, err := newXtablesFileLock()
 		fmu, err := newXtablesFileLock()
 		if err != nil {
 		if err != nil {

+ 19 - 0
iptables/iptables_test.go

@@ -50,6 +50,25 @@ func TestProto(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestTimeout(t *testing.T) {
+	ipt, err := New()
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+	if ipt.timeout != 0 {
+		t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
+	}
+
+	ipt2, err := New(Timeout(5))
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+	if ipt2.timeout != 5 {
+		t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
+	}
+
+}
+
 func randChain(t *testing.T) string {
 func randChain(t *testing.T) string {
 	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
 	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
 	if err != nil {
 	if err != nil {