diff --git a/pkg/transport/timeout_conn.go b/pkg/transport/timeout_conn.go new file mode 100644 index 000000000..32069af38 --- /dev/null +++ b/pkg/transport/timeout_conn.go @@ -0,0 +1,30 @@ +package transport + +import ( + "net" + "time" +) + +type timeoutConn struct { + net.Conn + wtimeoutd time.Duration + rdtimeoutd time.Duration +} + +func (c timeoutConn) Write(b []byte) (n int, err error) { + if c.wtimeoutd > 0 { + if err := c.SetWriteDeadline(time.Now().Add(c.wtimeoutd)); err != nil { + return 0, err + } + } + return c.Conn.Write(b) +} + +func (c timeoutConn) Read(b []byte) (n int, err error) { + if c.rdtimeoutd > 0 { + if err := c.SetReadDeadline(time.Now().Add(c.rdtimeoutd)); err != nil { + return 0, err + } + } + return c.Conn.Read(b) +} diff --git a/pkg/transport/timeout_dailer.go b/pkg/transport/timeout_dailer.go new file mode 100644 index 000000000..104cf3f7f --- /dev/null +++ b/pkg/transport/timeout_dailer.go @@ -0,0 +1,22 @@ +package transport + +import ( + "net" + "time" +) + +type rwTimeoutDialer struct { + wtimeoutd time.Duration + rdtimeoutd time.Duration + net.Dialer +} + +func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) { + conn, err := d.Dialer.Dial(network, address) + tconn := &timeoutConn{ + rdtimeoutd: d.rdtimeoutd, + wtimeoutd: d.wtimeoutd, + Conn: conn, + } + return tconn, err +} diff --git a/pkg/transport/timeout_dailer_test.go b/pkg/transport/timeout_dailer_test.go new file mode 100644 index 000000000..50357fd0c --- /dev/null +++ b/pkg/transport/timeout_dailer_test.go @@ -0,0 +1,73 @@ +package transport + +import ( + "net" + "testing" + "time" +) + +func TestReadWriteTimeoutDialer(t *testing.T) { + stop := make(chan struct{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("unexpected listen error: %v", err) + } + ts := testBlockingServer{ln, 2, stop} + go ts.Start(t) + + d := rwTimeoutDialer{ + wtimeoutd: 10 * time.Millisecond, + rdtimeoutd: 10 * time.Millisecond, + } + conn, err := d.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("unexpected dial error: %v", err) + } + defer conn.Close() + + // fill the socket buffer + data := make([]byte, 1024*1024) + timer := time.AfterFunc(d.wtimeoutd*5, func() { + t.Fatal("wait timeout") + }) + defer timer.Stop() + + _, err = conn.Write(data) + if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() { + t.Errorf("err = %v, want write i/o timeout error", err) + } + + timer.Reset(d.rdtimeoutd * 5) + + conn, err = d.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("unexpected dial error: %v", err) + } + defer conn.Close() + + buf := make([]byte, 10) + _, err = conn.Read(buf) + if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() { + t.Errorf("err = %v, want write i/o timeout error", err) + } + + stop <- struct{}{} +} + +type testBlockingServer struct { + ln net.Listener + n int + stop chan struct{} +} + +func (ts *testBlockingServer) Start(t *testing.T) { + for i := 0; i < ts.n; i++ { + conn, err := ts.ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + } + <-ts.stop +} diff --git a/pkg/transport/timeout_listener.go b/pkg/transport/timeout_listener.go new file mode 100644 index 000000000..a58f3ed36 --- /dev/null +++ b/pkg/transport/timeout_listener.go @@ -0,0 +1,24 @@ +package transport + +import ( + "net" + "time" +) + +type rwTimeoutListener struct { + net.Listener + wtimeoutd time.Duration + rdtimeoutd time.Duration +} + +func (rwln *rwTimeoutListener) Accept() (net.Conn, error) { + c, err := rwln.Listener.Accept() + if err != nil { + return nil, err + } + return timeoutConn{ + Conn: c, + wtimeoutd: rwln.wtimeoutd, + rdtimeoutd: rwln.rdtimeoutd, + }, nil +} diff --git a/pkg/transport/timeout_listener_test.go b/pkg/transport/timeout_listener_test.go new file mode 100644 index 000000000..ea0884bbb --- /dev/null +++ b/pkg/transport/timeout_listener_test.go @@ -0,0 +1,64 @@ +package transport + +import ( + "net" + "testing" + "time" +) + +func TestWriteReadTimeoutListener(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("unexpected listen error: %v", err) + } + wln := rwTimeoutListener{ + Listener: ln, + wtimeoutd: 10 * time.Millisecond, + rdtimeoutd: 10 * time.Millisecond, + } + stop := make(chan struct{}) + + blocker := func() { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("unexpected dail error: %v", err) + } + defer conn.Close() + // block the receiver until the writer timeout + <-stop + } + go blocker() + + conn, err := wln.Accept() + if err != nil { + t.Fatalf("unexpected accept error: %v", err) + } + defer conn.Close() + + // fill the socket buffer + data := make([]byte, 1024*1024) + timer := time.AfterFunc(wln.wtimeoutd*5, func() { + t.Fatal("wait timeout") + }) + defer timer.Stop() + + _, err = conn.Write(data) + if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() { + t.Errorf("err = %v, want write i/o timeout error", err) + } + stop <- struct{}{} + + timer.Reset(wln.rdtimeoutd * 5) + go blocker() + + conn, err = wln.Accept() + if err != nil { + t.Fatalf("unexpected accept error: %v", err) + } + buf := make([]byte, 10) + _, err = conn.Read(buf) + if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() { + t.Errorf("err = %v, want write i/o timeout error", err) + } + stop <- struct{}{} +}