simplify Read and Write mock calls in http3 tests

This commit is contained in:
Marten Seemann 2020-12-21 14:41:22 +07:00
parent 9693a46d31
commit 4c6496bc0e
2 changed files with 14 additions and 42 deletions

View file

@ -217,9 +217,7 @@ var _ = Describe("Client", func() {
// don't EXPECT any calls to HandshakeComplete()
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
str.EXPECT().Close()
str.EXPECT().CancelWrite(gomock.Any())
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
@ -283,9 +281,7 @@ var _ = Describe("Client", func() {
var err error
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
Expect(err).ToNot(HaveOccurred())
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return strBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
})
It("sends a request", func() {
@ -378,9 +374,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
done := make(chan struct{})
canceled := make(chan struct{})
@ -414,12 +408,8 @@ var _ = Describe("Client", func() {
str.EXPECT().Close().MaxTimes(1)
done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return rspBuf.Read(b)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) })
_, err := client.RoundTrip(req)
@ -437,9 +427,7 @@ var _ = Describe("Client", func() {
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
gomock.InOrder(
str.EXPECT().Close(),
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
@ -456,9 +444,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
gomock.InOrder(
str.EXPECT().Close(),
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
@ -481,9 +467,7 @@ var _ = Describe("Client", func() {
gz.Close()
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
@ -504,9 +488,7 @@ var _ = Describe("Client", func() {
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)

View file

@ -87,9 +87,7 @@ var _ = Describe("Server", func() {
encodeRequest := func(req *http.Request) []byte {
buf := &bytes.Buffer{}
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
rw := newRequestWriter(utils.DefaultLogger)
@ -151,9 +149,7 @@ var _ = Describe("Server", func() {
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(sess, str, qpackDecoder, nil)
@ -170,9 +166,7 @@ var _ = Describe("Server", func() {
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(sess, str, qpackDecoder, nil)
@ -210,9 +204,7 @@ var _ = Describe("Server", func() {
setRequest(append(requestData, buf.Bytes()...))
done := make(chan struct{})
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.ErrorCode(errorNoError))
str.EXPECT().Close().Do(func() { close(done) })
@ -235,9 +227,7 @@ var _ = Describe("Server", func() {
responseBuf := &bytes.Buffer{}
setRequest(append(requestData, buf.Bytes()...))
done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError)).Do(func(quic.ErrorCode) { close(done) })
s.handleConn(sess)