Add write lock to aead writer

This commit is contained in:
世界 2022-08-19 07:20:40 +08:00
parent 7cbd7d6346
commit 7461bb09a8
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -4,6 +4,7 @@ import (
"crypto/cipher" "crypto/cipher"
"encoding/binary" "encoding/binary"
"io" "io"
"sync"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
) )
@ -276,6 +277,7 @@ type Writer struct {
maxPacketSize int maxPacketSize int
buffer []byte buffer []byte
nonce []byte nonce []byte
access sync.Mutex
} }
func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer { func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer {
@ -337,12 +339,14 @@ func (w *Writer) Write(p []byte) (n int, err error) {
data = p data = p
pLen = 0 pLen = 0
} }
w.access.Lock()
binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
increaseNonce(w.nonce) increaseNonce(w.nonce)
offset := Overhead + PacketLengthBufferSize offset := Overhead + PacketLengthBufferSize
packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
increaseNonce(w.nonce) increaseNonce(w.nonce)
w.access.Unlock()
_, err = w.upstream.Write(w.buffer[:offset+len(packet)]) _, err = w.upstream.Write(w.buffer[:offset+len(packet)])
if err != nil { if err != nil {
return return
@ -372,12 +376,14 @@ func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error {
return err return err
} }
} }
w.access.Lock()
binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen)) binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen))
w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil) w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil)
increaseNonce(w.nonce) increaseNonce(w.nonce)
offset := index + Overhead + PacketLengthBufferSize offset := index + Overhead + PacketLengthBufferSize
w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil) w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil)
increaseNonce(w.nonce) increaseNonce(w.nonce)
w.access.Unlock()
index = offset + pLen + Overhead index = offset + pLen + Overhead
} }
} }