diff --git a/app/client.example.yaml b/app/client.example.yaml index df864af..df90b4b 100644 --- a/app/client.example.yaml +++ b/app/client.example.yaml @@ -33,3 +33,12 @@ http: # username: user # password: pass # realm: my_private_realm + +forwarding: + - listen: 127.0.0.1:6666 + remote: 127.0.0.1:5201 + protocol: tcp + - listen: 127.0.0.1:5353 + remote: 1.1.1.1:53 + protocol: udp + udpTimeout: 30s diff --git a/app/cmd/client.go b/app/cmd/client.go index 3d1967f..6a61ea3 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -5,6 +5,7 @@ import ( "errors" "net" "os" + "strings" "sync" "time" @@ -12,6 +13,7 @@ import ( "github.com/spf13/viper" "go.uber.org/zap" + "github.com/apernet/hysteria/app/internal/forwarding" "github.com/apernet/hysteria/app/internal/http" "github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/core/client" @@ -48,9 +50,10 @@ type clientConfig struct { Up string `mapstructure:"up"` Down string `mapstructure:"down"` } `mapstructure:"bandwidth"` - FastOpen bool `mapstructure:"fastOpen"` - SOCKS5 *socks5Config `mapstructure:"socks5"` - HTTP *httpConfig `mapstructure:"http"` + FastOpen bool `mapstructure:"fastOpen"` + SOCKS5 *socks5Config `mapstructure:"socks5"` + HTTP *httpConfig `mapstructure:"http"` + Forwarding []forwardingEntry `mapstructure:"forwarding"` } type socks5Config struct { @@ -67,6 +70,13 @@ type httpConfig struct { Realm string `mapstructure:"realm"` } +type forwardingEntry struct { + Listen string `mapstructure:"listen"` + Remote string `mapstructure:"remote"` + Protocol string `mapstructure:"protocol"` + UDPTimeout time.Duration `mapstructure:"udpTimeout"` +} + // Config validates the fields and returns a ready-to-use Hysteria client config func (c *clientConfig) Config() (*client.Config, error) { hyConfig := &client.Config{} @@ -174,6 +184,16 @@ func runClient(cmd *cobra.Command, args []string) { } }() } + if len(config.Forwarding) > 0 { + hasMode = true + wg.Add(1) + go func() { + defer wg.Done() + if err := clientForwarding(config.Forwarding, c); err != nil { + logger.Fatal("failed to run forwarding", zap.Error(err)) + } + }() + } if !hasMode { logger.Fatal("no mode specified") @@ -234,6 +254,53 @@ func clientHTTP(config httpConfig, c client.Client) error { return h.Serve(l) } +func clientForwarding(entries []forwardingEntry, c client.Client) error { + errChan := make(chan error, len(entries)) + for _, e := range entries { + if e.Listen == "" { + return configError{Field: "listen", Err: errors.New("listen address is empty")} + } + if e.Remote == "" { + return configError{Field: "remote", Err: errors.New("remote address is empty")} + } + switch strings.ToLower(e.Protocol) { + case "tcp": + l, err := net.Listen("tcp", e.Listen) + if err != nil { + return configError{Field: "listen", Err: err} + } + logger.Info("TCP forwarding listening", zap.String("addr", e.Listen), zap.String("remote", e.Remote)) + go func(remote string) { + t := &forwarding.TCPTunnel{ + HyClient: c, + Remote: remote, + EventLogger: &tcpLogger{}, + } + errChan <- t.Serve(l) + }(e.Remote) + case "udp": + l, err := net.ListenPacket("udp", e.Listen) + if err != nil { + return configError{Field: "listen", Err: err} + } + logger.Info("UDP forwarding listening", zap.String("addr", e.Listen), zap.String("remote", e.Remote)) + go func(remote string, timeout time.Duration) { + u := &forwarding.UDPTunnel{ + HyClient: c, + Remote: remote, + Timeout: timeout, + EventLogger: &udpLogger{}, + } + errChan <- u.Serve(l) + }(e.Remote, e.UDPTimeout) + default: + return configError{Field: "protocol", Err: errors.New("unsupported protocol")} + } + } + // Return if any one of the forwarding fails + return <-errChan +} + // parseServerAddrString parses server address string. // Server address can be in either "host:port" or "host" format (in which case we assume port 443). func parseServerAddrString(addrStr string) (host, hostPort string) { @@ -295,3 +362,31 @@ func (l *httpLogger) HTTPError(addr net.Addr, reqURL string, err error) { logger.Error("HTTP error", zap.String("addr", addr.String()), zap.String("reqURL", reqURL), zap.Error(err)) } } + +type tcpLogger struct{} + +func (l *tcpLogger) Connect(addr net.Addr) { + logger.Debug("TCP forwarding connect", zap.String("addr", addr.String())) +} + +func (l *tcpLogger) Error(addr net.Addr, err error) { + if err == nil { + logger.Debug("TCP forwarding closed", zap.String("addr", addr.String())) + } else { + logger.Error("TCP forwarding error", zap.String("addr", addr.String()), zap.Error(err)) + } +} + +type udpLogger struct{} + +func (l *udpLogger) Connect(addr net.Addr) { + logger.Debug("UDP forwarding connect", zap.String("addr", addr.String())) +} + +func (l *udpLogger) Error(addr net.Addr, err error) { + if err == nil { + logger.Debug("UDP forwarding closed", zap.String("addr", addr.String())) + } else { + logger.Error("UDP forwarding error", zap.String("addr", addr.String()), zap.Error(err)) + } +} diff --git a/app/cmd/client_test.go b/app/cmd/client_test.go index 6c37b7b..795dfa7 100644 --- a/app/cmd/client_test.go +++ b/app/cmd/client_test.go @@ -68,6 +68,19 @@ func TestClientConfig(t *testing.T) { Password: "bruh", Realm: "martian", }, + Forwarding: []forwardingEntry{ + { + Listen: "127.0.0.1:8088", + Remote: "internal.example.com:80", + Protocol: "tcp", + }, + { + Listen: "127.0.0.1:5353", + Remote: "internal.example.com:53", + Protocol: "udp", + UDPTimeout: 50 * time.Second, + }, + }, }) { t.Fatal("parsed client config is not equal to expected") } diff --git a/app/cmd/client_test.yaml b/app/cmd/client_test.yaml index e459951..ea5aa9e 100644 --- a/app/cmd/client_test.yaml +++ b/app/cmd/client_test.yaml @@ -33,3 +33,12 @@ http: username: qqq password: bruh realm: martian + +forwarding: + - listen: 127.0.0.1:8088 + remote: internal.example.com:80 + protocol: tcp + - listen: 127.0.0.1:5353 + remote: internal.example.com:53 + protocol: udp + udpTimeout: 50s diff --git a/app/internal/forwarding/tcp.go b/app/internal/forwarding/tcp.go new file mode 100644 index 0000000..da21bdb --- /dev/null +++ b/app/internal/forwarding/tcp.go @@ -0,0 +1,62 @@ +package forwarding + +import ( + "io" + "net" + + "github.com/apernet/hysteria/core/client" +) + +type TCPTunnel struct { + HyClient client.Client + Remote string + EventLogger TCPEventLogger +} + +type TCPEventLogger interface { + Connect(addr net.Addr) + Error(addr net.Addr, err error) +} + +func (t *TCPTunnel) Serve(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + return err + } + go t.handle(conn) + } +} + +func (t *TCPTunnel) handle(conn net.Conn) { + defer conn.Close() + + if t.EventLogger != nil { + t.EventLogger.Connect(conn.RemoteAddr()) + } + var closeErr error + defer func() { + if t.EventLogger != nil { + t.EventLogger.Error(conn.RemoteAddr(), closeErr) + } + }() + + rc, err := t.HyClient.DialTCP(t.Remote) + if err != nil { + closeErr = err + return + } + defer rc.Close() + + // Start forwarding + copyErrChan := make(chan error, 2) + go func() { + _, copyErr := io.Copy(rc, conn) + copyErrChan <- copyErr + }() + go func() { + _, copyErr := io.Copy(conn, rc) + copyErrChan <- copyErr + }() + closeErr = <-copyErrChan +} diff --git a/app/internal/forwarding/tcp_test.go b/app/internal/forwarding/tcp_test.go new file mode 100644 index 0000000..42710e3 --- /dev/null +++ b/app/internal/forwarding/tcp_test.go @@ -0,0 +1,49 @@ +package forwarding + +import ( + "bytes" + "crypto/rand" + "net" + "testing" + + "github.com/apernet/hysteria/app/internal/utils_test" +) + +func TestTCPTunnel(t *testing.T) { + // Start the tunnel + tunnel := &TCPTunnel{ + HyClient: &utils_test.MockEchoHyClient{}, + Remote: "whatever", + } + l, err := net.Listen("tcp", "127.0.0.1:34567") + if err != nil { + t.Fatal(err) + } + defer l.Close() + go tunnel.Serve(l) + + for i := 0; i < 10; i++ { + conn, err := net.Dial("tcp", "127.0.0.1:34567") + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 1024) + _, _ = rand.Read(data) + _, err = conn.Write(data) + if err != nil { + t.Fatal(err) + } + recv := make([]byte, 1024) + _, err = conn.Read(recv) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, recv) { + t.Fatalf("connection %d: data mismatch", i) + } + + _ = conn.Close() + } +} diff --git a/app/internal/forwarding/udp.go b/app/internal/forwarding/udp.go new file mode 100644 index 0000000..2bf46f4 --- /dev/null +++ b/app/internal/forwarding/udp.go @@ -0,0 +1,146 @@ +package forwarding + +import ( + "net" + "sync" + "sync/atomic" + "time" + + "github.com/apernet/hysteria/core/client" +) + +const ( + udpBufferSize = 4096 + + defaultTimeout = 5 * time.Minute +) + +type UDPTunnel struct { + HyClient client.Client + Remote string + Timeout time.Duration + EventLogger UDPEventLogger +} + +type UDPEventLogger interface { + Connect(addr net.Addr) + Error(addr net.Addr, err error) +} + +type sessionEntry struct { + HyConn client.HyUDPConn + Deadline atomic.Value +} + +type sessionManager struct { + SessionMap map[string]*sessionEntry + Timeout time.Duration + TimeoutFunc func(addr net.Addr) + Mutex sync.RWMutex +} + +func (sm *sessionManager) New(addr net.Addr, hyConn client.HyUDPConn) { + entry := &sessionEntry{ + HyConn: hyConn, + } + entry.Deadline.Store(time.Now().Add(sm.Timeout)) + + // Timeout cleanup routine + go func() { + for { + ttl := entry.Deadline.Load().(time.Time).Sub(time.Now()) + if ttl <= 0 { + // Inactive for too long, close the session + sm.Mutex.Lock() + delete(sm.SessionMap, addr.String()) + sm.Mutex.Unlock() + _ = hyConn.Close() + if sm.TimeoutFunc != nil { + sm.TimeoutFunc(addr) + } + return + } else { + time.Sleep(ttl) + } + } + }() + + sm.Mutex.Lock() + defer sm.Mutex.Unlock() + sm.SessionMap[addr.String()] = entry +} + +func (sm *sessionManager) Get(addr net.Addr) client.HyUDPConn { + sm.Mutex.RLock() + defer sm.Mutex.RUnlock() + if entry, ok := sm.SessionMap[addr.String()]; ok { + return entry.HyConn + } else { + return nil + } +} + +func (sm *sessionManager) Renew(addr net.Addr) { + sm.Mutex.RLock() // RLock is enough as we are not modifying the map itself, only a value in the entry + defer sm.Mutex.RUnlock() + if entry, ok := sm.SessionMap[addr.String()]; ok { + entry.Deadline.Store(time.Now().Add(sm.Timeout)) + } +} + +func (t *UDPTunnel) Serve(listener net.PacketConn) error { + sm := &sessionManager{ + SessionMap: make(map[string]*sessionEntry), + Timeout: t.Timeout, + TimeoutFunc: func(addr net.Addr) { t.EventLogger.Error(addr, nil) }, + } + if sm.Timeout <= 0 { + sm.Timeout = defaultTimeout + } + buf := make([]byte, udpBufferSize) + for { + n, addr, err := listener.ReadFrom(buf) + if err != nil { + return err + } + t.handle(listener, sm, addr, buf[:n]) + } +} + +func (t *UDPTunnel) handle(l net.PacketConn, sm *sessionManager, addr net.Addr, data []byte) { + hyConn := sm.Get(addr) + if hyConn != nil { + // Existing session + _ = hyConn.Send(data, t.Remote) + sm.Renew(addr) + } else { + // New session + if t.EventLogger != nil { + t.EventLogger.Connect(addr) + } + hyConn, err := t.HyClient.ListenUDP() + if err != nil { + if t.EventLogger != nil { + t.EventLogger.Error(addr, err) + } + return + } + sm.New(addr, hyConn) + _ = hyConn.Send(data, t.Remote) + + // Local <- Remote routine + go func() { + for { + data, _, err := hyConn.Receive() + if err != nil { + return + } + _, err = l.WriteTo(data, addr) + if err != nil { + return + } + sm.Renew(addr) + } + }() + } +} diff --git a/app/internal/forwarding/udp_test.go b/app/internal/forwarding/udp_test.go new file mode 100644 index 0000000..006dab8 --- /dev/null +++ b/app/internal/forwarding/udp_test.go @@ -0,0 +1,49 @@ +package forwarding + +import ( + "bytes" + "crypto/rand" + "net" + "testing" + + "github.com/apernet/hysteria/app/internal/utils_test" +) + +func TestUDPTunnel(t *testing.T) { + // Start the tunnel + tunnel := &UDPTunnel{ + HyClient: &utils_test.MockEchoHyClient{}, + Remote: "whatever", + } + l, err := net.ListenPacket("udp", "127.0.0.1:34567") + if err != nil { + t.Fatal(err) + } + defer l.Close() + go tunnel.Serve(l) + + for i := 0; i < 10; i++ { + conn, err := net.Dial("udp", "127.0.0.1:34567") + if err != nil { + t.Fatal(err) + } + + data := make([]byte, 1024) + _, _ = rand.Read(data) + _, err = conn.Write(data) + if err != nil { + t.Fatal(err) + } + recv := make([]byte, 1024) + _, err = conn.Read(recv) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, recv) { + t.Fatalf("connection %d: data mismatch", i) + } + + _ = conn.Close() + } +} diff --git a/app/internal/http/server_test.go b/app/internal/http/server_test.go index 3c942e8..878362f 100644 --- a/app/internal/http/server_test.go +++ b/app/internal/http/server_test.go @@ -15,25 +15,25 @@ const ( testKeyFile = "test.key" ) -type mockEchoHyClient struct{} +type mockHyClient struct{} -func (c *mockEchoHyClient) DialTCP(addr string) (net.Conn, error) { +func (c *mockHyClient) DialTCP(addr string) (net.Conn, error) { return net.Dial("tcp", addr) } -func (c *mockEchoHyClient) ListenUDP() (client.HyUDPConn, error) { +func (c *mockHyClient) ListenUDP() (client.HyUDPConn, error) { // Not implemented return nil, errors.New("not implemented") } -func (c *mockEchoHyClient) Close() error { +func (c *mockHyClient) Close() error { return nil } func TestServer(t *testing.T) { // Start the server s := &Server{ - HyClient: &mockEchoHyClient{}, + HyClient: &mockHyClient{}, } l, err := net.Listen("tcp", "127.0.0.1:18080") if err != nil { diff --git a/app/internal/socks5/server_test.go b/app/internal/socks5/server_test.go index 5fcb103..de85204 100644 --- a/app/internal/socks5/server_test.go +++ b/app/internal/socks5/server_test.go @@ -1,116 +1,17 @@ package socks5 import ( - "io" "net" "os/exec" "testing" - "time" - "github.com/apernet/hysteria/core/client" + "github.com/apernet/hysteria/app/internal/utils_test" ) -type mockEchoHyClient struct{} - -func (c *mockEchoHyClient) DialTCP(addr string) (net.Conn, error) { - return &mockEchoTCPConn{ - BufChan: make(chan []byte, 10), - }, nil -} - -func (c *mockEchoHyClient) ListenUDP() (client.HyUDPConn, error) { - return &mockEchoUDPConn{ - BufChan: make(chan mockEchoUDPPacket, 10), - }, nil -} - -func (c *mockEchoHyClient) Close() error { - return nil -} - -type mockEchoTCPConn struct { - BufChan chan []byte -} - -func (c *mockEchoTCPConn) Read(b []byte) (n int, err error) { - buf := <-c.BufChan - if buf == nil { - // EOF - return 0, io.EOF - } - return copy(b, buf), nil -} - -func (c *mockEchoTCPConn) Write(b []byte) (n int, err error) { - c.BufChan <- b - return len(b), nil -} - -func (c *mockEchoTCPConn) Close() error { - close(c.BufChan) - return nil -} - -func (c *mockEchoTCPConn) LocalAddr() net.Addr { - // Not implemented - return nil -} - -func (c *mockEchoTCPConn) RemoteAddr() net.Addr { - // Not implemented - return nil -} - -func (c *mockEchoTCPConn) SetDeadline(t time.Time) error { - // Not implemented - return nil -} - -func (c *mockEchoTCPConn) SetReadDeadline(t time.Time) error { - // Not implemented - return nil -} - -func (c *mockEchoTCPConn) SetWriteDeadline(t time.Time) error { - // Not implemented - return nil -} - -type mockEchoUDPPacket struct { - Data []byte - Addr string -} - -type mockEchoUDPConn struct { - BufChan chan mockEchoUDPPacket -} - -func (c *mockEchoUDPConn) Receive() ([]byte, string, error) { - p := <-c.BufChan - if p.Data == nil { - // EOF - return nil, "", io.EOF - } - return p.Data, p.Addr, nil -} - -func (c *mockEchoUDPConn) Send(bytes []byte, s string) error { - c.BufChan <- mockEchoUDPPacket{ - Data: bytes, - Addr: s, - } - return nil -} - -func (c *mockEchoUDPConn) Close() error { - close(c.BufChan) - return nil -} - func TestServer(t *testing.T) { // Start the server s := &Server{ - HyClient: &mockEchoHyClient{}, + HyClient: &utils_test.MockEchoHyClient{}, } l, err := net.Listen("tcp", "127.0.0.1:11080") if err != nil { diff --git a/app/internal/utils_test/mock.go b/app/internal/utils_test/mock.go new file mode 100644 index 0000000..4e04d85 --- /dev/null +++ b/app/internal/utils_test/mock.go @@ -0,0 +1,106 @@ +package utils_test + +import ( + "io" + "net" + "time" + + "github.com/apernet/hysteria/core/client" +) + +type MockEchoHyClient struct{} + +func (c *MockEchoHyClient) DialTCP(addr string) (net.Conn, error) { + return &mockEchoTCPConn{ + BufChan: make(chan []byte, 10), + }, nil +} + +func (c *MockEchoHyClient) ListenUDP() (client.HyUDPConn, error) { + return &mockEchoUDPConn{ + BufChan: make(chan mockEchoUDPPacket, 10), + }, nil +} + +func (c *MockEchoHyClient) Close() error { + return nil +} + +type mockEchoTCPConn struct { + BufChan chan []byte +} + +func (c *mockEchoTCPConn) Read(b []byte) (n int, err error) { + buf := <-c.BufChan + if buf == nil { + // EOF + return 0, io.EOF + } + return copy(b, buf), nil +} + +func (c *mockEchoTCPConn) Write(b []byte) (n int, err error) { + c.BufChan <- b + return len(b), nil +} + +func (c *mockEchoTCPConn) Close() error { + close(c.BufChan) + return nil +} + +func (c *mockEchoTCPConn) LocalAddr() net.Addr { + // Not implemented + return nil +} + +func (c *mockEchoTCPConn) RemoteAddr() net.Addr { + // Not implemented + return nil +} + +func (c *mockEchoTCPConn) SetDeadline(t time.Time) error { + // Not implemented + return nil +} + +func (c *mockEchoTCPConn) SetReadDeadline(t time.Time) error { + // Not implemented + return nil +} + +func (c *mockEchoTCPConn) SetWriteDeadline(t time.Time) error { + // Not implemented + return nil +} + +type mockEchoUDPPacket struct { + Data []byte + Addr string +} + +type mockEchoUDPConn struct { + BufChan chan mockEchoUDPPacket +} + +func (c *mockEchoUDPConn) Receive() ([]byte, string, error) { + p := <-c.BufChan + if p.Data == nil { + // EOF + return nil, "", io.EOF + } + return p.Data, p.Addr, nil +} + +func (c *mockEchoUDPConn) Send(bytes []byte, s string) error { + c.BufChan <- mockEchoUDPPacket{ + Data: bytes, + Addr: s, + } + return nil +} + +func (c *mockEchoUDPConn) Close() error { + close(c.BufChan) + return nil +} diff --git a/app/server.example.yaml b/app/server.example.yaml index e3d8fc1..788ea93 100644 --- a/app/server.example.yaml +++ b/app/server.example.yaml @@ -20,14 +20,14 @@ acme: # maxStreamReceiveWindow: 8388608 # initConnReceiveWindow: 20971520 # maxConnReceiveWindow: 20971520 -# maxIdleTimeout: 130s +# maxIdleTimeout: 30s # maxIncomingStreams: 1024 # disablePathMTUDiscovery: false # bandwidth: # up: 100 mbps # down: 100 mbps -# + # disableUDP: false auth: @@ -38,4 +38,4 @@ masquerade: type: proxy proxy: url: https://some.site.net - rewriteHost: true \ No newline at end of file + rewriteHost: true diff --git a/core/server/server.go b/core/server/server.go index 48c5fcc..2b0b1ca 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -323,47 +323,45 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { msgBuf := make([]byte, protocol.MaxUDPSize) for { udpN, rAddr, err := conn.ReadFrom(udpBuf) - if udpN > 0 { - if h.config.TrafficLogger != nil { - ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) - if !ok { - // TrafficLogger requested to disconnect the client - _ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "") - return - } - } - // Try no frag first - msg := protocol.UDPMessage{ - SessionID: sessionID, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: rAddr, - Data: udpBuf[:udpN], - } - msgN := msg.Serialize(msgBuf) - if msgN < 0 { - // Message even larger than MaxUDPSize, drop it - continue - } - sendErr := h.conn.SendMessage(msgBuf[:msgN]) - var errTooLarge quic.ErrMessageTooLarge - if errors.As(sendErr, &errTooLarge) { - // Message too large, try fragmentation - msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 - fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) - for _, fMsg := range fMsgs { - msgN = fMsg.Serialize(msgBuf) - _ = h.conn.SendMessage(msgBuf[:msgN]) - } + if err != nil { + connCloseFunc() + _ = stream.Close() + return + } + if h.config.TrafficLogger != nil { + ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) + if !ok { + // TrafficLogger requested to disconnect the client + _ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "") + return } } - if err != nil { - break + // Try no frag first + msg := protocol.UDPMessage{ + SessionID: sessionID, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: rAddr, + Data: udpBuf[:udpN], + } + msgN := msg.Serialize(msgBuf) + if msgN < 0 { + // Message even larger than MaxUDPSize, drop it + continue + } + sendErr := h.conn.SendMessage(msgBuf[:msgN]) + var errTooLarge quic.ErrMessageTooLarge + if errors.As(sendErr, &errTooLarge) { + // Message too large, try fragmentation + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) + for _, fMsg := range fMsgs { + msgN = fMsg.Serialize(msgBuf) + _ = h.conn.SendMessage(msgBuf[:msgN]) + } } } - connCloseFunc() - _ = stream.Close() }() // Hold (drain) the stream until the client closes it. diff --git a/hyperbole.py b/hyperbole.py index c1b94dd..a7bef2c 100644 --- a/hyperbole.py +++ b/hyperbole.py @@ -3,6 +3,7 @@ import argparse import os +import sys import subprocess import datetime import shutil @@ -162,9 +163,11 @@ def cmd_run(args): try: subprocess.check_call(cmd) - except Exception: - print('Failed to run app') - return + except KeyboardInterrupt: + pass + except subprocess.CalledProcessError as e: + # Pass through the exit code + sys.exit(e.returncode) def cmd_format(): @@ -176,7 +179,6 @@ def cmd_format(): subprocess.check_call(['gofumpt', '-w', '-l', '-extra', '.']) except Exception: print('Failed to format code') - return def cmd_clean():