mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
improve public quic and h2 APIs, embedding http.Server in h2quic.Server
ref #124
This commit is contained in:
parent
b0bc84c5aa
commit
bf3d89c795
5 changed files with 71 additions and 81 deletions
|
@ -68,13 +68,6 @@ func main() {
|
|||
})
|
||||
http.Handle("/", http.FileServer(http.Dir(*www)))
|
||||
|
||||
server, err := h2quic.NewServer(tlsConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// server.CloseAfterFirstRequest = true
|
||||
|
||||
if len(bs) == 0 {
|
||||
bs = binds{"localhost:6121"}
|
||||
}
|
||||
|
@ -84,7 +77,14 @@ func main() {
|
|||
for _, b := range bs {
|
||||
bCap := b
|
||||
go func() {
|
||||
err := server.ListenAndServe(bCap, nil)
|
||||
server := h2quic.Server{
|
||||
// CloseAfterFirstRequest: true,
|
||||
Server: &http.Server{
|
||||
Addr: bCap,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
}
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -19,36 +18,27 @@ type streamCreator interface {
|
|||
Close(error) error
|
||||
}
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections
|
||||
// Server is a HTTP2 server listening for QUIC connections.
|
||||
// The nil value is invalid, as a valid TLS config is required.
|
||||
type Server struct {
|
||||
server *quic.Server
|
||||
handler http.Handler
|
||||
*http.Server
|
||||
|
||||
// Private flag for demo, do not use
|
||||
CloseAfterFirstRequest bool
|
||||
}
|
||||
|
||||
// NewServer creates a new server instance
|
||||
func NewServer(tlsConfig *tls.Config) (*Server, error) {
|
||||
s := &Server{}
|
||||
|
||||
var err error
|
||||
s.server, err = quic.NewServer(tlsConfig, s.handleStreamCb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the network address and calls the handler.
|
||||
func (s *Server) ListenAndServe(addr string, handler http.Handler) error {
|
||||
if handler != nil {
|
||||
s.handler = handler
|
||||
} else {
|
||||
s.handler = http.DefaultServeMux
|
||||
func (s *Server) ListenAndServe() error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
return s.server.ListenAndServe(addr)
|
||||
|
||||
server, err := quic.NewServer(s.Addr, s.TLSConfig, s.handleStreamCb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) handleStreamCb(session *quic.Session, stream utils.Stream) {
|
||||
|
@ -109,7 +99,11 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
|||
responseWriter := newResponseWriter(headerStream, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
go func() {
|
||||
s.handler.ServeHTTP(responseWriter, req)
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
handler.ServeHTTP(responseWriter, req)
|
||||
if responseWriter.dataStream != nil {
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
|
|
|
@ -33,29 +33,15 @@ var _ = Describe("H2 server", func() {
|
|||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
s, err = NewServer(testdata.GetTLSConfig())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s).NotTo(BeNil())
|
||||
s = &Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
}
|
||||
dataStream = &mockStream{}
|
||||
session = &mockSession{dataStream: dataStream}
|
||||
})
|
||||
|
||||
It("uses default handler", func() {
|
||||
// We try binding to a low port number, s.t. it always fails
|
||||
err := s.ListenAndServe("127.0.0.1:80", nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(s.handler).To(Equal(http.DefaultServeMux))
|
||||
})
|
||||
|
||||
It("sets handler properly", func() {
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
// We try binding to a low port number, s.t. it always fails
|
||||
err := s.ListenAndServe("127.0.0.1:80", h)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(s.handler).NotTo(Equal(http.DefaultServeMux))
|
||||
})
|
||||
|
||||
Context("handling requests", func() {
|
||||
var (
|
||||
h2framer *http2.Framer
|
||||
|
@ -71,7 +57,7 @@ var _ = Describe("H2 server", func() {
|
|||
|
||||
It("handles a sample GET request", func() {
|
||||
var handlerCalled bool
|
||||
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
|
@ -88,7 +74,7 @@ var _ = Describe("H2 server", func() {
|
|||
|
||||
It("does not close the dataStream when end of stream is not set", func() {
|
||||
var handlerCalled bool
|
||||
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
|
@ -106,7 +92,7 @@ var _ = Describe("H2 server", func() {
|
|||
|
||||
It("handles the header stream", func() {
|
||||
var handlerCalled bool
|
||||
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
|
@ -122,7 +108,7 @@ var _ = Describe("H2 server", func() {
|
|||
|
||||
It("ignores other streams", func() {
|
||||
var handlerCalled bool
|
||||
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
|
@ -138,7 +124,7 @@ var _ = Describe("H2 server", func() {
|
|||
|
||||
It("supports closing after first request", func() {
|
||||
s.CloseAfterFirstRequest = true
|
||||
s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
|
@ -149,4 +135,20 @@ var _ = Describe("H2 server", func() {
|
|||
s.handleStream(session, headerStream)
|
||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("uses the default handler as fallback", func() {
|
||||
var handlerCalled bool
|
||||
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
}))
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
s.handleStream(session, headerStream)
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
|
36
server.go
36
server.go
|
@ -21,8 +21,8 @@ type packetHandler interface {
|
|||
|
||||
// A Server of QUIC
|
||||
type Server struct {
|
||||
conns []*net.UDPConn
|
||||
connsMutex sync.Mutex
|
||||
addr *net.UDPAddr
|
||||
conn *net.UDPConn
|
||||
|
||||
signer crypto.Signer
|
||||
scfg *handshake.ServerConfig
|
||||
|
@ -36,7 +36,7 @@ type Server struct {
|
|||
}
|
||||
|
||||
// NewServer makes a new server
|
||||
func NewServer(tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
|
||||
func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
|
||||
signer, err := crypto.NewRSASigner(tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -51,7 +51,13 @@ func NewServer(tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Server{
|
||||
addr: udpAddr,
|
||||
signer: signer,
|
||||
scfg: scfg,
|
||||
streamCallback: cb,
|
||||
|
@ -61,19 +67,12 @@ func NewServer(tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
|
|||
}
|
||||
|
||||
// ListenAndServe listens and serves a connection
|
||||
func (s *Server) ListenAndServe(address string) error {
|
||||
addr, err := net.ResolveUDPAddr("udp", address)
|
||||
func (s *Server) ListenAndServe() error {
|
||||
conn, err := net.ListenUDP("udp", s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.connsMutex.Lock()
|
||||
s.conns = append(s.conns, conn)
|
||||
s.connsMutex.Unlock()
|
||||
s.conn = conn
|
||||
|
||||
for {
|
||||
data := make([]byte, protocol.MaxPacketSize)
|
||||
|
@ -90,16 +89,11 @@ func (s *Server) ListenAndServe(address string) error {
|
|||
|
||||
// Close the server
|
||||
func (s *Server) Close() error {
|
||||
s.connsMutex.Lock()
|
||||
defer s.connsMutex.Unlock()
|
||||
for _, c := range s.conns {
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet []byte) error {
|
||||
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
|
||||
|
|
|
@ -84,11 +84,11 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("setups and responds with version negotiation", func(done Done) {
|
||||
server, err := NewServer(testdata.GetTLSConfig(), nil)
|
||||
server, err := NewServer("127.0.0.1:13370", testdata.GetTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.ListenAndServe("127.0.0.1:13370")
|
||||
err := server.ListenAndServe()
|
||||
Expect(err).To(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
@ -123,11 +123,11 @@ var _ = Describe("Server", func() {
|
|||
}, 1)
|
||||
|
||||
It("setups and responds with error on invalid frame", func(done Done) {
|
||||
server, err := NewServer(testdata.GetTLSConfig(), nil)
|
||||
server, err := NewServer("127.0.0.1:13370", testdata.GetTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := server.ListenAndServe("127.0.0.1:13370")
|
||||
err := server.ListenAndServe()
|
||||
Expect(err).To(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue