add QUIC support

This commit is contained in:
rui.zheng 2017-07-29 16:18:04 +08:00
parent 503568b093
commit 151a2f902b
233 changed files with 27733 additions and 8465 deletions

View File

@ -127,6 +127,7 @@ type HandshakeOptions struct {
TLSConfig *tls.Config TLSConfig *tls.Config
WSOptions *WSOptions WSOptions *WSOptions
KCPConfig *KCPConfig KCPConfig *KCPConfig
QUICConfig *QUICConfig
} }
// HandshakeOption allows a common way to set handshake options. // HandshakeOption allows a common way to set handshake options.

View File

@ -128,6 +128,7 @@ func main() {
}, },
*/ */
/*
// http+ssh // http+ssh
gost.Node{ gost.Node{
Addr: "127.0.0.1:12222", Addr: "127.0.0.1:12222",
@ -136,11 +137,21 @@ func main() {
gost.SSHTunnelTransporter(), gost.SSHTunnelTransporter(),
), ),
}, },
*/
// http+quic
gost.Node{
Addr: "localhost:6121",
Client: gost.NewClient(
gost.HTTPConnector(url.UserPassword("admin", "123456")),
gost.QUICTransporter(nil),
),
},
) )
total := 0 total := 0
for total < requests { for total < requests {
if total + concurrency > requests { if total+concurrency > requests {
concurrency = requests - total concurrency = requests - total
} }
startChan := make(chan struct{}) startChan := make(chan struct{})

View File

@ -40,7 +40,7 @@ func main() {
// go sshForwardServer() // go sshForwardServer()
go sshTunnelServer() go sshTunnelServer()
// go http2Server() // go http2Server()
go quicServer()
select {} select {}
} }
@ -242,6 +242,18 @@ func http2Server() {
log.Fatal(s.Serve(ln)) log.Fatal(s.Serve(ln))
} }
func quicServer() {
s := &gost.Server{}
s.Handle(gost.HTTPHandler(
gost.UsersHandlerOption(url.UserPassword("admin", "123456")),
))
ln, err := gost.QUICListener("localhost:6121", &gost.QUICConfig{TLSConfig: tlsConfig()})
if err != nil {
log.Fatal(err)
}
log.Fatal(s.Serve(ln))
}
var ( var (
rawCert = []byte(`-----BEGIN CERTIFICATE----- rawCert = []byte(`-----BEGIN CERTIFICATE-----
MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw

109
gost/examples/quic/quicc.go Normal file
View File

@ -0,0 +1,109 @@
package main
import (
"crypto/tls"
"flag"
"log"
"github.com/ginuerzh/gost/gost"
)
var (
laddr, faddr string
quiet bool
)
func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
flag.StringVar(&laddr, "L", ":18080", "listen address")
flag.StringVar(&faddr, "F", ":6121", "forward address")
flag.BoolVar(&quiet, "q", false, "quiet mode")
flag.BoolVar(&gost.Debug, "d", false, "debug mode")
flag.Parse()
if quiet {
gost.SetLogger(&gost.NopLogger{})
}
}
func main() {
chain := gost.NewChain(
gost.Node{
Protocol: "socks5",
Transport: "quic",
Addr: faddr,
Client: gost.NewClient(
gost.SOCKS5Connector(nil),
gost.QUICTransporter(nil),
),
},
)
s := &gost.Server{}
s.Handle(gost.SOCKS5Handler(
gost.ChainHandlerOption(chain),
gost.TLSConfigHandlerOption(tlsConfig()),
))
ln, err := gost.TCPListener(laddr)
if err != nil {
log.Fatal(err)
}
log.Fatal(s.Serve(ln))
}
var (
rawCert = []byte(`-----BEGIN CERTIFICATE-----
MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5
MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N
0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw
hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4
8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv
482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR
LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF
oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC
CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN
l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS
cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w
emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj
b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57
lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg==
-----END CERTIFICATE-----`)
rawKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N
ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo
N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv
GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh
Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg
IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n
IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H
r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe
yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru
kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk
TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU
k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o
/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK
HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg
HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY
CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d
JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr
pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt
/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD
xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL
vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX
1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt
7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4
fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw
cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI=
-----END RSA PRIVATE KEY-----`)
)
func tlsConfig() *tls.Config {
cert, err := tls.X509KeyPair(rawCert, rawKey)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{cert}}
}

102
gost/examples/quic/quics.go Normal file
View File

@ -0,0 +1,102 @@
package main
import (
"crypto/tls"
"flag"
"log"
"github.com/ginuerzh/gost/gost"
)
var (
laddr string
quiet bool
)
func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
flag.StringVar(&laddr, "L", ":6121", "listen address")
flag.BoolVar(&quiet, "q", false, "quiet mode")
flag.BoolVar(&gost.Debug, "d", false, "debug mode")
flag.Parse()
if quiet {
gost.SetLogger(&gost.NopLogger{})
}
}
func main() {
quicServer()
}
func quicServer() {
s := &gost.Server{}
s.Handle(
gost.SOCKS5Handler(gost.TLSConfigHandlerOption(tlsConfig())),
)
ln, err := gost.QUICListener(laddr, &gost.QUICConfig{TLSConfig: tlsConfig()})
if err != nil {
log.Fatal(err)
}
log.Println("server listen on", laddr)
log.Fatal(s.Serve(ln))
}
var (
rawCert = []byte(`-----BEGIN CERTIFICATE-----
MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5
MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N
0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw
hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4
8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv
482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR
LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF
oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC
CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN
l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS
cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w
emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj
b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57
lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg==
-----END CERTIFICATE-----`)
rawKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N
ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo
N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv
GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh
Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg
IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n
IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H
r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe
yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru
kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk
TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU
k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o
/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK
HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg
HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY
CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d
JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr
pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt
/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD
xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL
vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX
1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt
7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4
fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw
cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI=
-----END RSA PRIVATE KEY-----`)
)
func tlsConfig() *tls.Config {
cert, err := tls.X509KeyPair(rawCert, rawKey)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{cert}}
}

View File

@ -351,18 +351,6 @@ func (l *kcpListener) listenLoop() {
} }
} }
func (l *kcpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *kcpListener) mux(conn net.Conn) { func (l *kcpListener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig() smuxConfig := smux.DefaultConfig()
smuxConfig.MaxReceiveBuffer = l.config.SockBuf smuxConfig.MaxReceiveBuffer = l.config.SockBuf
@ -393,11 +381,22 @@ func (l *kcpListener) mux(conn net.Conn) {
select { select {
case l.connChan <- newKCPConn(conn, stream): case l.connChan <- newKCPConn(conn, stream):
default: default:
log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
} }
} }
} }
func (l *kcpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *kcpListener) Addr() net.Addr { func (l *kcpListener) Addr() net.Addr {
return l.ln.Addr() return l.ln.Addr()
} }

234
gost/quic.go Normal file
View File

@ -0,0 +1,234 @@
package gost
import (
"crypto/tls"
"errors"
"net"
"sync"
"time"
"github.com/go-log/log"
quic "github.com/lucas-clemente/quic-go"
)
type quicSession struct {
conn net.Conn
session quic.Session
}
func (session *quicSession) GetConn() (*quicConn, error) {
stream, err := session.session.OpenStream()
if err != nil {
return nil, err
}
return &quicConn{
Stream: stream,
laddr: session.session.LocalAddr(),
raddr: session.session.RemoteAddr(),
}, nil
}
func (session *quicSession) Close() error {
return session.session.Close(nil)
}
type quicTransporter struct {
config *QUICConfig
sessionMutex sync.Mutex
sessions map[string]*quicSession
}
// QUICTransporter creates a Transporter that is used by QUIC proxy client.
func QUICTransporter(config *QUICConfig) Transporter {
if config == nil {
config = &QUICConfig{}
}
return &quicTransporter{
config: config,
sessions: make(map[string]*quicSession),
}
}
func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr]
if !ok {
conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return
}
session = &quicSession{conn: conn}
tr.sessions[addr] = session
}
return session.conn, nil
}
func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
opts := &HandshakeOptions{}
for _, option := range options {
option(opts)
}
config := tr.config
if opts.QUICConfig != nil {
config = opts.QUICConfig
}
if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("quic: unrecognized connection")
}
if !ok || session.session == nil {
s, err := tr.initSession(opts.Addr, conn, config)
if err != nil {
conn.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
session = s
tr.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) {
udpConn, ok := conn.(*net.UDPConn)
if !ok {
return nil, errors.New("quic: wrong connection type")
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
quicConfig := &quic.Config{
HandshakeTimeout: config.Timeout,
KeepAlive: config.KeepAlive,
}
session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig)
if err != nil {
log.Log("quic dial", err)
return nil, err
}
return &quicSession{conn: conn, session: session}, nil
}
func (tr *quicTransporter) Multiplex() bool {
return true
}
type QUICConfig struct {
TLSConfig *tls.Config
Timeout time.Duration
KeepAlive bool
}
type quicListener struct {
ln quic.Listener
connChan chan net.Conn
errChan chan error
}
// QUICListener creates a Listener for QUIC proxy server.
func QUICListener(addr string, config *QUICConfig) (Listener, error) {
if config == nil {
config = &QUICConfig{}
}
quicConfig := &quic.Config{
HandshakeTimeout: config.Timeout,
KeepAlive: config.KeepAlive,
}
ln, err := quic.ListenAddr(addr, config.TLSConfig, quicConfig)
if err != nil {
return nil, err
}
l := &quicListener{
ln: ln,
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
}
go l.listenLoop()
return l, nil
}
func (l *quicListener) listenLoop() {
for {
session, err := l.ln.Accept()
if err != nil {
log.Log("[quic] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.sessionLoop(session)
}
}
func (l *quicListener) sessionLoop(session quic.Session) {
log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr())
defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr())
for {
stream, err := session.AcceptStream()
if err != nil {
log.Log("[quic] accept stream:", err)
return
}
select {
case l.connChan <- &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()}:
default:
log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr())
}
}
}
func (l *quicListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *quicListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *quicListener) Close() error {
return l.ln.Close()
}
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -173,6 +173,7 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
config := ssh.ClientConfig{ config := ssh.ClientConfig{
Timeout: opts.Timeout, Timeout: opts.Timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
if opts.User != nil { if opts.User != nil {
config.User = opts.User.Username() config.User = opts.User.Username()

View File

@ -1,15 +0,0 @@
ISC License
Copyright (c) 2012-2013 Dave Collins <dave@davec.name>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

View File

@ -1,152 +0,0 @@
// Copyright (c) 2015 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build !js,!appengine,!safe,!disableunsafe
package spew
import (
"reflect"
"unsafe"
)
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = false
// ptrSize is the size of a pointer on the current arch.
ptrSize = unsafe.Sizeof((*byte)(nil))
)
var (
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
// internal reflect.Value fields. These values are valid before golang
// commit ecccf07e7f9d which changed the format. The are also valid
// after commit 82f48826c6c7 which changed the format again to mirror
// the original format. Code in the init function updates these offsets
// as necessary.
offsetPtr = uintptr(ptrSize)
offsetScalar = uintptr(0)
offsetFlag = uintptr(ptrSize * 2)
// flagKindWidth and flagKindShift indicate various bits that the
// reflect package uses internally to track kind information.
//
// flagRO indicates whether or not the value field of a reflect.Value is
// read-only.
//
// flagIndir indicates whether the value field of a reflect.Value is
// the actual data or a pointer to the data.
//
// These values are valid before golang commit 90a7c3c86944 which
// changed their positions. Code in the init function updates these
// flags as necessary.
flagKindWidth = uintptr(5)
flagKindShift = uintptr(flagKindWidth - 1)
flagRO = uintptr(1 << 0)
flagIndir = uintptr(1 << 1)
)
func init() {
// Older versions of reflect.Value stored small integers directly in the
// ptr field (which is named val in the older versions). Versions
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
// scalar for this purpose which unfortunately came before the flag
// field, so the offset of the flag field is different for those
// versions.
//
// This code constructs a new reflect.Value from a known small integer
// and checks if the size of the reflect.Value struct indicates it has
// the scalar field. When it does, the offsets are updated accordingly.
vv := reflect.ValueOf(0xf00)
if unsafe.Sizeof(vv) == (ptrSize * 4) {
offsetScalar = ptrSize * 2
offsetFlag = ptrSize * 3
}
// Commit 90a7c3c86944 changed the flag positions such that the low
// order bits are the kind. This code extracts the kind from the flags
// field and ensures it's the correct type. When it's not, the flag
// order has been changed to the newer format, so the flags are updated
// accordingly.
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
upfv := *(*uintptr)(upf)
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
flagKindShift = 0
flagRO = 1 << 5
flagIndir = 1 << 6
// Commit adf9b30e5594 modified the flags to separate the
// flagRO flag into two bits which specifies whether or not the
// field is embedded. This causes flagIndir to move over a bit
// and means that flagRO is the combination of either of the
// original flagRO bit and the new bit.
//
// This code detects the change by extracting what used to be
// the indirect bit to ensure it's set. When it's not, the flag
// order has been changed to the newer format, so the flags are
// updated accordingly.
if upfv&flagIndir == 0 {
flagRO = 3 << 5
flagIndir = 1 << 7
}
}
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
// the typical safety restrictions preventing access to unaddressable and
// unexported data. It works by digging the raw pointer to the underlying
// value out of the protected value and generating a new unprotected (unsafe)
// reflect.Value to it.
//
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
indirects := 1
vt := v.Type()
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
if rvf&flagIndir != 0 {
vt = reflect.PtrTo(v.Type())
indirects++
} else if offsetScalar != 0 {
// The value is in the scalar field when it's not one of the
// reference types.
switch vt.Kind() {
case reflect.Uintptr:
case reflect.Chan:
case reflect.Func:
case reflect.Map:
case reflect.Ptr:
case reflect.UnsafePointer:
default:
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
offsetScalar)
}
}
pv := reflect.NewAt(vt, upv)
rv = pv
for i := 0; i < indirects; i++ {
rv = rv.Elem()
}
return rv
}

View File

@ -1,38 +0,0 @@
// Copyright (c) 2015 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is running on Google App Engine, compiled by GopherJS, or
// "-tags safe" is added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// +build js appengine safe disableunsafe
package spew
import "reflect"
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = true
)
// unsafeReflectValue typically converts the passed reflect.Value into a one
// that bypasses the typical safety restrictions preventing access to
// unaddressable and unexported data. However, doing this relies on access to
// the unsafe package. This is a stub version which simply returns the passed
// reflect.Value when the unsafe package is not available.
func unsafeReflectValue(v reflect.Value) reflect.Value {
return v
}

View File

@ -1,341 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"reflect"
"sort"
"strconv"
)
// Some constants in the form of bytes to avoid string overhead. This mirrors
// the technique used in the fmt package.
var (
panicBytes = []byte("(PANIC=")
plusBytes = []byte("+")
iBytes = []byte("i")
trueBytes = []byte("true")
falseBytes = []byte("false")
interfaceBytes = []byte("(interface {})")
commaNewlineBytes = []byte(",\n")
newlineBytes = []byte("\n")
openBraceBytes = []byte("{")
openBraceNewlineBytes = []byte("{\n")
closeBraceBytes = []byte("}")
asteriskBytes = []byte("*")
colonBytes = []byte(":")
colonSpaceBytes = []byte(": ")
openParenBytes = []byte("(")
closeParenBytes = []byte(")")
spaceBytes = []byte(" ")
pointerChainBytes = []byte("->")
nilAngleBytes = []byte("<nil>")
maxNewlineBytes = []byte("<max depth reached>\n")
maxShortBytes = []byte("<max>")
circularBytes = []byte("<already shown>")
circularShortBytes = []byte("<shown>")
invalidAngleBytes = []byte("<invalid>")
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
percentBytes = []byte("%")
precisionBytes = []byte(".")
openAngleBytes = []byte("<")
closeAngleBytes = []byte(">")
openMapBytes = []byte("map[")
closeMapBytes = []byte("]")
lenEqualsBytes = []byte("len=")
capEqualsBytes = []byte("cap=")
)
// hexDigits is used to map a decimal value to a hex digit.
var hexDigits = "0123456789abcdef"
// catchPanic handles any panics that might occur during the handleMethods
// calls.
func catchPanic(w io.Writer, v reflect.Value) {
if err := recover(); err != nil {
w.Write(panicBytes)
fmt.Fprintf(w, "%v", err)
w.Write(closeParenBytes)
}
}
// handleMethods attempts to call the Error and String methods on the underlying
// type the passed reflect.Value represents and outputes the result to Writer w.
//
// It handles panics in any called methods by catching and displaying the error
// as the formatted value.
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
// We need an interface to check if the type implements the error or
// Stringer interface. However, the reflect package won't give us an
// interface on certain things like unexported struct fields in order
// to enforce visibility rules. We use unsafe, when it's available,
// to bypass these restrictions since this package does not mutate the
// values.
if !v.CanInterface() {
if UnsafeDisabled {
return false
}
v = unsafeReflectValue(v)
}
// Choose whether or not to do error and Stringer interface lookups against
// the base type or a pointer to the base type depending on settings.
// Technically calling one of these methods with a pointer receiver can
// mutate the value, however, types which choose to satisify an error or
// Stringer interface with a pointer receiver should not be mutating their
// state inside these interface methods.
if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() {
v = unsafeReflectValue(v)
}
if v.CanAddr() {
v = v.Addr()
}
// Is it an error or Stringer?
switch iface := v.Interface().(type) {
case error:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.Error()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.Error()))
return true
case fmt.Stringer:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.String()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.String()))
return true
}
return false
}
// printBool outputs a boolean value as true or false to Writer w.
func printBool(w io.Writer, val bool) {
if val {
w.Write(trueBytes)
} else {
w.Write(falseBytes)
}
}
// printInt outputs a signed integer value to Writer w.
func printInt(w io.Writer, val int64, base int) {
w.Write([]byte(strconv.FormatInt(val, base)))
}
// printUint outputs an unsigned integer value to Writer w.
func printUint(w io.Writer, val uint64, base int) {
w.Write([]byte(strconv.FormatUint(val, base)))
}
// printFloat outputs a floating point value using the specified precision,
// which is expected to be 32 or 64bit, to Writer w.
func printFloat(w io.Writer, val float64, precision int) {
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
}
// printComplex outputs a complex value using the specified float precision
// for the real and imaginary parts to Writer w.
func printComplex(w io.Writer, c complex128, floatPrecision int) {
r := real(c)
w.Write(openParenBytes)
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
i := imag(c)
if i >= 0 {
w.Write(plusBytes)
}
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
w.Write(iBytes)
w.Write(closeParenBytes)
}
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
// prefix to Writer w.
func printHexPtr(w io.Writer, p uintptr) {
// Null pointer.
num := uint64(p)
if num == 0 {
w.Write(nilAngleBytes)
return
}
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
buf := make([]byte, 18)
// It's simpler to construct the hex string right to left.
base := uint64(16)
i := len(buf) - 1
for num >= base {
buf[i] = hexDigits[num%base]
num /= base
i--
}
buf[i] = hexDigits[num]
// Add '0x' prefix.
i--
buf[i] = 'x'
i--
buf[i] = '0'
// Strip unused leading bytes.
buf = buf[i:]
w.Write(buf)
}
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
// elements to be sorted.
type valuesSorter struct {
values []reflect.Value
strings []string // either nil or same len and values
cs *ConfigState
}
// newValuesSorter initializes a valuesSorter instance, which holds a set of
// surrogate keys on which the data should be sorted. It uses flags in
// ConfigState to decide if and how to populate those surrogate keys.
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
vs := &valuesSorter{values: values, cs: cs}
if canSortSimply(vs.values[0].Kind()) {
return vs
}
if !cs.DisableMethods {
vs.strings = make([]string, len(values))
for i := range vs.values {
b := bytes.Buffer{}
if !handleMethods(cs, &b, vs.values[i]) {
vs.strings = nil
break
}
vs.strings[i] = b.String()
}
}
if vs.strings == nil && cs.SpewKeys {
vs.strings = make([]string, len(values))
for i := range vs.values {
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
}
}
return vs
}
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
// directly, or whether it should be considered for sorting by surrogate keys
// (if the ConfigState allows it).
func canSortSimply(kind reflect.Kind) bool {
// This switch parallels valueSortLess, except for the default case.
switch kind {
case reflect.Bool:
return true
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return true
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.String:
return true
case reflect.Uintptr:
return true
case reflect.Array:
return true
}
return false
}
// Len returns the number of values in the slice. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Len() int {
return len(s.values)
}
// Swap swaps the values at the passed indices. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Swap(i, j int) {
s.values[i], s.values[j] = s.values[j], s.values[i]
if s.strings != nil {
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
}
}
// valueSortLess returns whether the first value should sort before the second
// value. It is used by valueSorter.Less as part of the sort.Interface
// implementation.
func valueSortLess(a, b reflect.Value) bool {
switch a.Kind() {
case reflect.Bool:
return !a.Bool() && b.Bool()
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return a.Int() < b.Int()
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return a.Uint() < b.Uint()
case reflect.Float32, reflect.Float64:
return a.Float() < b.Float()
case reflect.String:
return a.String() < b.String()
case reflect.Uintptr:
return a.Uint() < b.Uint()
case reflect.Array:
// Compare the contents of both arrays.
l := a.Len()
for i := 0; i < l; i++ {
av := a.Index(i)
bv := b.Index(i)
if av.Interface() == bv.Interface() {
continue
}
return valueSortLess(av, bv)
}
}
return a.String() < b.String()
}
// Less returns whether the value at index i should sort before the
// value at index j. It is part of the sort.Interface implementation.
func (s *valuesSorter) Less(i, j int) bool {
if s.strings == nil {
return valueSortLess(s.values[i], s.values[j])
}
return s.strings[i] < s.strings[j]
}
// sortValues is a sort function that handles both native types and any type that
// can be converted to error or Stringer. Other inputs are sorted according to
// their Value.String() value to ensure display stability.
func sortValues(values []reflect.Value, cs *ConfigState) {
if len(values) == 0 {
return
}
sort.Sort(newValuesSorter(values, cs))
}

