add AES encryption support for QUIC
This commit is contained in:
parent
db4591cadd
commit
4cdf5d5b8b
@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -223,6 +224,18 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
|||||||
TLSConfig: tlsCfg,
|
TLSConfig: tlsCfg,
|
||||||
KeepAlive: toBool(node.Values.Get("keepalive")),
|
KeepAlive: toBool(node.Values.Get("keepalive")),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||||
|
config.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
|
||||||
|
idle, _ := strconv.Atoi(node.Values.Get("idle"))
|
||||||
|
config.IdleTimeout = time.Duration(idle) * time.Second
|
||||||
|
|
||||||
|
if key := node.Values.Get("key"); key != "" {
|
||||||
|
sum := sha256.Sum256([]byte(key))
|
||||||
|
config.Key = sum[:]
|
||||||
|
}
|
||||||
|
|
||||||
tr = gost.QUICTransporter(config)
|
tr = gost.QUICTransporter(config)
|
||||||
case "http2":
|
case "http2":
|
||||||
tr = gost.HTTP2Transporter(tlsCfg)
|
tr = gost.HTTP2Transporter(tlsCfg)
|
||||||
@ -371,6 +384,15 @@ func (r *route) serve() error {
|
|||||||
}
|
}
|
||||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||||
config.Timeout = time.Duration(timeout) * time.Second
|
config.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
|
||||||
|
idle, _ := strconv.Atoi(node.Values.Get("idle"))
|
||||||
|
config.IdleTimeout = time.Duration(idle) * time.Second
|
||||||
|
|
||||||
|
if key := node.Values.Get("key"); key != "" {
|
||||||
|
sum := sha256.Sum256([]byte(key))
|
||||||
|
config.Key = sum[:]
|
||||||
|
}
|
||||||
|
|
||||||
ln, err = gost.QUICListener(node.Addr, config)
|
ln, err = gost.QUICListener(node.Addr, config)
|
||||||
case "http2":
|
case "http2":
|
||||||
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)
|
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)
|
||||||
|
117
quic.go
117
quic.go
@ -1,8 +1,12 @@
|
|||||||
package gost
|
package gost
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -55,10 +59,17 @@ func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
|||||||
|
|
||||||
session, ok := tr.sessions[addr]
|
session, ok := tr.sessions[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
var cc *net.UDPConn
|
||||||
|
cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn = cc
|
||||||
|
|
||||||
|
if tr.config != nil && tr.config.Key != nil {
|
||||||
|
conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key}
|
||||||
|
}
|
||||||
|
|
||||||
session = &quicSession{conn: conn}
|
session = &quicSession{conn: conn}
|
||||||
tr.sessions[addr] = session
|
tr.sessions[addr] = session
|
||||||
}
|
}
|
||||||
@ -107,7 +118,7 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) {
|
func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) {
|
||||||
udpConn, ok := conn.(*net.UDPConn)
|
udpConn, ok := conn.(net.PacketConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("quic: wrong connection type")
|
return nil, errors.New("quic: wrong connection type")
|
||||||
}
|
}
|
||||||
@ -118,6 +129,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC
|
|||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
HandshakeTimeout: config.Timeout,
|
HandshakeTimeout: config.Timeout,
|
||||||
KeepAlive: config.KeepAlive,
|
KeepAlive: config.KeepAlive,
|
||||||
|
IdleTimeout: config.IdleTimeout,
|
||||||
}
|
}
|
||||||
session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig)
|
session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,9 +145,11 @@ func (tr *quicTransporter) Multiplex() bool {
|
|||||||
|
|
||||||
// QUICConfig is the config for QUIC client and server
|
// QUICConfig is the config for QUIC client and server
|
||||||
type QUICConfig struct {
|
type QUICConfig struct {
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
KeepAlive bool
|
KeepAlive bool
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
Key []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type quicListener struct {
|
type quicListener struct {
|
||||||
@ -152,13 +166,31 @@ func QUICListener(addr string, config *QUICConfig) (Listener, error) {
|
|||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
HandshakeTimeout: config.Timeout,
|
HandshakeTimeout: config.Timeout,
|
||||||
KeepAlive: config.KeepAlive,
|
KeepAlive: config.KeepAlive,
|
||||||
|
IdleTimeout: config.IdleTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := config.TLSConfig
|
tlsConfig := config.TLSConfig
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
tlsConfig = DefaultTLSConfig
|
tlsConfig = DefaultTLSConfig
|
||||||
}
|
}
|
||||||
ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig)
|
|
||||||
|
var conn net.PacketConn
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lconn, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn = lconn
|
||||||
|
|
||||||
|
if config.Key != nil {
|
||||||
|
conn = &quicCipherConn{UDPConn: lconn, key: config.Key}
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := quic.Listen(conn, tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -241,3 +273,76 @@ func (c *quicConn) LocalAddr() net.Addr {
|
|||||||
func (c *quicConn) RemoteAddr() net.Addr {
|
func (c *quicConn) RemoteAddr() net.Addr {
|
||||||
return c.raddr
|
return c.raddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type quicCipherConn struct {
|
||||||
|
*net.UDPConn
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
|
||||||
|
n, addr, err = conn.UDPConn.ReadFrom(data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b, err := conn.decrypt(data[:n])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(data, b)
|
||||||
|
|
||||||
|
return len(b), addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
|
||||||
|
b, err := conn.encrypt(data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.UDPConn.WriteTo(b, addr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
|
||||||
|
c, err := aes.NewCipher(conn.key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
|
||||||
|
c, err := aes.NewCipher(conn.key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(data) < nonceSize {
|
||||||
|
return nil, errors.New("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||||
|
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
}
|
||||||
|
5
tls.go
5
tls.go
@ -56,6 +56,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
|||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr]
|
session, ok := tr.sessions[addr]
|
||||||
|
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||||
|
session.Close()
|
||||||
|
delete(tr.sessions, addr)
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
|
11
ws.go
11
ws.go
@ -158,6 +158,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
|
|||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr]
|
session, ok := tr.sessions[addr]
|
||||||
|
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||||
|
session.Close()
|
||||||
|
delete(tr.sessions, addr)
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
@ -193,6 +198,7 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
|
|||||||
session = s
|
session = s
|
||||||
tr.sessions[opts.Addr] = session
|
tr.sessions[opts.Addr] = session
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err := session.GetConn()
|
cc, err := session.GetConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
session.Close()
|
session.Close()
|
||||||
@ -281,6 +287,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
|||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr]
|
session, ok := tr.sessions[addr]
|
||||||
|
if session != nil && session.session != nil && session.session.IsClosed() {
|
||||||
|
session.Close()
|
||||||
|
delete(tr.sessions, addr)
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
|
Loading…
Reference in New Issue
Block a user