//go:build with_gvisor

package tun

import (
	"context"
	"errors"
	"net"
	"os"
	"syscall"
	"time"

	"github.com/sagernet/gvisor/pkg/tcpip"
	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
	"github.com/sagernet/gvisor/pkg/tcpip/header"
	"github.com/sagernet/gvisor/pkg/tcpip/stack"
	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
	"github.com/sagernet/gvisor/pkg/waiter"
)

// gLazyConn wraps a gVisor TCP forwarder request whose handshake is deferred.
// The endpoint is only created when HandshakeContext/HandshakeSuccess runs
// (or on the first Read/Write/Set*Deadline call), while HandshakeFailure and
// Close can instead reject the pending request.
//
// NOTE(review): no field is guarded by a lock; presumably a single gLazyConn
// is not driven from several goroutines at once — confirm with callers.
type gLazyConn struct {
	// tcpConn is assigned only after a successful handshake; nil before that.
	tcpConn       *gonet.TCPConn
	// parentCtx: when it is done during a handshake, all events on the
	// handshake's waiter queue are notified.
	parentCtx     context.Context
	// stack is used by HandshakeFailure to send ICMP rejections.
	stack         *stack.Stack
	// request is the pending forwarder request; it is completed (accepted or
	// reset) by HandshakeContext, HandshakeFailure, or a pre-handshake close.
	request       *tcp.ForwarderRequest
	localAddr     net.Addr
	remoteAddr    net.Addr
	// handshakeDone is set once HandshakeContext or HandshakeFailure has run.
	handshakeDone bool
	// handshakeErr is non-nil after a failed handshake, a HandshakeFailure
	// call, or a close before the handshake (net.ErrClosed).
	handshakeErr  error
}

// HandshakeContext performs the deferred TCP handshake for the forwarded
// request by creating its endpoint, then accepts the request and enables
// keepalive (15s idle, 15s interval) on the resulting connection.
//
// The handshake is attempted at most once. Later calls return the recorded
// outcome: nil after success, or the error stored by a failed handshake or by
// HandshakeFailure.
//
// ctx only bounds the watcher goroutine started below. Cancelling it does not
// by itself notify the waiter queue; only parentCtx does that.
//
// On failure the request is completed with Complete(true) and the netstack
// error, translated by gonet.TranslateNetstackError, is returned.
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
	if c.handshakeDone {
		// Report the recorded outcome rather than nil, so a connection that
		// already failed or was rejected never looks usable.
		return c.handshakeErr
	}
	defer func() {
		c.handshakeDone = true
	}()
	var wq waiter.Queue
	handshakeCtx, cancel := context.WithCancel(ctx)
	go func() {
		select {
		case <-c.parentCtx.Done():
			// Notify every event registered on the queue. Presumably this lets
			// a blocked handshake observe the shutdown — confirm against the
			// gVisor fork.
			wq.Notify(wq.Events())
		case <-handshakeCtx.Done():
		}
	}()
	endpoint, err := c.request.CreateEndpoint(&wq)
	// CreateEndpoint has returned; release the watcher goroutine.
	cancel()
	if err != nil {
		gErr := gonet.TranslateNetstackError(err)
		c.handshakeErr = gErr
		c.request.Complete(true)
		return gErr
	}
	c.request.Complete(false)
	// Keepalive tuning is best effort: option errors are deliberately ignored
	// so they cannot fail an otherwise established connection.
	endpoint.SocketOptions().SetKeepAlive(true)
	keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
	endpoint.SetSockOpt(&keepAliveIdle)
	keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
	endpoint.SetSockOpt(&keepAliveInterval)
	c.tcpConn = gonet.NewTCPConn(&wq, endpoint)
	return nil
}

// HandshakeFailure rejects the pending request because the outbound side
// failed with err. It is a no-op once the handshake has already been decided.
//
// gWriteUnreachable sends an ICMP rejection for the errno values it knows. It
// returns nil for ErrDrop (the SYN is silently dropped) and os.ErrInvalid for
// anything else. Only in the os.ErrInvalid case is the request completed with
// Complete(true), so the peer gets an RST.
// err is then recorded so later I/O reports it. The method itself always
// returns nil.
func (c *gLazyConn) HandshakeFailure(err error) error {
	if c.handshakeDone {
		return nil
	}
	unmapped := gWriteUnreachable(c.stack, c.request.Packet(), err) == os.ErrInvalid
	c.request.Complete(unmapped)
	c.handshakeDone = true
	c.handshakeErr = err
	return nil
}