View File

@ -1,306 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"os"
)
// ConfigState houses the configuration options used by spew to format and
// display values. There is a global instance, Config, that is used to control
// all top-level Formatter and Dump functionality. Each ConfigState instance
// provides methods equivalent to the top-level functions.
//
// The zero value for ConfigState provides no indentation. You would typically
// want to set it to a space or a tab.
//
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
// with default settings. See the documentation of NewDefaultConfig for default
// values.
type ConfigState struct {
// Indent specifies the string to use for each indentation level. The
// global config instance that all top-level functions use set this to a
// single space by default. If you would like more indentation, you might
// set this to a tab with "\t" or perhaps two spaces with " ".
Indent string
// MaxDepth controls the maximum number of levels to descend into nested
// data structures. The default, 0, means there is no limit.
//
// NOTE: Circular data structures are properly detected, so it is not
// necessary to set this value unless you specifically want to limit deeply
// nested data structures.
MaxDepth int
// DisableMethods specifies whether or not error and Stringer interfaces are
// invoked for types that implement them.
DisableMethods bool
// DisablePointerMethods specifies whether or not to check for and invoke
// error and Stringer interfaces on types which only accept a pointer
// receiver when the current type is not a pointer.
//
// NOTE: This might be an unsafe action since calling one of these methods
// with a pointer receiver could technically mutate the value, however,
// in practice, types which choose to satisify an error or Stringer
// interface with a pointer receiver should not be mutating their state
// inside these interface methods. As a result, this option relies on
// access to the unsafe package, so it will not have any effect when
// running in environments without access to the unsafe package such as
// Google App Engine or with the "safe" build tag specified.
DisablePointerMethods bool
// DisablePointerAddresses specifies whether to disable the printing of
// pointer addresses. This is useful when diffing data structures in tests.
DisablePointerAddresses bool
// DisableCapacities specifies whether to disable the printing of capacities
// for arrays, slices, maps and channels. This is useful when diffing
// data structures in tests.
DisableCapacities bool
// ContinueOnMethod specifies whether or not recursion should continue once
// a custom error or Stringer interface is invoked. The default, false,
// means it will print the results of invoking the custom error or Stringer
// interface and return immediately instead of continuing to recurse into
// the internals of the data type.
//
// NOTE: This flag does not have any effect if method invocation is disabled
// via the DisableMethods or DisablePointerMethods options.
ContinueOnMethod bool
// SortKeys specifies map keys should be sorted before being printed. Use
// this to have a more deterministic, diffable output. Note that only
// native types (bool, int, uint, floats, uintptr and string) and types
// that support the error or Stringer interfaces (if methods are
// enabled) are supported, with other types sorted according to the
// reflect.Value.String() output which guarantees display stability.
SortKeys bool
// SpewKeys specifies that, as a last resort attempt, map keys should
// be spewed to strings and sorted by those strings. This is only
// considered if SortKeys is true.
SpewKeys bool
}
// Config is the active configuration of the top-level functions.
// The configuration can be changed by modifying the contents of spew.Config.
var Config = ConfigState{Indent: " "}
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the formatted string as a value that satisfies error. See NewFormatter
// for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, c.convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, c.convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, c.convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a Formatter interface returned by c.NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, c.convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
return fmt.Print(c.convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, c.convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
return fmt.Println(c.convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprint(a ...interface{}) string {
return fmt.Sprint(c.convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, c.convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a Formatter interface returned by c.NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintln(a ...interface{}) string {
return fmt.Sprintln(c.convertArgs(a)...)
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
c.Printf, c.Println, or c.Printf.
*/
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(c, v)
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
fdump(c, w, a...)
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by modifying the public members
of c. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func (c *ConfigState) Dump(a ...interface{}) {
fdump(c, os.Stdout, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func (c *ConfigState) Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(c, &buf, a...)
return buf.String()
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a spew Formatter interface using
// the ConfigState associated with s.
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = newFormatter(c, arg)
}
return formatters
}
// NewDefaultConfig returns a ConfigState with the following default settings.
//
// Indent: " "
// MaxDepth: 0
// DisableMethods: false
// DisablePointerMethods: false
// ContinueOnMethod: false
// SortKeys: false
func NewDefaultConfig() *ConfigState {
return &ConfigState{Indent: " "}
}

View File

@ -1,202 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
/*
Package spew implements a deep pretty printer for Go data structures to aid in
debugging.
A quick overview of the additional features spew provides over the built-in
printing facilities for Go data types are as follows:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output (only when using
Dump style)
There are two different approaches spew allows for dumping Go data structures:
* Dump style which prints with newlines, customizable indentation,
and additional debug information such as types and all pointer addresses
used to indirect to the final value
* A custom Formatter interface that integrates cleanly with the standard fmt
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
similar to the default %v while providing the additional functionality
outlined above and passing unsupported format verbs such as %x and %q
along to fmt
Quick Start
This section demonstrates how to quickly get started with spew. See the
sections below for further details on formatting and configuration options.
To dump a variable with full newlines, indentation, type, and pointer
information use Dump, Fdump, or Sdump:
spew.Dump(myVar1, myVar2, ...)
spew.Fdump(someWriter, myVar1, myVar2, ...)
str := spew.Sdump(myVar1, myVar2, ...)
Alternatively, if you would prefer to use format strings with a compacted inline
printing style, use the convenience wrappers Printf, Fprintf, etc with
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
%#+v (adds types and pointer addresses):
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
Configuration Options
Configuration of spew is handled by fields in the ConfigState type. For
convenience, all of the top-level functions use a global state available
via the spew.Config global.
It is also possible to create a ConfigState instance that provides methods
equivalent to the top-level functions. This allows concurrent configuration
options. See the ConfigState documentation for more details.
The following configuration options are available:
* Indent
String to use for each indentation level for Dump functions.
It is a single space by default. A popular alternative is "\t".
* MaxDepth
Maximum number of levels to descend into nested data structures.
There is no limit by default.
* DisableMethods
Disables invocation of error and Stringer interface methods.
Method invocation is enabled by default.
* DisablePointerMethods
Disables invocation of error and Stringer interface methods on types
which only accept pointer receivers from non-pointer variables.
Pointer method invocation is enabled by default.
* ContinueOnMethod
Enables recursion into types after invoking error and Stringer interface
methods. Recursion after method invocation is disabled by default.
* SortKeys
Specifies map keys should be sorted before being printed. Use
this to have a more deterministic, diffable output. Note that
only native types (bool, int, uint, floats, uintptr and string)
and types which implement error or Stringer interfaces are
supported with other types sorted according to the
reflect.Value.String() output which guarantees display
stability. Natural map order is used by default.
* SpewKeys
Specifies that, as a last resort attempt, map keys should be
spewed to strings and sorted by those strings. This is only
considered if SortKeys is true.
Dump Usage
Simply call spew.Dump with a list of variables you want to dump:
spew.Dump(myVar1, myVar2, ...)
You may also call spew.Fdump if you would prefer to output to an arbitrary
io.Writer. For example, to dump to standard error:
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
A third option is to call spew.Sdump to get the formatted output as a string:
str := spew.Sdump(myVar1, myVar2, ...)
Sample Dump Output
See the Dump example for details on the setup of the types and variables being
shown here.
(main.Foo) {
unexportedField: (*main.Bar)(0xf84002e210)({
flag: (main.Flag) flagTwo,
data: (uintptr) <nil>
}),
ExportedField: (map[interface {}]interface {}) (len=1) {
(string) (len=3) "one": (bool) true
}
}
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
command as shown.
([]uint8) (len=32 cap=32) {
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
00000020 31 32 |12|
}
Custom Formatter
Spew provides a custom formatter that implements the fmt.Formatter interface
so that it integrates cleanly with standard fmt package printing functions. The
formatter is useful for inline printing of smaller data types similar to the
standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Custom Formatter Usage
The simplest way to make use of the spew custom formatter is to call one of the
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
functions have syntax you are most likely already familiar with:
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Println(myVar, myVar2)
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
See the Index for the full list convenience functions.
Sample Formatter Output
Double pointer to a uint8:
%v: <**>5
%+v: <**>(0xf8400420d0->0xf8400420c8)5
%#v: (**uint8)5
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
Pointer to circular struct with a uint8 field and a pointer to itself:
%v: <*>{1 <*><shown>}
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
See the Printf example for details on the setup of variables being shown
here.
Errors
Since it is possible for custom Stringer/error interfaces to panic, spew
detects them and handles them internally by printing the panic information
inline with the output. Since spew is intended to provide deep pretty printing
capabilities on structures, it intentionally does not return any errors.
*/
package spew

View File

@ -1,509 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"os"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
// uint8Type is a reflect.Type representing a uint8. It is used to
// convert cgo types to uint8 slices for hexdumping.
uint8Type = reflect.TypeOf(uint8(0))
// cCharRE is a regular expression that matches a cgo char.
// It is used to detect character arrays to hexdump them.
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
// char. It is used to detect unsigned character arrays to hexdump
// them.
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
// It is used to detect uint8_t arrays to hexdump them.
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
)
// dumpState contains information about the state of a dump operation.
type dumpState struct {
w io.Writer
depth int
pointers map[uintptr]int
ignoreNextType bool
ignoreNextIndent bool
cs *ConfigState
}
// indent performs indentation according to the depth level and cs.Indent
// option.
func (d *dumpState) indent() {
if d.ignoreNextIndent {
d.ignoreNextIndent = false
return
}
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
}
// unpackValue returns values inside of non-nil interfaces when possible.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface && !v.IsNil() {
v = v.Elem()
}
return v
}
// dumpPtr handles formatting of pointers by indirecting them as necessary.
func (d *dumpState) dumpPtr(v reflect.Value) {
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range d.pointers {
if depth >= d.depth {
delete(d.pointers, k)
}
}
// Keep list of all dereferenced pointers to show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by dereferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
cycleFound = true
indirects--
break
}
d.pointers[addr] = d.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type information.
d.w.Write(openParenBytes)
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
d.w.Write([]byte(ve.Type().String()))
d.w.Write(closeParenBytes)
// Display pointer information.
if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 {
d.w.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
d.w.Write(pointerChainBytes)
}
printHexPtr(d.w, addr)
}
d.w.Write(closeParenBytes)
}
// Display dereferenced value.
d.w.Write(openParenBytes)
switch {
case nilFound == true:
d.w.Write(nilAngleBytes)
case cycleFound == true:
d.w.Write(circularBytes)
default:
d.ignoreNextType = true
d.dump(ve)
}
d.w.Write(closeParenBytes)
}
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
// reflection) arrays and slices are dumped in hexdump -C fashion.
func (d *dumpState) dumpSlice(v reflect.Value) {
// Determine whether this type should be hex dumped or not. Also,
// for types which should be hexdumped, try to use the underlying data
// first, then fall back to trying to convert them to a uint8 slice.
var buf []uint8
doConvert := false
doHexDump := false
numEntries := v.Len()
if numEntries > 0 {
vt := v.Index(0).Type()
vts := vt.String()
switch {
// C types that need to be converted.
case cCharRE.MatchString(vts):
fallthrough
case cUnsignedCharRE.MatchString(vts):
fallthrough
case cUint8tCharRE.MatchString(vts):
doConvert = true
// Try to use existing uint8 slices and fall back to converting
// and copying if that fails.
case vt.Kind() == reflect.Uint8:
// We need an addressable interface to convert the type
// to a byte slice. However, the reflect package won't
// give us an interface on certain things like
// unexported struct fields in order to enforce
// visibility rules. We use unsafe, when available, to
// bypass these restrictions since this package does not
// mutate the values.
vs := v
if !vs.CanInterface() || !vs.CanAddr() {
vs = unsafeReflectValue(vs)
}
if !UnsafeDisabled {
vs = vs.Slice(0, numEntries)
// Use the existing uint8 slice if it can be
// type asserted.
iface := vs.Interface()
if slice, ok := iface.([]uint8); ok {
buf = slice
doHexDump = true
break
}
}
// The underlying data needs to be converted if it can't
// be type asserted to a uint8 slice.
doConvert = true
}
// Copy and convert the underlying type if needed.
if doConvert && vt.ConvertibleTo(uint8Type) {
// Convert and copy each element into a uint8 byte
// slice.
buf = make([]uint8, numEntries)
for i := 0; i < numEntries; i++ {
vv := v.Index(i)
buf[i] = uint8(vv.Convert(uint8Type).Uint())
}
doHexDump = true
}
}
// Hexdump the entire slice as needed.
if doHexDump {
indent := strings.Repeat(d.cs.Indent, d.depth)
str := indent + hex.Dump(buf)
str = strings.Replace(str, "\n", "\n"+indent, -1)
str = strings.TrimRight(str, d.cs.Indent)
d.w.Write([]byte(str))
return
}
// Recursively call dump for each item.
for i := 0; i < numEntries; i++ {
d.dump(d.unpackValue(v.Index(i)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
// dump is the main workhorse for dumping a value. It uses the passed reflect
// value to figure out what kind of object we are dealing with and formats it
// appropriately. It is a recursive function, however circular data structures
// are detected and handled properly.
func (d *dumpState) dump(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
d.w.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
d.indent()
d.dumpPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !d.ignoreNextType {
d.indent()
d.w.Write(openParenBytes)
d.w.Write([]byte(v.Type().String()))
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
d.ignoreNextType = false
// Display length and capacity if the built-in len and cap functions
// work with the value's kind and the len/cap itself is non-zero.
valueLen, valueCap := 0, 0
switch v.Kind() {
case reflect.Array, reflect.Slice, reflect.Chan:
valueLen, valueCap = v.Len(), v.Cap()
case reflect.Map, reflect.String:
valueLen = v.Len()
}
if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 {
d.w.Write(openParenBytes)
if valueLen != 0 {
d.w.Write(lenEqualsBytes)
printInt(d.w, int64(valueLen), 10)
}
if !d.cs.DisableCapacities && valueCap != 0 {
if valueLen != 0 {
d.w.Write(spaceBytes)
}
d.w.Write(capEqualsBytes)
printInt(d.w, int64(valueCap), 10)
}
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
// Call Stringer/error interfaces if they exist and the handle methods flag
// is enabled
if !d.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(d.cs, d.w, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(d.w, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(d.w, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(d.w, v.Uint(), 10)
case reflect.Float32:
printFloat(d.w, v.Float(), 32)
case reflect.Float64:
printFloat(d.w, v.Float(), 64)
case reflect.Complex64:
printComplex(d.w, v.Complex(), 32)
case reflect.Complex128:
printComplex(d.w, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
d.dumpSlice(v)
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.String:
d.w.Write([]byte(strconv.Quote(v.String())))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
d.w.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
numEntries := v.Len()
keys := v.MapKeys()
if d.cs.SortKeys {
sortValues(keys, d.cs)
}
for i, key := range keys {
d.dump(d.unpackValue(key))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.MapIndex(key)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Struct:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
vt := v.Type()
numFields := v.NumField()
for i := 0; i < numFields; i++ {
d.indent()
vtf := vt.Field(i)
d.w.Write([]byte(vtf.Name))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.Field(i)))
if i < (numFields - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(d.w, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(d.w, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it in case any new
// types are added.
default:
if v.CanInterface() {
fmt.Fprintf(d.w, "%v", v.Interface())
} else {
fmt.Fprintf(d.w, "%v", v.String())
}
}
}
// fdump is a helper function to consolidate the logic from the various public
// methods which take varying writers and config states.
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
for _, arg := range a {
if arg == nil {
w.Write(interfaceBytes)
w.Write(spaceBytes)
w.Write(nilAngleBytes)
w.Write(newlineBytes)
continue
}
d := dumpState{w: w, cs: cs}
d.pointers = make(map[uintptr]int)
d.dump(reflect.ValueOf(arg))
d.w.Write(newlineBytes)
}
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func Fdump(w io.Writer, a ...interface{}) {
fdump(&Config, w, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(&Config, &buf, a...)
return buf.String()
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by an exported package global,
spew.Config. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func Dump(a ...interface{}) {
fdump(&Config, os.Stdout, a...)
}

View File

@ -1,419 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
)
// supportedFlags is a list of all the character flags supported by fmt package.
const supportedFlags = "0-+# "
// formatState implements the fmt.Formatter interface and contains information
// about the state of a formatting operation. The NewFormatter function can
// be used to get a new Formatter which can be used directly as arguments
// in standard fmt package printing calls.
type formatState struct {
value interface{}
fs fmt.State
depth int
pointers map[uintptr]int
ignoreNextType bool
cs *ConfigState
}
// buildDefaultFormat recreates the original format string without precision
// and width information to pass in to fmt.Sprintf in the case of an
// unrecognized type. Unless new types are added to the language, this
// function won't ever be called.
func (f *formatState) buildDefaultFormat() (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
buf.WriteRune('v')
format = buf.String()
return format
}
// constructOrigFormat recreates the original format string including precision
// and width information to pass along to the standard fmt package. This allows
// automatic deferral of all format strings this package doesn't support.
func (f *formatState) constructOrigFormat(verb rune) (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
if width, ok := f.fs.Width(); ok {
buf.WriteString(strconv.Itoa(width))
}
if precision, ok := f.fs.Precision(); ok {
buf.Write(precisionBytes)
buf.WriteString(strconv.Itoa(precision))
}
buf.WriteRune(verb)
format = buf.String()
return format
}
// unpackValue returns values inside of non-nil interfaces when possible and
// ensures that types for values which have been unpacked from an interface
// are displayed when the show types flag is also set.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface {
f.ignoreNextType = false
if !v.IsNil() {
v = v.Elem()
}
}
return v
}
// formatPtr handles formatting of pointers by indirecting them as necessary.
func (f *formatState) formatPtr(v reflect.Value) {
// Display nil if top level pointer is nil.
showTypes := f.fs.Flag('#')
if v.IsNil() && (!showTypes || f.ignoreNextType) {
f.fs.Write(nilAngleBytes)
return
}
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range f.pointers {
if depth >= f.depth {
delete(f.pointers, k)
}
}
// Keep list of all dereferenced pointers to possibly show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by derferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
cycleFound = true
indirects--
break
}
f.pointers[addr] = f.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type or indirection level depending on flags.
if showTypes && !f.ignoreNextType {
f.fs.Write(openParenBytes)
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
f.fs.Write([]byte(ve.Type().String()))
f.fs.Write(closeParenBytes)
} else {
if nilFound || cycleFound {
indirects += strings.Count(ve.Type().String(), "*")
}
f.fs.Write(openAngleBytes)
f.fs.Write([]byte(strings.Repeat("*", indirects)))
f.fs.Write(closeAngleBytes)
}
// Display pointer information depending on flags.
if f.fs.Flag('+') && (len(pointerChain) > 0) {
f.fs.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
f.fs.Write(pointerChainBytes)
}
printHexPtr(f.fs, addr)
}
f.fs.Write(closeParenBytes)
}
// Display dereferenced value.
switch {
case nilFound == true:
f.fs.Write(nilAngleBytes)
case cycleFound == true:
f.fs.Write(circularShortBytes)
default:
f.ignoreNextType = true
f.format(ve)
}
}
// format is the main workhorse for providing the Formatter interface. It
// uses the passed reflect value to figure out what kind of object we are
// dealing with and formats it appropriately. It is a recursive function,
// however circular data structures are detected and handled properly.
func (f *formatState) format(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
f.fs.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
f.formatPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !f.ignoreNextType && f.fs.Flag('#') {
f.fs.Write(openParenBytes)
f.fs.Write([]byte(v.Type().String()))
f.fs.Write(closeParenBytes)
}
f.ignoreNextType = false
// Call Stringer/error interfaces if they exist and the handle methods
// flag is enabled.
if !f.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(f.cs, f.fs, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(f.fs, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(f.fs, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(f.fs, v.Uint(), 10)
case reflect.Float32:
printFloat(f.fs, v.Float(), 32)
case reflect.Float64:
printFloat(f.fs, v.Float(), 64)
case reflect.Complex64:
printComplex(f.fs, v.Complex(), 32)
case reflect.Complex128:
printComplex(f.fs, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
f.fs.Write(openBracketBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
numEntries := v.Len()
for i := 0; i < numEntries; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(v.Index(i)))
}
}
f.depth--
f.fs.Write(closeBracketBytes)
case reflect.String:
f.fs.Write([]byte(v.String()))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
f.fs.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
f.fs.Write(openMapBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
keys := v.MapKeys()
if f.cs.SortKeys {
sortValues(keys, f.cs)
}
for i, key := range keys {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(key))
f.fs.Write(colonBytes)
f.ignoreNextType = true
f.format(f.unpackValue(v.MapIndex(key)))
}
}
f.depth--
f.fs.Write(closeMapBytes)
case reflect.Struct:
numFields := v.NumField()
f.fs.Write(openBraceBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
vt := v.Type()
for i := 0; i < numFields; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
vtf := vt.Field(i)
if f.fs.Flag('+') || f.fs.Flag('#') {
f.fs.Write([]byte(vtf.Name))
f.fs.Write(colonBytes)
}
f.format(f.unpackValue(v.Field(i)))
}
}
f.depth--
f.fs.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(f.fs, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(f.fs, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it if any get added.
default:
format := f.buildDefaultFormat()
if v.CanInterface() {
fmt.Fprintf(f.fs, format, v.Interface())
} else {
fmt.Fprintf(f.fs, format, v.String())
}
}
}
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
// details.
func (f *formatState) Format(fs fmt.State, verb rune) {
f.fs = fs
// Use standard formatting for verbs that are not v.
if verb != 'v' {
format := f.constructOrigFormat(verb)
fmt.Fprintf(fs, format, f.value)
return
}
if f.value == nil {
if fs.Flag('#') {
fs.Write(interfaceBytes)
}
fs.Write(nilAngleBytes)
return
}
f.format(reflect.ValueOf(f.value))
}
// newFormatter is a helper function to consolidate the logic from the various
// public methods which take varying config states.
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
fs := &formatState{value: v, cs: cs}
fs.pointers = make(map[uintptr]int)
return fs
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
Printf, Println, or Fprintf.
*/
func NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(&Config, v)
}

View File

@ -1,148 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"fmt"
"io"
)
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the formatted string as a value that satisfies error. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a default Formatter interface returned by NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
func Print(a ...interface{}) (n int, err error) {
return fmt.Print(convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
func Println(a ...interface{}) (n int, err error) {
return fmt.Println(convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprint(a ...interface{}) string {
return fmt.Sprint(convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintln(a ...interface{}) string {
return fmt.Sprintln(convertArgs(a)...)
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a default spew Formatter interface.
func convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = NewFormatter(arg)
}
return formatters
}

View File

@ -1,2 +1,2 @@
# pht # pht
Pure HTTP Tunnel - Tunnel over HTTP using only GET and POST requests. Plain HTTP Tunnel - Tunnel over HTTP using only GET and POST requests, NO Websocket, NO CONNECT method.

View File

@ -5,7 +5,7 @@ Leveled execution logs for Go.
This is an efficient pure Go implementation of leveled logs in the This is an efficient pure Go implementation of leveled logs in the
manner of the open source C++ package manner of the open source C++ package
http://code.google.com/p/google-glog https://github.com/google/glog
By binding methods to booleans it is possible to use the log package By binding methods to booleans it is possible to use the log package
without paying the expense of evaluating the arguments to the log. without paying the expense of evaluating the arguments to the log.

View File

@ -676,7 +676,10 @@ func (l *loggingT) output(s severity, buf *buffer, file string, line int, alsoTo
} }
} }
data := buf.Bytes() data := buf.Bytes()
if l.toStderr { if !flag.Parsed() {
os.Stderr.Write([]byte("ERROR: logging before flag.Parse: "))
os.Stderr.Write(data)
} else if l.toStderr {
os.Stderr.Write(data) os.Stderr.Write(data)
} else { } else {
if alsoToStderr || l.alsoToStderr || s >= l.stderrThreshold.get() { if alsoToStderr || l.alsoToStderr || s >= l.stderrThreshold.get() {

17
vendor/github.com/lucas-clemente/quic-go/Changelog.md generated vendored Normal file
View File

@ -0,0 +1,17 @@
# Changelog
## v0.6.0 (unreleased)
- Added `quic.Config` options for maximal flow control windows
- Add a `quic.Config` option for QUIC versions
- Add a `quic.Config` option to request truncation of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Add a `quic.Config` option to configure keep-alive
- Implement `net.Conn`-style deadlines for streams
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
- Drop support for Go 1.7.
- Various bugfixes

View File

@ -1,4 +1,4 @@
# A QUIC server implementation in pure Go # A QUIC implementation in pure Go
<img src="docs/quic.png" width=303 height=124> <img src="docs/quic.png" width=303 height=124>
@ -7,32 +7,20 @@
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master) [![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/) [![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. While we're not far from being feature complete, there's still work to do regarding performance and security. At the moment, we do not recommend use in production systems. We appreciate any feedback :) quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
## Roadmap ## Roadmap
Done: quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
- Basic protocol with support for QUIC version 34-36
- QUIC client
- HTTP/2 support
- Crypto (RSA / ECDSA certificates, Curve25519 for key exchange, AES-GCM or Chacha20-Poly1305 as stream cipher)
- Loss detection and retransmission (currently fast retransmission & RTO)
- Flow Control
- Congestion control using cubic
Major TODOs:
- Security, especially DoS protections
- Performance
- Better packet loss detection
- Connection migration
## Guides ## Guides
Installing deps: We currently support Go 1.8+.
go get -t Installing and updating dependencies:
go get -t -u ./...
Running tests: Running tests:
@ -50,9 +38,13 @@ Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io /Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client ### Using the example client
go run example/client/main.go https://quic.clemente.io go run example/client/main.go https://clemente.io
## Usage ## Usage
@ -67,14 +59,14 @@ h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to
### As a client ### As a client
See the [example client](example/client/main.go). Use a `QuicRoundTripper` as a `Transport` in a `http.Client`. See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`.
```go ```go
http.Client{ http.Client{
Transport: &h2quic.QuicRoundTripper{}, Transport: &h2quic.RoundTripper{},
} }
``` ```
## Building on Windows ## Contributing
Due to the low Windows timer resolution (see [StackOverflow question](http://stackoverflow.com/questions/37706834/high-resolution-timers-millisecond-precision-in-go-on-windows)) available with Go 1.6.x, some optimizations might not work when compiled with this version of the compiler. Please use Go 1.7 on Windows. We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [want-help](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aopen+is%3Aissue+label%3Awant-help). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

View File

@ -9,21 +9,17 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
SendingAllowed() bool
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
MaybeQueueRTOs()
DequeuePacketForRetransmission() (packet *Packet) DequeuePacketForRetransmission() (packet *Packet)
BytesInFlight() protocol.ByteCount
GetLeastUnacked() protocol.PacketNumber GetLeastUnacked() protocol.PacketNumber
SendingAllowed() bool GetAlarmTimeout() time.Time
CheckForError() error OnAlarm()
TimeOfFirstRTO() time.Time
} }
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
@ -31,5 +27,6 @@ type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame GetAckFrame() *frames.AckFrame
} }

View File

@ -13,8 +13,7 @@ type Packet struct {
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
Frames []frames.Frame Frames []frames.Frame
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
MissingReports uint8
SendTime time.Time SendTime time.Time
} }

View File

@ -8,13 +8,6 @@ import (
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
var (
// ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct { type receivedPacketHandler struct {
@ -30,19 +23,13 @@ type receivedPacketHandler struct {
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
ackQueued bool ackQueued bool
ackAlarm time.Time ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { func NewReceivedPacketHandler() ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay, ackSendDelay: protocol.AckSendDelay,
} }
} }
@ -52,20 +39,11 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber return errInvalidPacketNumber
} }
// if the packet number is smaller than the largest LeastUnacked value of a StopWaiting we received, we cannot detect if this packet has a duplicate number if packetNumber > h.ignorePacketsBelow {
// the packet has to be ignored anyway if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
if packetNumber <= h.ignorePacketsBelow {
return ErrPacketSmallerThanLastStopWaiting
}
if h.packetHistory.IsDuplicate(packetNumber) {
return ErrDuplicatePacket
}
err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
return err return err
} }
}
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
@ -89,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
if shouldInstigateAck { if shouldInstigateAck {
@ -124,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
} else { } else {
if h.ackAlarm.IsZero() { if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay) h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
} }
} }
} }
@ -132,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
if h.ackQueued { if h.ackQueued {
// cancel the ack alarm // cancel the ack alarm
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
} }
} }
@ -164,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
return ack return ack
} }
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -2,9 +2,9 @@ package ackhandler
import ( import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type receivedPacketHistory struct { type receivedPacketHistory struct {

View File

@ -0,0 +1,38 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
)
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
res := make([]frames.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f frames.Frame) bool {
switch f.(type) {
case *frames.StopWaitingFrame:
return false
case *frames.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []frames.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}

View File

@ -7,9 +7,21 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils" )
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
) )
var ( var (
@ -22,11 +34,10 @@ var (
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
) )
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number.") var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
type sentPacketHandler struct { type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber lastSentPacketNumber protocol.PacketNumber
lastSentPacketTime time.Time
skippedPackets []protocol.PacketNumber skippedPackets []protocol.PacketNumber
LargestAcked protocol.PacketNumber LargestAcked protocol.PacketNumber
@ -40,10 +51,17 @@ type sentPacketHandler struct {
bytesInFlight protocol.ByteCount bytesInFlight protocol.ByteCount
rttStats *congestion.RTTStats
congestion congestion.SendAlgorithm congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
consecutiveRTOCount uint32 // The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
} }
// NewSentPacketHandler creates a new sentPacketHandler // NewSentPacketHandler creates a new sentPacketHandler
@ -64,40 +82,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
} }
} }
func (h *sentPacketHandler) ackPacket(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.packetHistory.Remove(packetElement)
}
// nackPacket NACKs a packet
// it returns true if a FastRetransmissions was triggered
func (h *sentPacketHandler) nackPacket(packetElement *PacketElement) bool {
packet := &packetElement.Value
packet.MissingReports++
if packet.MissingReports > protocol.RetransmissionThreshold {
utils.Debugf("\tQueueing packet 0x%x for retransmission (fast)", packet.PacketNumber)
h.queuePacketForRetransmission(packetElement)
return true
}
return false
}
// does NOT set packet.Retransmitted. This variable is not needed anymore
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.retransmissionQueue = append(h.retransmissionQueue, packet)
h.packetHistory.Remove(packetElement)
// strictly speaking, this is only necessary for RTO retransmissions
// this is because FastRetransmissions are triggered by missing ranges in ACKs, and then the LargestAcked will already be higher than the packet number of the retransmitted packet
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
}
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber { func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
if f := h.packetHistory.Front(); f != nil { if f := h.packetHistory.Front(); f != nil {
return f.Value.PacketNumber - 1 return f.Value.PacketNumber - 1
@ -110,6 +94,10 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
return errPacketNumberNotIncreasing return errPacketNumberNotIncreasing
} }
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
}
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p) h.skippedPackets = append(h.skippedPackets, p)
@ -118,25 +106,27 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
} }
} }
now := time.Now()
h.lastSentPacketTime = now
packet.SendTime = now
if packet.Length == 0 {
return errors.New("SentPacketHandler: packet cannot be empty")
}
h.bytesInFlight += packet.Length
h.lastSentPacketNumber = packet.PacketNumber h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet) h.packetHistory.PushBack(*packet)
}
h.congestion.OnPacketSent( h.congestion.OnPacketSent(
now, now,
h.BytesInFlight(), h.bytesInFlight,
packet.PacketNumber, packet.PacketNumber,
packet.Length, packet.Length,
true, /* TODO: is retransmittable */ isRetransmittable,
) )
h.updateLossDetectionAlarm()
return nil return nil
} }
@ -149,54 +139,57 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
if withPacketNumber <= h.largestReceivedPacketWithAck { if withPacketNumber <= h.largestReceivedPacketWithAck {
return ErrDuplicateOrOutOfOrderAck return ErrDuplicateOrOutOfOrderAck
} }
h.largestReceivedPacketWithAck = withPacketNumber h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked <= h.largestInOrderAcked() { if ackFrame.LargestAcked <= h.largestInOrderAcked() {
return nil return nil
} }
// check if it acks any packets that were skipped
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return ErrAckForSkippedPacket
}
}
h.LargestAcked = ackFrame.LargestAcked h.LargestAcked = ackFrame.LargestAcked
var ackedPackets congestion.PacketVector if h.skippedPacketsAcked(ackFrame) {
var lostPackets congestion.PacketVector return ErrAckForSkippedPacket
ackRangeIndex := 0 }
rttUpdated := false
var el, elNext *PacketElement rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
for el = h.packetHistory.Front(); el != nil; el = elNext {
// determine the next list element right at the beginning, because el.Next() is not avaible anymore, when the list element is deleted (i.e. when the packet is ACKed) if rttUpdated {
elNext = el.Next() h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
if len(ackedPackets) > 0 {
for _, p := range ackedPackets {
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
ackRangeIndex := 0
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value packet := el.Value
packetNumber := packet.PacketNumber packetNumber := packet.PacketNumber
// NACK packets below the LowestAcked // Ignore packets below the LowestAcked
if packetNumber < ackFrame.LowestAcked { if packetNumber < ackFrame.LowestAcked {
retransmitted := h.nackPacket(el)
if retransmitted {
lostPackets = append(lostPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length})
}
continue continue
} }
// Break after LargestAcked is reached
// Update the RTT
if packetNumber == h.LargestAcked {
rttUpdated = true
timeDelta := rcvTime.Sub(packet.SendTime)
h.rttStats.UpdateRTT(timeDelta, ackFrame.DelayTime, rcvTime)
if utils.Debug() {
utils.Debugf("\tEstimated RTT: %dms", h.rttStats.SmoothedRTT()/time.Millisecond)
}
}
if packetNumber > ackFrame.LargestAcked { if packetNumber > ackFrame.LargestAcked {
break break
} }
@ -211,59 +204,119 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range
if packetNumber > ackRange.LastPacketNumber { if packetNumber > ackRange.LastPacketNumber {
return fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber) return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber)
}
h.ackPacket(el)
ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length})
} else {
retransmitted := h.nackPacket(el)
if retransmitted {
lostPackets = append(lostPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length})
} }
ackedPackets = append(ackedPackets, el)
} }
} else { } else {
h.ackPacket(el) ackedPackets = append(ackedPackets, el)
ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length})
} }
} }
if rttUpdated { return ackedPackets, nil
// Reset counter if a new packet was acked }
h.consecutiveRTOCount = 0
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
return true
}
// Packets are sorted by number, so we can stop searching
if packet.PacketNumber > largestAcked {
break
}
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
return
} }
h.garbageCollectSkippedPackets() // TODO(#496): Handle handshake packets separately
// TODO(#497): TLP
if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = time.Now().Add(h.computeRTOTimeout())
}
}
h.stopWaitingManager.ReceivedAck(ackFrame) func (h *sentPacketHandler) detectLostPackets() {
h.lossTime = time.Time{}
now := time.Now()
h.congestion.OnCongestionEvent( maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
rttUpdated, delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
h.BytesInFlight(),
ackedPackets,
lostPackets,
)
return nil var lostPackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber > h.LargestAcked {
break
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() {
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
}
if len(lostPackets) > 0 {
for _, p := range lostPackets {
h.queuePacketForRetransmission(p)
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
}
func (h *sentPacketHandler) OnAlarm() {
// TODO(#496): Handle handshake packets separately
// TODO(#497): TLP
if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets()
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
}
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
h.bytesInFlight -= packetElement.Value.Length
h.rtoCount = 0
// TODO(#497): h.tlpCount = 0
h.packetHistory.Remove(packetElement)
} }
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 { if len(h.retransmissionQueue) == 0 {
return nil return nil
} }
packet := h.retransmissionQueue[0]
if len(h.retransmissionQueue) > 0 { // Shift the slice and don't retain anything that isn't needed.
queueLen := len(h.retransmissionQueue) copy(h.retransmissionQueue, h.retransmissionQueue[1:])
// packets are usually NACKed in descending order. So use the slice as a stack h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
packet := h.retransmissionQueue[queueLen-1] h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
h.retransmissionQueue = h.retransmissionQueue[:queueLen-1]
return packet return packet
}
return nil
}
func (h *sentPacketHandler) BytesInFlight() protocol.ByteCount {
return h.bytesInFlight
} }
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
@ -275,65 +328,67 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingF
} }
func (h *sentPacketHandler) SendingAllowed() bool { func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.BytesInFlight() > h.congestion.GetCongestionWindow() congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
return !(congestionLimited || maxTrackedLimited) if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
h.bytesInFlight,
h.congestion.GetCongestionWindow())
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
} }
func (h *sentPacketHandler) CheckForError() error { func (h *sentPacketHandler) retransmitOldestTwoPackets() {
length := len(h.retransmissionQueue) + h.packetHistory.Len() if p := h.packetHistory.Front(); p != nil {
if protocol.PacketNumber(length) > protocol.MaxTrackedSentPackets { h.queueRTO(p)
return ErrTooManyTrackedSentPackets
} }
return nil if p := h.packetHistory.Front(); p != nil {
} h.queueRTO(p)
func (h *sentPacketHandler) MaybeQueueRTOs() {
if time.Now().Before(h.TimeOfFirstRTO()) {
return
} }
// Always queue the two oldest packets
if h.packetHistory.Front() != nil {
h.queueRTO(h.packetHistory.Front())
}
if h.packetHistory.Front() != nil {
h.queueRTO(h.packetHistory.Front())
}
// Reset the RTO timer here, since it's not clear that this packet contained any retransmittable frames
h.lastSentPacketTime = time.Now()
h.consecutiveRTOCount++
} }
func (h *sentPacketHandler) queueRTO(el *PacketElement) { func (h *sentPacketHandler) queueRTO(el *PacketElement) {
packet := &el.Value packet := &el.Value
packetsLost := congestion.PacketVector{congestion.PacketInfo{ utils.Debugf(
Number: packet.PacketNumber, "\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
Length: packet.Length, packet.PacketNumber,
}} h.packetHistory.Len(),
h.congestion.OnCongestionEvent(false, h.BytesInFlight(), nil, packetsLost) )
h.congestion.OnRetransmissionTimeout(true)
utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO)", packet.PacketNumber)
h.queuePacketForRetransmission(el) h.queuePacketForRetransmission(el)
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
h.congestion.OnRetransmissionTimeout(true)
} }
func (h *sentPacketHandler) getRTO() time.Duration { func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.retransmissionQueue = append(h.retransmissionQueue, packet)
h.packetHistory.Remove(packetElement)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
rto := h.congestion.RetransmissionDelay() rto := h.congestion.RetransmissionDelay()
if rto == 0 { if rto == 0 {
rto = protocol.DefaultRetransmissionTime rto = defaultRTOTimeout
} }
rto = utils.MaxDuration(rto, protocol.MinRetransmissionTime) rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff // Exponential backoff
rto *= 1 << h.consecutiveRTOCount rto = rto << h.rtoCount
return utils.MinDuration(rto, protocol.MaxRetransmissionTime) return utils.MinDuration(rto, maxRTOTimeout)
} }
func (h *sentPacketHandler) TimeOfFirstRTO() time.Time { func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool {
if h.lastSentPacketTime.IsZero() { for _, p := range h.skippedPackets {
return time.Time{} if ackFrame.AcksPacket(p) {
return true
} }
return h.lastSentPacketTime.Add(h.getRTO()) }
return false
} }
func (h *sentPacketHandler) garbageCollectSkippedPackets() { func (h *sentPacketHandler) garbageCollectSkippedPackets() {

View File

@ -5,6 +5,7 @@ os: Windows Server 2012 R2
environment: environment:
GOPATH: c:\gopath GOPATH: c:\gopath
CGO_ENABLED: 0 CGO_ENABLED: 0
TIMESCALE_FACTOR: 20
matrix: matrix:
- GOARCH: 386 - GOARCH: 386
- GOARCH: amd64 - GOARCH: amd64
@ -13,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install: install:
- rmdir c:\go /s /q - rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.5.windows-amd64.zip - appveyor DownloadFile https://storage.googleapis.com/golang/go1.8.3.windows-amd64.zip
- 7z x go1.7.5.windows-amd64.zip -y -oC:\ > NUL - 7z x go1.8.3.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH% - echo %PATH%
- echo %GOPATH% - echo %GOPATH%
@ -27,7 +28,8 @@ install:
build_script: build_script:
- rm -r integrationtests - rm -r integrationtests
- ginkgo -r --randomizeAllSpecs --randomizeSuites --trace --progress - ginkgo -r --randomizeAllSpecs --randomizeSuites --trace --progress -skipPackage benchmark
- ginkgo --randomizeAllSpecs --randomizeSuites --trace --progress benchmark -- -samples=1
test: off test: off

View File

@ -13,7 +13,7 @@ func getPacketBuffer() []byte {
} }
func putPacketBuffer(buf []byte) { func putPacketBuffer(buf []byte) {
if cap(buf) != int(protocol.MaxPacketSize) { if cap(buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!") panic("putPacketBuffer called with packet of wrong size!")
} }
bufferPool.Put(buf[:0]) bufferPool.Put(buf[:0])
@ -21,6 +21,6 @@ func putPacketBuffer(buf []byte) {
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxPacketSize) return make([]byte, 0, protocol.MaxReceivePacketSize)
} }
} }

View File

@ -4,216 +4,332 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"strings" "strings"
"sync/atomic" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A Client of QUIC type client struct {
type Client struct { mutex sync.Mutex
addr *net.UDPAddr listenErr error
conn *net.UDPConn
conn connection
hostname string hostname string
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
version protocol.VersionNumber version protocol.VersionNumber
versionNegotiated bool
closed uint32 // atomic bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
versionNegotiateCallback VersionNegotiateCallback
session packetHandler session packetHandler
} }
// VersionNegotiateCallback is called once the client has a negotiated version
type VersionNegotiateCallback func() error
var errHostname = errors.New("Invalid hostname")
var ( var (
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
) )
// NewClient makes a new client // DialAddr establishes a new QUIC connection to a server.
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { // The hostname for SNI is taken from the given address.
udpAddr, err := net.ResolveUDPAddr("udp", host) func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connectionID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
} }
hostname, _, err := net.SplitHostPort(host) clientConfig := populateClientConfig(config)
if err != nil { c := &client{
return nil, err conn: &conn{pconn: pconn, currentAddr: remoteAddr},
} connectionID: connID,
client := &Client{
addr: udpAddr,
conn: conn,
hostname: hostname, hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default tlsConf: tlsConf,
connectionID: connectionID, config: clientConfig,
tlsConfig: tlsConfig, version: clientConfig.Versions[0],
cryptoChangeCallback: cryptoChangeCallback, errorChan: make(chan struct{}),
versionNegotiateCallback: versionNegotiateCallback,
} }
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) err = c.createNewSession(nil)
err = client.createNewSession(nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client, nil utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection()
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
return nil, err
}
return sess, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
}
}
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
go c.listen()
select {
case <-c.errorChan:
return c.listenErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
}
if ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
}
} }
// Listen listens // Listen listens
func (c *Client) Listen() error { func (c *client) listen() {
for { var err error
data := getPacketBuffer()
data = data[:protocol.MaxPacketSize]
n, _, err := c.conn.ReadFromUDP(data) for {
var n int
var addr net.Addr
data := getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err = c.conn.Read(data)
if err != nil { if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") { if !strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil c.session.Close(err)
} }
return err break
} }
data = data[:n] data = data[:n]
err = c.handlePacket(data) c.handlePacket(addr, data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
return err
}
} }
} }
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs) func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
return c.session.OpenStream(id)
}
// Close closes the connection
func (c *Client) Close(e error) error {
// Only close once
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
}
_ = c.session.Close(e)
return c.conn.Close()
}
func (c *Client) handlePacket(packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
}
rcvTime := time.Now() rcvTime := time.Now()
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil { if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
return
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock()
defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag { if c.versionNegotiated && hdr.VersionFlag {
return nil return
} }
// this is the first packet after the client sent a packet with the VersionFlag set // this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version // if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated { if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true c.versionNegotiated = true
err = c.versionNegotiateCallback()
if err != nil {
return err
}
} }
if hdr.VersionFlag { if hdr.VersionFlag {
var hasCommonVersion bool // check if we're supporting any of the offered versions // version negotiation packets have no payload
for _, v := range hdr.SupportedVersions { if err := c.handlePacketWithVersionFlag(hdr); err != nil {
// check if the server sent the offered version in supported versions c.session.Close(err)
if v == c.version {
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")
} }
if v != protocol.VersionUnsupported { return
hasCommonVersion = true
}
}
if !hasCommonVersion {
utils.Infof("No common version found.")
return qerr.InvalidVersion
}
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions)
if !ok {
return qerr.VersionNegotiationMismatch
}
utils.Infof("Switching to QUIC version %d", highestSupportedVersion)
c.version = highestSupportedVersion
c.versionNegotiated = true
c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession(hdr.SupportedVersions)
if err != nil {
return err
}
err = c.versionNegotiateCallback()
if err != nil {
return err
}
return nil // version negotiation packets have no payload
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
remoteAddr: c.addr, remoteAddr: remoteAddr,
publicHeader: hdr, publicHeader: hdr,
data: packet[len(packet)-r.Len():], data: packet[len(packet)-r.Len():],
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil
} }
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
// ignore it
return nil
}
}
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if newVersion == protocol.VersionUnsupported {
return qerr.InvalidVersion
}
// switch to negotiated version
c.version = newVersion
c.versionNegotiated = true
var err error var err error
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
return c.createNewSession(hdr.SupportedVersions)
}
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, c.handshakeChan, err = newClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
negotiatedVersions,
)
if err != nil { if err != nil {
return err return err
} }
go c.session.run() go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.listenErr = err
close(c.errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil return nil
} }
func (c *Client) streamCallback(session *Session, stream utils.Stream) {}
func (c *Client) closeCallback(id protocol.ConnectionID) {
utils.Infof("Connection %x closed.", id)
}

View File

@ -4,8 +4,8 @@ coverage:
- ackhandler/packet_linkedlist.go - ackhandler/packet_linkedlist.go
- h2quic/gzipreader.go - h2quic/gzipreader.go
- h2quic/response.go - h2quic/response.go
- utils/byteinterval_linkedlist.go - internal/utils/byteinterval_linkedlist.go
- utils/packetinterval_linkedlist.go - internal/utils/packetinterval_linkedlist.go
status: status:
project: project:
default: default:

View File

@ -12,12 +12,8 @@ type Bandwidth uint64
const ( const (
// BitsPerSecond is 1 bit per second // BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1 BitsPerSecond Bandwidth = 1
// KBitsPerSecond is 1000 bits per second
KBitsPerSecond = 1000 * BitsPerSecond
// BytesPerSecond is 1 byte per second // BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond BytesPerSecond = 8 * BitsPerSecond
// KBytesPerSecond is 1000 bytes per second
KBytesPerSecond = 1000 * BytesPerSecond
) )
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta // BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta

View File

@ -1,12 +0,0 @@
package congestion
import "github.com/lucas-clemente/quic-go/protocol"
// PacketInfo combines packet number and length of a packet for congestion calculation
type PacketInfo struct {
Number protocol.PacketNumber
Length protocol.ByteCount
}
// PacketVector is passed to the congestion algorithm
type PacketVector []PacketInfo

View File

@ -4,8 +4,8 @@ import (
"math" "math"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// This cubic implementation is based on the one found in Chromiums's QUIC // This cubic implementation is based on the one found in Chromiums's QUIC
@ -184,9 +184,19 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000 elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
offset := int64(c.timeToOriginPoint) - elapsedTime offset := int64(c.timeToOriginPoint) - elapsedTime
// Right-shifts of negative, signed numbers have
// implementation-dependent behavior. Force the offset to be
// positive, similar to the kernel implementation.
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale) deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
targetCongestionWindow := c.originPointCongestionWindow - deltaCongestionWindow var targetCongestionWindow protocol.PacketNumber
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// With dynamic beta/alpha based on number of active streams, it is possible // With dynamic beta/alpha based on number of active streams, it is possible
// for the required_ack_count to become much lower than acked_packets_count_ // for the required_ack_count to become much lower than acked_packets_count_
// suddenly, leading to more than one iteration through the following loop. // suddenly, leading to more than one iteration through the following loop.

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (
@ -125,24 +125,13 @@ func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
return c.slowstartThreshold return c.slowstartThreshold
} }
// OnCongestionEvent indicates an update to the congestion state, caused either by an incoming func (c *cubicSender) MaybeExitSlowStart() {
// ack or loss event timeout. |rttUpdated| indicates whether a new if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
// latest_rtt sample has been taken, |byte_in_flight| the bytes in flight
// prior to the congestion event. |ackedPackets| and |lostPackets| are
// any packets considered acked or lost as a result of the congestion event.
func (c *cubicSender) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) {
if rttUpdated && c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
c.ExitSlowstart() c.ExitSlowstart()
} }
for _, i := range lostPackets {
c.onPacketLost(i.Number, i.Length, bytesInFlight)
}
for _, i := range ackedPackets {
c.onPacketAcked(i.Number, i.Length, bytesInFlight)
}
} }
func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() { if c.InRecovery() {
// PRR is used when in recovery. // PRR is used when in recovery.
@ -155,7 +144,7 @@ func (c *cubicSender) onPacketAcked(ackedPacketNumber protocol.PacketNumber, ack
} }
} }
func (c *cubicSender) onPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected. // already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback { if packetNumber <= c.largestSentAtLastCutback {

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// Note(pwestin): the magic clamping numbers come from the original code in // Note(pwestin): the magic clamping numbers come from the original code in

View File

@ -11,7 +11,9 @@ type SendAlgorithm interface {
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
GetCongestionWindow() protocol.ByteCount GetCongestionWindow() protocol.ByteCount
OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets PacketVector, lostPackets PacketVector) MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int) SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool) OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration() OnConnectionMigration()

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937

View File

@ -3,7 +3,7 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
const ( const (

54
vendor/github.com/lucas-clemente/quic-go/conn.go generated vendored Normal file
View File

@ -0,0 +1,54 @@
package quic
import (
"net"
"sync"
)
type connection interface {
Write([]byte) error
Read([]byte) (int, net.Addr, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetCurrentRemoteAddr(net.Addr)
}
type conn struct {
mutex sync.RWMutex
pconn net.PacketConn
currentAddr net.Addr
}
var _ connection = &conn{}
func (c *conn) Write(p []byte) error {
_, err := c.pconn.WriteTo(p, c.currentAddr)
return err
}
func (c *conn) Read(p []byte) (int, net.Addr, error) {
return c.pconn.ReadFrom(p)
}
func (c *conn) SetCurrentRemoteAddr(addr net.Addr) {
c.mutex.Lock()
c.currentAddr = addr
c.mutex.Unlock()
}
func (c *conn) LocalAddr() net.Addr {
return c.pconn.LocalAddr()
}
func (c *conn) RemoteAddr() net.Addr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}
func (c *conn) Close() error {
return c.pconn.Close()
}

View File

@ -55,30 +55,59 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
return cert.Certificate[0], nil return cert.Certificate[0], nil
} }
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
if c.config.GetCertificate != nil { c := cc.config
cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) c, err := maybeGetConfigForClient(c, sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cert != nil { // The rest of this function is mostly copied from crypto/tls.getCertificate
return cert, nil
if c.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
} }
} }
if len(c.config.NameToCertificate) != 0 { if len(c.Certificates) == 0 {
if cert, ok := c.config.NameToCertificate[sni]; ok {
return cert, nil
}
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' })
if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok {
return cert, nil
}
}
if len(c.config.Certificates) != 0 {
return &c.config.Certificates[0], nil
}
return nil, errNoMatchingCertificate return nil, errNoMatchingCertificate
}
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &c.Certificates[0], nil
}
name := strings.ToLower(sni)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := c.NameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
} }

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type entryType uint8 type entryType uint8
@ -142,7 +142,7 @@ func decompressChain(data []byte) ([][]byte, error) {
} }
if numCerts == 0 { if numCerts == 0 {
return make([][]byte, 0, 0), nil return make([][]byte, 0), nil
} }
if hasCompressedCerts { if hasCompressedCerts {
@ -255,7 +255,7 @@ func splitHashes(hashes []byte) ([]uint64, error) {
} }
func getCommonCertificateHashes() []byte { func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets), 8*len(certSets)) ccs := make([]byte, 8*len(certSets))
i := 0 i := 0
for certSetHash := range certSets { for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash) binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)

View File

@ -41,7 +41,7 @@ func (c *certManager) SetData(data []byte) error {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
} }
chain := make([]*x509.Certificate, len(byteChain), len(byteChain)) chain := make([]*x509.Certificate, len(byteChain))
for i, data := range byteChain { for i, data := range byteChain {
cert, err := x509.ParseCertificate(data) cert, err := x509.ParseCertificate(data)
if err != nil { if err != nil {
@ -107,15 +107,14 @@ func (c *certManager) Verify(hostname string) error {
var opts x509.VerifyOptions var opts x509.VerifyOptions
if c.config != nil { if c.config != nil {
opts.Roots = c.config.RootCAs opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil { if c.config.Time == nil {
opts.CurrentTime = time.Now() opts.CurrentTime = time.Now()
} else { } else {
opts.CurrentTime = c.config.Time() opts.CurrentTime = c.config.Time()
} }
} else {
opts.DNSName = hostname
} }
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates // the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 { if len(c.chain) > 1 {

View File

@ -5,8 +5,8 @@ import (
"crypto/sha256" "crypto/sha256"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )

View File

@ -8,13 +8,24 @@ import (
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
// NullAEAD handles not-yet encrypted packets // nullAEAD handles not-yet encrypted packets
type NullAEAD struct{} type nullAEAD struct {
perspective protocol.Perspective
version protocol.VersionNumber
}
var _ AEAD = &NullAEAD{} var _ AEAD = &nullAEAD{}
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD {
return &nullAEAD{
perspective: p,
version: v,
}
}
// Open and verify the ciphertext // Open and verify the ciphertext
func (NullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if len(src) < 12 { if len(src) < 12 {
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
} }
@ -22,6 +33,13 @@ func (NullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associ
hash := fnv128a.New() hash := fnv128a.New()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src[12:]) hash.Write(src[12:])
if n.version >= protocol.Version37 {
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
}
testHigh, testLow := hash.Sum128() testHigh, testLow := hash.Sum128()
low := binary.LittleEndian.Uint64(src) low := binary.LittleEndian.Uint64(src)
@ -34,7 +52,7 @@ func (NullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associ
} }
// Seal writes hash and ciphertext to the buffer // Seal writes hash and ciphertext to the buffer
func (NullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < 12+len(src) { if cap(dst) < 12+len(src) {
dst = make([]byte, 12+len(src)) dst = make([]byte, 12+len(src))
} else { } else {
@ -44,6 +62,15 @@ func (NullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associ
hash := fnv128a.New() hash := fnv128a.New()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src) hash.Write(src)
if n.version >= protocol.Version37 {
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
}
high, low := hash.Sum128() high, low := hash.Sum128()
copy(dst[12:], src) copy(dst[12:], src)

View File

@ -5,48 +5,18 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
// StkSource is used to create and verify source address tokens // StkSource is used to create and verify source address tokens
type StkSource interface { type StkSource interface {
// NewToken creates a new token for a given IP address // NewToken creates a new token
NewToken(ip net.IP) ([]byte, error) NewToken([]byte) ([]byte, error)
// VerifyToken verifies if a token matches a given IP address and is not outdated // DecodeToken decodes a token
VerifyToken(ip net.IP, data []byte) error DecodeToken([]byte) ([]byte, error)
}
type sourceAddressToken struct {
ip net.IP
// unix timestamp in seconds
timestamp uint64
}
func (t *sourceAddressToken) serialize() []byte {
res := make([]byte, 8+len(t.ip))
binary.LittleEndian.PutUint64(res, t.timestamp)
copy(res[8:], t.ip)
return res
}
func parseToken(data []byte) (*sourceAddressToken, error) {
if len(data) != 8+4 && len(data) != 8+16 {
return nil, fmt.Errorf("invalid STK length: %d", len(data))
}
return &sourceAddressToken{
ip: data[8:],
timestamp: binary.LittleEndian.Uint64(data),
}, nil
} }
type stkSource struct { type stkSource struct {
@ -60,7 +30,11 @@ const stkKeySize = 16
const stkNonceSize = 16 const stkNonceSize = 16
// NewStkSource creates a source for source address tokens // NewStkSource creates a source for source address tokens
func NewStkSource(secret []byte) (StkSource, error) { func NewStkSource() (StkSource, error) {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
key, err := deriveKey(secret) key, err := deriveKey(secret)
if err != nil { if err != nil {
return nil, err return nil, err
@ -76,38 +50,20 @@ func NewStkSource(secret []byte) (StkSource, error) {
return &stkSource{aead: aead}, nil return &stkSource{aead: aead}, nil
} }
func (s *stkSource) NewToken(ip net.IP) ([]byte, error) { func (s *stkSource) NewToken(data []byte) ([]byte, error) {
return encryptToken(s.aead, &sourceAddressToken{ nonce := make([]byte, stkNonceSize)
ip: ip, if _, err := rand.Read(nonce); err != nil {
timestamp: uint64(time.Now().Unix()), return nil, err
}) }
return s.aead.Seal(nonce, nonce, data, nil), nil
} }
func (s *stkSource) VerifyToken(ip net.IP, data []byte) error { func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
if len(data) < stkNonceSize { if len(p) < stkNonceSize {
return errors.New("STK too short") return nil, fmt.Errorf("STK too short: %d", len(p))
} }
nonce := data[:stkNonceSize] nonce := p[:stkNonceSize]
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
if err != nil {
return err
}
token, err := parseToken(res)
if err != nil {
return err
}
if subtle.ConstantTimeCompare(token.ip, ip) != 1 {
return errors.New("invalid ip in STK")
}
if time.Now().Unix() > int64(token.timestamp)+protocol.STKExpiryTimeSec {
return errors.New("STK expired")
}
return nil
} }
func deriveKey(secret []byte) ([]byte, error) { func deriveKey(secret []byte) ([]byte, error) {
@ -118,11 +74,3 @@ func deriveKey(secret []byte) ([]byte, error) {
} }
return key, nil return key, nil
} }
func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) {
nonce := make([]byte, stkNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
return aead.Seal(nonce, nonce, token.serialize(), nil), nil
}

View File

@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowControlManager struct { type flowControlManager struct {
@ -17,7 +17,7 @@ type flowControlManager struct {
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
streamFlowController map[protocol.StreamID]*flowController streamFlowController map[protocol.StreamID]*flowController
contributesToConnectionFlowControl map[protocol.StreamID]bool connFlowController *flowController
mutex sync.RWMutex mutex sync.RWMutex
} }
@ -27,20 +27,17 @@ var errMapAccess = errors.New("Error accessing the flowController map.")
// NewFlowControlManager creates a new flow control manager // NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
fcm := flowControlManager{ return &flowControlManager{
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
rttStats: rttStats, rttStats: rttStats,
streamFlowController: make(map[protocol.StreamID]*flowController), streamFlowController: make(map[protocol.StreamID]*flowController),
contributesToConnectionFlowControl: make(map[protocol.StreamID]bool), connFlowController: newFlowController(0, false, connectionParameters, rttStats),
} }
// initialize connection level flow controller
fcm.streamFlowController[0] = newFlowController(0, connectionParameters, rttStats)
fcm.contributesToConnectionFlowControl[0] = false
return &fcm
} }
// NewStream creates new flow controllers for a stream // NewStream creates new flow controllers for a stream
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) { // it does nothing if the stream already exists
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
@ -48,15 +45,13 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo
return return
} }
f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParameters, f.rttStats) f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
f.contributesToConnectionFlowControl[streamID] = contributesToConnectionFlow
} }
// RemoveStream removes a closed stream from flow control // RemoveStream removes a closed stream from flow control
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Lock() f.mutex.Lock()
delete(f.streamFlowController, streamID) delete(f.streamFlowController, streamID)
delete(f.contributesToConnectionFlowControl, streamID)
f.mutex.Unlock() f.mutex.Unlock()
} }
@ -77,31 +72,19 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset
} }
if streamFlowController.CheckFlowControlViolation() { if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
} }
if f.contributesToConnectionFlowControl[streamID] { if streamFlowController.ContributesToConnection() {
connectionFlowController := f.streamFlowController[0] f.connFlowController.IncrementHighestReceived(increment)
connectionFlowController.IncrementHighestReceived(increment) if f.connFlowController.CheckFlowControlViolation() {
if connectionFlowController.CheckFlowControlViolation() { return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
} }
} }
return nil return nil
} }
func (f *flowControlManager) GetBytesSent(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return fc.GetBytesSent(), nil
}
// UpdateHighestReceived updates the highest received byte offset for a stream // UpdateHighestReceived updates the highest received byte offset for a stream
// it adds the number of additional bytes to connection level flow control // it adds the number of additional bytes to connection level flow control
// streamID must not be 0 here // streamID must not be 0 here
@ -118,14 +101,13 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
if streamFlowController.CheckFlowControlViolation() { if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
} }
if f.contributesToConnectionFlowControl[streamID] { if streamFlowController.ContributesToConnection() {
connectionFlowController := f.streamFlowController[0] f.connFlowController.IncrementHighestReceived(increment)
connectionFlowController.IncrementHighestReceived(increment) if f.connFlowController.CheckFlowControlViolation() {
if connectionFlowController.CheckFlowControlViolation() { return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
} }
} }
@ -137,15 +119,14 @@ func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID) fc, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return err return err
} }
streamFlowController.AddBytesRead(n) fc.AddBytesRead(n)
if fc.ContributesToConnection() {
if f.contributesToConnectionFlowControl[streamID] { f.connFlowController.AddBytesRead(n)
f.streamFlowController[0].AddBytesRead(n)
} }
return nil return nil
@ -154,38 +135,53 @@ func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
// get WindowUpdates for streams
for id, fc := range f.streamFlowController { for id, fc := range f.streamFlowController {
if necessary, offset := fc.MaybeTriggerWindowUpdate(); necessary { if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: id, Offset: offset}) res = append(res, WindowUpdate{StreamID: id, Offset: offset})
if fc.ContributesToConnection() && newIncrement != 0 {
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
} }
} }
return res }
// get a WindowUpdate for the connection
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
}
return
} }
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock() f.mutex.RLock()
defer f.mutex.Unlock() defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID) flowController, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return flowController.receiveFlowControlWindow, nil return flowController.receiveWindow, nil
} }
// streamID must not be 0 here // streamID must not be 0 here
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
f.mutex.Lock() f.mutex.Lock()
streamFlowController, err := f.getFlowController(streamID) defer f.mutex.Unlock()
f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return err return err
} }
streamFlowController.AddBytesSent(n) fc.AddBytesSent(n)
if fc.ContributesToConnection() {
if f.contributesToConnectionFlowControl[streamID] { f.connFlowController.AddBytesSent(n)
f.streamFlowController[0].AddBytesSent(n)
} }
return nil return nil
@ -193,45 +189,46 @@ func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol
// must not be called with StreamID 0 // must not be called with StreamID 0
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
f.mutex.RLock() f.mutex.RLock()
streamFlowController, err := f.getFlowController(streamID) defer f.mutex.RUnlock()
f.mutex.RUnlock()
fc, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return 0, err return 0, err
} }
res := streamFlowController.SendWindowSize() res := fc.SendWindowSize()
contributes, ok := f.contributesToConnectionFlowControl[streamID] if fc.ContributesToConnection() {
if !ok { res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
return 0, errMapAccess
}
if contributes {
res = utils.MinByteCount(res, f.streamFlowController[0].SendWindowSize())
} }
return res, nil return res, nil
} }
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
f.mutex.RLock() f.mutex.RLock()
res := f.streamFlowController[0].SendWindowSize() defer f.mutex.RUnlock()
f.mutex.RUnlock()
return res return f.connFlowController.SendWindowSize()
} }
// streamID may be 0 here // streamID may be 0 here
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
f.mutex.Lock() f.mutex.Lock()
streamFlowController, err := f.getFlowController(streamID) defer f.mutex.Unlock()
f.mutex.Unlock()
var fc *flowController
if streamID == 0 {
fc = f.connFlowController
} else {
var err error
fc, err = f.getFlowController(streamID)
if err != nil { if err != nil {
return false, err return false, err
} }
}
return streamFlowController.UpdateSendWindow(offset), nil return fc.UpdateSendWindow(offset), nil
} }
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) { func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {

View File

@ -6,91 +6,93 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowController struct { type flowController struct {
streamID protocol.StreamID streamID protocol.StreamID
contributesToConnection bool // does the stream contribute to connection level flow control
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
bytesSent protocol.ByteCount bytesSent protocol.ByteCount
sendFlowControlWindow protocol.ByteCount sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time lastWindowUpdateTime time.Time
bytesRead protocol.ByteCount bytesRead protocol.ByteCount
highestReceived protocol.ByteCount highestReceived protocol.ByteCount
receiveFlowControlWindow protocol.ByteCount receiveWindow protocol.ByteCount
receiveFlowControlWindowIncrement protocol.ByteCount receiveWindowIncrement protocol.ByteCount
maxReceiveFlowControlWindowIncrement protocol.ByteCount maxReceiveWindowIncrement protocol.ByteCount
} }
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously // ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
// newFlowController gets a new flow controller // newFlowController gets a new flow controller
func newFlowController(streamID protocol.StreamID, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
fc := flowController{ fc := flowController{
streamID: streamID, streamID: streamID,
contributesToConnection: contributesToConnection,
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
rttStats: rttStats, rttStats: rttStats,
} }
if streamID == 0 { if streamID == 0 {
fc.receiveFlowControlWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
} else { } else {
fc.receiveFlowControlWindow = connectionParameters.GetReceiveStreamFlowControlWindow() fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
} }
return &fc return &fc
} }
func (c *flowController) getSendFlowControlWindow() protocol.ByteCount { func (c *flowController) ContributesToConnection() bool {
if c.sendFlowControlWindow == 0 { return c.contributesToConnection
}
func (c *flowController) getSendWindow() protocol.ByteCount {
if c.sendWindow == 0 {
if c.streamID == 0 { if c.streamID == 0 {
return c.connectionParameters.GetSendConnectionFlowControlWindow() return c.connectionParameters.GetSendConnectionFlowControlWindow()
} }
return c.connectionParameters.GetSendStreamFlowControlWindow() return c.connectionParameters.GetSendStreamFlowControlWindow()
} }
return c.sendFlowControlWindow return c.sendWindow
} }
func (c *flowController) AddBytesSent(n protocol.ByteCount) { func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n c.bytesSent += n
} }
func (c *flowController) GetBytesSent() protocol.ByteCount {
return c.bytesSent
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame // UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated // it returns true if the window was actually updated
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
if newOffset > c.sendFlowControlWindow { if newOffset > c.sendWindow {
c.sendFlowControlWindow = newOffset c.sendWindow = newOffset
return true return true
} }
return false return false
} }
func (c *flowController) SendWindowSize() protocol.ByteCount { func (c *flowController) SendWindowSize() protocol.ByteCount {
sendFlowControlWindow := c.getSendFlowControlWindow() sendWindow := c.getSendWindow()
if c.bytesSent > sendFlowControlWindow { // should never happen, but make sure we don't do an underflow here if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
return 0 return 0
} }
return sendFlowControlWindow - c.bytesSent return sendWindow - c.bytesSent
} }
func (c *flowController) SendWindowOffset() protocol.ByteCount { func (c *flowController) SendWindowOffset() protocol.ByteCount {
return c.getSendFlowControlWindow() return c.getSendWindow()
} }
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher // UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
@ -117,28 +119,39 @@ func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount)
} }
func (c *flowController) AddBytesRead(n protocol.ByteCount) { func (c *flowController) AddBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate
if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now()
}
c.bytesRead += n c.bytesRead += n
} }
// MaybeTriggerWindowUpdate determines if it is necessary to send a WindowUpdate // MaybeUpdateWindow updates the receive window, if necessary
// if so, it returns true and the offset of the window // if the receive window increment is changed, the new value is returned, otherwise a 0
func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) { // the last return value is the new offset of the receive window
diff := c.receiveFlowControlWindow - c.bytesRead func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
diff := c.receiveWindow - c.bytesRead
// Chromium implements the same threshold // Chromium implements the same threshold
if diff < (c.receiveFlowControlWindowIncrement / 2) { if diff < (c.receiveWindowIncrement / 2) {
var newWindowIncrement protocol.ByteCount
oldWindowIncrement := c.receiveWindowIncrement
c.maybeAdjustWindowIncrement() c.maybeAdjustWindowIncrement()
c.lastWindowUpdateTime = time.Now() if c.receiveWindowIncrement != oldWindowIncrement {
newWindowIncrement = c.receiveWindowIncrement
c.receiveFlowControlWindow = c.bytesRead + c.receiveFlowControlWindowIncrement
return true, c.receiveFlowControlWindow
} }
return false, 0 c.lastWindowUpdateTime = time.Now()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
return true, newWindowIncrement, c.receiveWindow
}
return false, 0, 0
} }
// maybeAdjustWindowIncrement increases the receiveFlowControlWindowIncrement if we're sending WindowUpdates too often // maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *flowController) maybeAdjustWindowIncrement() { func (c *flowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() { if c.lastWindowUpdateTime.IsZero() {
return return
@ -149,19 +162,19 @@ func (c *flowController) maybeAdjustWindowIncrement() {
return return
} }
timeSinceLastWindowUpdate := time.Now().Sub(c.lastWindowUpdateTime) timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment // interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt { if timeSinceLastWindowUpdate >= 2*rtt {
return return
} }
oldWindowSize := c.receiveFlowControlWindowIncrement oldWindowSize := c.receiveWindowIncrement
c.receiveFlowControlWindowIncrement = utils.MinByteCount(2*c.receiveFlowControlWindowIncrement, c.maxReceiveFlowControlWindowIncrement) c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
// debug log, if the window size was actually increased // debug log, if the window size was actually increased
if oldWindowSize < c.receiveFlowControlWindowIncrement { if oldWindowSize < c.receiveWindowIncrement {
newWindowSize := c.receiveFlowControlWindowIncrement / (1 << 10) newWindowSize := c.receiveWindowIncrement / (1 << 10)
if c.streamID == 0 { if c.streamID == 0 {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize) utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
} else { } else {
@ -170,9 +183,16 @@ func (c *flowController) maybeAdjustWindowIncrement() {
} }
} }
func (c *flowController) CheckFlowControlViolation() bool { // EnsureMinimumWindowIncrement sets a minimum window increment
if c.highestReceived > c.receiveFlowControlWindow { // it is intended be used for the connection-level flow controller
return true // it should make sure that the connection-level window is increased when a stream-level window grows
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
if inc > c.receiveWindowIncrement {
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
} }
return false }
func (c *flowController) CheckFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
} }

View File

@ -5,8 +5,8 @@ import (
"errors" "errors"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
@ -222,7 +222,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
utils.WriteUint48(b, uint64(f.LargestAcked)) utils.WriteUint48(b, uint64(f.LargestAcked))
} }
f.DelayTime = time.Now().Sub(f.PacketReceivedTime) f.DelayTime = time.Since(f.PacketReceivedTime)
utils.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) utils.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond))
var numRanges uint64 var numRanges uint64
@ -332,8 +332,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
// MinLength of a written frame // MinLength of a written frame
func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
var length protocol.ByteCount length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
length = 1 + 2 + 1 // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked))
missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen()) missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen())
@ -351,10 +350,7 @@ func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount
// HasMissingRanges returns if this frame reports any missing packets // HasMissingRanges returns if this frame reports any missing packets
func (f *AckFrame) HasMissingRanges() bool { func (f *AckFrame) HasMissingRanges() bool {
if len(f.AckRanges) > 0 { return len(f.AckRanges) > 0
return true
}
return false
} }
func (f *AckFrame) validateAckRanges() bool { func (f *AckFrame) validateAckRanges() bool {

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A BlockedFrame in QUIC // A BlockedFrame in QUIC

View File

@ -6,9 +6,9 @@ import (
"io" "io"
"math" "math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A ConnectionCloseFrame in QUIC // A ConnectionCloseFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A GoawayFrame is a GOAWAY frame // A GoawayFrame is a GOAWAY frame

View File

@ -1,6 +1,6 @@
package frames package frames
import "github.com/lucas-clemente/quic-go/utils" import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received // LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) { func LogFrame(frame Frame, sent bool) {

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A RstStreamFrame in QUIC // A RstStreamFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StopWaitingFrame in QUIC // A StopWaitingFrame in QUIC
@ -56,8 +56,7 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber
// MinLength of a written frame // MinLength of a written frame
func (f *StopWaitingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { func (f *StopWaitingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
var minLength protocol.ByteCount minLength := protocol.ByteCount(1) // typeByte
minLength = 1 // typeByte
if f.PacketNumberLen == protocol.PacketNumberLenInvalid { if f.PacketNumberLen == protocol.PacketNumberLenInvalid {
return 0, errPacketNumberLenNotSet return 0, errPacketNumberLenNotSet

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StreamFrame of QUIC // A StreamFrame of QUIC
@ -64,7 +64,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) {
return nil, qerr.Error(qerr.InvalidStreamData, "data len too large") return nil, qerr.Error(qerr.InvalidStreamData, "data len too large")
} }
if dataLen == 0 { if !frame.DataLenPresent {
// The rest of the packet is data // The rest of the packet is data
dataLen = uint16(r.Len()) dataLen = uint16(r.Len())
} }
@ -79,7 +79,11 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) {
} }
} }
if !frame.FinBit && len(frame.Data) == 0 { if frame.Offset+frame.DataLen() < frame.Offset {
return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset")
}
if !frame.FinBit && frame.DataLen() == 0 {
return nil, qerr.EmptyStreamFrameNoFin return nil, qerr.EmptyStreamFrameNoFin
} }

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A WindowUpdateFrame in QUIC // A WindowUpdateFrame in QUIC

View File

@ -15,89 +15,90 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type quicClient interface { type roundTripperOpts struct {
OpenStream(protocol.StreamID) (utils.Stream, error) DisableCompression bool
Close(error) error
Listen() error
} }
// Client is a HTTP2 client doing QUIC requests var dialAddr = quic.DialAddr
type Client struct {
mutex sync.RWMutex
cryptoChangedCond sync.Cond
t *QuicRoundTripper // client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
handshakeErr error
dialOnce sync.Once
client quicClient session quic.Session
headerStream utils.Stream headerStream quic.Stream
headerErr *qerr.QuicError headerErr *qerr.QuicError
highestOpenedStream protocol.StreamID headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
} }
var _ h2quicClient = &Client{} var _ http.RoundTripper = &client{}
// NewClient creates a new client var defaultQuicConfig = &quic.Config{
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { RequestConnectionIDTruncation: true,
c := &Client{ KeepAlive: true,
t: t, }
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
highestOpenedStream: 3,
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
} }
c.cryptoChangedCond = sync.Cond{L: &c.mutex} }
// dial dials the connection
func (c *client) dial() error {
var err error var err error
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
if err != nil {
return nil, err
}
go c.client.Listen()
return c, nil
}
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) {
utils.Debugf("Handling stream %d", stream.StreamID())
}
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
c.cryptoChangedCond.L.Lock()
defer c.cryptoChangedCond.L.Unlock()
if isForwardSecure {
c.encryptionLevel = protocol.EncryptionForwardSecure
utils.Debugf("is forward secure")
} else {
c.encryptionLevel = protocol.EncryptionSecure
utils.Debugf("is secure")
}
c.cryptoChangedCond.Broadcast()
}
func (c *Client) versionNegotiateCallback() error {
var err error
// once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream(3)
if err != nil { if err != nil {
return err return err
} }
// once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStream()
if err != nil {
return err
}
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
return nil return nil
} }
func (c *Client) handleHeaderStream() { func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream) h2framer := http2.NewFramer(nil, c.headerStream)
@ -106,7 +107,7 @@ func (c *Client) handleHeaderStream() {
for { for {
frame, err := h2framer.ReadFrame() frame, err := h2framer.ReadFrame()
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame") c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
break break
} }
lastStream = protocol.StreamID(frame.Header().StreamID) lastStream = protocol.StreamID(frame.Header().StreamID)
@ -123,7 +124,7 @@ func (c *Client) handleHeaderStream() {
} }
c.mutex.RLock() c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock() c.mutex.RUnlock()
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
@ -134,57 +135,53 @@ func (c *Client) handleHeaderStream() {
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) c.headerErr = qerr.Error(qerr.InternalError, err.Error())
} }
headerChan <- rsp responseChan <- rsp
} }
// stop all running request // stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock() close(c.headerErrored)
for _, responseChan := range c.responses {
responseChan <- nil
}
c.mutex.Unlock()
} }
// Do executes a request and returns a response // Roundtrip executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one // TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" { if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme") return nil, errors.New("quic http2: unsupported scheme")
} }
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname) return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client") }
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
} }
hasBody := (req.Body != nil) hasBody := (req.Body != nil)
c.mutex.Lock() responseChan := make(chan *http.Response)
c.highestOpenedStream += 2 dataStream, err := c.session.OpenStreamSync()
dataStreamID := c.highestOpenedStream
for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait()
}
hdrChan := make(chan *http.Response)
c.responses[dataStreamID] = hdrChan
c.mutex.Unlock()
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream(dataStreamID)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock()
c.responses[dataStream.StreamID()] = responseChan
c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true requestedGzip = true
} }
// TODO: add support for trailers // TODO: add support for trailers
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
@ -206,20 +203,20 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
for !(bodySent && receivedResponse) { for !(bodySent && receivedResponse) {
select { select {
case res = <-hdrChan: case res = <-responseChan:
receivedResponse = true receivedResponse = true
c.mutex.Lock() c.mutex.Lock()
delete(c.responses, dataStreamID) delete(c.responses, dataStream.StreamID())
c.mutex.Unlock() c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc: case err := <-resc:
bodySent = true bodySent = true
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-c.headerErrored:
// an error occured on the header stream
_ = c.CloseWithError(c.headerErr)
return nil, c.headerErr
} }
} }
@ -238,16 +235,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
res.Header.Del("Content-Length") res.Header.Del("Content-Length")
res.ContentLength = -1 res.ContentLength = -1
res.Body = &gzipReader{body: res.Body} res.Body = &gzipReader{body: res.Body}
setUncompressed(res) res.Uncompressed = true
} }
} }
res.Request = req res.Request = req
return res, nil return res, nil
} }
func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) { func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() { defer func() {
cerr := body.Close() cerr := body.Close()
if err == nil { if err == nil {
@ -265,8 +261,15 @@ func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (
} }
// Close closes the client // Close closes the client
func (c *Client) Close(e error) { func (c *client) CloseWithError(e error) error {
_ = c.client.Close(e) if c.session == nil {
return nil
}
return c.session.Close(e)
}
func (c *client) Close() error {
return c.CloseWithError(nil)
} }
// copied from net/transport.go // copied from net/transport.go

View File

@ -3,18 +3,18 @@ package h2quic
import ( import (
"io" "io"
"github.com/lucas-clemente/quic-go/utils" quic "github.com/lucas-clemente/quic-go"
) )
type requestBody struct { type requestBody struct {
requestRead bool requestRead bool
dataStream utils.Stream dataStream quic.Stream
} }
// make sure the requestBody can be used as a http.Request.Body // make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{} var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream utils.Stream) *requestBody { func newRequestBody(stream quic.Stream) *requestBody {
return &requestBody{dataStream: stream} return &requestBody{dataStream: stream}
} }

View File

@ -12,13 +12,14 @@ import (
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type requestWriter struct { type requestWriter struct {
mutex sync.Mutex mutex sync.Mutex
headerStream utils.Stream headerStream quic.Stream
henc *hpack.Encoder henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
@ -26,7 +27,7 @@ type requestWriter struct {
const defaultUserAgent = "quic-go" const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream utils.Stream) *requestWriter { func newRequestWriter(headerStream quic.Stream) *requestWriter {
rw := &requestWriter{ rw := &requestWriter{
headerStream: headerStream, headerStream: headerStream,
} }

View File

@ -1,9 +0,0 @@
// +build go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
res.Uncompressed = true
}

View File

@ -1,9 +0,0 @@
// +build !go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
// http.Response.Uncompressed was introduced in go 1.7
}

View File

@ -7,17 +7,18 @@ import (
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
type responseWriter struct { type responseWriter struct {
dataStreamID protocol.StreamID dataStreamID protocol.StreamID
dataStream utils.Stream dataStream quic.Stream
headerStream utils.Stream headerStream quic.Stream
headerStreamMutex *sync.Mutex headerStreamMutex *sync.Mutex
header http.Header header http.Header
@ -25,7 +26,7 @@ type responseWriter struct {
headerWritten bool headerWritten bool
} }
func newResponseWriter(headerStream utils.Stream, headerStreamMutex *sync.Mutex, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter { func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
headerStream: headerStream, headerStream: headerStream,
@ -82,9 +83,15 @@ func (w *responseWriter) Write(p []byte) (int, error) {
func (w *responseWriter) Flush() {} func (w *responseWriter) Flush() {}
// TODO: Implement a functional CloseNotify method.
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher // test that we implement http.Flusher
var _ http.Flusher = &responseWriter{} var _ http.Flusher = &responseWriter{}
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}
// copied from http2/http2.go // copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code // bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4. // permits a body. See RFC 2616, section 4.4.

View File

@ -4,19 +4,23 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
) )
type h2quicClient interface { type roundTripCloser interface {
Do(*http.Request) (*http.Response, error) http.RoundTripper
io.Closer
} }
// QuicRoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from // DisableCompression, if true, prevents the Transport from
@ -33,13 +37,29 @@ type QuicRoundTripper struct {
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
clients map[string]h2quicClient // QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]roundTripCloser
} }
var _ http.RoundTripper = &QuicRoundTripper{} // RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
// RoundTrip does a round trip var _ roundTripCloser = &RoundTripper{}
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil { if req.URL == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL") return nil, errors.New("quic: nil Request.URL")
@ -75,35 +95,48 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname) cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.Do(req) return cl.RoundTrip(req)
} }
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { // RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]h2quicClient) r.clients = make(map[string]roundTripCloser)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
var err error if onlyCached {
client, err = NewClient(r, r.TLSClientConfig, hostname) return nil, ErrNoCachedConn
if err != nil {
return nil, err
} }
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client r.clients[hostname] = client
} }
return client, nil return client, nil
} }
func (r *QuicRoundTripper) disableCompression() bool { // Close closes the QUIC connections that this RoundTripper has used
return r.DisableCompression func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
} }
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {

View File

@ -7,35 +7,51 @@ import (
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
type streamCreator interface { type streamCreator interface {
GetOrOpenStream(protocol.StreamID) (utils.Stream, error) quic.Session
Close(error) error GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
RemoteAddr() *net.UDPAddr
} }
type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections. // Server is a HTTP2 server listening for QUIC connections.
type Server struct { type Server struct {
*http.Server *http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use // Private flag for demo, do not use
CloseAfterFirstRequest bool CloseAfterFirstRequest bool
port uint32 // used atomically port uint32 // used atomically
server *quic.Server listenerMutex sync.Mutex
serverMutex sync.Mutex listener quic.Listener
supportedVersionsAsString string
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
@ -63,39 +79,51 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
} }
// Serve an existing UDP connection. // Serve an existing UDP connection.
func (s *Server) Serve(conn *net.UDPConn) error { func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn) return s.serveImpl(s.TLSConfig, conn)
} }
func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
s.serverMutex.Lock() s.listenerMutex.Lock()
if s.server != nil { if s.listener != nil {
s.serverMutex.Unlock() s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
var ln quic.Listener
var err error var err error
server, err := quic.NewServer(s.Addr, tlsConfig, s.handleStreamCb) if conn == nil {
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else {
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
}
if err != nil { if err != nil {
s.serverMutex.Unlock() s.listenerMutex.Unlock()
return err return err
} }
s.server = server s.listener = ln
s.serverMutex.Unlock() s.listenerMutex.Unlock()
if conn == nil {
return server.ListenAndServe() for {
sess, err := ln.Accept()
if err != nil {
return err
}
go s.handleHeaderStream(sess.(streamCreator))
} }
return server.Serve(conn)
} }
func (s *Server) handleStreamCb(session *quic.Session, stream utils.Stream) { func (s *Server) handleHeaderStream(session streamCreator) {
s.handleStream(session, stream) stream, err := session.AcceptStream()
} if err != nil {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
func (s *Server) handleStream(session streamCreator, stream utils.Stream) { return
}
if stream.StreamID() != 3 { if stream.StreamID() != 3 {
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
return return
} }
@ -112,17 +140,17 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) {
if _, ok := err.(*qerr.QuicError); !ok { if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error()) utils.Errorf("error handling h2 request: %s", err.Error())
} }
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) session.Close(err)
return return
} }
} }
}() }()
} }
func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
h2frame, err := h2framer.ReadFrame() h2frame, err := h2framer.ReadFrame()
if err != nil { if err != nil {
return err return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
} }
h2headersFrame, ok := h2frame.(*http2.HeadersFrame) h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
if !ok { if !ok {
@ -154,10 +182,14 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
if err != nil { if err != nil {
return err return err
} }
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
if dataStream == nil {
return nil
}
var streamEnded bool var streamEnded bool
if h2headersFrame.StreamEnded() { if h2headersFrame.StreamEnded() {
dataStream.CloseRemote(0) dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof _, _ = dataStream.Read([]byte{0}) // read the eof
} }
@ -209,11 +241,11 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error { func (s *Server) Close() error {
s.serverMutex.Lock() s.listenerMutex.Lock()
defer s.serverMutex.Unlock() defer s.listenerMutex.Unlock()
if s.server != nil { if s.listener != nil {
err := s.server.Close() err := s.listener.Close()
s.server = nil s.listener = nil
return err return err
} }
return nil return nil
@ -228,7 +260,6 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error { func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port) port := atomic.LoadUint32(&s.port)
@ -247,8 +278,16 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
atomic.StoreUint32(&s.port, port) atomic.StoreUint32(&s.port, port)
} }
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port)) if s.supportedVersionsAsString == "" {
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString)) for i, v := range protocol.SupportedVersions {
s.supportedVersionsAsString += strconv.Itoa(int(v))
if i != len(protocol.SupportedVersions)-1 {
s.supportedVersionsAsString += ","
}
}
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil return nil
} }

View File

@ -2,13 +2,12 @@ package handshake
import ( import (
"bytes" "bytes"
"errors"
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// ConnectionParametersManager negotiates and stores the connection parameters // ConnectionParametersManager negotiates and stores the connection parameters
@ -42,7 +41,6 @@ type connectionParametersManager struct {
perspective protocol.Perspective perspective protocol.Perspective
flowControlNegotiated bool flowControlNegotiated bool
hasReceivedMaxIncomingDynamicStreams bool
truncateConnectionID bool truncateConnectionID bool
maxStreamsPerConnection uint32 maxStreamsPerConnection uint32
@ -52,12 +50,12 @@ type connectionParametersManager struct {
sendConnectionFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
} }
var _ ConnectionParametersManager = &connectionParametersManager{} var _ ConnectionParametersManager = &connectionParametersManager{}
var errTagNotInConnectionParameterMap = errors.New("ConnectionParametersManager: Tag not found in ConnectionsParameter map")
// ErrMalformedTag is returned when the tag value cannot be read // ErrMalformedTag is returned when the tag value cannot be read
var ( var (
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
@ -65,7 +63,10 @@ var (
) )
// NewConnectionParamatersManager creates a new connection parameters manager // NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber) ConnectionParametersManager { func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{ h := &connectionParametersManager{
perspective: pers, perspective: pers,
version: v, version: v,
@ -73,6 +74,8 @@ func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.Versio
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
if h.perspective == protocol.PerspectiveServer { if h.perspective == protocol.PerspectiveServer {
@ -113,7 +116,6 @@ func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
return ErrMalformedTag return ErrMalformedTag
} }
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue) h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
h.hasReceivedMaxIncomingDynamicStreams = true
} }
if value, ok := params[TagICSL]; ok { if value, ok := params[TagICSL]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
@ -175,23 +177,18 @@ func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
mspc := bytes.NewBuffer([]byte{}) mspc := bytes.NewBuffer([]byte{})
utils.WriteUint32(mspc, h.maxStreamsPerConnection) utils.WriteUint32(mspc, h.maxStreamsPerConnection)
mids := bytes.NewBuffer([]byte{})
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
icsl := bytes.NewBuffer([]byte{}) icsl := bytes.NewBuffer([]byte{})
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
tags := map[Tag][]byte{ return map[Tag][]byte{
TagICSL: icsl.Bytes(), TagICSL: icsl.Bytes(),
TagMSPC: mspc.Bytes(), TagMSPC: mspc.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(), TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(), TagSFCW: sfcw.Bytes(),
} }, nil
if h.version > protocol.Version34 {
mids := bytes.NewBuffer([]byte{})
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
tags[TagMIDS] = mids.Bytes()
}
return tags, nil
} }
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data // GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
@ -217,10 +214,7 @@ func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protoc
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveStreamFlowControlWindow
return protocol.MaxReceiveStreamFlowControlWindowServer
}
return protocol.MaxReceiveStreamFlowControlWindowClient
} }
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
@ -232,10 +226,7 @@ func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() pr
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveConnectionFlowControlWindow
return protocol.MaxReceiveConnectionFlowControlWindowServer
}
return protocol.MaxReceiveConnectionFlowControlWindowClient
} }
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection // GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
@ -243,10 +234,7 @@ func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.version > protocol.Version34 && h.hasReceivedMaxIncomingDynamicStreams {
return h.maxIncomingDynamicStreamsPerConnection return h.maxIncomingDynamicStreamsPerConnection
}
return h.maxStreamsPerConnection
} }
// GetMaxIncomingStreams get the maximum number of incoming streams per connection // GetMaxIncomingStreams get the maximum number of incoming streams per connection
@ -254,14 +242,8 @@ func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
var val uint32 maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
if h.version <= protocol.Version34 { return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
val = h.maxStreamsPerConnection
} else {
val = protocol.MaxIncomingDynamicStreamsPerConnection
}
return utils.MaxUint32(val+protocol.MaxStreamsMinimumIncrement, uint32(float64(val)*protocol.MaxStreamsMultiplier))
} }
// GetIdleConnectionStateLifetime gets the idle timeout // GetIdleConnectionStateLifetime gets the idle timeout

View File

@ -12,9 +12,9 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type cryptoSetupClient struct { type cryptoSetupClient struct {
@ -25,7 +25,7 @@ type cryptoSetupClient struct {
version protocol.VersionNumber version protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber negotiatedVersions []protocol.VersionNumber
cryptoStream utils.Stream cryptoStream io.ReadWriter
serverConfig *serverConfigClient serverConfig *serverConfigClient
@ -33,24 +33,28 @@ type cryptoSetupClient struct {
sno []byte sno []byte
nonc []byte nonc []byte
proof []byte proof []byte
diversificationNonce []byte
chloForSignature []byte chloForSignature []byte
lastSentCHLO []byte lastSentCHLO []byte
certManager crypto.CertManager certManager crypto.CertManager
divNonceChan chan []byte
diversificationNonce []byte
clientHelloCounter int clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
keyExchange KeyExchangeFunction
receivedSecurePacket bool receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
aeadChanged chan struct{} aeadChanged chan<- protocol.EncryptionLevel
params *TransportParameters
connectionParameters ConnectionParametersManager connectionParameters ConnectionParametersManager
} }
var _ crypto.AEAD = &cryptoSetupClient{}
var _ CryptoSetup = &cryptoSetupClient{} var _ CryptoSetup = &cryptoSetupClient{}
var ( var (
@ -64,10 +68,11 @@ func NewCryptoSetupClient(
hostname string, hostname string,
connID protocol.ConnectionID, connID protocol.ConnectionID,
version protocol.VersionNumber, version protocol.VersionNumber,
cryptoStream utils.Stream, cryptoStream io.ReadWriter,
tlsConfig *tls.Config, tlsConfig *tls.Config,
connectionParameters ConnectionParametersManager, connectionParameters ConnectionParametersManager,
aeadChanged chan struct{}, aeadChanged chan<- protocol.EncryptionLevel,
params *TransportParameters,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
return &cryptoSetupClient{ return &cryptoSetupClient{
@ -78,57 +83,77 @@ func NewCryptoSetupClient(
certManager: crypto.NewCertManager(tlsConfig), certManager: crypto.NewCertManager(tlsConfig),
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
keyDerivation: crypto.DeriveKeysAESGCM, keyDerivation: crypto.DeriveKeysAESGCM,
keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
params: params,
}, nil }, nil
} }
func (h *cryptoSetupClient) HandleCryptoStream() error { func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error)
go func() {
for {
message, err := ParseHandshakeMessage(h.cryptoStream)
if err != nil {
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
return
}
messageChan <- message
}
}()
for { for {
err := h.maybeUpgradeCrypto() err := h.maybeUpgradeCrypto()
if err != nil { if err != nil {
return err return err
} }
// send CHLOs until the forward secure encryption is established h.mutex.RLock()
if h.forwardSecureAEAD == nil { sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
err = h.sendCHLO() err = h.sendCHLO()
if err != nil { if err != nil {
return err return err
} }
} }
var shloData bytes.Buffer var message HandshakeMessage
select {
messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &shloData)) case divNonce := <-h.divNonceChan:
if err != nil { if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
return qerr.HandshakeFailed return errConflictingDiversificationNonces
}
h.diversificationNonce = divNonce
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err = <-errorChan:
return err
} }
if messageTag != TagSHLO && messageTag != TagREJ { utils.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
err = h.handleREJMessage(message.Data)
case TagSHLO:
err = h.handleSHLOMessage(message.Data)
default:
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
if messageTag == TagSHLO {
utils.Debugf("Got SHLO:\n%s", printHandshakeMessage(cryptoData))
err = h.handleSHLOMessage(cryptoData)
if err != nil { if err != nil {
return err return err
} }
} }
if messageTag == TagREJ {
err = h.handleREJMessage(cryptoData)
if err != nil {
return err
}
}
}
} }
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
utils.Debugf("Got REJ:\n%s", printHandshakeMessage(cryptoData))
var err error var err error
if stk, ok := cryptoData[TagSTK]; ok { if stk, ok := cryptoData[TagSTK]; ok {
@ -244,7 +269,8 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
return qerr.InvalidCryptoMessageParameter return qerr.InvalidCryptoMessageParameter
} }
h.aeadChanged <- struct{}{} h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
return nil return nil
} }
@ -264,7 +290,7 @@ func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
return false return false
} }
ver := protocol.VersionTagToNumber(verTag) ver := protocol.VersionTagToNumber(verTag)
if !protocol.IsSupportedVersion(ver) { if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) {
ver = protocol.VersionUnsupported ver = protocol.VersionUnsupported
} }
if ver != negotiatedVersion { if ver != negotiatedVersion {
@ -274,67 +300,90 @@ func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
return true return true
} }
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
return data, nil return data, protocol.EncryptionForwardSecure, nil
} }
return nil, err return nil, protocol.EncryptionUnspecified, err
} }
if h.secureAEAD != nil { if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
h.receivedSecurePacket = true h.receivedSecurePacket = true
return data, nil return data, protocol.EncryptionSecure, nil
} }
if h.receivedSecurePacket { if h.receivedSecurePacket {
return nil, err return nil, protocol.EncryptionUnspecified, err
} }
} }
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
return (&crypto.NullAEAD{}).Open(dst, src, packetNumber, associatedData) if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, nil
} }
func (h *cryptoSetupClient) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure
} else {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
} }
if h.secureAEAD != nil { }
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.sealUnencrypted, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.sealSecure, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.sealForwardSecure, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
} }
return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData)
func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
} }
func (h *cryptoSetupClient) DiversificationNonce() []byte { func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient") panic("not needed for cryptoSetupClient")
} }
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
if len(h.diversificationNonce) == 0 { h.divNonceChan <- data
h.diversificationNonce = data
return h.maybeUpgradeCrypto()
}
if !bytes.Equal(h.diversificationNonce, data) {
return errConflictingDiversificationNonces
}
return nil
}
func (h *cryptoSetupClient) LockForSealing() {
}
func (h *cryptoSetupClient) UnlockForSealing() {
}
func (h *cryptoSetupClient) HandshakeComplete() bool {
h.mutex.RLock()
complete := h.forwardSecureAEAD != nil
h.mutex.RUnlock()
return complete
} }
func (h *cryptoSetupClient) sendCHLO() error { func (h *cryptoSetupClient) sendCHLO() error {
@ -350,9 +399,13 @@ func (h *cryptoSetupClient) sendCHLO() error {
return err return err
} }
h.addPadding(tags) h.addPadding(tags)
message := HandshakeMessage{
Tag: TagCHLO,
Data: tags,
}
utils.Debugf("Sending CHLO:\n%s", printHandshakeMessage(tags)) utils.Debugf("Sending %s", message)
WriteHandshakeMessage(b, TagCHLO, tags) message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes()) _, err = h.cryptoStream.Write(b.Bytes())
if err != nil { if err != nil {
@ -377,14 +430,16 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags[TagCCS] = ccs tags[TagCCS] = ccs
} }
versionTag := make([]byte, 4, 4) versionTag := make([]byte, 4)
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
tags[TagVER] = versionTag tags[TagVER] = versionTag
if h.params.RequestConnectionIDTruncation {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
if len(h.stk) > 0 { if len(h.stk) > 0 {
tags[TagSTK] = h.stk tags[TagSTK] = h.stk
} }
if len(h.sno) > 0 { if len(h.sno) > 0 {
tags[TagSNO] = h.sno tags[TagSNO] = h.sno
} }
@ -395,7 +450,7 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
leafCert := h.certManager.GetLeafCert() leafCert := h.certManager.GetLeafCert()
if leafCert != nil { if leafCert != nil {
certHash, _ := h.certManager.GetLeafCertHash() certHash, _ := h.certManager.GetLeafCertHash()
xlct := make([]byte, 8, 8) xlct := make([]byte, 8)
binary.LittleEndian.PutUint64(xlct, certHash) binary.LittleEndian.PutUint64(xlct, certHash)
tags[TagNONC] = h.nonc tags[TagNONC] = h.nonc
@ -430,7 +485,6 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
defer h.mutex.Unlock() defer h.mutex.Unlock()
leafCert := h.certManager.GetLeafCert() leafCert := h.certManager.GetLeafCert()
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) { if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
var err error var err error
var nonce []byte var nonce []byte
@ -455,7 +509,7 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
return err return err
} }
h.aeadChanged <- struct{}{} h.aeadChanged <- protocol.EncryptionSecure
} }
return nil return nil

View File

@ -1,16 +0,0 @@
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// CryptoSetup is a crypto setup
type CryptoSetup interface {
HandleCryptoStream() error
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
LockForSealing()
UnlockForSealing()
HandshakeComplete() bool
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
}

View File

@ -4,14 +4,15 @@ import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
@ -23,48 +24,71 @@ type KeyExchangeFunction func() crypto.KeyExchange
// The CryptoSetupServer handles all things crypto for the Session // The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct { type cryptoSetupServer struct {
connID protocol.ConnectionID connID protocol.ConnectionID
ip net.IP remoteAddr net.Addr
version protocol.VersionNumber
scfg *ServerConfig scfg *ServerConfig
stkGenerator *STKGenerator
diversificationNonce []byte diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *STK) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
receivedForwardSecurePacket bool receivedForwardSecurePacket bool
sentSHLO bool
receivedSecurePacket bool receivedSecurePacket bool
aeadChanged chan struct{} aeadChanged chan<- protocol.EncryptionLevel
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
keyExchange KeyExchangeFunction keyExchange KeyExchangeFunction
cryptoStream utils.Stream cryptoStream io.ReadWriter
connectionParameters ConnectionParametersManager connectionParameters ConnectionParametersManager
mutex sync.RWMutex mutex sync.RWMutex
} }
var _ crypto.AEAD = &cryptoSetupServer{} var _ CryptoSetup = &cryptoSetupServer{}
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO
// this is an expiremnt implemented by Chrome in QUIC 36, which we don't support
// TODO: remove this when dropping support for QUIC 36
var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server // NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup( func NewCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
ip net.IP, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
scfg *ServerConfig, scfg *ServerConfig,
cryptoStream utils.Stream, cryptoStream io.ReadWriter,
connectionParametersManager ConnectionParametersManager, connectionParametersManager ConnectionParametersManager,
aeadChanged chan struct{}, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *STK) bool,
aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
stkGenerator, err := NewSTKGenerator()
if err != nil {
return nil, err
}
return &cryptoSetupServer{ return &cryptoSetupServer{
connID: connID, connID: connID,
ip: ip, remoteAddr: remoteAddr,
version: version, version: version,
supportedVersions: supportedVersions,
scfg: scfg, scfg: scfg,
stkGenerator: stkGenerator,
keyDerivation: crypto.DeriveKeysAESGCM, keyDerivation: crypto.DeriveKeysAESGCM,
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
connectionParameters: connectionParametersManager, connectionParameters: connectionParametersManager,
acceptSTKCallback: acceptSTK,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
}, nil }, nil
} }
@ -73,17 +97,16 @@ func NewCryptoSetup(
func (h *cryptoSetupServer) HandleCryptoStream() error { func (h *cryptoSetupServer) HandleCryptoStream() error {
for { for {
var chloData bytes.Buffer var chloData bytes.Buffer
messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
if err != nil { if err != nil {
return qerr.HandshakeFailed return qerr.HandshakeFailed
} }
if messageTag != TagCHLO { if message.Tag != TagCHLO {
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
utils.Debugf("Got CHLO:\n%s", printHandshakeMessage(cryptoData)) utils.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
done, err := h.handleMessage(chloData.Bytes(), cryptoData)
if err != nil { if err != nil {
return err return err
} }
@ -94,6 +117,10 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
} }
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) { func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment {
return false, ErrHOLExperiment
}
sniSlice, ok := cryptoData[TagSNI] sniSlice, ok := cryptoData[TagSNI]
if !ok { if !ok {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
@ -115,7 +142,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
verTag := binary.LittleEndian.Uint32(verSlice) verTag := binary.LittleEndian.Uint32(verSlice)
ver := protocol.VersionTagToNumber(verTag) ver := protocol.VersionTagToNumber(verTag)
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack. // If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(ver) { if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
} }
@ -146,49 +173,93 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
return false, err return false, err
} }
_, err = h.cryptoStream.Write(reply) _, err = h.cryptoStream.Write(reply)
if err != nil {
return false, err return false, err
}
return false, nil
} }
// Open a message // Open a message
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.receivedForwardSecurePacket = true h.receivedForwardSecurePacket = true
return res, nil close(h.aeadChanged)
}
return res, protocol.EncryptionForwardSecure, nil
} }
if h.receivedForwardSecurePacket { if h.receivedForwardSecurePacket {
return nil, err return nil, protocol.EncryptionUnspecified, err
} }
} }
if h.secureAEAD != nil { if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
h.receivedSecurePacket = true h.receivedSecurePacket = true
return res, nil return res, protocol.EncryptionSecure, nil
} }
if h.receivedSecurePacket { if h.receivedSecurePacket {
return nil, err return nil, protocol.EncryptionUnspecified, err
} }
} }
return (&crypto.NullAEAD{}).Open(dst, src, packetNumber, associatedData) res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return res, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, err
} }
// Seal a message, call LockForSealing() before! func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
func (h *cryptoSetupServer) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { h.mutex.RLock()
if h.receivedForwardSecurePacket { defer h.mutex.RUnlock()
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) if h.forwardSecureAEAD != nil {
} else if h.secureAEAD != nil { return protocol.EncryptionForwardSecure, h.sealForwardSecure
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
} else {
return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData)
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure
}
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.sealUnencrypted, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD")
}
return h.sealSecure, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
}
return h.sealForwardSecure, nil
}
return nil, errors.New("CryptoSetupServer: no encryption level specified")
}
func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupServer) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
} }
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
@ -207,11 +278,16 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
if crypto.HashCert(cert) != xlct { if crypto.HashCert(cert) != xlct {
return true return true
} }
if err := h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]); err != nil { return !h.acceptSTK(cryptoData[TagSTK])
utils.Infof("STK invalid: %s", err.Error()) }
return true
} func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.stkGenerator.DecodeToken(token)
if err != nil {
utils.Debugf("STK invalid: %s", err.Error())
return false return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
} }
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
@ -219,7 +295,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
} }
token, err := h.scfg.stkSource.NewToken(h.ip) token, err := h.stkGenerator.NewToken(h.remoteAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -230,7 +306,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
TagSVID: []byte("quic-go"), TagSVID: []byte("quic-go"),
} }
if h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]) == nil { if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo) proof, err := h.scfg.Sign(sni, chlo)
if err != nil { if err != nil {
return nil, err return nil, err
@ -248,9 +324,14 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
replyMap[TagCERT] = certCompressed replyMap[TagCERT] = certCompressed
} }
message := HandshakeMessage{
Tag: TagREJ,
Data: replyMap,
}
var serverReply bytes.Buffer var serverReply bytes.Buffer
WriteHandshakeMessage(&serverReply, TagREJ, replyMap) message.Write(&serverReply)
utils.Debugf("Sending REJ:\n%s", printHandshakeMessage(replyMap)) utils.Debugf("Sending %s", message)
return serverReply.Bytes(), nil return serverReply.Bytes(), nil
} }
@ -310,6 +391,8 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err return nil, err
} }
h.aeadChanged <- protocol.EncryptionSecure
// Generate a new curve instance to derive the forward secure key // Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer var fsNonce bytes.Buffer
fsNonce.Write(clientNonce) fsNonce.Write(clientNonce)
@ -345,46 +428,37 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err return nil, err
} }
// add crypto parameters // add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.WriteUint32(verTag, protocol.VersionNumberToTag(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey() replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce replyMap[TagSNO] = serverNonce
replyMap[TagVER] = protocol.SupportedVersionsAsTags replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
message := HandshakeMessage{
Tag: TagSHLO,
Data: replyMap,
}
var reply bytes.Buffer var reply bytes.Buffer
WriteHandshakeMessage(&reply, TagSHLO, replyMap) message.Write(&reply)
utils.Debugf("Sending SHLO:\n%s", printHandshakeMessage(replyMap)) utils.Debugf("Sending %s", message)
h.aeadChanged <- struct{}{} h.aeadChanged <- protocol.EncryptionForwardSecure
return reply.Bytes(), nil return reply.Bytes(), nil
} }
// DiversificationNonce returns a diversification nonce if required in the next packet to be Seal'ed. See LockForSealing()! // DiversificationNonce returns the diversification nonce
func (h *cryptoSetupServer) DiversificationNonce() []byte { func (h *cryptoSetupServer) DiversificationNonce() []byte {
if h.receivedForwardSecurePacket || h.secureAEAD == nil {
return nil
}
return h.diversificationNonce return h.diversificationNonce
} }
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error { func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer") panic("not needed for cryptoSetupServer")
} }
// LockForSealing should be called before Seal(). It is needed so that diversification nonces can be obtained before packets are sealed, and the AEADs are not changed in the meantime.
func (h *cryptoSetupServer) LockForSealing() {
h.mutex.RLock()
}
// UnlockForSealing should be called after Seal() is complete, see LockForSealing().
func (h *cryptoSetupServer) UnlockForSealing() {
h.mutex.RUnlock()
}
// HandshakeComplete returns true after the first forward secure packet was received form the client.
func (h *cryptoSetupServer) HandshakeComplete() bool {
return h.receivedForwardSecurePacket
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 { if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View File

@ -5,8 +5,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
@ -29,14 +29,14 @@ func getEphermalKEX() (res crypto.KeyExchange) {
res = kexCurrent res = kexCurrent
t := kexCurrentTime t := kexCurrentTime
kexMutex.RUnlock() kexMutex.RUnlock()
if res != nil && time.Now().Sub(t) < kexLifetime { if res != nil && time.Since(t) < kexLifetime {
return res return res
} }
kexMutex.Lock() kexMutex.Lock()
defer kexMutex.Unlock() defer kexMutex.Unlock()
// Check if still unfulfilled // Check if still unfulfilled
if kexCurrent == nil || time.Now().Sub(kexCurrentTime) > kexLifetime { if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
utils.Errorf("could not set KEX: %s", err.Error()) utils.Errorf("could not set KEX: %s", err.Error())

View File

@ -7,32 +7,40 @@ import (
"io" "io"
"sort" "sort"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A HandshakeMessage is a handshake message
type HandshakeMessage struct {
Tag Tag
Data map[Tag][]byte
}
var _ fmt.Stringer = &HandshakeMessage{}
// ParseHandshakeMessage reads a crypto message // ParseHandshakeMessage reads a crypto message
func ParseHandshakeMessage(r io.Reader) (Tag, map[Tag][]byte, error) { func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
slice4 := make([]byte, 4) slice4 := make([]byte, 4)
if _, err := io.ReadFull(r, slice4); err != nil { if _, err := io.ReadFull(r, slice4); err != nil {
return 0, nil, err return HandshakeMessage{}, err
} }
messageTag := Tag(binary.LittleEndian.Uint32(slice4)) messageTag := Tag(binary.LittleEndian.Uint32(slice4))
if _, err := io.ReadFull(r, slice4); err != nil { if _, err := io.ReadFull(r, slice4); err != nil {
return 0, nil, err return HandshakeMessage{}, err
} }
nPairs := binary.LittleEndian.Uint32(slice4) nPairs := binary.LittleEndian.Uint32(slice4)
if nPairs > protocol.CryptoMaxParams { if nPairs > protocol.CryptoMaxParams {
return 0, nil, qerr.CryptoTooManyEntries return HandshakeMessage{}, qerr.CryptoTooManyEntries
} }
index := make([]byte, nPairs*8) index := make([]byte, nPairs*8)
if _, err := io.ReadFull(r, index); err != nil { if _, err := io.ReadFull(r, index); err != nil {
return 0, nil, err return HandshakeMessage{}, err
} }
resultMap := map[Tag][]byte{} resultMap := map[Tag][]byte{}
@ -44,24 +52,27 @@ func ParseHandshakeMessage(r io.Reader) (Tag, map[Tag][]byte, error) {
dataLen := dataEnd - dataStart dataLen := dataEnd - dataStart
if dataLen > protocol.CryptoParameterMaxLength { if dataLen > protocol.CryptoParameterMaxLength {
return 0, nil, qerr.Error(qerr.CryptoInvalidValueLength, "value too long") return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
} }
data := make([]byte, dataLen) data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil { if _, err := io.ReadFull(r, data); err != nil {
return 0, nil, err return HandshakeMessage{}, err
} }
resultMap[tag] = data resultMap[tag] = data
dataStart = dataEnd dataStart = dataEnd
} }
return messageTag, resultMap, nil return HandshakeMessage{
Tag: messageTag,
Data: resultMap}, nil
} }
// WriteHandshakeMessage writes a crypto message // Write writes a crypto message
func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte) { func (h HandshakeMessage) Write(b *bytes.Buffer) {
utils.WriteUint32(b, uint32(messageTag)) data := h.Data
utils.WriteUint32(b, uint32(h.Tag))
utils.WriteUint16(b, uint16(len(data))) utils.WriteUint16(b, uint16(len(data)))
utils.WriteUint16(b, 0) utils.WriteUint16(b, 0)
@ -71,17 +82,8 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte)
indexData := make([]byte, 8*len(data)) indexData := make([]byte, 8*len(data))
b.Write(indexData) // Will be updated later b.Write(indexData) // Will be updated later
// Sort the tags
tags := make([]uint32, len(data))
i := 0
for t := range data {
tags[i] = uint32(t)
i++
}
sort.Sort(utils.Uint32Slice(tags))
offset := uint32(0) offset := uint32(0)
for i, t := range tags { for i, t := range h.getTagsSorted() {
v := data[Tag(t)] v := data[Tag(t)]
b.Write(v) b.Write(v)
offset += uint32(len(v)) offset += uint32(len(v))
@ -93,21 +95,32 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte)
copy(b.Bytes()[indexStart:], indexData) copy(b.Bytes()[indexStart:], indexData)
} }
func printHandshakeMessage(data map[Tag][]byte) string { func (h *HandshakeMessage) getTagsSorted() []uint32 {
var res string tags := make([]uint32, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = uint32(t)
i++
}
sort.Sort(utils.Uint32Slice(tags))
return tags
}
func (h HandshakeMessage) String() string {
var pad string var pad string
for k, v := range data { res := tagToString(h.Tag) + ":\n"
if k == TagPAD { for _, t := range h.getTagsSorted() {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(k), len(v)) tag := Tag(t)
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else { } else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v)) res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
} }
} }
if len(pad) > 0 { if len(pad) > 0 {
res += pad res += pad
} }
return res return res
} }

View File

@ -0,0 +1,24 @@
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// Sealer seals a packet
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
// CryptoSetup is a crypto setup
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
RequestConnectionIDTruncation bool
}

View File

@ -13,7 +13,6 @@ type ServerConfig struct {
certChain crypto.CertChain certChain crypto.CertChain
ID []byte ID []byte
obit []byte obit []byte
stkSource crypto.StkSource
} }
// NewServerConfig creates a new server config // NewServerConfig creates a new server config
@ -24,41 +23,34 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
return nil, err return nil, err
} }
stkSecret := make([]byte, 32)
if _, err = rand.Read(stkSecret); err != nil {
return nil, err
}
obit := make([]byte, 8) obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil { if _, err = rand.Read(obit); err != nil {
return nil, err return nil, err
} }
stkSource, err := crypto.NewStkSource(stkSecret)
if err != nil {
return nil, err
}
return &ServerConfig{ return &ServerConfig{
kex: kex, kex: kex,
certChain: certChain, certChain: certChain,
ID: id, ID: id,
obit: obit, obit: obit,
stkSource: stkSource,
}, nil }, nil
} }
// Get the server config binary representation // Get the server config binary representation
func (s *ServerConfig) Get() []byte { func (s *ServerConfig) Get() []byte {
var serverConfig bytes.Buffer var serverConfig bytes.Buffer
WriteHandshakeMessage(&serverConfig, TagSCFG, map[Tag][]byte{ msg := HandshakeMessage{
Tag: TagSCFG,
Data: map[Tag][]byte{
TagSCID: s.ID, TagSCID: s.ID,
TagKEXS: []byte("C255"), TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"), TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...), TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: s.obit, TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}) },
}
msg.Write(&serverConfig)
return serverConfig.Bytes() return serverConfig.Bytes()
} }

View File

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type serverConfigClient struct { type serverConfigClient struct {
@ -28,16 +28,16 @@ var (
// parseServerConfig parses a server config // parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) { func parseServerConfig(data []byte) (*serverConfigClient, error) {
tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data)) message, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if tag != TagSCFG { if message.Tag != TagSCFG {
return nil, errMessageNotServerConfig return nil, errMessageNotServerConfig
} }
scfg := &serverConfigClient{raw: data} scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(tagMap) err = scfg.parseValues(message.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,7 +57,6 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
s.ID = scfgID s.ID = scfgID
// KEXS // KEXS
// TODO: allow for P256 in the list
// TODO: setup Key Exchange // TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS] kexs, ok := tagMap[TagKEXS]
if !ok { if !ok {
@ -66,8 +65,16 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
if len(kexs)%4 != 0 { if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS") return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
} }
if !bytes.Equal(kexs, []byte("C255")) { c255Foundat := -1
return qerr.Error(qerr.CryptoNoSupport, "KEXS")
for i := 0; i < len(kexs)/4; i++ {
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
c255Foundat = i
break
}
}
if c255Foundat < 0 {
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
} }
// AEAD // AEAD
@ -90,12 +97,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
} }
// PUBS // PUBS
// TODO: save this value
pubs, ok := tagMap[TagPUBS] pubs, ok := tagMap[TagPUBS]
if !ok { if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
} }
if len(pubs) != 35 {
var pubs_kexs []struct{Length uint32; Value []byte}
var last_len uint32
for i := 0; i < len(pubs)-3; i += int(last_len)+3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len);
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if last_len == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(last_len) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]})
}
if c255Foundat >= len(pubs_kexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubs_kexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
@ -105,8 +137,8 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return err return err
} }
// the PUBS value is always prepended by []byte{0x20, 0x00, 0x00}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:]) s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
if err != nil { if err != nil {
return err return err
} }

View File

@ -0,0 +1,100 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
const (
stkPrefixIP byte = iota
stkPrefixString
)
// An STK is a source address token
type STK struct {
RemoteAddr string
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// An STKGenerator generates STKs
type STKGenerator struct {
stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator
func NewSTKGenerator() (*STKGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &STKGenerator{
stkSource: stkSource,
}, nil
}
// NewToken generates a new STK token for a given source address
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.stkSource.NewToken(data)
}
// DecodeToken decodes an STK token
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.stkSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &STK{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK
func decodeRemoteAddr(data []byte) string {
// data will never be empty for an STK that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == stkPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View File

@ -50,6 +50,11 @@ const (
// TagSFCW is the initial stream flow control receive window. // TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24 TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagFHL2 forces head of line blocking.
// Chrome experiment (see https://codereview.chromium.org/2115033002)
// unsupported by quic-go
TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24
// TagSTK is the source-address token // TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16 TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
// TagSNO is the server nonce // TagSNO is the server nonce

121
vendor/github.com/lucas-clemente/quic-go/interface.go generated vendored Normal file
View File

@ -0,0 +1,121 @@
package quic
import (
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
)
// Stream is the interface implemented by QUIC streams
type Stream interface {
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
io.Writer
io.Closer
StreamID() protocol.StreamID
// Reset closes the stream with an error.
Reset(error)
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peeer's concurrent stream limit is reached.
// New streams always have the smallest possible stream ID.
// TODO: Enable testing for the special error
OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
// It always picks the smallest possible stream ID.
OpenStreamSync() (Stream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
}
// An STK is a Source Address token.
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
type STK struct {
// The remote address this token was issued for.
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
remoteAddr string
// The time that the STK was issued (resolution 1 second)
sentTime time.Time
}
// Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDTruncation bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}
// A Listener for incoming QUIC connections
type Listener interface {
// Close the server, sending CONNECTION_CLOSE frames to each peer.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new sessions. It should be called in a loop.
Accept() (Session, error)
}

View File

@ -9,7 +9,7 @@ import (
// GenerateConnectionID generates a connection ID using cryptographic random // GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID() (protocol.ConnectionID, error) { func GenerateConnectionID() (protocol.ConnectionID, error) {
b := make([]byte, 8, 8) b := make([]byte, 8)
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -0,0 +1,94 @@
package utils
import (
"fmt"
"log"
"os"
"time"
)
// LogLevel of quic-go
type LogLevel uint8
const logEnv = "QUIC_GO_LOG_LEVEL"
const (
// LogLevelNothing disables
LogLevelNothing LogLevel = iota
// LogLevelError enables err logs
LogLevelError
// LogLevelInfo enables info logs (e.g. packets)
LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
)
var (
logLevel = LogLevelNothing
timeFormat = ""
)
// SetLogLevel sets the log level
func SetLogLevel(level LogLevel) {
logLevel = level
}
// SetLogTimeFormat sets the format of the timestamp
// an empty string disables the logging of timestamps
func SetLogTimeFormat(format string) {
log.SetFlags(0) // disable timestamp logging done by the log package
timeFormat = format
}
// Debugf logs something
func Debugf(format string, args ...interface{}) {
if logLevel == LogLevelDebug {
logMessage(format, args...)
}
}
// Infof logs something
func Infof(format string, args ...interface{}) {
if logLevel >= LogLevelInfo {
logMessage(format, args...)
}
}
// Errorf logs something
func Errorf(format string, args ...interface{}) {
if logLevel >= LogLevelError {
logMessage(format, args...)
}
}
func logMessage(format string, args ...interface{}) {
if len(timeFormat) > 0 {
log.Printf(time.Now().Format(timeFormat)+" "+format, args...)
} else {
log.Printf(format, args...)
}
}
// Debug returns true if the log level is LogLevelDebug
func Debug() bool {
return logLevel == LogLevelDebug
}
func init() {
readLoggingEnv()
}
func readLoggingEnv() {
switch os.Getenv(logEnv) {
case "":
return
case "DEBUG":
logLevel = LogLevelDebug
case "INFO":
logLevel = LogLevelInfo
case "ERROR":
logLevel = LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
}
}

View File

@ -0,0 +1,43 @@
package utils
import "time"
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(deadline.Sub(time.Now()))
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
@ -14,6 +15,7 @@ type packedPacket struct {
number protocol.PacketNumber number protocol.PacketNumber
raw []byte raw []byte
frames []frames.Frame frames []frames.Frame
encryptionLevel protocol.EncryptionLevel
} }
type packetPacker struct { type packetPacker struct {
@ -23,14 +25,22 @@ type packetPacker struct {
cryptoSetup handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer streamFramer *streamFramer
controlFrames []frames.Frame controlFrames []frames.Frame
stopWaiting *frames.StopWaitingFrame
ackFrame *frames.AckFrame
leastUnacked protocol.PacketNumber
} }
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker { func newPacketPacker(connectionID protocol.ConnectionID,
cryptoSetup handshake.CryptoSetup,
connectionParameters handshake.ConnectionParametersManager,
streamFramer *streamFramer,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
connectionID: connectionID, connectionID: connectionID,
@ -43,135 +53,168 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.C
} }
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*packedPacket, error) {
// in case the connection is closed, all queued control frames aren't of any use anymore frames := []frames.Frame{ccf}
// discard them and queue the ConnectionCloseFrame encLevel, sealer := p.cryptoSetup.GetSealer()
p.controlFrames = []frames.Frame{ccf} ph := p.getPublicHeader(encLevel)
return p.packPacket(nil, leastUnacked) raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
}
func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
if p.ackFrame == nil {
return nil, errors.New("packet packer BUG: no ack frame queued")
}
encLevel, sealer := p.cryptoSetup.GetSealer()
ph := p.getPublicHeader(encLevel)
frames := []frames.Frame{p.ackFrame}
if p.stopWaiting != nil {
p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames = append(frames, p.stopWaiting)
p.stopWaiting = nil
}
p.ackFrame = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
}
// PackHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")
}
sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
if err != nil {
return nil, err
}
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
ph := p.getPublicHeader(packet.EncryptionLevel)
p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames := append([]frames.Frame{p.stopWaiting}, packet.Frames...)
p.stopWaiting = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: packet.EncryptionLevel,
}, err
} }
// PackPacket packs a new packet // PackPacket packs a new packet
// the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.controlFrames = append(p.controlFrames, controlFrames...) if p.streamFramer.HasCryptoStreamFrame() {
return p.packPacket(stopWaitingFrame, leastUnacked) return p.packCryptoPacket()
}
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
// cryptoSetup needs to be locked here, so that the AEADs are not changed between
// calling DiversificationNonce() and Seal().
p.cryptoSetup.LockForSealing()
defer p.cryptoSetup.UnlockForSealing()
currentPacketNumber := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked)
responsePublicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
} }
if p.perspective == protocol.PerspectiveServer { encLevel, sealer := p.cryptoSetup.GetSealer()
responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
// TODO: stop sending version numbers once a version has been negotiated publicHeader := p.getPublicHeader(encLevel)
if p.perspective == protocol.PerspectiveClient { publicHeaderLength, err := publicHeader.GetLength(p.perspective)
responsePublicHeader.VersionFlag = true
responsePublicHeader.VersionNumber = p.version
}
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if p.stopWaiting != nil {
if stopWaitingFrame != nil { p.stopWaiting.PacketNumber = publicHeader.PacketNumber
stopWaitingFrame.PacketNumber = currentPacketNumber p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen
stopWaitingFrame.PacketNumberLen = packetNumberLen
} }
// we're packing a ConnectionClose, don't add any StreamFrames maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
var isConnectionClose bool payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
if len(p.controlFrames) == 1 {
_, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame)
}
var payloadFrames []frames.Frame
if isConnectionClose {
payloadFrames = []frames.Frame{p.controlFrames[0]}
} else {
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
// Check if we have enough frames to send // Check if we have enough frames to send
if len(payloadFrames) == 0 { if len(payloadFrames) == 0 {
return nil, nil return nil, nil
} }
// Don't send out packets that only contain a StopWaitingFrame // Don't send out packets that only contain a StopWaitingFrame
if len(payloadFrames) == 1 && stopWaitingFrame != nil { if len(payloadFrames) == 1 && p.stopWaiting != nil {
return nil, nil return nil, nil
} }
p.stopWaiting = nil
p.ackFrame = nil
raw := getPacketBuffer() raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer)
buffer := bytes.NewBuffer(raw)
if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
for _, frame := range payloadFrames {
err := frame.Write(buffer, p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
p.cryptoSetup.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
num := p.packetNumberGenerator.Pop()
if num != currentPacketNumber {
return nil, errors.New("PacketPacker BUG: Peeked and Popped packet numbers do not match.")
}
return &packedPacket{ return &packedPacket{
number: currentPacketNumber, number: publicHeader.PacketNumber,
raw: raw, raw: raw,
frames: payloadFrames, frames: payloadFrames,
encryptionLevel: encLevel,
}, nil }, nil
} }
func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, publicHeaderLength protocol.ByteCount) ([]frames.Frame, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
var payloadLength protocol.ByteCount encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
var payloadFrames []frames.Frame publicHeader := p.getPublicHeader(encLevel)
publicHeaderLength, err := publicHeader.GetLength(p.perspective)
maxFrameSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
if stopWaitingFrame != nil {
payloadFrames = append(payloadFrames, stopWaitingFrame)
minLength, err := stopWaitingFrame.MinLength(p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payloadLength += minLength maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength
frames := []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
raw, err := p.writeAndSealPacket(publicHeader, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
number: publicHeader.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
}
func (p *packetPacker) composeNextPacket(
maxFrameSize protocol.ByteCount,
canSendStreamFrames bool,
) ([]frames.Frame, error) {
var payloadLength protocol.ByteCount
var payloadFrames []frames.Frame
// STOP_WAITING and ACK will always fit
if p.stopWaiting != nil {
payloadFrames = append(payloadFrames, p.stopWaiting)
l, err := p.stopWaiting.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
}
if p.ackFrame != nil {
payloadFrames = append(payloadFrames, p.ackFrame)
l, err := p.ackFrame.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
} }
for len(p.controlFrames) > 0 { for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1] frame := p.controlFrames[len(p.controlFrames)-1]
minLength, _ := frame.MinLength(p.version) // controlFrames does not contain any StopWaitingFrames. So it will *never* return an error minLength, err := frame.MinLength(p.version)
if err != nil {
return nil, err
}
if payloadLength+minLength > maxFrameSize { if payloadLength+minLength > maxFrameSize {
break break
} }
@ -184,6 +227,10 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
} }
if !canSendStreamFrames {
return payloadFrames, nil
}
// temporarily increase the maxFrameSize by 2 bytes // temporarily increase the maxFrameSize by 2 bytes
// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
// however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size
@ -206,6 +253,79 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return payloadFrames, nil return payloadFrames, nil
} }
func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { func (p *packetPacker) QueueControlFrame(frame frames.Frame) {
switch f := frame.(type) {
case *frames.StopWaitingFrame:
p.stopWaiting = f
case *frames.AckFrame:
p.ackFrame = f
default:
p.controlFrames = append(p.controlFrames, f) p.controlFrames = append(p.controlFrames, f)
}
}
func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *PublicHeader {
pnum := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked)
publicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
publicHeader.VersionFlag = true
publicHeader.VersionNumber = p.version
}
return publicHeader
}
func (p *packetPacker) writeAndSealPacket(
publicHeader *PublicHeader,
payloadFrames []frames.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
for _, frame := range payloadFrames {
err := frame.Write(buffer, p.version)
if err != nil {
return nil, err
}
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
_ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
num := p.packetNumberGenerator.Pop()
if num != publicHeader.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
}
func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
if p.perspective == protocol.PerspectiveClient {
return encLevel >= protocol.EncryptionSecure
}
return encLevel == protocol.EncryptionForwardSecure
}
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
p.leastUnacked = leastUnacked
} }

View File

@ -5,21 +5,29 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
type quicAEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
type packetUnpacker struct { type packetUnpacker struct {
version protocol.VersionNumber version protocol.VersionNumber
aead crypto.AEAD aead quicAEAD
} }
func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) {
buf := getPacketBuffer() buf := getPacketBuffer()
defer putPacketBuffer(buf) defer putPacketBuffer(buf)
decrypted, err := u.aead.Open(buf, data, hdr.PacketNumber, publicHeaderBinary) decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, publicHeaderBinary)
if err != nil { if err != nil {
// Wrap err in quicError so that public reset is sent by session // Wrap err in quicError so that public reset is sent by session
return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
@ -33,9 +41,11 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da
fs := make([]frames.Frame, 0, 2) fs := make([]frames.Frame, 0, 2)
// Read all frames in the packet // Read all frames in the packet
ReadLoop:
for r.Len() > 0 { for r.Len() > 0 {
typeByte, _ := r.ReadByte() typeByte, _ := r.ReadByte()
if typeByte == 0x0 { // PADDING frame
continue
}
r.UnreadByte() r.UnreadByte()
var frame frames.Frame var frame frames.Frame
@ -43,6 +53,11 @@ ReadLoop:
frame, err = frames.ParseStreamFrame(r) frame, err = frames.ParseStreamFrame(r)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error()) err = qerr.Error(qerr.InvalidStreamData, err.Error())
} else {
streamID := frame.(*frames.StreamFrame).StreamID
if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted {
err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID))
}
} }
} else if typeByte&0xc0 == 0x40 { } else if typeByte&0xc0 == 0x40 {
frame, err = frames.ParseAckFrame(r, u.version) frame, err = frames.ParseAckFrame(r, u.version)
@ -53,8 +68,6 @@ ReadLoop:
err = errors.New("unimplemented: CONGESTION_FEEDBACK") err = errors.New("unimplemented: CONGESTION_FEEDBACK")
} else { } else {
switch typeByte { switch typeByte {
case 0x0: // PAD, end of frames
break ReadLoop
case 0x01: case 0x01:
frame, err = frames.ParseRstStreamFrame(r) frame, err = frames.ParseRstStreamFrame(r)
if err != nil { if err != nil {
@ -100,6 +113,7 @@ ReadLoop:
} }
return &unpackedPacket{ return &unpackedPacket{
encryptionLevel: encryptionLevel,
frames: fs, frames: fs,
}, nil }, nil
} }

View File

@ -5,10 +5,24 @@ package protocol
type EncryptionLevel int type EncryptionLevel int
const ( const (
// Unencrypted is not encrypted // EncryptionUnspecified is a not specified encryption level
Unencrypted EncryptionLevel = iota EncryptionUnspecified EncryptionLevel = iota
// EncryptionUnencrypted is not encrypted
EncryptionUnencrypted
// EncryptionSecure is encrypted, but not forward secure // EncryptionSecure is encrypted, but not forward secure
EncryptionSecure EncryptionSecure
// EncryptionForwardSecure is forward secure // EncryptionForwardSecure is forward secure
EncryptionForwardSecure EncryptionForwardSecure
) )
func (e EncryptionLevel) String() string {
switch e {
case EncryptionUnencrypted:
return "unencrypted"
case EncryptionSecure:
return "encrypted (not forward-secure)"
case EncryptionForwardSecure:
return "forward-secure"
}
return "unknown"
}

View File

@ -1,9 +1,6 @@
package protocol package protocol
import ( import "math"
"math"
"time"
)
// A PacketNumber in QUIC // A PacketNumber in QUIC
type PacketNumber uint64 type PacketNumber uint64
@ -34,14 +31,13 @@ type StreamID uint32
type ByteCount uint64 type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount // MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = math.MaxUint64 const MaxByteCount = ByteCount(math.MaxUint64)
// MaxPacketSize is the maximum packet size, including the public header // MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370) // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
const MaxPacketSize ByteCount = 1350 // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
// MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader const MaxReceivePacketSize ByteCount = 1452
const MaxFrameAndPublicHeaderSize = MaxPacketSize - 12 /*crypto signature*/
// DefaultTCPMSS is the default maximum packet size used in the Linux TCP implementation. // DefaultTCPMSS is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes. // Used in QUIC for congestion window computations in bytes.
@ -53,15 +49,6 @@ const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB
// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending // InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending
const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB
// DefaultRetransmissionTime is the RTO time on new connections
const DefaultRetransmissionTime = 500 * time.Millisecond
// MinRetransmissionTime is the minimum RTO time
const MinRetransmissionTime = 200 * time.Millisecond
// MaxRetransmissionTime is the maximum RTO time
const MaxRetransmissionTime = 60 * time.Second
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
const ClientHelloMinimumSize = 1024 const ClientHelloMinimumSize = 1024

View File

@ -2,6 +2,17 @@ package protocol
import "time" import "time"
// MaxPacketSize is the maximum packet size, including the public header, that we use for sending packets
// This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370)
const MaxPacketSize ByteCount = 1350
// MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader
const MaxFrameAndPublicHeaderSize = MaxPacketSize - 12 /*crypto signature*/
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
const NonForwardSecurePacketSizeReduction = 50
// DefaultMaxCongestionWindow is the default for the max congestion window // DefaultMaxCongestionWindow is the default for the max congestion window
const DefaultMaxCongestionWindow = 1000 const DefaultMaxCongestionWindow = 1000
@ -12,6 +23,10 @@ const InitialCongestionWindow = 32
// session queues for later until it sends a public reset. // session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10 const MaxUndecryptablePackets = 10
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
const PublicResetTimeout = 500 * time.Millisecond
// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet // AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet
// This is the value Chromium is using // This is the value Chromium is using
const AckSendDelay = 25 * time.Millisecond const AckSendDelay = 25 * time.Millisecond
@ -24,21 +39,25 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using // This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5
// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection // MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection
const MaxStreamsPerConnection = 100 const MaxStreamsPerConnection = 100
@ -59,17 +78,14 @@ const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow
// RetransmissionThreshold + 1 is the number of times a packet has to be NACKed so that it gets retransmitted
const RetransmissionThreshold = 3
// SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack // SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack
const SkipPacketAveragePeriodLength PacketNumber = 500 const SkipPacketAveragePeriodLength PacketNumber = 500
// MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation // MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation
const MaxTrackedSkippedPackets = 10 const MaxTrackedSkippedPackets = 10
// STKExpiryTimeSec is the valid time of a source address token in seconds // STKExpiryTime is the valid time of a source address token
const STKExpiryTimeSec = 24 * 60 * 60 const STKExpiryTime = 24 * time.Hour
// MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation // MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation
const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow
@ -112,8 +128,8 @@ const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server // MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed // ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted // after this time all information about the old connection will be deleted

View File

@ -1,35 +1,23 @@
package protocol package protocol
import (
"bytes"
"encoding/binary"
"strconv"
)
// VersionNumber is a version number as int // VersionNumber is a version number as int
type VersionNumber int type VersionNumber int
// The version numbers, making grepping easier // The version numbers, making grepping easier
const ( const (
Version34 VersionNumber = 34 + iota Version35 VersionNumber = 35 + iota
Version35
Version36 Version36
VersionWhatever = 0 // for when the version doesn't matter Version37
VersionUnsupported = -1 VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnsupported VersionNumber = -1
) )
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
// must be in sorted order // must be in sorted descending order
var SupportedVersions = []VersionNumber{ var SupportedVersions = []VersionNumber{
Version34, Version35, Version36, Version37, Version36, Version35,
} }
// SupportedVersionsAsTags is needed for the SHLO crypto message
var SupportedVersionsAsTags []byte
// SupportedVersionsAsString is needed for the Alt-Scv HTTP header
var SupportedVersionsAsString string
// VersionNumberToTag maps version numbers ('32') to tags ('Q032') // VersionNumberToTag maps version numbers ('32') to tags ('Q032')
func VersionNumberToTag(vn VersionNumber) uint32 { func VersionNumberToTag(vn VersionNumber) uint32 {
v := uint32(vn) v := uint32(vn)
@ -42,8 +30,8 @@ func VersionTagToNumber(v uint32) VersionNumber {
} }
// IsSupportedVersion returns true if the server supports this version // IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(v VersionNumber) bool { func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
for _, t := range SupportedVersions { for _, t := range supported {
if t == v { if t == v {
return true return true
} }
@ -51,41 +39,17 @@ func IsSupportedVersion(v VersionNumber) bool {
return false return false
} }
// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions // ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// the versions in other do not need to be ordered // ours is a slice of versions that we support, sorted by our preference (descending)
// it returns true and the version number, if there is one, otherwise false // theirs is a slice of versions offered by the peer. The order does not matter
func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) { // if no suitable version is found, it returns VersionUnsupported
var otherSupported []VersionNumber func ChooseSupportedVersion(ours, theirs []VersionNumber) VersionNumber {
for _, ver := range other { for _, ourVer := range ours {
if ver != VersionUnsupported { for _, theirVer := range theirs {
otherSupported = append(otherSupported, ver) if ourVer == theirVer {
} return ourVer
}
for i := len(SupportedVersions) - 1; i >= 0; i-- {
for _, ver := range otherSupported {
if ver == SupportedVersions[i] {
return true, ver
} }
} }
} }
return VersionUnsupported
return false, 0
}
func init() {
var b bytes.Buffer
for _, v := range SupportedVersions {
s := make([]byte, 4)
binary.LittleEndian.PutUint32(s, VersionNumberToTag(v))
b.Write(s)
}
SupportedVersionsAsTags = b.Bytes()
for i := len(SupportedVersions) - 1; i >= 0; i-- {
SupportedVersionsAsString += strconv.Itoa(int(SupportedVersions[i]))
if i != 0 {
SupportedVersionsAsString += ","
}
}
} }

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
@ -17,7 +17,7 @@ var (
errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
) )
// The PublicHeader of a QUIC packet // The PublicHeader of a QUIC packet. Warning: This struct should not be considered stable and will change soon.
type PublicHeader struct { type PublicHeader struct {
Raw []byte Raw []byte
ConnectionID protocol.ConnectionID ConnectionID protocol.ConnectionID
@ -31,7 +31,7 @@ type PublicHeader struct {
DiversificationNonce []byte DiversificationNonce []byte
} }
// Write writes a public header // Write writes a public header. Warning: This API should not be considered stable and will change soon.
func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error { func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error {
publicFlagByte := uint8(0x00) publicFlagByte := uint8(0x00)
@ -109,8 +109,9 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe
return nil return nil
} }
// ParsePublicHeader parses a QUIC packet's public header // ParsePublicHeader parses a QUIC packet's public header.
// the packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient.
// Warning: This API should not be considered stable and will change soon.
func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) {
header := &PublicHeader{} header := &PublicHeader{}
@ -128,7 +129,8 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
// return nil, errors.New("diversification nonces should only be sent by servers") // return nil, errors.New("diversification nonces should only be sent by servers")
// } // }
if publicFlagByte&0x08 == 0 { header.TruncateConnectionID = publicFlagByte&0x08 == 0
if header.TruncateConnectionID && packetSentBy == protocol.PerspectiveClient {
return nil, errReceivedTruncatedConnectionID return nil, errReceivedTruncatedConnectionID
} }
@ -146,15 +148,17 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
} }
// Connection ID // Connection ID
connID, err := utils.ReadUint64(b) if !header.TruncateConnectionID {
var connID uint64
connID, err = utils.ReadUint64(b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
header.ConnectionID = protocol.ConnectionID(connID) header.ConnectionID = protocol.ConnectionID(connID)
if header.ConnectionID == 0 { if header.ConnectionID == 0 {
return nil, errInvalidConnectionID return nil, errInvalidConnectionID
} }
}
if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 {
// TODO: remove the if once the Google servers send the correct value // TODO: remove the if once the Google servers send the correct value
@ -192,9 +196,6 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
break break
} }
v := protocol.VersionTagToNumber(versionTag) v := protocol.VersionTagToNumber(versionTag)
if !protocol.IsSupportedVersion(v) {
v = protocol.VersionUnsupported
}
header.SupportedVersions = append(header.SupportedVersions, v) header.SupportedVersions = append(header.SupportedVersions, v)
} }
} }
@ -213,8 +214,8 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
return header, nil return header, nil
} }
// GetLength gets the length of the publicHeader in bytes // GetLength gets the length of the publicHeader in bytes.
// can only be called for regular packets // It can only be called for regular packets.
func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) {
if h.VersionFlag && h.ResetFlag { if h.VersionFlag && h.ResetFlag {
return 0, errResetAndVersionFlagSet return 0, errResetAndVersionFlagSet

View File

@ -6,8 +6,8 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type publicReset struct { type publicReset struct {
@ -32,15 +32,15 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p
func parsePublicReset(r *bytes.Reader) (*publicReset, error) { func parsePublicReset(r *bytes.Reader) (*publicReset, error) {
pr := publicReset{} pr := publicReset{}
tag, tagMap, err := handshake.ParseHandshakeMessage(r) msg, err := handshake.ParseHandshakeMessage(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if tag != handshake.TagPRST { if msg.Tag != handshake.TagPRST {
return nil, errors.New("wrong public reset tag") return nil, errors.New("wrong public reset tag")
} }
rseq, ok := tagMap[handshake.TagRSEQ] rseq, ok := msg.Data[handshake.TagRSEQ]
if !ok { if !ok {
return nil, errors.New("RSEQ missing") return nil, errors.New("RSEQ missing")
} }
@ -49,7 +49,7 @@ func parsePublicReset(r *bytes.Reader) (*publicReset, error) {
} }
pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq))
rnon, ok := tagMap[handshake.TagRNON] rnon, ok := msg.Data[handshake.TagRNON]
if !ok { if !ok {
return nil, errors.New("RNON missing") return nil, errors.New("RNON missing")
} }

View File

@ -89,6 +89,9 @@ const (
EmptyStreamFrameNoFin ErrorCode = 50 EmptyStreamFrameNoFin ErrorCode = 50
// We received invalid data on the headers stream. // We received invalid data on the headers stream.
InvalidHeadersStreamData ErrorCode = 56 InvalidHeadersStreamData ErrorCode = 56
// Invalid data on the headers stream received because of decompression
// failure.
HeadersStreamDataDecompressFailure ErrorCode = 97
// The peer received too much data, violating flow control. // The peer received too much data, violating flow control.
FlowControlReceivedTooMuchData ErrorCode = 59 FlowControlReceivedTooMuchData ErrorCode = 59
// The peer sent too much data, violating flow control. // The peer sent too much data, violating flow control.

Some files were not shown because too many files have changed in this diff Show More