diff --git a/frames/connection_close_frame.go b/frames/connection_close_frame.go index f7739573..1c988589 100644 --- a/frames/connection_close_frame.go +++ b/frames/connection_close_frame.go @@ -67,10 +67,7 @@ func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNu reasonPhraseLen := uint16(len(f.ReasonPhrase)) utils.WriteUint16(b, reasonPhraseLen) - - for i := 0; i < int(reasonPhraseLen); i++ { - b.WriteByte(uint8(f.ReasonPhrase[i])) - } + b.WriteString(f.ReasonPhrase) return nil } diff --git a/frames/goaway_frame.go b/frames/goaway_frame.go new file mode 100644 index 00000000..b890ba20 --- /dev/null +++ b/frames/goaway_frame.go @@ -0,0 +1,73 @@ +package frames + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" +) + +// A GoawayFrame is a GOAWAY frame +type GoawayFrame struct { + ErrorCode qerr.ErrorCode + LastGoodStream protocol.StreamID + ReasonPhrase string +} + +// ParseGoawayFrame parses a GOAWAY frame +func ParseGoawayFrame(r *bytes.Reader) (*GoawayFrame, error) { + frame := &GoawayFrame{} + + _, err := r.ReadByte() + if err != nil { + return nil, err + } + + errorCode, err := utils.ReadUint32(r) + if err != nil { + return nil, err + } + frame.ErrorCode = qerr.ErrorCode(errorCode) + + lastGoodStream, err := utils.ReadUint32(r) + if err != nil { + return nil, err + } + frame.LastGoodStream = protocol.StreamID(lastGoodStream) + + reasonPhraseLen, err := utils.ReadUint16(r) + if err != nil { + return nil, err + } + + if reasonPhraseLen > uint16(protocol.MaxPacketSize) { + return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long") + } + + reasonPhrase := make([]byte, reasonPhraseLen) + if _, err := io.ReadFull(r, reasonPhrase); err != nil { + return nil, err + } + frame.ReasonPhrase = string(reasonPhrase) + + return frame, nil +} + +func (f *GoawayFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + typeByte := uint8(0x03) + b.WriteByte(typeByte) + + utils.WriteUint32(b, uint32(f.ErrorCode)) + utils.WriteUint32(b, uint32(f.LastGoodStream)) + utils.WriteUint16(b, uint16(len(f.ReasonPhrase))) + b.WriteString(f.ReasonPhrase) + + return nil +} + +// MinLength of a written frame +func (f *GoawayFrame) MinLength() (protocol.ByteCount, error) { + return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil +} diff --git a/frames/goaway_frame_test.go b/frames/goaway_frame_test.go new file mode 100644 index 00000000..e25b96ec --- /dev/null +++ b/frames/goaway_frame_test.go @@ -0,0 +1,70 @@ +package frames + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("GoawayFrame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{ + 0x03, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, + 'f', 'o', 'o', + }) + frame, err := ParseGoawayFrame(b) + Expect(frame).To(Equal(&GoawayFrame{ + ErrorCode: 1, + LastGoodStream: 2, + ReasonPhrase: "foo", + })) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(Equal(0)) + }) + + It("rejects long reason phrases", func() { + b := bytes.NewReader([]byte{ + 0x03, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0xff, 0xff, + }) + _, err := ParseGoawayFrame(b) + Expect(err).To(MatchError(qerr.Error(qerr.InvalidGoawayData, "reason phrase too long"))) + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + b := &bytes.Buffer{} + frame := GoawayFrame{ + ErrorCode: 1, + LastGoodStream: 2, + ReasonPhrase: "foo", + } + frame.Write(b, 0) + Expect(b.Bytes()).To(Equal([]byte{ + 0x03, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, + 'f', 'o', 'o', + })) + }) + + It("has the correct min length", func() { + frame := GoawayFrame{ + ReasonPhrase: "foo", + } + Expect(frame.MinLength()).To(Equal(protocol.ByteCount(14))) + }) + }) +}) diff --git a/packet_unpacker.go b/packet_unpacker.go index 1e1682dc..2bedbf10 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -73,7 +73,10 @@ ReadLoop: err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) } case 0x03: - err = errors.New("unimplemented: GOAWAY") + frame, err = frames.ParseGoawayFrame(r) + if err != nil { + err = qerr.Error(qerr.InvalidGoawayData, err.Error()) + } case 0x04: frame, err = frames.ParseWindowUpdateFrame(r) if err != nil { diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index c9f1c255..6124499c 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -110,10 +110,23 @@ var _ = Describe("Packet unpacker", func() { Expect(packet.frames).To(Equal([]frames.Frame{f})) }) - It("errors on GOAWAY frames", func() { - setReader([]byte{0x03}) - _, err := unpacker.Unpack(hdrBin, hdr, r) - Expect(err).To(MatchError("unimplemented: GOAWAY")) + It("accepts GOAWAY frames", func() { + setReader([]byte{ + 0x03, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, + 'f', 'o', 'o', + }) + packet, err := unpacker.Unpack(hdrBin, hdr, r) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]frames.Frame{ + &frames.GoawayFrame{ + ErrorCode: 1, + LastGoodStream: 2, + ReasonPhrase: "foo", + }, + })) }) It("accepts WINDOW_UPDATE frames", func() { diff --git a/session.go b/session.go index b2a46a23..813ca2d9 100644 --- a/session.go +++ b/session.go @@ -250,6 +250,9 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da case *frames.ConnectionCloseFrame: utils.Debugf("\t<- %#v", frame) s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) + case *frames.GoawayFrame: + utils.Debugf("\t<- %#v", frame) + err = errors.New("unimplemented: handling GOAWAY frames") case *frames.StopWaitingFrame: utils.Debugf("\t<- %#v", frame) err = s.receivedPacketHandler.ReceivedStopWaiting(frame)