// HandshakeSuccess performs the deferred handshake with no caller-supplied
// cancellation; it is HandshakeContext with a background context.
func (c *gLazyConn) HandshakeSuccess() error {
	bg := context.Background()
	return c.HandshakeContext(bg)
}

// Read reads from the established connection. It performs the deferred
// handshake first if it has not happened yet, and returns the recorded
// handshake error if the handshake already failed.
func (c *gLazyConn) Read(b []byte) (n int, err error) {
	if c.handshakeDone {
		if c.handshakeErr != nil {
			return 0, c.handshakeErr
		}
	} else if err = c.HandshakeContext(context.Background()); err != nil {
		return 0, err
	}
	return c.tcpConn.Read(b)
}

// Write writes to the established connection. It performs the deferred
// handshake first if needed, and returns the recorded handshake error if the
// handshake already failed.
func (c *gLazyConn) Write(b []byte) (n int, err error) {
	switch {
	case !c.handshakeDone:
		if err = c.HandshakeContext(context.Background()); err != nil {
			return 0, err
		}
	case c.handshakeErr != nil:
		return 0, c.handshakeErr
	}
	return c.tcpConn.Write(b)
}

// LocalAddr returns the local address stored on the connection; it does not
// trigger the handshake.
func (c *gLazyConn) LocalAddr() net.Addr {
	addr := c.localAddr
	return addr
}

// RemoteAddr returns the remote address stored on the connection; it does not
// trigger the handshake.
func (c *gLazyConn) RemoteAddr() net.Addr {
	addr := c.remoteAddr
	return addr
}

// SetDeadline sets both deadlines on the established connection. Note that
// calling it forces the deferred handshake if it has not happened yet.
func (c *gLazyConn) SetDeadline(t time.Time) error {
	if c.handshakeDone {
		if c.handshakeErr != nil {
			return c.handshakeErr
		}
		return c.tcpConn.SetDeadline(t)
	}
	if err := c.HandshakeContext(context.Background()); err != nil {
		return err
	}
	return c.tcpConn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the established connection,
// forcing the deferred handshake first if it has not happened yet.
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
	var err error
	if !c.handshakeDone {
		err = c.HandshakeContext(context.Background())
	} else {
		err = c.handshakeErr
	}
	if err != nil {
		return err
	}
	return c.tcpConn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the established connection,
// forcing the deferred handshake first if it has not happened yet.
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
	switch {
	case !c.handshakeDone:
		if err := c.HandshakeContext(context.Background()); err != nil {
			return err
		}
	case c.handshakeErr != nil:
		return c.handshakeErr
	}
	return c.tcpConn.SetWriteDeadline(t)
}

// Close closes the connection.
//
// Before the handshake, the pending request is completed with a reset
// (Complete(true)). The handshake is then marked done with net.ErrClosed, so
// later I/O fails with that error. Because every path that completes the
// request checks handshakeDone first, the request is never completed a second
// time; gVisor rejects a double Complete.
//
// After a failed handshake there is nothing to close and nil is returned.
func (c *gLazyConn) Close() error {
	if !c.handshakeDone {
		c.request.Complete(true)
		c.handshakeDone = true
		c.handshakeErr = net.ErrClosed
		return nil
	}
	if c.handshakeErr != nil {
		return nil
	}
	return c.tcpConn.Close()
}

// CloseRead shuts down the read side of the established connection.
//
// Before the handshake there is no connection to half-close. The pending
// request is therefore reset and the handshake is marked done with
// net.ErrClosed, which prevents any later path from completing the request
// again.
func (c *gLazyConn) CloseRead() error {
	if !c.handshakeDone {
		c.request.Complete(true)
		c.handshakeDone = true
		c.handshakeErr = net.ErrClosed
		return nil
	}
	if c.handshakeErr != nil {
		return nil
	}
	return c.tcpConn.CloseRead()
}

