diff --git a/conn.go b/conn.go index 785abbaf..2f4e3a23 100644 --- a/conn.go +++ b/conn.go @@ -12,7 +12,7 @@ import ( type connection interface { ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, info *packetInfo) (int, error) + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr io.Closer } @@ -60,6 +60,6 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } diff --git a/conn_generic.go b/conn_generic.go index f5fe9dcf..451a70ae 100644 --- a/conn_generic.go +++ b/conn_generic.go @@ -11,3 +11,5 @@ func newConn(c net.PacketConn) (connection, error) { func inspectReadBuffer(interface{}) (int, error) { return 0, nil } + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/conn_oob.go b/conn_oob.go index 36ec2c54..7e9878b7 100644 --- a/conn_oob.go +++ b/conn_oob.go @@ -171,8 +171,8 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { }, nil } -func (c *oobConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) { - n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, info.OOB(), addr.(*net.UDPAddr)) +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { + n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } @@ -180,11 +180,6 @@ func (info *packetInfo) OOB() []byte { if info == nil { return nil } - info.once.Do(info.computeOOB) - return info.oob -} - -func (info *packetInfo) computeOOB() { if ip4 := info.addr.To4(); ip4 != nil { // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ @@ -208,7 +203,7 @@ func (info *packetInfo) computeOOB() { off += 4 } copy(oob[off:], ip4) - info.oob = oob + return oob } else if len(info.addr) == 16 { // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ @@ -224,8 +219,9 @@ func (info *packetInfo) computeOOB() { off := cmsgLen(0) off += copy(oob[off:], info.addr) binary.LittleEndian.PutUint32(oob[off:], info.ifIndex) - info.oob = oob + return oob } + return nil } func cmsgLen(datalen int) int { diff --git a/conn_windows.go b/conn_windows.go index 1d80eefe..833f2374 100644 --- a/conn_windows.go +++ b/conn_windows.go @@ -35,3 +35,5 @@ func inspectReadBuffer(c net.PacketConn) (int, error) { } return size, serr } + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/packet_handler_map.go b/packet_handler_map.go index 8db6caaa..63d33703 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -469,7 +469,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info); err != nil { + if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { h.logger.Debugf("Error sending Stateless Reset: %s", err) } } diff --git a/send_conn.go b/send_conn.go index 677f73d2..b276af11 100644 --- a/send_conn.go +++ b/send_conn.go @@ -17,16 +17,22 @@ type sconn struct { remoteAddr net.Addr info *packetInfo + oob []byte } var _ sendConn = &sconn{} func newSendConn(c connection, remote net.Addr, info *packetInfo) sendConn { - return &sconn{connection: c, remoteAddr: remote, info: info} + return &sconn{ + connection: c, + remoteAddr: remote, + info: info, + oob: info.OOB(), + } } func (c *sconn) Write(p []byte) error { - _, err := c.WritePacket(p, c.remoteAddr, c.info) + _, err := c.WritePacket(p, c.remoteAddr, c.oob) return err } diff --git a/server.go b/server.go index 9d5e7dd9..f95a081f 100644 --- a/server.go +++ b/server.go @@ -556,7 +556,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack if s.config.Tracer != nil { s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) } - _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info) + _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) return err } @@ -633,7 +633,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.config.Tracer != nil { s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) } - _, err := s.conn.WritePacket(raw, remoteAddr, info) + _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) return err } @@ -656,7 +656,7 @@ func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.H nil, ) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/session.go b/session.go index b6fef87f..7b0a1d47 100644 --- a/session.go +++ b/session.go @@ -62,8 +62,6 @@ type cryptoStreamHandler interface { type packetInfo struct { addr net.IP ifIndex uint32 - once sync.Once - oob []byte } type receivedPacket struct {