add AES encryption support for QUIC
This commit is contained in:
parent
db4591cadd
commit
4cdf5d5b8b
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
@ -223,6 +224,18 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
||||
TLSConfig: tlsCfg,
|
||||
KeepAlive: toBool(node.Values.Get("keepalive")),
|
||||
}
|
||||
|
||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||
config.Timeout = time.Duration(timeout) * time.Second
|
||||
|
||||
idle, _ := strconv.Atoi(node.Values.Get("idle"))
|
||||
config.IdleTimeout = time.Duration(idle) * time.Second
|
||||
|
||||
if key := node.Values.Get("key"); key != "" {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
config.Key = sum[:]
|
||||
}
|
||||
|
||||
tr = gost.QUICTransporter(config)
|
||||
case "http2":
|
||||
tr = gost.HTTP2Transporter(tlsCfg)
|
||||
@ -371,6 +384,15 @@ func (r *route) serve() error {
|
||||
}
|
||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||
config.Timeout = time.Duration(timeout) * time.Second
|
||||
|
||||
idle, _ := strconv.Atoi(node.Values.Get("idle"))
|
||||
config.IdleTimeout = time.Duration(idle) * time.Second
|
||||
|
||||
if key := node.Values.Get("key"); key != "" {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
config.Key = sum[:]
|
||||
}
|
||||
|
||||
ln, err = gost.QUICListener(node.Addr, config)
|
||||
case "http2":
|
||||
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)
|
||||
|
117
quic.go
117
quic.go
@ -1,8 +1,12 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@ -55,10 +59,17 @@ func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
||||
|
||||
session, ok := tr.sessions[addr]
|
||||
if !ok {
|
||||
conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
var cc *net.UDPConn
|
||||
cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn = cc
|
||||
|
||||
if tr.config != nil && tr.config.Key != nil {
|
||||
conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key}
|
||||
}
|
||||
|
||||
session = &quicSession{conn: conn}
|
||||
tr.sessions[addr] = session
|
||||
}
|
||||
@ -107,7 +118,7 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption)
|
||||
}
|
||||
|
||||
func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) {
|
||||
udpConn, ok := conn.(*net.UDPConn)
|
||||
udpConn, ok := conn.(net.PacketConn)
|
||||
if !ok {
|
||||
return nil, errors.New("quic: wrong connection type")
|
||||
}
|
||||
@ -118,6 +129,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC
|
||||
quicConfig := &quic.Config{
|
||||
HandshakeTimeout: config.Timeout,
|
||||
KeepAlive: config.KeepAlive,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
}
|
||||
session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig)
|
||||
if err != nil {
|
||||
@ -133,9 +145,11 @@ func (tr *quicTransporter) Multiplex() bool {
|
||||
|
||||
// QUICConfig is the config for QUIC client and server
|
||||
type QUICConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
Timeout time.Duration
|
||||
KeepAlive bool
|
||||
TLSConfig *tls.Config
|
||||
Timeout time.Duration
|
||||
KeepAlive bool
|
||||
IdleTimeout time.Duration
|
||||
Key []byte
|
||||
}
|
||||
|
||||
type quicListener struct {
|
||||
@ -152,13 +166,31 @@ func QUICListener(addr string, config *QUICConfig) (Listener, error) {
|
||||
quicConfig := &quic.Config{
|
||||
HandshakeTimeout: config.Timeout,
|
||||
KeepAlive: config.KeepAlive,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
}
|
||||
|
||||
tlsConfig := config.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = DefaultTLSConfig
|
||||
}
|
||||
ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig)
|
||||
|
||||
var conn net.PacketConn
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lconn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = lconn
|
||||
|
||||
if config.Key != nil {
|
||||
conn = &quicCipherConn{UDPConn: lconn, key: config.Key}
|
||||
}
|
||||
|
||||
ln, err := quic.Listen(conn, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -241,3 +273,76 @@ func (c *quicConn) LocalAddr() net.Addr {
|
||||
func (c *quicConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
type quicCipherConn struct {
|
||||
*net.UDPConn
|
||||
key []byte
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = conn.UDPConn.ReadFrom(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b, err := conn.decrypt(data[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
copy(data, b)
|
||||
|
||||
return len(b), addr, nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
|
||||
b, err := conn.encrypt(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.UDPConn.WriteTo(b, addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
|
||||
c, err := aes.NewCipher(conn.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
|
5
tls.go
5
tls.go
@ -56,6 +56,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr]
|
||||
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||
session.Close()
|
||||
delete(tr.sessions, addr)
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
|
11
ws.go
11
ws.go
@ -158,6 +158,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr]
|
||||
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||
session.Close()
|
||||
delete(tr.sessions, addr)
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
@ -193,6 +198,7 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
|
||||
session = s
|
||||
tr.sessions[opts.Addr] = session
|
||||
}
|
||||
|
||||
cc, err := session.GetConn()
|
||||
if err != nil {
|
||||
session.Close()
|
||||
@ -281,6 +287,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr]
|
||||
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||
session.Close()
|
||||
delete(tr.sessions, addr)
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
|
Loading…
Reference in New Issue
Block a user