diff --git a/client_conn.go b/client_conn.go index 2ccfc0b..c5c7d14 100644 --- a/client_conn.go +++ b/client_conn.go @@ -20,6 +20,10 @@ type clientConn struct { responseRead bool } +func (c *clientConn) NeedHandshake() bool { + return !c.requestWritten +} + func (c *clientConn) readResponse() error { response, err := ReadStreamResponse(c.Conn) if err != nil { @@ -94,6 +98,10 @@ type clientPacketConn struct { responseRead bool } +func (c *clientPacketConn) NeedHandshake() bool { + return !c.requestWritten +} + func (c *clientPacketConn) readResponse() error { response, err := ReadStreamResponse(c.ExtendedConn) if err != nil { @@ -276,6 +284,10 @@ type clientPacketAddrConn struct { responseRead bool } +func (c *clientPacketAddrConn) NeedHandshake() bool { + return !c.requestWritten +} + func (c *clientPacketAddrConn) readResponse() error { response, err := ReadStreamResponse(c.ExtendedConn) if err != nil { diff --git a/protocol_conn.go b/protocol_conn.go index aaf3ffe..75ddfa9 100644 --- a/protocol_conn.go +++ b/protocol_conn.go @@ -32,6 +32,10 @@ func newProtocolConn(conn net.Conn, request Request) net.Conn { } } +func (c *protocolConn) NeedHandshake() bool { + return !c.requestWritten +} + func (c *protocolConn) Write(p []byte) (n int, err error) { if c.requestWritten { return c.Conn.Write(p) diff --git a/server_conn.go b/server_conn.go index 70a9689..bac7add 100644 --- a/server_conn.go +++ b/server_conn.go @@ -20,6 +20,10 @@ type serverConn struct { responseWritten bool } +func (c *serverConn) NeedHandshake() bool { + return !c.responseWritten +} + func (c *serverConn) HandshakeFailure(err error) error { errMessage := err.Error() buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) @@ -85,6 +89,10 @@ type serverPacketConn struct { responseWritten bool } +func (c *serverPacketConn) NeedHandshake() bool { + return !c.responseWritten +} + func (c *serverPacketConn) HandshakeFailure(err error) error { errMessage := err.Error() buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) @@ -186,6 +194,10 @@ type serverPacketAddrConn struct { responseWritten bool } +func (c *serverPacketAddrConn) NeedHandshake() bool { + return !c.responseWritten +} + func (c *serverPacketAddrConn) HandshakeFailure(err error) error { errMessage := err.Error() buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))