diff --git a/bench_test.go b/bench_test.go index 6f3ef1e..2351cd0 100644 --- a/bench_test.go +++ b/bench_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net" + "sync" "testing" ) @@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) { } }() - b.ResetTimer() + donec := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(b.N) + b.ResetTimer() for i := 0; i < b.N; i++ { c := &mockConn{ r: bytes.NewReader(benchHTTPPayload), } - m.serve(c) + m.serve(c, donec, &wg) } } diff --git a/cmux.go b/cmux.go index cbc259b..e124799 100644 --- a/cmux.go +++ b/cmux.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "sync" ) // Matcher matches a connection based on its content. @@ -48,6 +49,7 @@ func New(l net.Listener) CMux { root: l, bufLen: 1024, errh: func(_ error) bool { return true }, + donec: make(chan struct{}), } } @@ -74,6 +76,7 @@ type cMux struct { root net.Listener bufLen int errh ErrorHandler + donec chan struct{} sls []matchersListener } @@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener { ml := muxListener{ Listener: m.root, connc: make(chan net.Conn, m.bufLen), - donec: make(chan struct{}), } m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) return ml } func (m *cMux) Serve() error { + var wg sync.WaitGroup + defer func() { + close(m.donec) + wg.Wait() + for _, sl := range m.sls { - close(sl.l.donec) + close(sl.l.connc) } }() @@ -103,11 +110,14 @@ func (m *cMux) Serve() error { continue } - go m.serve(c) + wg.Add(1) + go m.serve(c, m.donec, &wg) } } -func (m *cMux) serve(c net.Conn) { +func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + muc := newMuxConn(c) for _, sl := range m.sls { for _, s := range sl.ss { @@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) { if matched { select { case sl.l.connc <- muc: - case <-sl.l.donec: - _ = c.Close() + default: + select { + case <-donec: + _ = c.Close() + default: + } } return } @@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool { type muxListener struct { net.Listener connc chan net.Conn - donec chan struct{} } func (l muxListener) Accept() (net.Conn, error) { - select { - case c := <-l.connc: - return c, nil - case <-l.donec: + c, ok := <-l.connc + if !ok { return nil, ErrListenerClosed } + return c, nil } // MuxConn wraps a net.Conn and provides transparent sniffing of connection data. diff --git a/cmux_test.go b/cmux_test.go index 3f0588d..433763c 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -1,6 +1,7 @@ package cmux import ( + "errors" "fmt" "io/ioutil" "net" @@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { } } +type chanListener struct { + net.Listener + connCh chan net.Conn +} + +func newChanListener() *chanListener { + return &chanListener{connCh: make(chan net.Conn, 1)} +} + +func (l *chanListener) Accept() (net.Conn, error) { + if c, ok := <-l.connCh; ok { + return c, nil + } + return nil, errors.New("use of closed network connection") +} + func testListener(t *testing.T) (net.Listener, func()) { l, err := net.Listen("tcp", ":0") if err != nil { @@ -235,21 +252,41 @@ func TestErrorHandler(t *testing.T) { } } -type closerConn struct { - net.Conn -} - -func (c closerConn) Close() error { return nil } - -func TestClosed(t *testing.T) { +func TestClose(t *testing.T) { defer leakCheck(t)() - mux := &cMux{} - lis := mux.Match(Any()).(muxListener) - close(lis.donec) - mux.serve(closerConn{}) - _, err := lis.Accept() - if _, ok := err.(errListenerClosed); !ok { - t.Errorf("expected errListenerClosed got %v", err) + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l := newChanListener() + + c1, c2 := net.Pipe() + + muxl := New(l) + anyl := muxl.Match(Any()) + + go safeServe(errCh, muxl) + + l.connCh <- c1 + + // First connection goes through. + if _, err := anyl.Accept(); err != nil { + t.Fatal(err) + } + + // Second connection is sent + l.connCh <- c2 + + // Listener is closed. + close(l.connCh) + + // Second connection goes through. + if _, err := anyl.Accept(); err != nil { + t.Fatal(err) } }