uquic/session.go

171 lines
4.3 KiB
Go

package quic
import (
"bytes"
"errors"
"fmt"
"net"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
)
// StreamCallback is invoked with a received STREAM frame and returns the
// frames to send back in reply; a nil result means nothing is sent.
type StreamCallback func(frame *StreamFrame) (replies []Frame)
// A Session is a QUIC session: the per-connection state used to open incoming
// packets, dispatch their frames, and seal and send replies over a UDP socket.
type Session struct {
// VersionNumber is the QUIC version the session was created with; it is
// passed to the handshake.
VersionNumber protocol.VersionNumber
// ConnectionID identifies this session; it is written into the public header
// of every outgoing packet and passed to the handshake.
ConnectionID protocol.ConnectionID
// Connection is the UDP socket outgoing packets are written to.
Connection *net.UDPConn
// CurrentRemoteAddr is the peer address taken from the most recently handled
// packet; outgoing packets are sent there.
CurrentRemoteAddr *net.UDPAddr
// ServerConfig is the server configuration handed to the handshake.
ServerConfig *handshake.ServerConfig
// hshk opens incoming packets, seals outgoing ones, and processes messages
// received on stream 1.
hshk *handshake.Handshake
// Entropy accumulates the entropy bits of received packets; its value is
// reported in outgoing ACK frames.
Entropy EntropyAccumulator
// lastSentPacketNumber is the packet number of the most recently sent
// packet; it is incremented before each send.
lastSentPacketNumber protocol.PacketNumber
// streamCallback handles STREAM frames for streams other than 0 and 1.
streamCallback StreamCallback
}
// NewSession returns a Session that sends over conn, identified by
// connectionID and using version v. A handshake is created from
// connectionID, v and sCfg. streamCallback receives STREAM frames for
// streams other than the crypto stream.
func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) *Session {
	hs := handshake.NewHandshake(connectionID, v, sCfg)

	sess := &Session{
		VersionNumber:  v,
		ConnectionID:   connectionID,
		Connection:     conn,
		ServerConfig:   sCfg,
		streamCallback: streamCallback,
	}
	sess.hshk = hs
	return sess
}
// HandlePacket processes a single incoming packet for this session.
//
// addr is the sender's address, publicHeaderBinary the raw public header and
// publicHeader its parsed form; r holds the rest of the packet, which is
// opened with the session's handshake before use.
//
// After opening, the private-flags byte is read, its entropy bit recorded, and
// an ACK for the packet is sent immediately. Then every frame is processed:
//   - STREAM on stream 1 goes to the crypto handshake, and any reply is sent
//     back on stream 1. Other non-zero streams go to the stream callback, and
//     its reply frames (if any) are sent.
//   - ACK, STOP_WAITING and CONNECTION_CLOSE are parsed and logged.
//   - A PADDING byte (0x00) ends processing of the packet.
//
// It returns an error if the packet cannot be opened, a frame is malformed,
// unsupported or invalid (including stream ID 0), or a reply cannot be sent.
func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, publicHeader *PublicHeader, r *bytes.Reader) error {
	// TODO: Only do this after authenticating.
	// Comparing the pointers first would not add anything: the result is the
	// same as assigning unconditionally.
	s.CurrentRemoteAddr = addr

	r, err := s.hshk.Open(publicHeader.PacketNumber, publicHeaderBinary, r)
	if err != nil {
		return err
	}

	privateFlag, err := r.ReadByte()
	if err != nil {
		return err
	}
	s.Entropy.Add(publicHeader.PacketNumber, privateFlag&0x01 > 0)

	// Acknowledge the packet before looking at its frames.
	if err := s.SendFrames([]Frame{&AckFrame{
		LargestObserved: uint64(publicHeader.PacketNumber),
		Entropy:         s.Entropy.Get(),
	}}); err != nil {
		return err
	}

	frameCounter := 0
	// read all frames in the packet
	for r.Len() > 0 {
		typeByte, err := r.ReadByte()
		if err != nil {
			fmt.Println("No more frames in this packet.")
			break
		}
		// Peek only: each frame parser expects to read the type byte itself.
		// UnreadByte cannot fail right after a successful ReadByte.
		r.UnreadByte()
		frameCounter++
		fmt.Printf("Reading frame %d\n", frameCounter)

		switch {
		case typeByte&0x80 == 0x80: // STREAM (1fdooossB)
			fmt.Println("Detected STREAM")
			frame, err := ParseStreamFrame(r)
			if err != nil {
				return err
			}
			fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID)
			switch frame.StreamID {
			case 0:
				return errors.New("Session: 0 is not a valid Stream ID")
			case 1: // crypto stream, handled by the handshake
				reply, err := s.hshk.HandleCryptoMessage(frame.Data)
				if err != nil {
					return err
				}
				if reply != nil {
					if err := s.SendFrames([]Frame{&StreamFrame{StreamID: 1, Data: reply}}); err != nil {
						return err
					}
				}
			default:
				if replyFrames := s.streamCallback(frame); replyFrames != nil {
					if err := s.SendFrames(replyFrames); err != nil {
						return err
					}
				}
			}
		case typeByte&0xC0 == 0x40: // ACK (01ntllmm)
			fmt.Println("Detected ACK")
			frame, err := ParseAckFrame(r)
			if err != nil {
				return err
			}
			fmt.Printf("%#v\n", frame)
		case typeByte&0xE0 == 0x20: // CONGESTION_FEEDBACK (001xxxxx)
			return errors.New("Detected CONGESTION_FEEDBACK")
		// Regular frame types are exact byte values, not bit flags. Masking
		// would misread e.g. 0x07 (PING) as STOP_WAITING and 0x03 (GOAWAY) as
		// CONNECTION_CLOSE.
		case typeByte == 0x06: // STOP_WAITING
			fmt.Println("Detected STOP_WAITING")
			if _, err := ParseStopWaitingFrame(r, publicHeader.PacketNumberLen); err != nil {
				return err
			}
			// TODO: react to receiving this frame
		case typeByte == 0x02: // CONNECTION_CLOSE
			fmt.Println("Detected CONNECTION_CLOSE")
			frame, err := ParseConnectionCloseFrame(r)
			if err != nil {
				return err
			}
			fmt.Printf("%#v\n", frame)
		case typeByte == 0x00: // PADDING fills the remainder of the packet
			return nil
		default:
			return errors.New("Session: invalid Frame Type Field")
		}
	}
	return nil
}
// SendFrames packs frames into one packet and writes it to the peer at
// s.CurrentRemoteAddr over s.Connection.
//
// The payload is a private-flags byte (currently always 0), followed by each
// frame in order. Once the payload is serialized, the packet number is
// incremented and the public header is written. The packet number, the packet
// buffer, the header bytes and the payload are then passed to s.hshk.Seal,
// which presumably appends the sealed payload to the buffer — confirm against
// the handshake package.
//
// It returns the first error from serializing a frame, writing the public
// header, or writing to the socket.
func (s *Session) SendFrames(frames []Frame) error {
	payload := &bytes.Buffer{}
	payload.WriteByte(0) // private flags; TODO: entropy
	for i := range frames {
		if err := frames[i].Write(payload); err != nil {
			return err
		}
	}

	s.lastSentPacketNumber++
	pn := s.lastSentPacketNumber
	header := PublicHeader{ConnectionID: s.ConnectionID, PacketNumber: pn}
	fmt.Printf("Sending packet # %d\n", header.PacketNumber)

	packet := &bytes.Buffer{}
	if err := header.WritePublicHeader(packet); err != nil {
		return err
	}
	s.hshk.Seal(pn, packet, packet.Bytes(), payload.Bytes())

	raw := packet.Bytes()
	fmt.Printf("Sending %d bytes to %v\n", len(raw), s.CurrentRemoteAddr)
	_, err := s.Connection.WriteToUDP(raw, s.CurrentRemoteAddr)
	return err
}