// CloseWrite shuts down the write side of the established connection.
//
// Before the handshake there is no connection to half-close. The pending
// request is therefore reset and the handshake is marked done with
// net.ErrClosed, which prevents any later path from completing the request
// again.
func (c *gLazyConn) CloseWrite() error {
	if !c.handshakeDone {
		c.request.Complete(true)
		c.handshakeDone = true
		c.handshakeErr = net.ErrClosed
		return nil
	}
	if c.handshakeErr != nil {
		return nil
	}
	// Previously this called CloseRead, shutting down the wrong direction.
	return c.tcpConn.CloseWrite()
}

// ReaderReplaceable reports whether the handshake has finished successfully.
// Presumably callers use it to decide whether reads may go straight to
// Upstream().
func (c *gLazyConn) ReaderReplaceable() bool {
	if !c.handshakeDone {
		return false
	}
	return c.handshakeErr == nil
}

// WriterReplaceable reports whether the handshake has finished successfully.
// Presumably callers use it to decide whether writes may go straight to
// Upstream().
func (c *gLazyConn) WriterReplaceable() bool {
	if !c.handshakeDone {
		return false
	}
	return c.handshakeErr == nil
}

// Upstream returns the underlying gonet connection once a successful handshake
// has created it.
//
// Before that, or after a failed handshake, it returns an untyped nil. A nil
// *gonet.TCPConn stored in an `any` would compare non-nil, and unwrapping
// helpers could then call methods on a nil pointer.
func (c *gLazyConn) Upstream() any {
	if c.tcpConn == nil {
		return nil
	}
	return c.tcpConn
}

// gWriteUnreachable answers the packet with an ICMP/ICMPv6 rejection that
// matches err. The mapping is:
//   - ENETUNREACH: net unreachable (v4) / no route (v6)
//   - EHOSTUNREACH: host unreachable (v4) / address unreachable (v6)
//   - ECONNREFUSED: port unreachable
//
// ErrDrop yields nil without sending anything. Any other error yields
// os.ErrInvalid, again without sending anything. Packets that are not IPv4
// are treated as IPv6. The packet is only inspected when a rejection is
// actually sent.
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
	var (
		code4 stack.RejectIPv4WithICMPType
		code6 stack.RejectIPv6WithICMPType
	)
	switch {
	case errors.Is(err, ErrDrop):
		return nil
	case errors.Is(err, syscall.ENETUNREACH):
		code4, code6 = stack.RejectIPv4WithICMPNetUnreachable, stack.RejectIPv6WithICMPNoRoute
	case errors.Is(err, syscall.EHOSTUNREACH):
		code4, code6 = stack.RejectIPv4WithICMPHostUnreachable, stack.RejectIPv6WithICMPAddrUnreachable
	case errors.Is(err, syscall.ECONNREFUSED):
		code4, code6 = stack.RejectIPv4WithICMPPortUnreachable, stack.RejectIPv6WithICMPPortUnreachable
	default:
		return os.ErrInvalid
	}
	if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
		return gWriteUnreachable4(gStack, packet, code4)
	}
	return gWriteUnreachable6(gStack, packet, code6)
}

// gWriteUnreachable4 sends an ICMPv4 rejection of the given type for packet
// through the stack's IPv4 protocol instance. The netstack error is translated
// with gonet.TranslateNetstackError.
//
// If the stack has no IPv4 instance, or the instance cannot send rejections,
// it returns os.ErrInvalid instead of panicking on the type assertion. That is
// the same value gWriteUnreachable uses for "no ICMP mapping", so
// HandshakeFailure falls back to resetting the connection.
func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
	handler, ok := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler)
	if !ok {
		return os.ErrInvalid
	}
	return gonet.TranslateNetstackError(handler.SendRejectionError(packet, icmpCode, true))
}

// gWriteUnreachable6 sends an ICMPv6 rejection of the given type for packet
// through the stack's IPv6 protocol instance. The netstack error is translated
// with gonet.TranslateNetstackError.
//
// If the stack has no IPv6 instance, or the instance cannot send rejections,
// it returns os.ErrInvalid instead of panicking on the type assertion. That is
// the same value gWriteUnreachable uses for "no ICMP mapping", so
// HandshakeFailure falls back to resetting the connection.
func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
	handler, ok := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler)
	if !ok {
		return os.ErrInvalid
	}
	return gonet.TranslateNetstackError(handler.SendRejectionError(packet, icmpCode, true))
}