package quic

import (
	"bytes"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/lucas-clemente/quic-go/internal/protocol"
	"github.com/lucas-clemente/quic-go/internal/utils"
	"github.com/lucas-clemente/quic-go/internal/wire"
)

var (
	// connMuxerOnce guards the one-time construction of connMuxer.
	connMuxerOnce sync.Once
	// connMuxer is the package-wide multiplexer returned by getMultiplexer.
	connMuxer multiplexer
)

// multiplexer registers packet conns and the packet handlers that should
// receive the packets read from them.
type multiplexer interface {
	AddConn(net.PacketConn, int) (packetHandlerManager, error)
	AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error
}

// connManager pairs the packet handler manager of a packet conn with the
// connection ID length that was given when the conn was first added.
type connManager struct {
	connIDLen int
	manager   packetHandlerManager
}

// The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the session handler.
type connMultiplexer struct {
	mutex sync.Mutex

	conns                   map[net.PacketConn]connManager
	newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests

	logger utils.Logger
}

var _ multiplexer = &connMultiplexer{}

// getMultiplexer returns the shared multiplexer, creating it on first use.
func getMultiplexer() multiplexer {
	connMuxerOnce.Do(func() {
		connMuxer = &connMultiplexer{
			conns:                   make(map[net.PacketConn]connManager),
			logger:                  utils.DefaultLogger.WithPrefix("muxer"),
			newPacketHandlerManager: newPacketHandlerMap,
		}
	})
	return connMuxer
}

// AddConn registers the packet conn c and returns the packet handler manager
// associated with it.
//
// The first time a conn is added, a new manager is created and a goroutine is
// started that reads packets from c. Adding a conn that is already known
// returns the existing manager, unless connIDLen differs from the length the
// conn was first registered with, in which case an error is returned.
func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	p, ok := m.conns[c]
	if !ok {
		manager := m.newPacketHandlerManager()
		p = connManager{connIDLen: connIDLen, manager: manager}
		m.conns[c] = p
		// If we didn't know this packet conn before, listen for incoming packets
		// and dispatch them to the right sessions.
		go m.listen(c, &p)
	}
	if p.connIDLen != connIDLen {
		return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
	}
	return p.manager, nil
}

// AddHandler registers handler for packets that arrive on c with the
// destination connection ID connID.
//
// It returns an error if c was not previously registered via AddConn.
func (m *connMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	p, ok := m.conns[c]
	if !ok {
		// errors.New does not interpret format verbs, so the message must not
		// contain one.
		return errors.New("unknown packet conn")
	}
	p.manager.Add(connID, handler)
	return nil
}

// listen reads packets from c in a loop and passes each one to handlePacket.
//
// The loop ends on the first read error. The manager is closed on that error
// unless the error text ends in "use of closed network connection".
// NOTE(review): presumably that case means the conn was closed locally and
// the manager is shut down elsewhere; confirm against the callers.
func (m *connMultiplexer) listen(c net.PacketConn, p *connManager) {
	for {
		data := *getPacketBuffer()
		data = data[:protocol.MaxReceivePacketSize]
		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
		// If it does, we only read a truncated packet, which will then end up undecryptable
		n, addr, err := c.ReadFrom(data)
		if err != nil {
			if !strings.HasSuffix(err.Error(), "use of closed network connection") {
				p.manager.Close()
			}
			return
		}
		data = data[:n]

		// A packet that cannot be handled is logged and dropped; the read loop
		// keeps running.
		if err := m.handlePacket(addr, data, p); err != nil {
			m.logger.Debugf("error handling packet from %s: %s", addr, err)
		}
	}
}

// handlePacket parses the header of a received packet and passes the packet
// to the handler registered for its destination connection ID.
//
// It returns an error if the header cannot be parsed, if no entry exists for
// the destination connection ID, or if a long header packet carries less
// payload than its header announces. A nil handler entry causes the packet to
// be dropped without an error.
func (m *connMultiplexer) handlePacket(addr net.Addr, data []byte, p *connManager) error {
	rcvTime := time.Now()

	r := bytes.NewReader(data)
	iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen)
	// drop the packet if we can't parse the header
	if err != nil {
		return fmt.Errorf("error parsing invariant header: %s", err)
	}
	handler, ok := p.manager.Get(iHdr.DestConnectionID)
	if !ok {
		return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
	}
	if handler == nil {
		// Late packet for closed session
		return nil
	}
	hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, handler.GetVersion())
	if err != nil {
		return fmt.Errorf("error parsing header: %s", err)
	}

	// r.Len() is the number of bytes not yet consumed by header parsing, so
	// the header ends and the payload begins at len(data)-r.Len().
	hdr.Raw = data[:len(data)-r.Len()]
	packetData := data[len(data)-r.Len():]

	if hdr.IsLongHeader {
		if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
		}
		// Bytes beyond the announced payload length are cut off.
		packetData = packetData[:int(hdr.PayloadLen)]
		// TODO(#1312): implement parsing of compound packets
	}

	handler.handlePacket(&receivedPacket{
		remoteAddr: addr,
		header:     hdr,
		data:       packetData,
		rcvTime:    rcvTime,
	})
	return nil
}