diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 7fd8de3..6cc244d 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/json" @@ -223,6 +224,18 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { TLSConfig: tlsCfg, KeepAlive: toBool(node.Values.Get("keepalive")), } + + timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + config.Timeout = time.Duration(timeout) * time.Second + + idle, _ := strconv.Atoi(node.Values.Get("idle")) + config.IdleTimeout = time.Duration(idle) * time.Second + + if key := node.Values.Get("key"); key != "" { + sum := sha256.Sum256([]byte(key)) + config.Key = sum[:] + } + tr = gost.QUICTransporter(config) case "http2": tr = gost.HTTP2Transporter(tlsCfg) @@ -371,6 +384,15 @@ func (r *route) serve() error { } timeout, _ := strconv.Atoi(node.Values.Get("timeout")) config.Timeout = time.Duration(timeout) * time.Second + + idle, _ := strconv.Atoi(node.Values.Get("idle")) + config.IdleTimeout = time.Duration(idle) * time.Second + + if key := node.Values.Get("key"); key != "" { + sum := sha256.Sum256([]byte(key)) + config.Key = sum[:] + } + ln, err = gost.QUICListener(node.Addr, config) case "http2": ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) diff --git a/quic.go b/quic.go index 56a61c8..5353c77 100644 --- a/quic.go +++ b/quic.go @@ -1,8 +1,12 @@ package gost import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" "crypto/tls" "errors" + "io" "net" "sync" "time" @@ -55,10 +59,17 @@ func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Co session, ok := tr.sessions[addr] if !ok { - conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + var cc *net.UDPConn + cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return } + conn = cc + + if tr.config != nil && tr.config.Key != nil { + conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key} + } + session = &quicSession{conn: conn} tr.sessions[addr] = session } @@ -107,7 +118,7 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) } func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { - udpConn, ok := conn.(*net.UDPConn) + udpConn, ok := conn.(net.PacketConn) if !ok { return nil, errors.New("quic: wrong connection type") } @@ -118,6 +129,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC quicConfig := &quic.Config{ HandshakeTimeout: config.Timeout, KeepAlive: config.KeepAlive, + IdleTimeout: config.IdleTimeout, } session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) if err != nil { @@ -133,9 +145,11 @@ func (tr *quicTransporter) Multiplex() bool { // QUICConfig is the config for QUIC client and server type QUICConfig struct { - TLSConfig *tls.Config - Timeout time.Duration - KeepAlive bool + TLSConfig *tls.Config + Timeout time.Duration + KeepAlive bool + IdleTimeout time.Duration + Key []byte } type quicListener struct { @@ -152,13 +166,31 @@ func QUICListener(addr string, config *QUICConfig) (Listener, error) { quicConfig := &quic.Config{ HandshakeTimeout: config.Timeout, KeepAlive: config.KeepAlive, + IdleTimeout: config.IdleTimeout, } tlsConfig := config.TLSConfig if tlsConfig == nil { tlsConfig = DefaultTLSConfig } - ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig) + + var conn net.PacketConn + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + lconn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + conn = lconn + + if config.Key != nil { + conn = &quicCipherConn{UDPConn: lconn, key: config.Key} + } + + ln, err := quic.Listen(conn, tlsConfig, quicConfig) if err != nil { return nil, err } @@ -241,3 +273,76 @@ func (c *quicConn) LocalAddr() net.Addr { func (c *quicConn) RemoteAddr() net.Addr { return c.raddr } + +type quicCipherConn struct { + *net.UDPConn + key []byte +} + +func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { + n, addr, err = conn.UDPConn.ReadFrom(data) + if err != nil { + return + } + b, err := conn.decrypt(data[:n]) + if err != nil { + return + } + + copy(data, b) + + return len(b), addr, nil +} + +func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { + b, err := conn.encrypt(data) + if err != nil { + return + } + + _, err = conn.UDPConn.WriteTo(b, addr) + if err != nil { + return + } + + return len(b), nil +} + +func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, data, nil), nil +} + +func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + return gcm.Open(nil, nonce, ciphertext, nil) +} diff --git a/tls.go b/tls.go index f444330..35791d6 100644 --- a/tls.go +++ b/tls.go @@ -56,6 +56,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) diff --git a/ws.go b/ws.go index fa9c043..0b9c61f 100644 --- a/ws.go +++ b/ws.go @@ -158,6 +158,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) @@ -193,6 +198,7 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( session = s tr.sessions[opts.Addr] = session } + cc, err := session.GetConn() if err != nil { session.Close() @@ -281,6 +287,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout)