From ad5a3e2fa068a33cc5d30e333ab635e6b173ce8e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 20 Jul 2018 08:26:36 -0400 Subject: [PATCH] also use the multiplexer for the server --- client.go | 13 ++ internal/protocol/perspective.go | 5 + internal/protocol/perspective_test.go | 5 + mock_packet_handler_manager_test.go | 37 ++-- mock_packet_handler_test.go | 91 ++++++++ mock_unknown_packet_handler_test.go | 56 +++++ mockgen.go | 2 + multiplexer.go | 4 +- packet_handler_map.go | 94 +++++--- packet_handler_map_test.go | 130 ++++++----- server.go | 234 ++++++-------------- server_session.go | 63 ++++++ server_session_test.go | 101 +++++++++ server_test.go | 302 ++++++-------------------- server_tls.go | 6 +- 15 files changed, 631 insertions(+), 512 deletions(-) create mode 100644 mock_packet_handler_test.go create mode 100644 mock_unknown_packet_handler_test.go create mode 100644 server_session.go create mode 100644 server_session_test.go diff --git a/client.go b/client.go index 8aabc0b7..74a9bc28 100644 --- a/client.go +++ b/client.go @@ -544,9 +544,22 @@ func (c *client) Close() error { return c.session.Close() } +func (c *client) destroy(e error) { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.session == nil { + return + } + c.session.destroy(e) +} + func (c *client) GetVersion() protocol.VersionNumber { c.mutex.Lock() v := c.version c.mutex.Unlock() return v } + +func (c *client) GetPerspective() protocol.Perspective { + return protocol.PerspectiveClient +} diff --git a/internal/protocol/perspective.go b/internal/protocol/perspective.go index 948e371a..43358fec 100644 --- a/internal/protocol/perspective.go +++ b/internal/protocol/perspective.go @@ -9,6 +9,11 @@ const ( PerspectiveClient Perspective = 2 ) +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + func (p Perspective) String() string { switch p { case PerspectiveServer: diff --git a/internal/protocol/perspective_test.go b/internal/protocol/perspective_test.go index 55e47706..0ae23d7c 100644 --- a/internal/protocol/perspective_test.go +++ b/internal/protocol/perspective_test.go @@ -11,4 +11,9 @@ var _ = Describe("Perspective", func() { Expect(PerspectiveServer.String()).To(Equal("Server")) Expect(Perspective(0).String()).To(Equal("invalid perspective")) }) + + It("returns the opposite", func() { + Expect(PerspectiveClient.Opposite()).To(Equal(PerspectiveServer)) + Expect(PerspectiveServer.Opposite()).To(Equal(PerspectiveClient)) + }) }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 5cd722cb..fef4eb84 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -44,29 +44,14 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } -// Close mocks base method -func (m *MockPacketHandlerManager) Close() error { - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 +// CloseServer mocks base method +func (m *MockPacketHandlerManager) CloseServer() { + m.ctrl.Call(m, "CloseServer") } -// Close indicates an expected call of Close -func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close)) -} - -// Get mocks base method -func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(packetHandler) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get -func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) +// CloseServer indicates an expected call of CloseServer +func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } // Remove mocks base method @@ -78,3 +63,13 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } + +// SetServer mocks base method +func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) { + m.ctrl.Call(m, "SetServer", arg0) +} + +// SetServer indicates an expected call of SetServer +func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0) +} diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go new file mode 100644 index 00000000..dfa884a9 --- /dev/null +++ b/mock_packet_handler_test.go @@ -0,0 +1,91 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandler) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockPacketHandler is a mock of PacketHandler interface +type MockPacketHandler struct { + ctrl *gomock.Controller + recorder *MockPacketHandlerMockRecorder +} + +// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler +type MockPacketHandlerMockRecorder struct { + mock *MockPacketHandler +} + +// NewMockPacketHandler creates a new mock instance +func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler { + mock := &MockPacketHandler{ctrl: ctrl} + mock.recorder = &MockPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockPacketHandler) Close() error { + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close)) +} + +// GetPerspective mocks base method +func (m *MockPacketHandler) GetPerspective() protocol.Perspective { + ret := m.ctrl.Call(m, "GetPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// GetPerspective indicates an expected call of GetPerspective +func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective)) +} + +// GetVersion mocks base method +func (m *MockPacketHandler) GetVersion() protocol.VersionNumber { + ret := m.ctrl.Call(m, "GetVersion") + ret0, _ := ret[0].(protocol.VersionNumber) + return ret0 +} + +// GetVersion indicates an expected call of GetVersion +func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion)) +} + +// destroy mocks base method +func (m *MockPacketHandler) destroy(arg0 error) { + m.ctrl.Call(m, "destroy", arg0) +} + +// destroy indicates an expected call of destroy +func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) +} + +// handlePacket mocks base method +func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) +} diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go new file mode 100644 index 00000000..65f2978a --- /dev/null +++ b/mock_unknown_packet_handler_test.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: UnknownPacketHandler) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockUnknownPacketHandler is a mock of UnknownPacketHandler interface +type MockUnknownPacketHandler struct { + ctrl *gomock.Controller + recorder *MockUnknownPacketHandlerMockRecorder +} + +// MockUnknownPacketHandlerMockRecorder is the mock recorder for MockUnknownPacketHandler +type MockUnknownPacketHandlerMockRecorder struct { + mock *MockUnknownPacketHandler +} + +// NewMockUnknownPacketHandler creates a new mock instance +func NewMockUnknownPacketHandler(ctrl *gomock.Controller) *MockUnknownPacketHandler { + mock := &MockUnknownPacketHandler{ctrl: ctrl} + mock.recorder = &MockUnknownPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorder { + return m.recorder +} + +// closeWithError mocks base method +func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error { + ret := m.ctrl.Call(m, "closeWithError", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// closeWithError indicates an expected call of closeWithError +func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0) +} + +// handlePacket mocks base method +func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket +func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) +} diff --git a/mockgen.go b/mockgen.go index 13da145a..9c0e08b8 100644 --- a/mockgen.go +++ b/mockgen.go @@ -13,6 +13,8 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler UnknownPacketHandler" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager" //go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer Multiplexer" //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" diff --git a/multiplexer.go b/multiplexer.go index 6c3c1689..c4482ac2 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -28,7 +28,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[net.PacketConn]connManager - newPacketHandlerManager func(net.PacketConn, int, utils.Logger, bool) packetHandlerManager // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests logger utils.Logger } @@ -52,7 +52,7 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle p, ok := m.conns[c] if !ok { - manager := m.newPacketHandlerManager(c, connIDLen, m.logger, true) + manager := m.newPacketHandlerManager(c, connIDLen, m.logger) p = connManager{connIDLen: connIDLen, manager: manager} m.conns[c] = p } diff --git a/packet_handler_map.go b/packet_handler_map.go index 95d67a06..fb9346e0 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "net" - "strings" "sync" "time" @@ -24,6 +23,7 @@ type packetHandlerMap struct { connIDLen int handlers map[string] /* string(ConnectionID)*/ packetHandler + server unknownPacketHandler closed bool deleteClosedSessionsAfter time.Duration @@ -33,8 +33,7 @@ type packetHandlerMap struct { var _ packetHandlerManager = &packetHandlerMap{} -// TODO(#561): remove the listen flag -func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger, listen bool) packetHandlerManager { +func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager { m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, @@ -42,19 +41,10 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, logger: logger, } - if listen { - go m.listen() - } + go m.listen() return m } -func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { - h.mutex.RLock() - sess, ok := h.handlers[string(id)] - h.mutex.RUnlock() - return sess, ok -} - func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() h.handlers[string(id)] = handler @@ -62,18 +52,47 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.removeByConnectionIDAsString(string(id)) +} + +func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { h.mutex.Lock() - h.handlers[string(id)] = nil + h.handlers[id] = nil h.mutex.Unlock() time.AfterFunc(h.deleteClosedSessionsAfter, func() { h.mutex.Lock() - delete(h.handlers, string(id)) + delete(h.handlers, id) h.mutex.Unlock() }) } -func (h *packetHandlerMap) Close() error { +func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { + h.mutex.Lock() + h.server = s + h.mutex.Unlock() +} + +func (h *packetHandlerMap) CloseServer() { + h.mutex.Lock() + h.server = nil + var wg sync.WaitGroup + for id, handler := range h.handlers { + if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer { + wg.Add(1) + go func(id string, handler packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = handler.Close() + h.removeByConnectionIDAsString(id) + wg.Done() + }(id, handler) + } + } + h.mutex.Unlock() + wg.Wait() +} + +func (h *packetHandlerMap) close(e error) error { h.mutex.Lock() if h.closed { h.mutex.Unlock() @@ -86,12 +105,15 @@ func (h *packetHandlerMap) Close() error { if handler != nil { wg.Add(1) go func(handler packetHandler) { - // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = handler.Close() + handler.destroy(e) wg.Done() }(handler) } } + + if h.server != nil { + h.server.closeWithError(e) + } h.mutex.Unlock() wg.Wait() return nil @@ -105,9 +127,7 @@ func (h *packetHandlerMap) listen() { // If it does, we only read a truncated packet, which will then end up undecryptable n, addr, err := h.conn.ReadFrom(data) if err != nil { - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - h.Close() - } + h.close(err) return } data = data[:n] @@ -127,15 +147,33 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { if err != nil { return fmt.Errorf("error parsing invariant header: %s", err) } - handler, ok := h.Get(iHdr.DestConnectionID) - if !ok { - return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) - } - if handler == nil { + + h.mutex.RLock() + handler, ok := h.handlers[string(iHdr.DestConnectionID)] + server := h.server + h.mutex.RUnlock() + + var sentBy protocol.Perspective + var version protocol.VersionNumber + var handlePacket func(*receivedPacket) + if ok && handler == nil { // Late packet for closed session return nil } - hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion()) + if !ok { + if server == nil { // no server set + return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) + } + handlePacket = server.handlePacket + sentBy = protocol.PerspectiveClient + version = iHdr.Version + } else { + sentBy = handler.GetPerspective().Opposite() + version = handler.GetVersion() + handlePacket = handler.handlePacket + } + + hdr, err := iHdr.Parse(r, sentBy, version) if err != nil { return fmt.Errorf("error parsing header: %s", err) } @@ -150,7 +188,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { // TODO(#1312): implement parsing of compound packets } - handler.handlePacket(&receivedPacket{ + handlePacket(&receivedPacket{ remoteAddr: addr, header: hdr, data: packetData, diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 726d1c27..04eaaab8 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "time" "github.com/golang/mock/gomock" @@ -18,66 +19,38 @@ var _ = Describe("Packet Handler Map", func() { conn *mockPacketConn ) + getPacket := func(connID protocol.ConnectionID) []byte { + buf := &bytes.Buffer{} + err := (&wire.Header{ + DestConnectionID: connID, + PacketNumberLen: protocol.PacketNumberLen1, + }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + return buf.Bytes() + } + BeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger, true).(*packetHandlerMap) - }) - - It("adds and gets", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - sess := &mockSession{} - handler.Add(connID, sess) - session, ok := handler.Get(connID) - Expect(ok).To(BeTrue()) - Expect(session).To(Equal(sess)) - }) - - It("deletes", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - handler.Add(connID, &mockSession{}) - handler.Remove(connID) - session, ok := handler.Get(connID) - Expect(ok).To(BeTrue()) - Expect(session).To(BeNil()) - }) - - It("deletes nil session entries after a wait time", func() { - handler.deleteClosedSessionsAfter = 25 * time.Millisecond - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - handler.Add(connID, &mockSession{}) - handler.Remove(connID) - Eventually(func() bool { - _, ok := handler.Get(connID) - return ok - }).Should(BeFalse()) + handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) }) It("closes", func() { - sess1 := NewMockQuicSession(mockCtrl) - sess1.EXPECT().Close() - sess2 := NewMockQuicSession(mockCtrl) - sess2.EXPECT().Close() + testErr := errors.New("test error ") + sess1 := NewMockPacketHandler(mockCtrl) + sess1.EXPECT().destroy(testErr) + sess2 := NewMockPacketHandler(mockCtrl) + sess2.EXPECT().destroy(testErr) handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) - handler.Close() + handler.close(testErr) }) Context("handling packets", func() { - getPacket := func(connID protocol.ConnectionID) []byte { - buf := &bytes.Buffer{} - err := (&wire.Header{ - DestConnectionID: connID, - PacketNumberLen: protocol.PacketNumberLen1, - }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - return buf.Bytes() - } - It("handles packets for different packet handlers on the same packet conn", func() { connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - packetHandler1 := NewMockQuicSession(mockCtrl) - packetHandler2 := NewMockQuicSession(mockCtrl) + packetHandler1 := NewMockPacketHandler(mockCtrl) + packetHandler2 := NewMockPacketHandler(mockCtrl) handledPacket1 := make(chan struct{}) handledPacket2 := make(chan struct{}) packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { @@ -85,11 +58,13 @@ var _ = Describe("Packet Handler Map", func() { close(handledPacket1) }) packetHandler1.EXPECT().GetVersion() + packetHandler1.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { Expect(p.header.DestConnectionID).To(Equal(connID2)) close(handledPacket2) }) packetHandler2.EXPECT().GetVersion() + packetHandler2.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) handler.Add(connID1, packetHandler1) handler.Add(connID2, packetHandler2) @@ -99,8 +74,8 @@ var _ = Describe("Packet Handler Map", func() { Eventually(handledPacket2).Should(BeClosed()) // makes the listen go routine return - packetHandler1.EXPECT().Close().AnyTimes() - packetHandler2.EXPECT().Close().AnyTimes() + packetHandler1.EXPECT().destroy(gomock.Any()).AnyTimes() + packetHandler2.EXPECT().destroy(gomock.Any()).AnyTimes() close(conn.dataToRead) }) @@ -110,10 +85,20 @@ var _ = Describe("Packet Handler Map", func() { Expect(err.Error()).To(ContainSubstring("error parsing invariant header:")) }) + It("deletes nil session entries after a wait time", func() { + handler.deleteClosedSessionsAfter = 10 * time.Millisecond + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.Add(connID, NewMockPacketHandler(mockCtrl)) + handler.Remove(connID) + Eventually(func() error { + return handler.handlePacket(nil, getPacket(connID)) + }).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + }) + It("ignores packets arriving late for closed sessions", func() { handler.deleteClosedSessionsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.Add(connID, NewMockQuicSession(mockCtrl)) + handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) err := handler.handlePacket(nil, getPacket(connID)) Expect(err).ToNot(HaveOccurred()) @@ -127,8 +112,9 @@ var _ = Describe("Packet Handler Map", func() { It("errors on packets that are smaller than the Payload Length in the packet header", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockQuicSession(mockCtrl) + packetHandler := NewMockPacketHandler(mockCtrl) packetHandler.EXPECT().GetVersion().Return(versionIETFFrames) + packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) handler.Add(connID, packetHandler) hdr := &wire.Header{ IsLongHeader: true, @@ -148,8 +134,9 @@ var _ = Describe("Packet Handler Map", func() { It("cuts packets at the Payload Length", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockQuicSession(mockCtrl) + packetHandler := NewMockPacketHandler(mockCtrl) packetHandler.EXPECT().GetVersion().Return(versionIETFFrames) + packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) handler.Add(connID, packetHandler) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { Expect(p.data).To(HaveLen(456)) @@ -172,8 +159,9 @@ var _ = Describe("Packet Handler Map", func() { It("closes the packet handlers when reading from the conn fails", func() { done := make(chan struct{}) - packetHandler := NewMockQuicSession(mockCtrl) - packetHandler.EXPECT().Close().Do(func() { + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { + Expect(e).To(HaveOccurred()) close(done) }) handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) @@ -181,4 +169,38 @@ var _ = Describe("Packet Handler Map", func() { Eventually(done).Should(BeClosed()) }) }) + + Context("running a server", func() { + It("adds a server", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID)) + }) + handler.SetServer(server) + Expect(handler.handlePacket(nil, p)).To(Succeed()) + }) + + It("closes all server sessions", func() { + clientSess := NewMockPacketHandler(mockCtrl) + clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) + serverSess := NewMockPacketHandler(mockCtrl) + serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer) + serverSess.EXPECT().Close() + + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) + handler.CloseServer() + }) + + It("stops handling packets with unknown connection IDs after the server is closed", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + handler.SetServer(server) + handler.CloseServer() + Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) + }) + }) }) diff --git a/server.go b/server.go index b438a124..f25796a7 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "crypto/tls" "errors" "fmt" @@ -14,21 +13,27 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/qerr" ) // packetHandler handles packets type packetHandler interface { handlePacket(*receivedPacket) - GetVersion() protocol.VersionNumber io.Closer + destroy(error) + GetVersion() protocol.VersionNumber + GetPerspective() protocol.Perspective +} + +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + closeWithError(error) error } type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) - Get(protocol.ConnectionID) (packetHandler, bool) + SetServer(unknownPacketHandler) Remove(protocol.ConnectionID) - io.Closer + CloseServer() } type quicSession interface { @@ -84,6 +89,7 @@ type server struct { } var _ Listener = &server{} +var _ unknownPacketHandler = &server{} // ListenAddr creates a QUIC server listening on a given address. // The tls.Config must not be nil, the quic.Config may be nil. @@ -125,7 +131,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } } - logger := utils.DefaultLogger.WithPrefix("server") + sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength) + if err != nil { + return nil, err + } s := &server{ conn: conn, tlsConf: tlsConf, @@ -133,11 +142,11 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, certChain: certChain, scfg: scfg, newSession: newSession, - sessionHandler: newPacketHandlerMap(conn, config.ConnectionIDLength, logger, false), + sessionHandler: sessionHandler, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), supportsTLS: supportsTLS, - logger: logger, + logger: utils.DefaultLogger.WithPrefix("server"), } s.setup() if supportsTLS { @@ -145,7 +154,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, return nil, err } } - go s.serve() + sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -176,7 +185,8 @@ func (s *server) setupTLS() error { case tlsSession := <-sessionChan: // The connection ID is a randomly chosen 8 byte value. // It is safe to assume that it doesn't collide with other randomly chosen values. - s.sessionHandler.Add(tlsSession.connID, tlsSession.sess) + serverSession := newServerSession(tlsSession.sess, s.config, s.logger) + s.sessionHandler.Add(tlsSession.connID, serverSession) } } }() @@ -263,27 +273,6 @@ func populateServerConfig(config *Config) *Config { } } -// serve listens on an existing PacketConn -func (s *server) serve() { - for { - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] - // The packet size should not exceed protocol.MaxReceivePacketSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - n, remoteAddr, err := s.conn.ReadFrom(data) - if err != nil { - s.serverError = err - close(s.errorChan) - _ = s.Close() - return - } - data = data[:n] - if err := s.handlePacket(remoteAddr, data); err != nil { - s.logger.Errorf("error handling packet: %s", err.Error()) - } - } -} - // Accept returns newly openend sessions func (s *server) Accept() (Session, error) { var sess Session @@ -297,10 +286,13 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { - s.sessionHandler.Close() - err := s.conn.Close() - <-s.errorChan // wait for serve() to return - return err + s.sessionHandler.CloseServer() + // TODO: close the conn if this server was started with ListenAddr() (but not with Listen(net.PacketConn)) + if s.serverError == nil { + s.serverError = errors.New("server closed") + } + close(s.errorChan) + return nil } // Addr returns the server's network address @@ -308,157 +300,65 @@ func (s *server) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error { - rcvTime := time.Now() - - r := bytes.NewReader(packet) - iHdr, err := wire.ParseInvariantHeader(r, s.config.ConnectionIDLength) - if err != nil { - return qerr.Error(qerr.InvalidPacketHeader, err.Error()) - } - session, sessionKnown := s.sessionHandler.Get(iHdr.DestConnectionID) - if sessionKnown && session == nil { - // Late packet for closed session - return nil - } - version := protocol.VersionUnknown - if sessionKnown { - version = session.GetVersion() - } - hdr, err := iHdr.Parse(r, protocol.PerspectiveClient, version) - if err != nil { - return qerr.Error(qerr.InvalidPacketHeader, err.Error()) - } - hdr.Raw = packet[:len(packet)-r.Len()] - packetData := packet[len(packet)-r.Len():] - - if hdr.IsPublicHeader { - return s.handleGQUICPacket(session, hdr, packetData, remoteAddr, rcvTime) - } - return s.handleIETFQUICPacket(session, hdr, packetData, remoteAddr, rcvTime) +func (s *server) closeWithError(e error) error { + s.serverError = e + return s.Close() } -func (s *server) handleIETFQUICPacket( - session packetHandler, - hdr *wire.Header, - packetData []byte, - remoteAddr net.Addr, - rcvTime time.Time, -) error { - if hdr.IsLongHeader { - if !s.supportsTLS { - return errors.New("Received an IETF QUIC Long Header") - } - if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { - return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) - } - packetData = packetData[:int(hdr.PayloadLen)] - // TODO(#1312): implement parsing of compound packets - - switch hdr.Type { - case protocol.PacketTypeInitial: - go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) - return nil - case protocol.PacketTypeHandshake: - // nothing to do here. Packet will be passed to the session. - default: - // Note that this also drops 0-RTT packets. - return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) - } +func (s *server) handlePacket(p *receivedPacket) { + if err := s.handlePacketImpl(p); err != nil { + s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) } - - if session == nil { - s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID) - return nil - } - - session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) - return nil } -func (s *server) handleGQUICPacket( - session packetHandler, - hdr *wire.Header, - packetData []byte, - remoteAddr net.Addr, - rcvTime time.Time, -) error { - // ignore all Public Reset packets - if hdr.ResetFlag { - s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID) +func (s *server) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + version := hdr.Version + + if hdr.Type == protocol.PacketTypeInitial { + go s.serverTLS.HandleInitial(p.remoteAddr, hdr, p.data) return nil } - sessionKnown := session != nil - - // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset - // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. - if !sessionKnown && !hdr.VersionFlag { - _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr) + if !hdr.VersionFlag { + _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr) return err } - // a session is only created once the client sent a supported version - // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated - // it is safe to drop it - if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - return nil + // This is (potentially) a Client Hello. + // Make sure it has the minimum required size before spending any more ressources on it. + if len(p.data) < protocol.MinClientHelloSize { + return errors.New("dropping small packet for unknown connection") } // send a Version Negotiation Packet if the client is speaking a different protocol version // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet - if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - // drop packets that are too small to be valid first packets - if len(packetData) < protocol.MinClientHelloSize { - return errors.New("dropping small packet with unknown version") - } - s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) - _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr) + if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, version) { + s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", version) + _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), p.remoteAddr) return err } - if !sessionKnown { - // This is (potentially) a Client Hello. - // Make sure it has the minimum required size before spending any more ressources on it. - if len(packetData) < protocol.MinClientHelloSize { - return errors.New("dropping small packet for unknown connection") - } - - version := hdr.Version - if !protocol.IsSupportedVersion(s.config.Versions, version) { - return errors.New("Server BUG: negotiated version not supported") - } - - s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr) - sess, err := s.newSession( - &conn{pconn: s.conn, currentAddr: remoteAddr}, - s.sessionRunner, - version, - hdr.DestConnectionID, - s.scfg, - s.tlsConf, - s.config, - s.logger, - ) - if err != nil { - return err - } - s.sessionHandler.Add(hdr.DestConnectionID, sess) - - go sess.run() - session = sess + if !protocol.IsSupportedVersion(s.config.Versions, version) { + return errors.New("Server BUG: negotiated version not supported") } - session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) + s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, p.remoteAddr) + sess, err := s.newSession( + &conn{pconn: s.conn, currentAddr: p.remoteAddr}, + s.sessionRunner, + version, + hdr.DestConnectionID, + s.scfg, + s.tlsConf, + s.config, + s.logger, + ) + if err != nil { + return err + } + s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger)) + go sess.run() + sess.handlePacket(p) return nil } diff --git a/server_session.go b/server_session.go new file mode 100644 index 00000000..6c7dd81c --- /dev/null +++ b/server_session.go @@ -0,0 +1,63 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type serverSession struct { + quicSession + + config *Config + + logger utils.Logger +} + +var _ packetHandler = &serverSession{} + +func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler { + return &serverSession{ + quicSession: sess, + config: config, + logger: logger, + } +} + +func (s *serverSession) handlePacket(p *receivedPacket) { + if err := s.handlePacketImpl(p); err != nil { + s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) + } +} + +func (s *serverSession) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + // ignore all Public Reset packets + if hdr.ResetFlag { + return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID) + } + + // Probably an old packet that was sent by the client before the version was negotiated. + // It is safe to drop it. + if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() { + return nil + } + + if hdr.IsLongHeader { + switch hdr.Type { + case protocol.PacketTypeHandshake: + // nothing to do here. Packet will be passed to the session. + default: + // Note that this also drops 0-RTT packets. + return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) + } + } + + s.quicSession.handlePacket(p) + return nil +} + +func (s *serverSession) GetPerspective() protocol.Perspective { + return protocol.PerspectiveServer +} diff --git a/server_session_test.go b/server_session_test.go new file mode 100644 index 00000000..426ca9ad --- /dev/null +++ b/server_session_test.go @@ -0,0 +1,101 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Server Session", func() { + var ( + qsess *MockQuicSession + sess *serverSession + ) + + BeforeEach(func() { + qsess = NewMockQuicSession(mockCtrl) + sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession) + }) + + It("handles packets", func() { + p := &receivedPacket{ + header: &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}}, + } + qsess.EXPECT().handlePacket(p) + sess.handlePacket(p) + }) + + It("ignores Public Resets", func() { + p := &receivedPacket{ + header: &wire.Header{ + ResetFlag: true, + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + } + err := sess.handlePacketImpl(p) + Expect(err).To(MatchError("Received unexpected Public Reset for connection 0xdeadbeef")) + }) + + It("ignores delayed packets with mismatching versions, for gQUIC", func() { + qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) + // don't EXPECT any calls to handlePacket() + p := &receivedPacket{ + header: &wire.Header{ + VersionFlag: true, + Version: protocol.VersionNumber(123), + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + } + err := sess.handlePacketImpl(p) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores delayed packets with mismatching versions, for IETF QUIC", func() { + qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) + // don't EXPECT any calls to handlePacket() + p := &receivedPacket{ + header: &wire.Header{ + IsLongHeader: true, + Version: protocol.VersionNumber(123), + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + } + err := sess.handlePacketImpl(p) + Expect(err).ToNot(HaveOccurred()) + }) + + It("ignores packets with the wrong Long Header type", func() { + qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) + p := &receivedPacket{ + header: &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketType0RTT, + Version: protocol.VersionNumber(100), + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + } + err := sess.handlePacketImpl(p) + Expect(err).To(MatchError("Received unsupported packet type: 0-RTT Protected")) + }) + + It("passes on Handshake packets", func() { + p := &receivedPacket{ + header: &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + Version: protocol.VersionNumber(100), + DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + } + qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) + qsess.EXPECT().handlePacket(p) + Expect(sess.handlePacketImpl(p)).To(Succeed()) + }) + + It("has the right perspective", func() { + Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer)) + }) +}) diff --git a/server_test.go b/server_test.go index 248039df..ac55f598 100644 --- a/server_test.go +++ b/server_test.go @@ -14,7 +14,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -27,6 +26,8 @@ type mockSession struct { runner sessionRunner } +func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") } + var _ = Describe("Server", func() { var ( conn *mockPacketConn @@ -89,7 +90,7 @@ var _ = Describe("Server", func() { Context("with mock session", func() { var ( serv *server - firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) + firstPacket *receivedPacket connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} sessions = make([]*MockQuicSession, 0) sessionHandler *MockPacketHandlerManager @@ -126,9 +127,16 @@ var _ = Describe("Server", func() { serv.setup() b := &bytes.Buffer{} utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0])) - firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} - firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) - firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding + firstPacket = &receivedPacket{ + header: &wire.Header{ + VersionFlag: true, + Version: serv.config.Versions[0], + DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, + PacketNumber: 1, + }, + data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize), + rcvTime: time.Now(), + } }) AfterEach(func() { @@ -150,12 +158,10 @@ var _ = Describe("Server", func() { s.EXPECT().run().Do(func() { close(run) }) sessions = append(sessions, s) - sessionHandler.EXPECT().Get(connID) - sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { - Expect(sess.(*mockSession).connID).To(Equal(connID)) + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) { + Expect(cid).To(Equal(connID)) }) - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) + Expect(serv.handlePacketImpl(firstPacket)).To(Succeed()) Eventually(run).Should(BeClosed()) }) @@ -165,7 +171,8 @@ var _ = Describe("Server", func() { err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) added := make(chan struct{}) - sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) { + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) { + Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer)) close(added) }) serv.serverTLS.sessionChan <- tlsSession{ @@ -184,17 +191,15 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := serv.Accept() + _, err := serv.Accept() Expect(err).ToNot(HaveOccurred()) - Expect(sess.(*mockSession).connID).To(Equal(connID)) close(done) }() - sessionHandler.EXPECT().Get(connID) sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { Consistently(done).ShouldNot(BeClosed()) - sess.(*mockSession).runner.onHandshakeComplete(sess.(Session)) + sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session)) }) - err := serv.handlePacket(nil, firstPacket) + err := serv.handlePacketImpl(firstPacket) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) Eventually(run).Should(BeClosed()) @@ -212,45 +217,20 @@ var _ = Describe("Server", func() { serv.Accept() close(done) }() - sessionHandler.EXPECT().Get(connID) - sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) { run <- errors.New("handshake error") }) - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) + Expect(serv.handlePacketImpl(firstPacket)).To(Succeed()) Consistently(done).ShouldNot(BeClosed()) + // make the go routine return - sessionHandler.EXPECT().Close() close(serv.errorChan) - serv.Close() Eventually(done).Should(BeClosed()) }) - It("assigns packets to existing sessions", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().GetVersion() - - sessionHandler.EXPECT().Get(connID).Return(sess, true) - err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) - Expect(err).ToNot(HaveOccurred()) - }) - - It("closes the sessionHandler and the connection when Close is called", func() { - go func() { - defer GinkgoRecover() - serv.serve() - }() - // close the server - sessionHandler.EXPECT().Close().AnyTimes() + It("closes the sessionHandler when Close is called", func() { + sessionHandler.EXPECT().CloseServer() Expect(serv.Close()).To(Succeed()) - Expect(conn.closed).To(BeTrue()) - }) - - It("ignores packets for closed sessions", func() { - sessionHandler.EXPECT().Get(connID).Return(nil, true) - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) }) It("works if no quic.Config is given", func(done Done) { @@ -264,163 +244,56 @@ var _ = Describe("Server", func() { ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config) Expect(err).ToNot(HaveOccurred()) - var returned bool - go func() { - defer GinkgoRecover() - _, err := ln.Accept() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("use of closed network connection")) - returned = true - }() - ln.Close() - Eventually(func() bool { return returned }).Should(BeTrue()) - }) - - It("errors when encountering a connection error", func() { - testErr := errors.New("connection error") - conn.readErr = testErr - sessionHandler.EXPECT().Close() done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.serve() + ln.Accept() close(done) }() - _, err := serv.Accept() - Expect(err).To(MatchError(testErr)) + ln.Close() Eventually(done).Should(BeClosed()) }) - It("ignores delayed packets with mismatching versions", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().GetVersion() - // don't EXPECT any handlePacket() calls to this session - sessionHandler.EXPECT().Get(connID).Return(sess, true) - - b := &bytes.Buffer{} - // add an unsupported version - data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} - utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1)) - data = append(append(data, b.Bytes()...), 0x01) - err := serv.handlePacket(nil, data) - Expect(err).ToNot(HaveOccurred()) - // if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn - Expect(conn.dataWritten.Bytes()).To(BeEmpty()) + It("returns Accept when it is closed", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept() + Expect(err).To(MatchError("server closed")) + close(done) + }() + sessionHandler.EXPECT().CloseServer() + Expect(serv.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) }) - It("errors on invalid public header", func() { - err := serv.handlePacket(nil, nil) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) - }) - - It("errors on packets that are smaller than the Payload Length in the packet header", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().GetVersion().Return(protocol.VersionTLS) - sessionHandler.EXPECT().Get(connID).Return(sess, true) - - serv.supportsTLS = true - b := &bytes.Buffer{} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - PayloadLen: 1000, - SrcConnectionID: connID, - DestConnectionID: connID, - PacketNumberLen: protocol.PacketNumberLen1, - Version: versionIETFFrames, - } - Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) - Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)")) - }) - - It("cuts packets at the payload length", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) { - Expect(packet.data).To(HaveLen(123)) - }) - sess.EXPECT().GetVersion().Return(protocol.VersionTLS) - sessionHandler.EXPECT().Get(connID).Return(sess, true) - - serv.supportsTLS = true - b := &bytes.Buffer{} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - PayloadLen: 123, - SrcConnectionID: connID, - DestConnectionID: connID, - PacketNumberLen: protocol.PacketNumberLen1, - Version: versionIETFFrames, - } - Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) - Expect(err).ToNot(HaveOccurred()) - }) - - It("drops packets with invalid packet types", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().GetVersion().Return(protocol.VersionTLS) - sessionHandler.EXPECT().Get(connID).Return(sess, true) - - serv.supportsTLS = true - b := &bytes.Buffer{} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - PayloadLen: 123, - SrcConnectionID: connID, - DestConnectionID: connID, - PacketNumberLen: protocol.PacketNumberLen1, - Version: versionIETFFrames, - } - Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) - Expect(err).To(MatchError("Received unsupported packet type: Retry")) - }) - - It("ignores Public Resets", func() { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().GetVersion().Return(protocol.VersionTLS) - sessionHandler.EXPECT().Get(connID).Return(sess, true) - - err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337)) - Expect(err).ToNot(HaveOccurred()) + It("returns Accept with the right error when closeWithError is called", func() { + testErr := errors.New("connection error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept() + Expect(err).To(MatchError(testErr)) + close(done) + }() + sessionHandler.EXPECT().CloseServer() + serv.closeWithError(testErr) + Eventually(done).Should(BeClosed()) }) It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() { config.Versions = []protocol.VersionNumber{99} - b := &bytes.Buffer{} - hdr := wire.Header{ - VersionFlag: true, - DestConnectionID: connID, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, + p := &receivedPacket{ + header: &wire.Header{ + VersionFlag: true, + DestConnectionID: connID, + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, + }, + data: make([]byte, protocol.MinClientHelloSize), } - Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed()) - b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO - serv.conn = conn - sessionHandler.EXPECT().Get(connID) - err := serv.handlePacket(nil, b.Bytes()) + Expect(serv.handlePacketImpl(p)).To(Succeed()) Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't respond with a version negotiation packet if the first packet is too small", func() { - b := &bytes.Buffer{} - hdr := wire.Header{ - VersionFlag: true, - DestConnectionID: connID, - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - } - Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed()) - b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small - serv.conn = conn - sessionHandler.EXPECT().Get(connID) - err := serv.handlePacket(udpAddr, b.Bytes()) - Expect(err).To(MatchError("dropping small packet with unknown version")) - Expect(conn.dataWritten.Len()).Should(BeZero()) }) }) @@ -523,8 +396,11 @@ var _ = Describe("Server", func() { }) It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { - connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} config.Versions = append(config.Versions, protocol.VersionTLS) + ln, err := Listen(conn, testdata.GetTLSConfig(), config) + Expect(err).ToNot(HaveOccurred()) + + connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} b := &bytes.Buffer{} hdr := wire.Header{ Type: protocol.PacketTypeInitial, @@ -536,13 +412,10 @@ var _ = Describe("Server", func() { Version: 0x1234, PayloadLen: protocol.MinInitialPacketSize, } - err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) - Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)).To(Succeed()) b.Write(bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) // add a fake CHLO conn.dataToRead <- b.Bytes() conn.dataReadFrom = udpAddr - ln, err := Listen(conn, testdata.GetTLSConfig(), config) - Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { @@ -568,51 +441,6 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() { - connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - b := &bytes.Buffer{} - hdr := wire.Header{ - Type: protocol.PacketTypeInitial, - IsLongHeader: true, - DestConnectionID: connID, - SrcConnectionID: connID, - PacketNumber: 0x55, - PacketNumberLen: protocol.PacketNumberLen1, - Version: protocol.VersionTLS, - } - err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) - Expect(err).ToNot(HaveOccurred()) - b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO - conn.dataToRead <- b.Bytes() - conn.dataReadFrom = udpAddr - ln, err := Listen(conn, testdata.GetTLSConfig(), config) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero()) - }) - - It("ignores non-Initial Long Header packets for unknown connections", func() { - connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - b := &bytes.Buffer{} - hdr := wire.Header{ - Type: protocol.PacketTypeHandshake, - IsLongHeader: true, - DestConnectionID: connID, - SrcConnectionID: connID, - PacketNumber: 0x55, - PacketNumberLen: protocol.PacketNumberLen1, - Version: protocol.VersionTLS, - } - err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) - Expect(err).ToNot(HaveOccurred()) - conn.dataToRead <- b.Bytes() - conn.dataReadFrom = udpAddr - ln, err := Listen(conn, testdata.GetTLSConfig(), config) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero()) - }) - It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { conn.dataReadFrom = udpAddr conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01} diff --git a/server_tls.go b/server_tls.go index 5c94dee8..150a889f 100644 --- a/server_tls.go +++ b/server_tls.go @@ -17,7 +17,7 @@ import ( type tlsSession struct { connID protocol.ConnectionID - sess packetHandler + sess quicSession } type serverTLS struct { @@ -126,7 +126,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea return err } -func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) { +func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (quicSession, protocol.ConnectionID, error) { if hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { return nil, nil, errors.New("dropping Initial packet with too short connection ID") } @@ -164,7 +164,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat return sess, connID, nil } -func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) { +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (quicSession, protocol.ConnectionID, error) { version := hdr.Version bc := handshake.NewCryptoStreamConn(remoteAddr) bc.AddDataForReading(frame.Data)