add AES encryption support for QUIC

This commit is contained in:
rui.zheng 2017-11-21 13:45:26 +08:00
parent db4591cadd
commit 4cdf5d5b8b
4 changed files with 149 additions and 6 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@ -223,6 +224,18 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
KeepAlive: toBool(node.Values.Get("keepalive")), KeepAlive: toBool(node.Values.Get("keepalive")),
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
config.Timeout = time.Duration(timeout) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle"))
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key))
config.Key = sum[:]
}
tr = gost.QUICTransporter(config) tr = gost.QUICTransporter(config)
case "http2": case "http2":
tr = gost.HTTP2Transporter(tlsCfg) tr = gost.HTTP2Transporter(tlsCfg)
@ -371,6 +384,15 @@ func (r *route) serve() error {
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
config.Timeout = time.Duration(timeout) * time.Second config.Timeout = time.Duration(timeout) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle"))
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key))
config.Key = sum[:]
}
ln, err = gost.QUICListener(node.Addr, config) ln, err = gost.QUICListener(node.Addr, config)
case "http2": case "http2":
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)

117
quic.go
View File

@ -1,8 +1,12 @@
package gost package gost
import ( import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -55,10 +59,17 @@ func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Co
session, ok := tr.sessions[addr] session, ok := tr.sessions[addr]
if !ok { if !ok {
conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) var cc *net.UDPConn
cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
return return
} }
conn = cc
if tr.config != nil && tr.config.Key != nil {
conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key}
}
session = &quicSession{conn: conn} session = &quicSession{conn: conn}
tr.sessions[addr] = session tr.sessions[addr] = session
} }
@ -107,7 +118,7 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption)
} }
func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) {
udpConn, ok := conn.(*net.UDPConn) udpConn, ok := conn.(net.PacketConn)
if !ok { if !ok {
return nil, errors.New("quic: wrong connection type") return nil, errors.New("quic: wrong connection type")
} }
@ -118,6 +129,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC
quicConfig := &quic.Config{ quicConfig := &quic.Config{
HandshakeTimeout: config.Timeout, HandshakeTimeout: config.Timeout,
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
IdleTimeout: config.IdleTimeout,
} }
session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig)
if err != nil { if err != nil {
@ -133,9 +145,11 @@ func (tr *quicTransporter) Multiplex() bool {
// QUICConfig is the config for QUIC client and server // QUICConfig is the config for QUIC client and server
type QUICConfig struct { type QUICConfig struct {
TLSConfig *tls.Config TLSConfig *tls.Config
Timeout time.Duration Timeout time.Duration
KeepAlive bool KeepAlive bool
IdleTimeout time.Duration
Key []byte
} }
type quicListener struct { type quicListener struct {
@ -152,13 +166,31 @@ func QUICListener(addr string, config *QUICConfig) (Listener, error) {
quicConfig := &quic.Config{ quicConfig := &quic.Config{
HandshakeTimeout: config.Timeout, HandshakeTimeout: config.Timeout,
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
IdleTimeout: config.IdleTimeout,
} }
tlsConfig := config.TLSConfig tlsConfig := config.TLSConfig
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig)
var conn net.PacketConn
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
lconn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
conn = lconn
if config.Key != nil {
conn = &quicCipherConn{UDPConn: lconn, key: config.Key}
}
ln, err := quic.Listen(conn, tlsConfig, quicConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -241,3 +273,76 @@ func (c *quicConn) LocalAddr() net.Addr {
func (c *quicConn) RemoteAddr() net.Addr { func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr return c.raddr
} }
type quicCipherConn struct {
*net.UDPConn
key []byte
}
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
n, addr, err = conn.UDPConn.ReadFrom(data)
if err != nil {
return
}
b, err := conn.decrypt(data[:n])
if err != nil {
return
}
copy(data, b)
return len(b), addr, nil
}
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
b, err := conn.encrypt(data)
if err != nil {
return
}
_, err = conn.UDPConn.WriteTo(b, addr)
if err != nil {
return
}
return len(b), nil
}
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, data, nil), nil
}
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}

5
tls.go
View File

@ -56,6 +56,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
defer tr.sessionMutex.Unlock() defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] session, ok := tr.sessions[addr]
if session != nil && session.session != nil && session.session.IsClosed() {
session.Close()
delete(tr.sessions, addr)
ok = false
}
if !ok { if !ok {
if opts.Chain == nil { if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout) conn, err = net.DialTimeout("tcp", addr, opts.Timeout)

11
ws.go
View File

@ -158,6 +158,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
defer tr.sessionMutex.Unlock() defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] session, ok := tr.sessions[addr]
if session != nil && session.session != nil && session.session.IsClosed() {
session.Close()
delete(tr.sessions, addr)
ok = false
}
if !ok { if !ok {
if opts.Chain == nil { if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout) conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
@ -193,6 +198,7 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
session = s session = s
tr.sessions[opts.Addr] = session tr.sessions[opts.Addr] = session
} }
cc, err := session.GetConn() cc, err := session.GetConn()
if err != nil { if err != nil {
session.Close() session.Close()
@ -281,6 +287,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
defer tr.sessionMutex.Unlock() defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] session, ok := tr.sessions[addr]
if session != nil && session.session != nil && session.session.IsClosed() {
session.Close()
delete(tr.sessions, addr)
ok = false
}
if !ok { if !ok {
if opts.Chain == nil { if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout) conn, err = net.DialTimeout("tcp", addr, opts.Timeout)