update vendor
This commit is contained in:
parent
030fbd8521
commit
25d223c384
6
vendor/github.com/bifurcation/mint/README.md
generated
vendored
6
vendor/github.com/bifurcation/mint/README.md
generated
vendored
@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns
|
||||
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
|
||||
off.
|
||||
|
||||
## DTLS Support
|
||||
|
||||
Mint has partial support for DTLS, but that support is not yet complete
|
||||
and may still contain serious defects.
|
||||
|
||||
|
||||
## Quickstart
|
||||
|
||||
Installation is the same as for any other Go package:
|
||||
|
2
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
2
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
@ -46,6 +46,7 @@ const (
|
||||
AlertBadCertificateHashValue Alert = 114
|
||||
AlertUnknownPSKIdentity Alert = 115
|
||||
AlertNoApplicationProtocol Alert = 120
|
||||
AlertStatelessRetry Alert = 253
|
||||
AlertWouldBlock Alert = 254
|
||||
AlertNoAlert Alert = 255
|
||||
)
|
||||
@ -82,6 +83,7 @@ var alertText = map[Alert]string{
|
||||
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||
AlertNoApplicationProtocol: "no application protocol",
|
||||
AlertNoRenegotiation: "no renegotiation",
|
||||
AlertStatelessRetry: "stateless retry",
|
||||
AlertWouldBlock: "would have blocked",
|
||||
AlertNoAlert: "no alert",
|
||||
}
|
||||
|
683
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
683
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
File diff suppressed because it is too large
Load Diff
118
vendor/github.com/bifurcation/mint/common.go
generated
vendored
118
vendor/github.com/bifurcation/mint/common.go
generated
vendored
@ -5,9 +5,14 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedVersion uint16 = 0x7f15 // draft-21
|
||||
const (
|
||||
supportedVersion uint16 = 0x7f16 // draft-22
|
||||
tls12Version uint16 = 0x0303
|
||||
tls10Version uint16 = 0x0301
|
||||
dtls12WireVersion uint16 = 0xfefd
|
||||
)
|
||||
|
||||
var (
|
||||
// Flags for some minor compat issues
|
||||
allowWrongVersionNumber = true
|
||||
allowPKCS1 = true
|
||||
@ -20,6 +25,7 @@ const (
|
||||
RecordTypeAlert RecordType = 21
|
||||
RecordTypeHandshake RecordType = 22
|
||||
RecordTypeApplicationData RecordType = 23
|
||||
RecordTypeAck RecordType = 25
|
||||
)
|
||||
|
||||
// enum {...} HandshakeType;
|
||||
@ -42,6 +48,13 @@ const (
|
||||
HandshakeTypeMessageHash HandshakeType = 254
|
||||
)
|
||||
|
||||
var hrrRandomSentinel = [32]byte{
|
||||
0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11,
|
||||
0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
|
||||
0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e,
|
||||
0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
|
||||
}
|
||||
|
||||
// uint8 CipherSuite[2];
|
||||
type CipherSuite uint16
|
||||
|
||||
@ -150,3 +163,104 @@ const (
|
||||
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||
KeyUpdateRequested KeyUpdateRequest = 1
|
||||
)
|
||||
|
||||
type State uint8
|
||||
|
||||
const (
|
||||
StateInit = 0
|
||||
|
||||
// states valid for the client
|
||||
StateClientStart State = iota
|
||||
StateClientWaitSH
|
||||
StateClientWaitEE
|
||||
StateClientWaitCert
|
||||
StateClientWaitCV
|
||||
StateClientWaitFinished
|
||||
StateClientWaitCertCR
|
||||
StateClientConnected
|
||||
// states valid for the server
|
||||
StateServerStart State = iota
|
||||
StateServerRecvdCH
|
||||
StateServerNegotiated
|
||||
StateServerReadPastEarlyData
|
||||
StateServerWaitEOED
|
||||
StateServerWaitFlight2
|
||||
StateServerWaitCert
|
||||
StateServerWaitCV
|
||||
StateServerWaitFinished
|
||||
StateServerConnected
|
||||
)
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClientStart:
|
||||
return "Client START"
|
||||
case StateClientWaitSH:
|
||||
return "Client WAIT_SH"
|
||||
case StateClientWaitEE:
|
||||
return "Client WAIT_EE"
|
||||
case StateClientWaitCert:
|
||||
return "Client WAIT_CERT"
|
||||
case StateClientWaitCV:
|
||||
return "Client WAIT_CV"
|
||||
case StateClientWaitFinished:
|
||||
return "Client WAIT_FINISHED"
|
||||
case StateClientWaitCertCR:
|
||||
return "Client WAIT_CERT_CR"
|
||||
case StateClientConnected:
|
||||
return "Client CONNECTED"
|
||||
case StateServerStart:
|
||||
return "Server START"
|
||||
case StateServerRecvdCH:
|
||||
return "Server RECVD_CH"
|
||||
case StateServerNegotiated:
|
||||
return "Server NEGOTIATED"
|
||||
case StateServerReadPastEarlyData:
|
||||
return "Server READ_PAST_EARLY_DATA"
|
||||
case StateServerWaitEOED:
|
||||
return "Server WAIT_EOED"
|
||||
case StateServerWaitFlight2:
|
||||
return "Server WAIT_FLIGHT2"
|
||||
case StateServerWaitCert:
|
||||
return "Server WAIT_CERT"
|
||||
case StateServerWaitCV:
|
||||
return "Server WAIT_CV"
|
||||
case StateServerWaitFinished:
|
||||
return "Server WAIT_FINISHED"
|
||||
case StateServerConnected:
|
||||
return "Server CONNECTED"
|
||||
default:
|
||||
return fmt.Sprintf("unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Epochs for DTLS (also used for key phase labelling)
|
||||
type Epoch uint16
|
||||
|
||||
const (
|
||||
EpochClear Epoch = 0
|
||||
EpochEarlyData Epoch = 1
|
||||
EpochHandshakeData Epoch = 2
|
||||
EpochApplicationData Epoch = 3
|
||||
EpochUpdate Epoch = 4
|
||||
)
|
||||
|
||||
func (e Epoch) label() string {
|
||||
switch e {
|
||||
case EpochClear:
|
||||
return "clear"
|
||||
case EpochEarlyData:
|
||||
return "early data"
|
||||
case EpochHandshakeData:
|
||||
return "handshake"
|
||||
case EpochApplicationData:
|
||||
return "application data"
|
||||
}
|
||||
return "Application data (updated)"
|
||||
}
|
||||
|
||||
func assert(b bool) {
|
||||
if !b {
|
||||
panic("Assertion failed")
|
||||
}
|
||||
}
|
||||
|
539
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
539
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
@ -4,6 +4,7 @@ import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -12,8 +13,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||
|
||||
type Certificate struct {
|
||||
Chain []*x509.Certificate
|
||||
PrivateKey crypto.Signer
|
||||
@ -36,16 +35,20 @@ type PreSharedKeyCache interface {
|
||||
Size() int
|
||||
}
|
||||
|
||||
type PSKMapCache map[string]PreSharedKey
|
||||
|
||||
// A CookieHandler does two things:
|
||||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||
// - validates this byte string echoed by the client in the ClientHello
|
||||
// A CookieHandler can be used to give the application more fine-grained control over Cookies.
|
||||
// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie.
|
||||
// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie.
|
||||
type CookieHandler interface {
|
||||
// Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||
// If Generate returns nil, mint will not send a HelloRetryRequest.
|
||||
Generate(*Conn) ([]byte, error)
|
||||
// Validate is called when receiving a ClientHello containing a Cookie.
|
||||
// If validation failed, the handshake is aborted.
|
||||
Validate(*Conn, []byte) bool
|
||||
}
|
||||
|
||||
type PSKMapCache map[string]PreSharedKey
|
||||
|
||||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||
psk, ok = cache[key]
|
||||
return
|
||||
@ -74,14 +77,49 @@ type Config struct {
|
||||
AllowEarlyData bool
|
||||
// Require the client to echo a cookie.
|
||||
RequireCookie bool
|
||||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||
// The default cookie handler uses 32 random bytes as a cookie.
|
||||
CookieHandler CookieHandler
|
||||
// A CookieHandler can be used to set and validate a cookie.
|
||||
// The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector.
|
||||
// If no CookieHandler is set, mint will always send a cookie.
|
||||
// The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent.
|
||||
CookieHandler CookieHandler
|
||||
// The CookieProtector is used to encrypt / decrypt cookies.
|
||||
// It should make sure that the Cookie cannot be read and tampered with by the client.
|
||||
// If non-blocking mode is used, and cookies are required, this field has to be set.
|
||||
// In blocking mode, a default cookie protector is used, if this is unused.
|
||||
CookieProtector CookieProtector
|
||||
// The ExtensionHandler is used to add custom extensions.
|
||||
ExtensionHandler AppExtensionHandler
|
||||
RequireClientAuth bool
|
||||
|
||||
// Time returns the current time as the number of seconds since the epoch.
|
||||
// If Time is nil, TLS uses time.Now.
|
||||
Time func() time.Time
|
||||
// RootCAs defines the set of root certificate authorities
|
||||
// that clients use when verifying server certificates.
|
||||
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||
RootCAs *x509.CertPool
|
||||
// InsecureSkipVerify controls whether a client verifies the
|
||||
// server's certificate chain and host name.
|
||||
// If InsecureSkipVerify is true, TLS accepts any certificate
|
||||
// presented by the server and any host name in that certificate.
|
||||
// In this mode, TLS is susceptible to man-in-the-middle attacks.
|
||||
// This should be used only for testing.
|
||||
InsecureSkipVerify bool
|
||||
|
||||
// Shared fields
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Certificates []*Certificate
|
||||
// VerifyPeerCertificate, if not nil, is called after normal
|
||||
// certificate verification by either a TLS client or server. It
|
||||
// receives the raw ASN.1 certificates provided by the peer and also
|
||||
// any verified chains that normal processing found. If it returns a
|
||||
// non-nil error, the handshake is aborted and that error results.
|
||||
//
|
||||
// If normal verification fails then the handshake will abort before
|
||||
// considering this callback. If normal verification is disabled by
|
||||
// setting InsecureSkipVerify then this callback will be considered but
|
||||
// the verifiedChains argument will always be nil.
|
||||
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
@ -89,6 +127,7 @@ type Config struct {
|
||||
PSKs PreSharedKeyCache
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
NonBlocking bool
|
||||
UseDTLS bool
|
||||
|
||||
// The same config object can be shared among different connections, so it
|
||||
// needs its own mutex
|
||||
@ -110,17 +149,24 @@ func (c *Config) Clone() *Config {
|
||||
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||
AllowEarlyData: c.AllowEarlyData,
|
||||
RequireCookie: c.RequireCookie,
|
||||
CookieHandler: c.CookieHandler,
|
||||
CookieProtector: c.CookieProtector,
|
||||
ExtensionHandler: c.ExtensionHandler,
|
||||
RequireClientAuth: c.RequireClientAuth,
|
||||
Time: c.Time,
|
||||
RootCAs: c.RootCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
|
||||
Certificates: c.Certificates,
|
||||
AuthCertificate: c.AuthCertificate,
|
||||
CipherSuites: c.CipherSuites,
|
||||
Groups: c.Groups,
|
||||
SignatureSchemes: c.SignatureSchemes,
|
||||
NextProtos: c.NextProtos,
|
||||
PSKs: c.PSKs,
|
||||
PSKModes: c.PSKModes,
|
||||
NonBlocking: c.NonBlocking,
|
||||
Certificates: c.Certificates,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
CipherSuites: c.CipherSuites,
|
||||
Groups: c.Groups,
|
||||
SignatureSchemes: c.SignatureSchemes,
|
||||
NextProtos: c.NextProtos,
|
||||
PSKs: c.PSKs,
|
||||
PSKModes: c.PSKModes,
|
||||
NonBlocking: c.NonBlocking,
|
||||
UseDTLS: c.UseDTLS,
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,28 +193,6 @@ func (c *Config) Init(isClient bool) error {
|
||||
if len(c.PSKModes) == 0 {
|
||||
c.PSKModes = defaultPSKModes
|
||||
}
|
||||
|
||||
// If there is no certificate, generate one
|
||||
if !isClient && len(c.Certificates) == 0 {
|
||||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Certificates = []*Certificate{
|
||||
{
|
||||
Chain: []*x509.Certificate{cert},
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -183,6 +207,14 @@ func (c *Config) ValidForClient() bool {
|
||||
return len(c.ServerName) > 0
|
||||
}
|
||||
|
||||
func (c *Config) time() time.Time {
|
||||
t := c.Time
|
||||
if t == nil {
|
||||
t = time.Now
|
||||
}
|
||||
return t()
|
||||
}
|
||||
|
||||
var (
|
||||
defaultSupportedCipherSuites = []CipherSuite{
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
@ -214,10 +246,13 @@ var (
|
||||
)
|
||||
|
||||
type ConnectionState struct {
|
||||
HandshakeState string // string representation of the handshake state.
|
||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||
NextProto string // Selected ALPN proto
|
||||
HandshakeState State
|
||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
|
||||
NextProto string // Selected ALPN proto
|
||||
UsingPSK bool // Are we using PSK.
|
||||
UsingEarlyData bool // Did we negotiate 0-RTT.
|
||||
}
|
||||
|
||||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||
@ -228,9 +263,7 @@ type Conn struct {
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
EarlyData []byte
|
||||
|
||||
state StateConnected
|
||||
state stateConnected
|
||||
hState HandshakeState
|
||||
handshakeMutex sync.Mutex
|
||||
handshakeAlert Alert
|
||||
@ -238,18 +271,28 @@ type Conn struct {
|
||||
|
||||
readBuffer []byte
|
||||
in, out *RecordLayer
|
||||
hIn, hOut *HandshakeLayer
|
||||
|
||||
extHandler AppExtensionHandler
|
||||
hsCtx *HandshakeContext
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||
c.in = NewRecordLayer(c.conn)
|
||||
c.out = NewRecordLayer(c.conn)
|
||||
c.hIn = NewHandshakeLayer(c.in)
|
||||
c.hIn.nonblocking = c.config.NonBlocking
|
||||
c.hOut = NewHandshakeLayer(c.out)
|
||||
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
|
||||
if !config.UseDTLS {
|
||||
c.in = NewRecordLayerTLS(c.conn, directionRead)
|
||||
c.out = NewRecordLayerTLS(c.conn, directionWrite)
|
||||
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
|
||||
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
|
||||
} else {
|
||||
c.in = NewRecordLayerDTLS(c.conn, directionRead)
|
||||
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
|
||||
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
|
||||
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
|
||||
c.hsCtx.timeoutMS = initialTimeout
|
||||
c.hsCtx.timers = newTimerSet()
|
||||
c.hsCtx.waitingNextFlight = true
|
||||
}
|
||||
c.in.label = c.label()
|
||||
c.out.label = c.label()
|
||||
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
|
||||
return c
|
||||
}
|
||||
|
||||
@ -267,8 +310,12 @@ func (c *Conn) consumeRecord() error {
|
||||
// We do not support fragmentation of post-handshake handshake messages.
|
||||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||
start := 0
|
||||
headerLen := handshakeHeaderLenTLS
|
||||
if c.config.UseDTLS {
|
||||
headerLen = handshakeHeaderLenDTLS
|
||||
}
|
||||
for start < len(pt.fragment) {
|
||||
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||
if len(pt.fragment[start:]) < headerLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||
}
|
||||
|
||||
@ -276,14 +323,15 @@ func (c *Conn) consumeRecord() error {
|
||||
hm.msgType = HandshakeType(pt.fragment[start])
|
||||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||
|
||||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||
if len(pt.fragment[start+headerLen:]) < hmLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||
}
|
||||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||
|
||||
// Advance state machine
|
||||
state, actions, alert := c.state.Next(hm)
|
||||
hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen]
|
||||
|
||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||
// authentication, we'll need to allow transitions other than
|
||||
// Connected -> Connected
|
||||
state, actions, alert := c.state.ProcessMessage(hm)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
@ -299,18 +347,15 @@ func (c *Conn) consumeRecord() error {
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||
// authentication, we'll need to allow transitions other than
|
||||
// Connected -> Connected
|
||||
var connected bool
|
||||
c.state, connected = state.(StateConnected)
|
||||
c.state, connected = state.(stateConnected)
|
||||
if !connected {
|
||||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
start += handshakeHeaderLen + hmLen
|
||||
start += headerLen + hmLen
|
||||
}
|
||||
case RecordTypeAlert:
|
||||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
@ -332,17 +377,54 @@ func (c *Conn) consumeRecord() error {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
case RecordTypeAck:
|
||||
if !c.hsCtx.hIn.datagram {
|
||||
logf(logTypeHandshake, "Received ACK in TLS mode")
|
||||
return AlertUnexpectedMessage
|
||||
}
|
||||
return c.hsCtx.processAck(pt.fragment)
|
||||
|
||||
case RecordTypeApplicationData:
|
||||
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func readPartial(in *[]byte, buffer []byte) int {
|
||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
|
||||
read := copy(buffer, *in)
|
||||
*in = (*in)[read:]
|
||||
|
||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||
return read
|
||||
}
|
||||
|
||||
// Read application data up to the size of buffer. Handshake and alert records
|
||||
// are consumed by the Conn object directly.
|
||||
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
if _, connected := c.hState.(stateConnected); !connected {
|
||||
// Clients can't call Read prior to handshake completion.
|
||||
if c.isClient {
|
||||
return 0, errors.New("Read called before the handshake completed")
|
||||
}
|
||||
|
||||
// Neither can servers that don't allow early data.
|
||||
if !c.config.AllowEarlyData {
|
||||
return 0, errors.New("Read called before the handshake completed")
|
||||
}
|
||||
|
||||
// If there's no early data, then return WouldBlock
|
||||
if len(c.hsCtx.earlyData) == 0 {
|
||||
return 0, AlertWouldBlock
|
||||
}
|
||||
|
||||
return readPartial(&c.hsCtx.earlyData, buffer), nil
|
||||
}
|
||||
|
||||
// The handshake is now connected.
|
||||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||
return 0, alert
|
||||
@ -352,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Run our timers.
|
||||
if c.config.UseDTLS {
|
||||
if err := c.hsCtx.timers.check(time.Now()); err != nil {
|
||||
return 0, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Lock the input channel
|
||||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
@ -361,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
// err can be nil if consumeRecord processed a non app-data
|
||||
// record.
|
||||
if err != nil {
|
||||
if c.config.NonBlocking || err != WouldBlock {
|
||||
if c.config.NonBlocking || err != AlertWouldBlock {
|
||||
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var read int
|
||||
n := len(buffer)
|
||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||
if len(c.readBuffer) <= n {
|
||||
buffer = buffer[:len(c.readBuffer)]
|
||||
copy(buffer, c.readBuffer)
|
||||
read = len(c.readBuffer)
|
||||
c.readBuffer = c.readBuffer[:0]
|
||||
} else {
|
||||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||
copy(buffer[:n], c.readBuffer[:n])
|
||||
c.readBuffer = c.readBuffer[n:]
|
||||
read = n
|
||||
}
|
||||
|
||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||
return read, nil
|
||||
return readPartial(&c.readBuffer, buffer), nil
|
||||
}
|
||||
|
||||
// Write application data
|
||||
@ -393,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) {
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
if !c.Writable() {
|
||||
return 0, errors.New("Write called before the handshake completed (and early data not in use)")
|
||||
}
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
sent := 0
|
||||
@ -495,84 +572,44 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||
}
|
||||
|
||||
switch action := actionGeneric.(type) {
|
||||
case SendHandshakeMessage:
|
||||
err := c.hOut.WriteMessage(action.Message)
|
||||
case QueueHandshakeMessage:
|
||||
logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType)
|
||||
err := c.hsCtx.hOut.QueueMessage(action.Message)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case SendQueuedHandshake:
|
||||
_, err := c.hsCtx.hOut.SendQueuedMessages()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
if c.config.UseDTLS {
|
||||
c.hsCtx.timers.start(retransmitTimerLabel,
|
||||
c.hsCtx.handshakeRetransmit,
|
||||
c.hsCtx.timeoutMS)
|
||||
}
|
||||
case RekeyIn:
|
||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
|
||||
err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyOut:
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
|
||||
err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case SendEarlyData:
|
||||
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||
_, err := c.Write(c.EarlyData)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case ReadPastEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||
// Scan past all records that fail to decrypt
|
||||
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok := err.(DecryptError)
|
||||
|
||||
for ok {
|
||||
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok = err.(DecryptError)
|
||||
}
|
||||
|
||||
case ReadEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||
|
||||
for t == RecordTypeApplicationData {
|
||||
// Read a record into the buffer. Note that this is safe
|
||||
// in blocking mode because we read the record in in
|
||||
// PeekRecordType.
|
||||
pt, err := c.in.ReadRecord()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||
|
||||
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||
}
|
||||
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||
case ResetOut:
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
|
||||
c.out.ResetClear(action.seq)
|
||||
|
||||
case StorePSK:
|
||||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||
@ -585,7 +622,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||
}
|
||||
|
||||
default:
|
||||
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||
logf(logTypeHandshake, "%s Unknown action type", label)
|
||||
assert(false)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
@ -602,33 +640,13 @@ func (c *Conn) HandshakeSetup() Alert {
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
// Set things up
|
||||
caps := Capabilities{
|
||||
CipherSuites: c.config.CipherSuites,
|
||||
Groups: c.config.Groups,
|
||||
SignatureSchemes: c.config.SignatureSchemes,
|
||||
PSKs: c.config.PSKs,
|
||||
PSKModes: c.config.PSKModes,
|
||||
AllowEarlyData: c.config.AllowEarlyData,
|
||||
RequireCookie: c.config.RequireCookie,
|
||||
CookieHandler: c.config.CookieHandler,
|
||||
RequireClientAuth: c.config.RequireClientAuth,
|
||||
NextProtos: c.config.NextProtos,
|
||||
Certificates: c.config.Certificates,
|
||||
ExtensionHandler: c.extHandler,
|
||||
}
|
||||
opts := ConnectionOptions{
|
||||
ServerName: c.config.ServerName,
|
||||
NextProtos: c.config.NextProtos,
|
||||
EarlyData: c.EarlyData,
|
||||
}
|
||||
|
||||
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||
caps.CookieHandler = &defaultCookieHandler{}
|
||||
}
|
||||
|
||||
if c.isClient {
|
||||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||
state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||
return alert
|
||||
@ -642,14 +660,56 @@ func (c *Conn) HandshakeSetup() Alert {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
state = ServerStateStart{Caps: caps, conn: c}
|
||||
if c.config.RequireCookie && c.config.CookieProtector == nil {
|
||||
logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.")
|
||||
if c.config.NonBlocking {
|
||||
logf(logTypeHandshake, "Not possible in non-blocking mode.")
|
||||
return AlertInternalError
|
||||
}
|
||||
var err error
|
||||
c.config.CookieProtector, err = NewDefaultCookieProtector()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
|
||||
return AlertInternalError
|
||||
}
|
||||
}
|
||||
state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
type handshakeMessageReader interface {
|
||||
ReadMessage() (*HandshakeMessage, Alert)
|
||||
}
|
||||
|
||||
type handshakeMessageReaderImpl struct {
|
||||
hsCtx *HandshakeContext
|
||||
}
|
||||
|
||||
var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
|
||||
|
||||
func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
|
||||
var hm *HandshakeMessage
|
||||
var err error
|
||||
for {
|
||||
hm, err = r.hsCtx.hIn.ReadMessage()
|
||||
if err == AlertWouldBlock {
|
||||
return nil, AlertWouldBlock
|
||||
}
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error reading message: %v", err)
|
||||
return nil, AlertCloseNotify
|
||||
}
|
||||
if hm != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return hm, AlertNoAlert
|
||||
}
|
||||
|
||||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||
// determines whether a client or server handshake is performed. If a
|
||||
// handshake has already been performed, then its result will be returned.
|
||||
@ -669,48 +729,48 @@ func (c *Conn) Handshake() Alert {
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
var alert Alert
|
||||
if c.hState == nil {
|
||||
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||
alert = c.HandshakeSetup()
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label)
|
||||
alert := c.HandshakeSetup()
|
||||
if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) {
|
||||
return alert
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
|
||||
state := c.hState
|
||||
_, connected := state.(StateConnected)
|
||||
|
||||
var actions []HandshakeAction
|
||||
_, connected := state.(stateConnected)
|
||||
|
||||
hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
|
||||
for !connected {
|
||||
// Read a handshake message
|
||||
hm, err := c.hIn.ReadMessage()
|
||||
if err == WouldBlock {
|
||||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||
var alert Alert
|
||||
var actions []HandshakeAction
|
||||
|
||||
// Advance the state machine
|
||||
state, actions, alert = state.Next(hmr)
|
||||
if alert == AlertWouldBlock {
|
||||
logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
|
||||
// If we blocked, then run our timers to see if any have expired.
|
||||
if c.hsCtx.hIn.datagram {
|
||||
if err := c.hsCtx.timers.check(time.Now()); err != nil {
|
||||
return AlertInternalError
|
||||
}
|
||||
}
|
||||
return AlertWouldBlock
|
||||
}
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||
if alert == AlertCloseNotify {
|
||||
logf(logTypeHandshake, "%s Error reading message: %s", label, alert)
|
||||
c.sendAlert(AlertCloseNotify)
|
||||
return AlertCloseNotify
|
||||
}
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
|
||||
// Advance the state machine
|
||||
state, actions, alert = state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
if alert != AlertNoAlert && alert != AlertStatelessRetry {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for index, action := range actions {
|
||||
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
if alert := c.takeAction(action); alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
@ -719,30 +779,48 @@ func (c *Conn) Handshake() Alert {
|
||||
|
||||
c.hState = state
|
||||
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||
_, connected = state.(stateConnected)
|
||||
if connected {
|
||||
c.state = state.(stateConnected)
|
||||
c.handshakeComplete = true
|
||||
|
||||
_, connected = state.(StateConnected)
|
||||
}
|
||||
if !c.isClient {
|
||||
// Send NewSessionTicket if configured to
|
||||
if c.config.SendSessionTickets {
|
||||
actions, alert := c.state.NewSessionTicket(
|
||||
c.config.TicketLen,
|
||||
c.config.TicketLifetime,
|
||||
c.config.EarlyDataLifetime)
|
||||
|
||||
c.state = state.(StateConnected)
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send NewSessionTicket if acting as server
|
||||
if !c.isClient && c.config.SendSessionTickets {
|
||||
actions, alert := c.state.NewSessionTicket(
|
||||
c.config.TicketLen,
|
||||
c.config.TicketLifetime,
|
||||
c.config.EarlyDataLifetime)
|
||||
// If there is early data, move it into the main buffer
|
||||
if c.hsCtx.earlyData != nil {
|
||||
c.readBuffer = c.hsCtx.earlyData
|
||||
c.hsCtx.earlyData = nil
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
} else {
|
||||
assert(c.hsCtx.earlyData == nil)
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.NonBlocking {
|
||||
if alert == AlertStatelessRetry {
|
||||
return AlertStatelessRetry
|
||||
}
|
||||
return AlertNoAlert
|
||||
}
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
@ -775,12 +853,15 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetHsState() string {
|
||||
return reflect.TypeOf(c.hState).Name()
|
||||
func (c *Conn) GetHsState() State {
|
||||
if c.hState == nil {
|
||||
return StateInit
|
||||
}
|
||||
return c.hState.State()
|
||||
}
|
||||
|
||||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
_, connected := c.hState.(StateConnected)
|
||||
_, connected := c.hState.(stateConnected)
|
||||
if !connected {
|
||||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||
}
|
||||
@ -796,7 +877,7 @@ func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]b
|
||||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||
}
|
||||
|
||||
func (c *Conn) State() ConnectionState {
|
||||
func (c *Conn) ConnectionState() ConnectionState {
|
||||
state := ConnectionState{
|
||||
HandshakeState: c.GetHsState(),
|
||||
}
|
||||
@ -804,16 +885,32 @@ func (c *Conn) State() ConnectionState {
|
||||
if c.handshakeComplete {
|
||||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||
state.NextProto = c.state.Params.NextProto
|
||||
state.VerifiedChains = c.state.verifiedChains
|
||||
state.PeerCertificates = c.state.peerCertificates
|
||||
state.UsingPSK = c.state.Params.UsingPSK
|
||||
state.UsingEarlyData = c.state.Params.UsingEarlyData
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||
if c.hState != nil {
|
||||
return fmt.Errorf("Can't set extension handler after setup")
|
||||
func (c *Conn) Writable() bool {
|
||||
// If we're connected, we're writable.
|
||||
if _, connected := c.hState.(stateConnected); connected {
|
||||
return true
|
||||
}
|
||||
|
||||
c.extHandler = h
|
||||
return nil
|
||||
// If we're a client in 0-RTT, then we're writable.
|
||||
if c.isClient && c.out.cipher.epoch == EpochEarlyData {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Conn) label() string {
|
||||
if c.isClient {
|
||||
return "client"
|
||||
}
|
||||
return "server"
|
||||
}
|
||||
|
86
vendor/github.com/bifurcation/mint/cookie-protector.go
generated
vendored
Normal file
86
vendor/github.com/bifurcation/mint/cookie-protector.go
generated
vendored
Normal file
@ -0,0 +1,86 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// CookieProtector is used to create and verify a cookie
|
||||
type CookieProtector interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
const cookieSecretSize = 32
|
||||
const cookieNonceSize = 32
|
||||
|
||||
// The DefaultCookieProtector is a simple implementation for the CookieProtector.
|
||||
type DefaultCookieProtector struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
var _ CookieProtector = &DefaultCookieProtector{}
|
||||
|
||||
// NewDefaultCookieProtector creates a source for source address tokens
|
||||
func NewDefaultCookieProtector() (CookieProtector, error) {
|
||||
secret := make([]byte, cookieSecretSize)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DefaultCookieProtector{secret: secret}, nil
|
||||
}
|
||||
|
||||
// NewToken encodes data into a new token.
|
||||
func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, cookieNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token.
|
||||
func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < cookieNonceSize {
|
||||
return nil, fmt.Errorf("Token too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:cookieNonceSize]
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
|
||||
}
|
||||
|
||||
func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||
h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source"))
|
||||
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||
if _, err := io.ReadFull(h, key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aeadNonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return aead, aeadNonce, nil
|
||||
}
|
81
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
81
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
@ -331,40 +331,6 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||
sigAlg, ok := x509AlgMap[alg]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||
}
|
||||
if len(name) == 0 {
|
||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||
SignatureAlgorithm: sigAlg,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
DNSNames: []string{name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// It is safe to ignore the error here because we're parsing known-good data
|
||||
cert, _ := x509.ParseCertificate(der)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// XXX(rlb): Copied from crypto/x509
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
@ -652,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||
}
|
||||
}
|
||||
|
||||
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
|
||||
priv, err := newSigningKey(alg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cert, err := newSelfSigned(name, alg, priv)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return priv, cert, nil
|
||||
}
|
||||
|
||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||
sigAlg, ok := x509AlgMap[alg]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||
}
|
||||
if len(name) == 0 {
|
||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||
SignatureAlgorithm: sigAlg,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
DNSNames: []string{name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// It is safe to ignore the error here because we're parsing known-good data
|
||||
cert, _ := x509.ParseCertificate(der)
|
||||
return cert, nil
|
||||
}
|
||||
|
222
vendor/github.com/bifurcation/mint/dtls.go
generated
vendored
Normal file
222
vendor/github.com/bifurcation/mint/dtls.go
generated
vendored
Normal file
@ -0,0 +1,222 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
initialMtu = 1200
|
||||
initialTimeout = 100
|
||||
)
|
||||
|
||||
// labels for timers
|
||||
const (
|
||||
retransmitTimerLabel = "handshake retransmit"
|
||||
ackTimerLabel = "ack timer"
|
||||
)
|
||||
|
||||
type SentHandshakeFragment struct {
|
||||
seq uint32
|
||||
offset int
|
||||
fragLength int
|
||||
record uint64
|
||||
acked bool
|
||||
}
|
||||
|
||||
type DtlsAck struct {
|
||||
RecordNumbers []uint64 `tls:"head=2"`
|
||||
}
|
||||
|
||||
func wireVersion(h *HandshakeLayer) uint16 {
|
||||
if h.datagram {
|
||||
return dtls12WireVersion
|
||||
}
|
||||
return tls12Version
|
||||
}
|
||||
|
||||
func dtlsConvertVersion(version uint16) uint16 {
|
||||
if version == tls12Version {
|
||||
return dtls12WireVersion
|
||||
}
|
||||
if version == tls10Version {
|
||||
return 0xfeff
|
||||
}
|
||||
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
|
||||
}
|
||||
|
||||
// TODO(ekr@rtfm.com): Move these to state-machine.go
|
||||
func (h *HandshakeContext) handshakeRetransmit() error {
|
||||
if _, err := h.hOut.SendQueuedMessages(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.timers.start(retransmitTimerLabel,
|
||||
h.handshakeRetransmit,
|
||||
h.timeoutMS)
|
||||
|
||||
// TODO(ekr@rtfm.com): Back off timer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) sendAck() error {
|
||||
toack := h.hIn.recvdRecords
|
||||
|
||||
count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
|
||||
if len(toack) > count {
|
||||
toack = toack[:count]
|
||||
}
|
||||
logf(logTypeHandshake, "Sending ACK: [%x]", toack)
|
||||
|
||||
ack := &DtlsAck{toack}
|
||||
body, err := syntax.Marshal(&ack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.hOut.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAck,
|
||||
fragment: body,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) processAck(data []byte) error {
|
||||
// Cancel the retransmit timer because we will be resending
|
||||
// and possibly re-arming later.
|
||||
h.timers.cancel(retransmitTimerLabel)
|
||||
|
||||
ack := &DtlsAck{}
|
||||
read, err := syntax.Unmarshal(data, &ack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) != read {
|
||||
return fmt.Errorf("Invalid encoding: Extra data not consumed")
|
||||
}
|
||||
logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
|
||||
|
||||
for _, r := range ack.RecordNumbers {
|
||||
for _, m := range h.sentFragments {
|
||||
if r == m.record {
|
||||
logf(logTypeHandshake, "Marking %v %v(%v) as acked",
|
||||
m.seq, m.offset, m.fragLength)
|
||||
m.acked = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
count, err := h.hOut.SendQueuedMessages()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
logf(logTypeHandshake, "All messages ACKed")
|
||||
h.hOut.ClearQueuedMessages()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset the timer
|
||||
h.timers.start(retransmitTimerLabel,
|
||||
h.handshakeRetransmit,
|
||||
h.timeoutMS)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
|
||||
return c.hsCtx.timers.remaining()
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) receivedHandshakeMessage() {
|
||||
logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
|
||||
// This just enables tests.
|
||||
if h.hIn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !h.hIn.datagram {
|
||||
return
|
||||
}
|
||||
|
||||
if h.waitingNextFlight {
|
||||
logf(logTypeHandshake, "Received the start of the flight")
|
||||
|
||||
// Clear the outgoing DTLS queue and terminate the retransmit timer
|
||||
h.hOut.ClearQueuedMessages()
|
||||
h.timers.cancel(retransmitTimerLabel)
|
||||
|
||||
// OK, we're not waiting any more.
|
||||
h.waitingNextFlight = false
|
||||
}
|
||||
|
||||
// Now pre-emptively arm the ACK timer if it's not armed already.
|
||||
// We'll automatically dis-arm it at the end of the handshake.
|
||||
if h.timers.getTimer(ackTimerLabel) == nil {
|
||||
h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) receivedEndOfFlight() {
|
||||
logf(logTypeHandshake, "%p Received the end of the flight", h)
|
||||
if !h.hIn.datagram {
|
||||
return
|
||||
}
|
||||
|
||||
// Empty incoming queue
|
||||
h.hIn.queued = nil
|
||||
|
||||
// Note that we are waiting for the next flight.
|
||||
h.waitingNextFlight = true
|
||||
|
||||
// Clear the ACK queue.
|
||||
h.hIn.recvdRecords = nil
|
||||
|
||||
// Disarm the ACK timer
|
||||
h.timers.cancel(ackTimerLabel)
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) receivedFinalFlight() {
|
||||
logf(logTypeHandshake, "%p Received final flight", h)
|
||||
if !h.hIn.datagram {
|
||||
return
|
||||
}
|
||||
|
||||
// Disarm the ACK timer
|
||||
h.timers.cancel(ackTimerLabel)
|
||||
|
||||
// But send an ACK immediately.
|
||||
h.sendAck()
|
||||
}
|
||||
|
||||
func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
|
||||
logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
|
||||
for _, f := range h.sentFragments {
|
||||
if !f.acked {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.seq != seq {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.offset > offset {
|
||||
continue
|
||||
}
|
||||
|
||||
// At this point, we know that the stored fragment starts
|
||||
// at or before what we want to send, so check where the end
|
||||
// is.
|
||||
if f.offset+f.fragLength < offset+fraglen {
|
||||
continue
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
102
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
102
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
@ -3,7 +3,6 @@ package mint
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||
return err == nil
|
||||
func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
|
||||
found := make(map[ExtensionType]bool)
|
||||
|
||||
for _, dst := range dsts {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
if found[dst.Type()] {
|
||||
return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
|
||||
}
|
||||
|
||||
err := safeUnmarshal(dst, ext.ExtensionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
found[dst.Type()] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
return found, nil
|
||||
}
|
||||
|
||||
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
err := safeUnmarshal(dst, ext.ExtensionData)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||
// ProtocolVersion versions<2..254>;
|
||||
// } SupportedVersions;
|
||||
type SupportedVersionsExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
Versions []uint16
|
||||
}
|
||||
|
||||
type SupportedVersionsClientHelloInner struct {
|
||||
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||
}
|
||||
|
||||
type SupportedVersionsServerHelloInner struct {
|
||||
Version uint16
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedVersions
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sv)
|
||||
switch sv.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
|
||||
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
|
||||
return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sv)
|
||||
switch sv.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner SupportedVersionsClientHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sv.Versions = inner.Versions
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
|
||||
var inner SupportedVersionsServerHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sv.Versions = []uint16{inner.Version}
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// struct {
|
||||
@ -562,25 +624,3 @@ func (c CookieExtension) Marshal() ([]byte, error) {
|
||||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, c)
|
||||
}
|
||||
|
||||
// defaultCookieLength is the default length of a cookie
|
||||
const defaultCookieLength = 32
|
||||
|
||||
type defaultCookieHandler struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
var _ CookieHandler = &defaultCookieHandler{}
|
||||
|
||||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||
h.data = make([]byte, defaultCookieLength)
|
||||
if _, err := prng.Read(h.data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.data, nil
|
||||
}
|
||||
|
||||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||
return bytes.Equal(h.data, data)
|
||||
}
|
||||
|
8
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
8
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
@ -66,8 +66,8 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||
f.remainder = f.remainder[copied:]
|
||||
f.writeOffset += copied
|
||||
if f.writeOffset < len(f.working) {
|
||||
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||
return nil, nil, WouldBlock
|
||||
logf(logTypeVerbose, "Read would have blocked 1")
|
||||
return nil, nil, AlertWouldBlock
|
||||
}
|
||||
// Reset the write offset, because we are now full.
|
||||
f.writeOffset = 0
|
||||
@ -93,6 +93,6 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||
f.state = kFrameReaderBody
|
||||
}
|
||||
|
||||
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||
return nil, nil, WouldBlock
|
||||
logf(logTypeVerbose, "Read would have blocked 2")
|
||||
return nil, nil, AlertWouldBlock
|
||||
}
|
||||
|
448
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
448
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
@ -7,7 +7,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeHeaderLen = 4 // handshake message header length
|
||||
handshakeHeaderLenTLS = 4 // handshake message header length
|
||||
handshakeHeaderLenDTLS = 12 // handshake message header length
|
||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||
)
|
||||
|
||||
@ -27,28 +28,42 @@ const (
|
||||
// opaque msg<0..2^24-1>
|
||||
// } Handshake;
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type HandshakeMessage struct {
|
||||
// Omitted: length
|
||||
msgType HandshakeType
|
||||
body []byte
|
||||
msgType HandshakeType
|
||||
seq uint32
|
||||
body []byte
|
||||
datagram bool
|
||||
offset uint32 // Used for DTLS
|
||||
length uint32
|
||||
cipher *cipherState
|
||||
}
|
||||
|
||||
// Note: This could be done with the `syntax` module, using the simplified
|
||||
// syntax as discussed above. However, since this is so simple, there's not
|
||||
// much benefit to doing so.
|
||||
// When datagram is set, we marshal this as a whole DTLS record.
|
||||
func (hm *HandshakeMessage) Marshal() []byte {
|
||||
if hm == nil {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
msgLen := len(hm.body)
|
||||
data := make([]byte, 4+len(hm.body))
|
||||
data[0] = byte(hm.msgType)
|
||||
data[1] = byte(msgLen >> 16)
|
||||
data[2] = byte(msgLen >> 8)
|
||||
data[3] = byte(msgLen)
|
||||
copy(data[4:], hm.body)
|
||||
fragLen := len(hm.body)
|
||||
var data []byte
|
||||
|
||||
if hm.datagram {
|
||||
data = make([]byte, handshakeHeaderLenDTLS+fragLen)
|
||||
} else {
|
||||
data = make([]byte, handshakeHeaderLenTLS+fragLen)
|
||||
}
|
||||
tmp := data
|
||||
tmp = encodeUint(uint64(hm.msgType), 1, tmp)
|
||||
tmp = encodeUint(uint64(hm.length), 3, tmp)
|
||||
if hm.datagram {
|
||||
tmp = encodeUint(uint64(hm.seq), 2, tmp)
|
||||
tmp = encodeUint(uint64(hm.offset), 3, tmp)
|
||||
tmp = encodeUint(uint64(fragLen), 3, tmp)
|
||||
}
|
||||
copy(tmp, hm.body)
|
||||
return data
|
||||
}
|
||||
|
||||
@ -61,8 +76,6 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
body = new(ClientHelloBody)
|
||||
case HandshakeTypeServerHello:
|
||||
body = new(ServerHelloBody)
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
body = new(HelloRetryRequestBody)
|
||||
case HandshakeTypeEncryptedExtensions:
|
||||
body = new(EncryptedExtensionsBody)
|
||||
case HandshakeTypeCertificate:
|
||||
@ -83,62 +96,104 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||
}
|
||||
|
||||
_, err := body.Unmarshal(hm.body)
|
||||
err := safeUnmarshal(body, hm.body)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
data, err := body.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
}, nil
|
||||
m := &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
seq: h.msgSeq,
|
||||
datagram: h.datagram,
|
||||
length: uint32(len(data)),
|
||||
}
|
||||
h.msgSeq++
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type HandshakeLayer struct {
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
ctx *HandshakeContext // The handshake we are attached to
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
datagram bool // Is this DTLS?
|
||||
msgSeq uint32 // The DTLS message sequence number
|
||||
queued []*HandshakeMessage // In/out queue
|
||||
sent []*HandshakeMessage // Sent messages for DTLS
|
||||
recvdRecords []uint64 // Records we have received.
|
||||
maxFragmentLen int
|
||||
}
|
||||
|
||||
type handshakeLayerFrameDetails struct{}
|
||||
type handshakeLayerFrameDetails struct {
|
||||
datagram bool
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||
return handshakeHeaderLen
|
||||
if d.datagram {
|
||||
return handshakeHeaderLenDTLS
|
||||
}
|
||||
return handshakeHeaderLenTLS
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||
return handshakeHeaderLen + maxFragmentLen
|
||||
return d.headerLen() + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
logf(logTypeIO, "Header=%x", hdr)
|
||||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||
// The length of this fragment (as opposed to the message)
|
||||
// is always the last three bytes for both TLS and DTLS
|
||||
val, _ := decodeUint(hdr[len(hdr)-3:], 3)
|
||||
return int(val), nil
|
||||
}
|
||||
|
||||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||
func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.ctx = c
|
||||
h.conn = r
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||
h.datagram = false
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
|
||||
h.maxFragmentLen = maxFragmentLen
|
||||
return &h
|
||||
}
|
||||
|
||||
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.ctx = c
|
||||
h.conn = r
|
||||
h.datagram = true
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
|
||||
h.maxFragmentLen = initialMtu // Not quite right
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) readRecord() error {
|
||||
logf(logTypeIO, "Trying to read record")
|
||||
pt, err := h.conn.ReadRecord()
|
||||
logf(logTypeVerbose, "Trying to read record")
|
||||
pt, err := h.conn.readRecordAnyEpoch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pt.contentType != RecordTypeHandshake &&
|
||||
pt.contentType != RecordTypeAlert {
|
||||
switch pt.contentType {
|
||||
case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
|
||||
default:
|
||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAck {
|
||||
if !h.datagram {
|
||||
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
|
||||
}
|
||||
logf(logTypeIO, "read ACK")
|
||||
return h.ctx.processAck(pt.fragment)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAlert {
|
||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||
if len(pt.fragment) < 2 {
|
||||
@ -148,7 +203,19 @@ func (h *HandshakeLayer) readRecord() error {
|
||||
return Alert(pt.fragment[1])
|
||||
}
|
||||
|
||||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||
assert(h.ctx.hIn.conn != nil)
|
||||
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
|
||||
// This is out of order but we're dropping it.
|
||||
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
|
||||
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Anything else shouldn't happen.
|
||||
return AlertIllegalParameter
|
||||
}
|
||||
|
||||
h.recvdRecords = append(h.recvdRecords, pt.seq)
|
||||
h.frame.addChunk(pt.fragment)
|
||||
|
||||
return nil
|
||||
@ -171,83 +238,314 @@ func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
|
||||
h.msgSeq = seq + 1
|
||||
var i int
|
||||
var m *HandshakeMessage
|
||||
for i, m = range h.queued {
|
||||
if m.seq > seq {
|
||||
break
|
||||
}
|
||||
}
|
||||
h.queued = h.queued[i:]
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
|
||||
if hm.seq < h.msgSeq {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
|
||||
// out of order.
|
||||
h.ctx.receivedHandshakeMessage()
|
||||
|
||||
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||
// TODO(ekr@rtfm.com): Check the length?
|
||||
// This is complete.
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// Now insert sorted.
|
||||
var i int
|
||||
for i = 0; i < len(h.queued); i++ {
|
||||
f := h.queued[i]
|
||||
if hm.seq < f.seq {
|
||||
break
|
||||
}
|
||||
if hm.offset < f.offset {
|
||||
break
|
||||
}
|
||||
}
|
||||
tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
|
||||
tmp = append(tmp, h.queued[:i]...)
|
||||
tmp = append(tmp, hm)
|
||||
tmp = append(tmp, h.queued[i:]...)
|
||||
h.queued = tmp
|
||||
|
||||
return h.checkMessageAvailable()
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
|
||||
if len(h.queued) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hm := h.queued[0]
|
||||
if hm.seq != h.msgSeq {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||
// TODO(ekr@rtfm.com): Check the length?
|
||||
// This is complete.
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// OK, this at least might complete the message.
|
||||
end := uint32(0)
|
||||
buf := make([]byte, hm.length)
|
||||
|
||||
for _, f := range h.queued {
|
||||
// Out of fragments
|
||||
if f.seq > hm.seq {
|
||||
break
|
||||
}
|
||||
|
||||
if f.length != uint32(len(buf)) {
|
||||
return nil, fmt.Errorf("Mismatched DTLS length")
|
||||
}
|
||||
|
||||
if f.offset > end {
|
||||
break
|
||||
}
|
||||
|
||||
if f.offset+uint32(len(f.body)) > end {
|
||||
// OK, this is adding something we don't know about
|
||||
copy(buf[f.offset:], f.body)
|
||||
end = f.offset + uint32(len(f.body))
|
||||
if end == hm.length {
|
||||
h2 := *hm
|
||||
h2.offset = 0
|
||||
h2.body = buf
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return &h2, nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||
var hdr, body []byte
|
||||
var err error
|
||||
|
||||
hm, err := h.checkMessageAvailable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hm != nil {
|
||||
return hm, nil
|
||||
}
|
||||
for {
|
||||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
if h.frame.needed() > 0 {
|
||||
logf(logTypeHandshake, "Trying to read a new record")
|
||||
logf(logTypeVerbose, "Trying to read a new record")
|
||||
err = h.readRecord()
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
|
||||
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hdr, body, err = h.frame.process()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "read handshake message")
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm = &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(hdr[0])
|
||||
|
||||
hm.datagram = h.datagram
|
||||
hm.body = make([]byte, len(body))
|
||||
copy(hm.body, body)
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
if h.datagram {
|
||||
tmp, hdr := decodeUint(hdr[1:], 3)
|
||||
hm.length = uint32(tmp)
|
||||
tmp, hdr = decodeUint(hdr, 2)
|
||||
hm.seq = uint32(tmp)
|
||||
tmp, hdr = decodeUint(hdr, 3)
|
||||
hm.offset = uint32(tmp)
|
||||
|
||||
return h.newFragmentReceived(hm)
|
||||
}
|
||||
|
||||
hm.length = uint32(len(body))
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||
func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
|
||||
hm.cipher = h.conn.cipher
|
||||
h.queued = append(h.queued, hm)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
|
||||
logf(logTypeHandshake, "Sending outgoing messages")
|
||||
count, err := h.WriteMessages(h.queued)
|
||||
if !h.datagram {
|
||||
h.ClearQueuedMessages()
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ClearQueuedMessages() {
|
||||
logf(logTypeHandshake, "Clearing outgoing hs message queue")
|
||||
h.queued = nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
|
||||
var buf []byte
|
||||
|
||||
// Figure out if we're going to want the full header or just
|
||||
// the body
|
||||
hdrlen := 0
|
||||
if hm.datagram {
|
||||
hdrlen = handshakeHeaderLenDTLS
|
||||
} else if start == 0 {
|
||||
hdrlen = handshakeHeaderLenTLS
|
||||
}
|
||||
|
||||
// Compute the amount of body we can fit in
|
||||
room -= hdrlen
|
||||
if room == 0 {
|
||||
// This works because we are doing one record per
|
||||
// message
|
||||
panic("Too short max fragment len")
|
||||
}
|
||||
bodylen := len(hm.body) - start
|
||||
if bodylen > room {
|
||||
bodylen = room
|
||||
}
|
||||
body := hm.body[start : start+bodylen]
|
||||
|
||||
// Now see if this chunk has been ACKed. This doesn't produce ideal
|
||||
// retransmission but is simple.
|
||||
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
|
||||
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
|
||||
return false, start + bodylen, nil
|
||||
}
|
||||
|
||||
// Encode the data.
|
||||
if hdrlen > 0 {
|
||||
hm2 := *hm
|
||||
hm2.offset = uint32(start)
|
||||
hm2.body = body
|
||||
buf = hm2.Marshal()
|
||||
hm = &hm2
|
||||
} else {
|
||||
buf = body
|
||||
}
|
||||
|
||||
if h.datagram {
|
||||
// Remember that we sent this.
|
||||
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
|
||||
hm.seq,
|
||||
start,
|
||||
len(body),
|
||||
h.conn.cipher.combineSeq(true),
|
||||
false,
|
||||
})
|
||||
}
|
||||
return true, start + bodylen, h.conn.writeRecordWithPadding(
|
||||
&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buf,
|
||||
},
|
||||
hm.cipher, 0)
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
|
||||
start := int(0)
|
||||
|
||||
if len(hm.body) > maxHandshakeMessageLen {
|
||||
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
|
||||
}
|
||||
|
||||
written := 0
|
||||
wrote := false
|
||||
|
||||
// Always make one pass through to allow EOED (which is empty).
|
||||
for {
|
||||
var err error
|
||||
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if wrote {
|
||||
written++
|
||||
}
|
||||
if start >= len(hm.body) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
|
||||
written := 0
|
||||
for _, hm := range hms {
|
||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||
}
|
||||
|
||||
// Write out headers and bodies
|
||||
buffer := []byte{}
|
||||
for _, msg := range hms {
|
||||
msgLen := len(msg.body)
|
||||
if msgLen > maxHandshakeMessageLen {
|
||||
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||
}
|
||||
|
||||
buffer = append(buffer, msg.Marshal()...)
|
||||
}
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
wrote, err := h.WriteMessage(hm)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
written += wrote
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
func encodeUint(v uint64, size int, out []byte) []byte {
|
||||
for i := size - 1; i >= 0; i-- {
|
||||
out[i] = byte(v & 0xff)
|
||||
v >>= 8
|
||||
}
|
||||
return out[size:]
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func decodeUint(in []byte, size int) (uint64, []byte) {
|
||||
val := uint64(0)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
val <<= 8
|
||||
val += uint64(in[i])
|
||||
}
|
||||
return val, in[size:]
|
||||
}
|
||||
|
||||
type marshalledPDU interface {
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
|
||||
read, err := pdu.Unmarshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) != read {
|
||||
return fmt.Errorf("Invalid encoding: Extra data not consumed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
151
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
151
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
@ -25,15 +25,14 @@ type HandshakeMessageBody interface {
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ClientHello;
|
||||
type ClientHelloBody struct {
|
||||
// Omitted: clientVersion
|
||||
// Omitted: legacySessionID
|
||||
// Omitted: legacyCompressionMethods
|
||||
Random [32]byte
|
||||
CipherSuites []CipherSuite
|
||||
Extensions ExtensionList
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte
|
||||
CipherSuites []CipherSuite
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type clientHelloBodyInner struct {
|
||||
type clientHelloBodyInnerTLS struct {
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
@ -42,40 +41,86 @@ type clientHelloBodyInner struct {
|
||||
Extensions []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
type clientHelloBodyInnerDTLS struct {
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
EmptyCookie uint8
|
||||
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||
Extensions []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeClientHello
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(clientHelloBodyInner{
|
||||
LegacyVersion: 0x0303,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
if ch.LegacyVersion == tls12Version {
|
||||
return syntax.Marshal(clientHelloBodyInnerTLS{
|
||||
LegacyVersion: ch.LegacyVersion,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
} else {
|
||||
return syntax.Marshal(clientHelloBodyInnerDTLS{
|
||||
LegacyVersion: ch.LegacyVersion,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
var inner clientHelloBodyInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var read int
|
||||
var err error
|
||||
|
||||
// We are strict about these things because we only support 1.3
|
||||
if inner.LegacyVersion != 0x0303 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||
}
|
||||
// Note that this might be 0, in which case we do TLS. That
|
||||
// makes the tests easier.
|
||||
if ch.LegacyVersion != dtls12WireVersion {
|
||||
var inner clientHelloBodyInnerTLS
|
||||
read, err = syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
|
||||
ch.Random = inner.Random
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
ch.LegacyVersion = inner.LegacyVersion
|
||||
ch.Random = inner.Random
|
||||
ch.LegacySessionID = inner.LegacySessionID
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
} else {
|
||||
var inner clientHelloBodyInnerDTLS
|
||||
read, err = syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if inner.EmptyCookie != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid cookie")
|
||||
}
|
||||
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
|
||||
ch.LegacyVersion = inner.LegacyVersion
|
||||
ch.Random = inner.Random
|
||||
ch.LegacySessionID = inner.LegacySessionID
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
}
|
||||
return read, nil
|
||||
}
|
||||
|
||||
@ -90,10 +135,15 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||
}
|
||||
|
||||
chm, err := HandshakeMessageFromBody(&ch)
|
||||
body, err := ch.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chm := &HandshakeMessage{
|
||||
msgType: ch.Type(),
|
||||
body: body,
|
||||
length: uint32(len(body)),
|
||||
}
|
||||
chData := chm.Marshal()
|
||||
|
||||
psk := PreSharedKeyExtension{
|
||||
@ -116,39 +166,20 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion server_version;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } HelloRetryRequest;
|
||||
type HelloRetryRequestBody struct {
|
||||
Version uint16
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeHelloRetryRequest
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(hrr)
|
||||
}
|
||||
|
||||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, hrr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version;
|
||||
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||
// Random random;
|
||||
// opaque legacy_session_id_echo<0..32>;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// uint8 legacy_compression_method = 0;
|
||||
// Extension extensions<6..2^16-1>;
|
||||
// } ServerHello;
|
||||
type ServerHelloBody struct {
|
||||
Version uint16
|
||||
Random [32]byte
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
Version uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
CipherSuite CipherSuite
|
||||
LegacyCompressionMethod uint8
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Type() HandshakeType {
|
||||
|
11
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
11
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
@ -148,7 +148,7 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
|
||||
}
|
||||
|
||||
if len(candidatesByName) == 0 {
|
||||
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||
return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName)
|
||||
}
|
||||
|
||||
candidates = candidatesByName
|
||||
@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
|
||||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||
}
|
||||
|
||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||
return usingEarlyData
|
||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) {
|
||||
using = gotEarlyData && usingPSK && allowEarlyData
|
||||
rejected = gotEarlyData && !using
|
||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected)
|
||||
return
|
||||
}
|
||||
|
||||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||
|
306
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
306
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
@ -1,7 +1,6 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -9,9 +8,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sequenceNumberLen = 8 // sequence number length
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||
sequenceNumberLen = 8 // sequence number length
|
||||
recordHeaderLenTLS = 5 // record header length (TLS)
|
||||
recordHeaderLenDTLS = 13 // record header length (DTLS)
|
||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||
)
|
||||
|
||||
type DecryptError string
|
||||
@ -20,9 +20,16 @@ func (err DecryptError) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
type direction uint8
|
||||
|
||||
const (
|
||||
directionWrite = direction(1)
|
||||
directionRead = direction(2)
|
||||
)
|
||||
|
||||
// struct {
|
||||
// ContentType type;
|
||||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||
// ProtocolVersion record_version [0301 for CH, 0303 for others]
|
||||
// uint16 length;
|
||||
// opaque fragment[TLSPlaintext.length];
|
||||
// } TLSPlaintext;
|
||||
@ -30,87 +37,177 @@ type TLSPlaintext struct {
|
||||
// Omitted: record_version (static)
|
||||
// Omitted: length (computed from fragment)
|
||||
contentType RecordType
|
||||
epoch Epoch
|
||||
seq uint64
|
||||
fragment []byte
|
||||
}
|
||||
|
||||
type cipherState struct {
|
||||
epoch Epoch // DTLS epoch
|
||||
ivLength int // Length of the seq and nonce fields
|
||||
seq uint64 // Zero-padded sequence number
|
||||
iv []byte // Buffer for the IV
|
||||
cipher cipher.AEAD // AEAD cipher
|
||||
}
|
||||
|
||||
type RecordLayer struct {
|
||||
sync.Mutex
|
||||
|
||||
label string
|
||||
direction direction
|
||||
version uint16 // The current version number
|
||||
conn io.ReadWriter // The underlying connection
|
||||
frame *frameReader // The buffered frame reader
|
||||
nextData []byte // The next record to send
|
||||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||
cachedError error // Error on the last record read
|
||||
|
||||
ivLength int // Length of the seq and nonce fields
|
||||
seq []byte // Zero-padded sequence number
|
||||
nonce []byte // Buffer for per-record nonces
|
||||
cipher cipher.AEAD // AEAD cipher
|
||||
cipher *cipherState
|
||||
readCiphers map[Epoch]*cipherState
|
||||
|
||||
datagram bool
|
||||
}
|
||||
|
||||
type recordLayerFrameDetails struct{}
|
||||
type recordLayerFrameDetails struct {
|
||||
datagram bool
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) headerLen() int {
|
||||
return recordHeaderLen
|
||||
if d.datagram {
|
||||
return recordHeaderLenDTLS
|
||||
}
|
||||
return recordHeaderLenTLS
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||
return recordHeaderLen + maxFragmentLen
|
||||
return d.headerLen() + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||
return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
|
||||
}
|
||||
|
||||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||
func newCipherStateNull() *cipherState {
|
||||
return &cipherState{EpochClear, 0, 0, nil, nil}
|
||||
}
|
||||
|
||||
func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
|
||||
cipher, err := factory(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
|
||||
}
|
||||
|
||||
func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
|
||||
r := RecordLayer{}
|
||||
r.label = ""
|
||||
r.direction = dir
|
||||
r.conn = conn
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||
r.ivLength = 0
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{false})
|
||||
r.cipher = newCipherStateNull()
|
||||
r.version = tls10Version
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||
var err error
|
||||
r.cipher, err = cipher(key)
|
||||
func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
|
||||
r := RecordLayer{}
|
||||
r.label = ""
|
||||
r.direction = dir
|
||||
r.conn = conn
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{true})
|
||||
r.cipher = newCipherStateNull()
|
||||
r.readCiphers = make(map[Epoch]*cipherState, 0)
|
||||
r.readCiphers[0] = r.cipher
|
||||
r.datagram = true
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *RecordLayer) SetVersion(v uint16) {
|
||||
r.version = v
|
||||
}
|
||||
|
||||
func (r *RecordLayer) ResetClear(seq uint64) {
|
||||
r.cipher = newCipherStateNull()
|
||||
r.cipher.seq = seq
|
||||
}
|
||||
|
||||
func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
|
||||
cipher, err := newCipherStateAead(epoch, factory, key, iv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ivLength = len(iv)
|
||||
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||
r.nonce = make([]byte, r.ivLength)
|
||||
copy(r.nonce, iv)
|
||||
r.cipher = cipher
|
||||
if r.datagram && r.direction == directionRead {
|
||||
r.readCiphers[epoch] = cipher
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) incrementSequenceNumber() {
|
||||
if r.ivLength == 0 {
|
||||
// TODO(ekr@rtfm.com): This is never used, which is a bug.
|
||||
func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
|
||||
if !r.datagram {
|
||||
return
|
||||
}
|
||||
|
||||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||
r.seq[i]++
|
||||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||
if r.seq[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Not allowed to let sequence number wrap.
|
||||
// Instead, must renegotiate before it does.
|
||||
// Not likely enough to bother.
|
||||
panic("TLS: sequence number wraparound")
|
||||
_, ok := r.readCiphers[epoch]
|
||||
assert(ok)
|
||||
delete(r.readCiphers, epoch)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
func (c *cipherState) combineSeq(datagram bool) uint64 {
|
||||
seq := c.seq
|
||||
if datagram {
|
||||
seq |= uint64(c.epoch) << 48
|
||||
}
|
||||
return seq
|
||||
}
|
||||
|
||||
func (c *cipherState) computeNonce(seq uint64) []byte {
|
||||
nonce := make([]byte, len(c.iv))
|
||||
copy(nonce, c.iv)
|
||||
|
||||
s := seq
|
||||
|
||||
offset := len(c.iv)
|
||||
for i := 0; i < 8; i++ {
|
||||
nonce[(offset-i)-1] ^= byte(s & 0xff)
|
||||
s >>= 8
|
||||
}
|
||||
logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
|
||||
|
||||
return nonce
|
||||
}
|
||||
|
||||
func (c *cipherState) incrementSequenceNumber() {
|
||||
if c.seq >= (1<<48 - 1) {
|
||||
// Not allowed to let sequence number wrap.
|
||||
// Instead, must renegotiate before it does.
|
||||
// Not likely enough to bother. This is the
|
||||
// DTLS limit.
|
||||
panic("TLS: sequence number wraparound")
|
||||
}
|
||||
c.seq++
|
||||
}
|
||||
|
||||
func (c *cipherState) overhead() int {
|
||||
if c.cipher == nil {
|
||||
return 0
|
||||
}
|
||||
return c.cipher.Overhead()
|
||||
}
|
||||
|
||||
func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
assert(r.direction == directionWrite)
|
||||
logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
|
||||
// Expand the fragment to hold contentType, padding, and overhead
|
||||
originalLen := len(pt.fragment)
|
||||
plaintextLen := originalLen + 1 + padLen
|
||||
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||
ciphertextLen := plaintextLen + cipher.overhead()
|
||||
|
||||
// Assemble the revised plaintext
|
||||
out := &TLSPlaintext{
|
||||
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: make([]byte, ciphertextLen),
|
||||
}
|
||||
@ -122,25 +219,28 @@ func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
|
||||
// Encrypt the fragment
|
||||
payload := out.fragment[:plaintextLen]
|
||||
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||
cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||
if len(pt.fragment) < r.cipher.Overhead() {
|
||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
|
||||
assert(r.direction == directionRead)
|
||||
logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
|
||||
if len(pt.fragment) < r.cipher.overhead() {
|
||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
|
||||
return nil, 0, DecryptError(msg)
|
||||
}
|
||||
|
||||
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||
decryptLen := len(pt.fragment) - r.cipher.overhead()
|
||||
out := &TLSPlaintext{
|
||||
contentType: pt.contentType,
|
||||
fragment: make([]byte, decryptLen),
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||
_, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
|
||||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||
}
|
||||
|
||||
@ -155,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||
|
||||
// Truncate the message to remove contentType, padding, overhead
|
||||
out.fragment = out.fragment[:newLen]
|
||||
out.seq = seq
|
||||
return out, padLen, nil
|
||||
}
|
||||
|
||||
@ -163,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||
var err error
|
||||
|
||||
for {
|
||||
pt, err = r.nextRecord()
|
||||
pt, err = r.nextRecord(false)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if !block || err != WouldBlock {
|
||||
if !block || err != AlertWouldBlock {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
@ -175,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||
}
|
||||
|
||||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||
pt, err := r.nextRecord()
|
||||
pt, err := r.nextRecord(false)
|
||||
|
||||
// Consume the cached record if there was one
|
||||
r.cachedRecord = nil
|
||||
@ -184,9 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||
return pt, err
|
||||
}
|
||||
|
||||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
|
||||
pt, err := r.nextRecord(true)
|
||||
|
||||
// Consume the cached record if there was one
|
||||
r.cachedRecord = nil
|
||||
r.cachedError = nil
|
||||
|
||||
return pt, err
|
||||
}
|
||||
|
||||
func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
|
||||
cipher := r.cipher
|
||||
if r.cachedRecord != nil {
|
||||
logf(logTypeIO, "Returning cached record")
|
||||
logf(logTypeIO, "%s Returning cached record", r.label)
|
||||
return r.cachedRecord, r.cachedError
|
||||
}
|
||||
|
||||
@ -194,34 +306,35 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
//
|
||||
// 1. We get a frame
|
||||
// 2. We try to read off the socket and get nothing, in which case
|
||||
// return WouldBlock
|
||||
// returnAlertWouldBlock
|
||||
// 3. We get an error.
|
||||
err := WouldBlock
|
||||
var err error
|
||||
err = AlertWouldBlock
|
||||
var header, body []byte
|
||||
|
||||
for err != nil {
|
||||
if r.frame.needed() > 0 {
|
||||
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||
buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
|
||||
n, err := r.conn.Read(buf)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "Error reading, %v", err)
|
||||
logf(logTypeIO, "%s Error reading, %v", r.label, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nil, WouldBlock
|
||||
return nil, AlertWouldBlock
|
||||
}
|
||||
|
||||
logf(logTypeIO, "Read %v bytes", n)
|
||||
logf(logTypeIO, "%s Read %v bytes", r.label, n)
|
||||
|
||||
buf = buf[:n]
|
||||
r.frame.addChunk(buf)
|
||||
}
|
||||
|
||||
header, body, err = r.frame.process()
|
||||
// Loop around on WouldBlock to see if some
|
||||
// Loop around onAlertWouldBlock to see if some
|
||||
// data is now available.
|
||||
if err != nil && err != WouldBlock {
|
||||
if err != nil && err != AlertWouldBlock {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -231,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
switch RecordType(header[0]) {
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
|
||||
pt.contentType = RecordType(header[0])
|
||||
}
|
||||
|
||||
@ -241,7 +354,8 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
}
|
||||
|
||||
// Validate size < max
|
||||
size := (int(header[3]) << 8) + int(header[4])
|
||||
size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
|
||||
|
||||
if size > maxFragmentLen+256 {
|
||||
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||
}
|
||||
@ -249,33 +363,67 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
pt.fragment = make([]byte, size)
|
||||
copy(pt.fragment, body)
|
||||
|
||||
// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
|
||||
|
||||
// Attempt to decrypt fragment
|
||||
if r.cipher != nil {
|
||||
pt, _, err = r.decrypt(pt)
|
||||
seq := cipher.seq
|
||||
if r.datagram {
|
||||
// TODO(ekr@rtfm.com): Handle duplicates.
|
||||
seq, _ = decodeUint(header[3:11], 8)
|
||||
epoch := Epoch(seq >> 48)
|
||||
|
||||
// Look up the cipher suite from the epoch
|
||||
c, ok := r.readCiphers[epoch]
|
||||
if !ok {
|
||||
logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
|
||||
return nil, AlertWouldBlock
|
||||
}
|
||||
|
||||
if epoch != cipher.epoch {
|
||||
logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
|
||||
cipher.epoch, allowOldEpoch)
|
||||
if !allowOldEpoch {
|
||||
return nil, AlertWouldBlock
|
||||
}
|
||||
cipher = c
|
||||
}
|
||||
}
|
||||
|
||||
if cipher.cipher != nil {
|
||||
logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
|
||||
pt, _, err = r.decrypt(pt, seq)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "%s Decryption failed", r.label)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
pt.epoch = cipher.epoch
|
||||
|
||||
// Check that plaintext length is not too long
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||
}
|
||||
|
||||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
|
||||
|
||||
r.cachedRecord = pt
|
||||
r.incrementSequenceNumber()
|
||||
cipher.incrementSequenceNumber()
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||
return r.WriteRecordWithPadding(pt, 0)
|
||||
return r.writeRecordWithPadding(pt, r.cipher, 0)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||
if r.cipher != nil {
|
||||
pt = r.encrypt(pt, padLen)
|
||||
return r.writeRecordWithPadding(pt, r.cipher, padLen)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
|
||||
seq := cipher.combineSeq(r.datagram)
|
||||
if cipher.cipher != nil {
|
||||
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
|
||||
pt = r.encrypt(cipher, seq, pt, padLen)
|
||||
} else if padLen > 0 {
|
||||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||
}
|
||||
@ -285,12 +433,26 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error
|
||||
}
|
||||
|
||||
length := len(pt.fragment)
|
||||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||
var header []byte
|
||||
|
||||
if !r.datagram {
|
||||
header = []byte{byte(pt.contentType),
|
||||
byte(r.version >> 8), byte(r.version & 0xff),
|
||||
byte(length >> 8), byte(length)}
|
||||
} else {
|
||||
header = make([]byte, 13)
|
||||
version := dtlsConvertVersion(r.version)
|
||||
copy(header, []byte{byte(pt.contentType),
|
||||
byte(version >> 8), byte(version & 0xff),
|
||||
})
|
||||
encodeUint(seq, 8, header[3:])
|
||||
encodeUint(uint64(length), 2, header[11:])
|
||||
}
|
||||
record := append(header, pt.fragment...)
|
||||
|
||||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
|
||||
|
||||
r.incrementSequenceNumber()
|
||||
cipher.incrementSequenceNumber()
|
||||
_, err := r.conn.Write(record)
|
||||
return err
|
||||
}
|
||||
|
709
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
709
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
File diff suppressed because it is too large
Load Diff
111
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
111
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
@ -1,6 +1,7 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -8,32 +9,35 @@ import (
|
||||
// state transitions.
|
||||
type HandshakeAction interface{}
|
||||
|
||||
type SendHandshakeMessage struct {
|
||||
type QueueHandshakeMessage struct {
|
||||
Message *HandshakeMessage
|
||||
}
|
||||
|
||||
type SendQueuedHandshake struct{}
|
||||
|
||||
type SendEarlyData struct{}
|
||||
|
||||
type ReadEarlyData struct{}
|
||||
|
||||
type ReadPastEarlyData struct{}
|
||||
|
||||
type RekeyIn struct {
|
||||
Label string
|
||||
epoch Epoch
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type RekeyOut struct {
|
||||
Label string
|
||||
epoch Epoch
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type ResetOut struct {
|
||||
seq uint64
|
||||
}
|
||||
|
||||
type StorePSK struct {
|
||||
PSK PreSharedKey
|
||||
}
|
||||
|
||||
type HandshakeState interface {
|
||||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||
Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert)
|
||||
State() State
|
||||
}
|
||||
|
||||
type AppExtensionHandler interface {
|
||||
@ -41,35 +45,11 @@ type AppExtensionHandler interface {
|
||||
Receive(hs HandshakeType, el *ExtensionList) error
|
||||
}
|
||||
|
||||
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||
// as an input to TLS negotiation
|
||||
type Capabilities struct {
|
||||
// For both client and server
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
PSKs PreSharedKeyCache
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
ExtensionHandler AppExtensionHandler
|
||||
|
||||
// For client
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
|
||||
// For server
|
||||
NextProtos []string
|
||||
AllowEarlyData bool
|
||||
RequireCookie bool
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
}
|
||||
|
||||
// ConnectionOptions objects represent per-connection settings for a client
|
||||
// initiating a connection
|
||||
type ConnectionOptions struct {
|
||||
ServerName string
|
||||
NextProtos []string
|
||||
EarlyData []byte
|
||||
}
|
||||
|
||||
// ConnectionParameters objects represent the parameters negotiated for a
|
||||
@ -79,6 +59,7 @@ type ConnectionParameters struct {
|
||||
UsingDH bool
|
||||
ClientSendingEarlyData bool
|
||||
UsingEarlyData bool
|
||||
RejectedEarlyData bool
|
||||
UsingClientAuth bool
|
||||
|
||||
CipherSuite CipherSuite
|
||||
@ -86,18 +67,50 @@ type ConnectionParameters struct {
|
||||
NextProto string
|
||||
}
|
||||
|
||||
// StateConnected is symmetric between client and server
|
||||
type StateConnected struct {
|
||||
// Working state for the handshake.
|
||||
type HandshakeContext struct {
|
||||
timeoutMS uint32
|
||||
timers *timerSet
|
||||
recvdRecords []uint64
|
||||
sentFragments []*SentHandshakeFragment
|
||||
hIn, hOut *HandshakeLayer
|
||||
waitingNextFlight bool
|
||||
earlyData []byte
|
||||
}
|
||||
|
||||
func (hc *HandshakeContext) SetVersion(version uint16) {
|
||||
if hc.hIn.conn != nil {
|
||||
hc.hIn.conn.SetVersion(version)
|
||||
}
|
||||
if hc.hOut.conn != nil {
|
||||
hc.hOut.conn.SetVersion(version)
|
||||
}
|
||||
}
|
||||
|
||||
// stateConnected is symmetric between client and server
|
||||
type stateConnected struct {
|
||||
Params ConnectionParameters
|
||||
hsCtx *HandshakeContext
|
||||
isClient bool
|
||||
cryptoParams CipherSuiteParams
|
||||
resumptionSecret []byte
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
peerCertificates []*x509.Certificate
|
||||
verifiedChains [][]*x509.Certificate
|
||||
}
|
||||
|
||||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||
var _ HandshakeState = &stateConnected{}
|
||||
|
||||
func (state stateConnected) State() State {
|
||||
if state.isClient {
|
||||
return StateClientConnected
|
||||
}
|
||||
return StateServerConnected
|
||||
}
|
||||
|
||||
func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||
var trafficKeys keySet
|
||||
if state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
@ -109,20 +122,21 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||
kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{kum},
|
||||
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||
QueueHandshakeMessage{kum},
|
||||
SendQueuedHandshake{},
|
||||
RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||
func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||
tkt, err := NewSessionTicket(length, lifetime)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||
@ -149,7 +163,7 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
|
||||
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||
}
|
||||
|
||||
tktm, err := HandshakeMessageFromBody(tkt)
|
||||
tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
@ -157,12 +171,18 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
StorePSK{newPSK},
|
||||
SendHandshakeMessage{tktm},
|
||||
QueueHandshakeMessage{tktm},
|
||||
SendQueuedHandshake{},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
// Next does nothing for this state.
|
||||
func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||
return state, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
@ -187,20 +207,18 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||
toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}}
|
||||
|
||||
// If requested, roll outbound keys and send a KeyUpdate
|
||||
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||
logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest)
|
||||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||
if alert != AlertNoAlert {
|
||||
return nil, nil, alert
|
||||
}
|
||||
|
||||
toSend = append(toSend, moreToSend...)
|
||||
}
|
||||
|
||||
return state, toSend, AlertNoAlert
|
||||
|
||||
case *NewSessionTicketBody:
|
||||
// XXX: Allow NewSessionTicket in both directions?
|
||||
if !state.isClient {
|
||||
@ -209,7 +227,6 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
psk := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
|
10
vendor/github.com/bifurcation/mint/syntax/README.md
generated
vendored
10
vendor/github.com/bifurcation/mint/syntax/README.md
generated
vendored
@ -72,3 +72,13 @@ The available annotations right now are all related to vectors:
|
||||
fragment[TLSPlaintext.length]`. Note, however, that in cases where the length
|
||||
immediately preceds the array, these can be reframed as vectors with
|
||||
appropriate sizes.
|
||||
|
||||
|
||||
QUIC Extensions Syntax
|
||||
======================
|
||||
syntax also supports some minor extensions to allow implementing QUIC.
|
||||
|
||||
* The `varint` annotation describes a QUIC-style varint
|
||||
* `head=none` means no header, i.e., the bytes are encoded directly on the wire.
|
||||
On reading, the decoder will consume all available data.
|
||||
* `head=varint` means to encode the header as a varint
|
||||
|
203
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
203
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
@ -16,12 +16,22 @@ func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface implemented by types that can
|
||||
// unmarshal a TLS description of themselves. Note that unlike the
|
||||
// JSON unmarshaler interface, it is not known a priori how much of
|
||||
// the input data will be consumed. So the Unmarshaler must state
|
||||
// how much of the input data it consumed.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalTLS([]byte) (int, error)
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type decOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
varint bool // whether to decode as a varint
|
||||
}
|
||||
|
||||
type decodeState struct {
|
||||
@ -65,8 +75,14 @@ func typeDecoder(t reflect.Type) decoderFunc {
|
||||
return newTypeDecoder(t)
|
||||
}
|
||||
|
||||
var (
|
||||
unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
|
||||
)
|
||||
|
||||
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) {
|
||||
return unmarshalerDecoder
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
@ -77,6 +93,8 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
return newSliceDecoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructDecoder(t)
|
||||
case reflect.Ptr:
|
||||
return newPointerDecoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
@ -84,35 +102,87 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
|
||||
///// Specific decoders below
|
||||
|
||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
var uintLen int
|
||||
switch v.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
uintLen = 1
|
||||
case reflect.Uint16:
|
||||
uintLen = 2
|
||||
case reflect.Uint32:
|
||||
uintLen = 4
|
||||
case reflect.Uint64:
|
||||
uintLen = 8
|
||||
func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
um, ok := v.Interface().(Unmarshaler)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder"))
|
||||
}
|
||||
|
||||
buf := make([]byte, uintLen)
|
||||
n, err := d.Read(buf)
|
||||
read, err := um.UnmarshalTLS(d.Bytes())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != uintLen {
|
||||
|
||||
if read > d.Len() {
|
||||
panic(fmt.Errorf("Invalid return value from UnmarshalTLS"))
|
||||
}
|
||||
|
||||
d.Next(read)
|
||||
return read
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
if opts.varint {
|
||||
return varintDecoder(d, v, opts)
|
||||
}
|
||||
|
||||
uintLen := int(v.Elem().Type().Size())
|
||||
buf := d.Next(uintLen)
|
||||
if len(buf) != uintLen {
|
||||
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||
}
|
||||
|
||||
return setUintFromBuffer(v, buf)
|
||||
}
|
||||
|
||||
func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
l, val := readVarint(d)
|
||||
|
||||
uintLen := int(v.Elem().Type().Size())
|
||||
if uintLen < l {
|
||||
panic(fmt.Errorf("Uint too small to fit varint: %d < %d", uintLen, l))
|
||||
}
|
||||
|
||||
v.Elem().SetUint(val)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func readVarint(d *decodeState) (int, uint64) {
|
||||
// Read the first octet and decide the size of the presented varint
|
||||
first := d.Next(1)
|
||||
if len(first) != 1 {
|
||||
panic(fmt.Errorf("Insufficient data to read varint length"))
|
||||
}
|
||||
|
||||
twoBits := uint(first[0] >> 6)
|
||||
varintLen := 1 << twoBits
|
||||
|
||||
rest := d.Next(varintLen - 1)
|
||||
if len(rest) != varintLen-1 {
|
||||
panic(fmt.Errorf("Insufficient data to read varint"))
|
||||
}
|
||||
|
||||
buf := append(first, rest...)
|
||||
buf[0] &= 0x3f
|
||||
|
||||
return len(buf), decodeUintFromBuffer(buf)
|
||||
}
|
||||
|
||||
func decodeUintFromBuffer(buf []byte) uint64 {
|
||||
val := uint64(0)
|
||||
for _, b := range buf {
|
||||
val = (val << 8) + uint64(b)
|
||||
}
|
||||
|
||||
v.Elem().SetUint(val)
|
||||
return uintLen
|
||||
return val
|
||||
}
|
||||
|
||||
func setUintFromBuffer(v reflect.Value, buf []byte) int {
|
||||
v.Elem().SetUint(decodeUintFromBuffer(buf))
|
||||
return len(buf)
|
||||
}
|
||||
|
||||
//////////
|
||||
@ -143,44 +213,57 @@ type sliceDecoder struct {
|
||||
}
|
||||
|
||||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
var length uint64
|
||||
var read int
|
||||
var data []byte
|
||||
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||
}
|
||||
|
||||
lengthBytes := make([]byte, opts.head)
|
||||
n, err := d.Read(lengthBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != opts.head {
|
||||
panic(fmt.Errorf("Not enough data to read header"))
|
||||
}
|
||||
// If the caller indicated there is no header, then read everything from the buffer
|
||||
if opts.head == headValueNoHead {
|
||||
for {
|
||||
chunk := d.Next(1024)
|
||||
data = append(data, chunk...)
|
||||
if len(chunk) != 1024 {
|
||||
break
|
||||
}
|
||||
}
|
||||
length = uint64(len(data))
|
||||
if opts.max > 0 && length > uint64(opts.max) {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < uint64(opts.min) {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
} else {
|
||||
if opts.head != headValueVarint {
|
||||
lengthBytes := d.Next(int(opts.head))
|
||||
if len(lengthBytes) != int(opts.head) {
|
||||
panic(fmt.Errorf("Not enough data to read header"))
|
||||
}
|
||||
read = len(lengthBytes)
|
||||
length = decodeUintFromBuffer(lengthBytes)
|
||||
} else {
|
||||
read, length = readVarint(d)
|
||||
}
|
||||
if opts.max > 0 && length > uint64(opts.max) {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < uint64(opts.min) {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
|
||||
length := uint(0)
|
||||
for _, b := range lengthBytes {
|
||||
length = (length << 8) + uint(b)
|
||||
}
|
||||
|
||||
if opts.max > 0 && length > opts.max {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < opts.min {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err = d.Read(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != length {
|
||||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||
data = d.Next(int(length))
|
||||
if len(data) != int(length) {
|
||||
panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length))
|
||||
}
|
||||
}
|
||||
|
||||
elemBuf := &decodeState{}
|
||||
elemBuf.Write(data)
|
||||
elems := []reflect.Value{}
|
||||
read := int(opts.head)
|
||||
for elemBuf.Len() > 0 {
|
||||
elem := reflect.New(sd.elementType)
|
||||
read += sd.elementDec(elemBuf, elem, opts)
|
||||
@ -231,9 +314,10 @@ func newStructDecoder(t reflect.Type) decoderFunc {
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
sd.fieldOpts[i] = decOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
varint: tagOpts[varintOption] > 0,
|
||||
}
|
||||
|
||||
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||
@ -241,3 +325,20 @@ func newStructDecoder(t reflect.Type) decoderFunc {
|
||||
|
||||
return sd.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type pointerDecoder struct {
|
||||
base decoderFunc
|
||||
}
|
||||
|
||||
func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
v.Elem().Set(reflect.New(v.Elem().Type().Elem()))
|
||||
return pd.base(d, v.Elem(), opts)
|
||||
}
|
||||
|
||||
func newPointerDecoder(t reflect.Type) decoderFunc {
|
||||
baseDecoder := typeDecoder(t.Elem())
|
||||
pd := pointerDecoder{base: baseDecoder}
|
||||
return pd.decode
|
||||
}
|
||||
|
145
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
145
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
@ -16,12 +16,19 @@ func Marshal(v interface{}) ([]byte, error) {
|
||||
return e.Bytes(), nil
|
||||
}
|
||||
|
||||
// Marshaler is the interface implemented by types that
|
||||
// have a defined TLS encoding.
|
||||
type Marshaler interface {
|
||||
MarshalTLS() ([]byte, error)
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type encOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
varint bool // whether to encode as a varint
|
||||
}
|
||||
|
||||
type encodeState struct {
|
||||
@ -62,8 +69,14 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
||||
return newTypeEncoder(t)
|
||||
}
|
||||
|
||||
var (
|
||||
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
|
||||
)
|
||||
|
||||
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
if t.Implements(marshalerType) {
|
||||
return marshalerEncoder
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
@ -74,6 +87,8 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
return newSliceEncoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructEncoder(t)
|
||||
case reflect.Ptr:
|
||||
return newPointerEncoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
@ -81,19 +96,65 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
|
||||
///// Specific encoders below
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
u := v.Uint()
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.WriteByte(byte(u))
|
||||
case reflect.Uint16:
|
||||
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||
case reflect.Uint32:
|
||||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
case reflect.Uint64:
|
||||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
panic(fmt.Errorf("Cannot encode nil pointer"))
|
||||
}
|
||||
|
||||
m, ok := v.Interface().(Marshaler)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder"))
|
||||
}
|
||||
|
||||
b, err := m.MarshalTLS()
|
||||
if err == nil {
|
||||
_, err = e.Write(b)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if opts.varint {
|
||||
varintEncoder(e, v, opts)
|
||||
return
|
||||
}
|
||||
|
||||
writeUint(e, v.Uint(), int(v.Type().Size()))
|
||||
}
|
||||
|
||||
func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
writeVarint(e, v.Uint())
|
||||
}
|
||||
|
||||
func writeVarint(e *encodeState, u uint64) {
|
||||
if (u >> 62) > 0 {
|
||||
panic(fmt.Errorf("uint value is too big for varint"))
|
||||
}
|
||||
|
||||
var varintLen int
|
||||
for _, len := range []uint{1, 2, 4, 8} {
|
||||
if u < (uint64(1) << (8*len - 2)) {
|
||||
varintLen = int(len)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen]
|
||||
shift := uint(8*varintLen - 2)
|
||||
writeUint(e, u|(twoBits<<shift), varintLen)
|
||||
}
|
||||
|
||||
func writeUint(e *encodeState, u uint64, len int) {
|
||||
data := make([]byte, len)
|
||||
for i := 0; i < len; i += 1 {
|
||||
data[i] = byte(u >> uint(8*(len-i-1)))
|
||||
}
|
||||
e.Write(data)
|
||||
}
|
||||
|
||||
//////////
|
||||
@ -121,27 +182,34 @@ type sliceEncoder struct {
|
||||
}
|
||||
|
||||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||
}
|
||||
|
||||
arrayState := &encodeState{}
|
||||
se.ae.encode(arrayState, v, opts)
|
||||
|
||||
n := uint(arrayState.Len())
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||
}
|
||||
|
||||
if opts.max > 0 && n > opts.max {
|
||||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||
}
|
||||
if n>>(8*opts.head) > 0 {
|
||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||
}
|
||||
if n < opts.min {
|
||||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||
}
|
||||
|
||||
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||
switch opts.head {
|
||||
case headValueNoHead:
|
||||
// None.
|
||||
case headValueVarint:
|
||||
writeVarint(e, uint64(n))
|
||||
default:
|
||||
if n>>(8*opts.head) > 0 {
|
||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||
}
|
||||
|
||||
writeUint(e, uint64(n), int(opts.head))
|
||||
}
|
||||
|
||||
e.Write(arrayState.Bytes())
|
||||
}
|
||||
|
||||
@ -176,12 +244,33 @@ func newStructEncoder(t reflect.Type) encoderFunc {
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
se.fieldOpts[i] = encOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
varint: tagOpts[varintOption] > 0,
|
||||
}
|
||||
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||
}
|
||||
|
||||
return se.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type pointerEncoder struct {
|
||||
base encoderFunc
|
||||
}
|
||||
|
||||
func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if v.IsNil() {
|
||||
panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer"))
|
||||
}
|
||||
|
||||
pe.base(e, v.Elem(), opts)
|
||||
}
|
||||
|
||||
func newPointerEncoder(t reflect.Type) encoderFunc {
|
||||
baseEncoder := typeEncoder(t.Elem())
|
||||
pe := pointerEncoder{base: baseEncoder}
|
||||
return pe.encode
|
||||
}
|
||||
|
28
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
28
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
@ -5,16 +5,27 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// `tls:"head=2,min=2,max=255"`
|
||||
// `tls:"head=2,min=2,max=255,varint"`
|
||||
|
||||
type tagOptions map[string]uint
|
||||
|
||||
var (
|
||||
varintOption = "varint"
|
||||
|
||||
headOptionNone = "none"
|
||||
headOptionVarint = "varint"
|
||||
headValueNoHead = uint(255)
|
||||
headValueVarint = uint(254)
|
||||
)
|
||||
|
||||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||
// name=value pairs, where the values MUST be unsigned integers
|
||||
// name=value pairs, where the values MUST be unsigned integers, or in
|
||||
// the special case of head, "none" or "varint"
|
||||
func parseTag(tag string) tagOptions {
|
||||
opts := tagOptions{}
|
||||
for _, token := range strings.Split(tag, ",") {
|
||||
if strings.Index(token, "=") == -1 {
|
||||
if token == varintOption {
|
||||
opts[varintOption] = 1
|
||||
continue
|
||||
}
|
||||
|
||||
@ -22,7 +33,16 @@ func parseTag(tag string) tagOptions {
|
||||
if len(parts[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||
|
||||
if len(parts) == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if parts[0] == "head" && parts[1] == headOptionNone {
|
||||
opts[parts[0]] = headValueNoHead
|
||||
} else if parts[0] == "head" && parts[1] == headOptionVarint {
|
||||
opts[parts[0]] = headValueVarint
|
||||
} else if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||
opts[parts[0]] = uint(val)
|
||||
}
|
||||
}
|
||||
|
122
vendor/github.com/bifurcation/mint/timer.go
generated
vendored
Normal file
122
vendor/github.com/bifurcation/mint/timer.go
generated
vendored
Normal file
@ -0,0 +1,122 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// This is a simple timer implementation. Timers are stored in a sorted
|
||||
// list.
|
||||
// TODO(ekr@rtfm.com): Add a way to uncouple these from the system
|
||||
// clock.
|
||||
type timerCb func() error
|
||||
|
||||
type timer struct {
|
||||
label string
|
||||
cb timerCb
|
||||
deadline time.Time
|
||||
duration uint32
|
||||
}
|
||||
|
||||
type timerSet struct {
|
||||
ts []*timer
|
||||
}
|
||||
|
||||
func newTimerSet() *timerSet {
|
||||
return &timerSet{}
|
||||
}
|
||||
|
||||
func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer {
|
||||
now := time.Now()
|
||||
t := timer{
|
||||
label,
|
||||
cb,
|
||||
now.Add(time.Millisecond * time.Duration(delayMs)),
|
||||
delayMs,
|
||||
}
|
||||
logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline)
|
||||
|
||||
var i int
|
||||
ntimers := len(ts.ts)
|
||||
for i = 0; i < ntimers; i++ {
|
||||
if t.deadline.Before(ts.ts[i].deadline) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tmp := make([]*timer, 0, ntimers+1)
|
||||
tmp = append(tmp, ts.ts[:i]...)
|
||||
tmp = append(tmp, &t)
|
||||
tmp = append(tmp, ts.ts[i:]...)
|
||||
ts.ts = tmp
|
||||
|
||||
return &t
|
||||
}
|
||||
|
||||
// TODO(ekr@rtfm.com): optimize this now that the list is sorted.
|
||||
// We should be able to do just one list manipulation, as long
|
||||
// as we're careful about how we handle inserts during callbacks.
|
||||
func (ts *timerSet) check(now time.Time) error {
|
||||
for i, t := range ts.ts {
|
||||
if now.After(t.deadline) {
|
||||
ts.ts = append(ts.ts[:i], ts.ts[:i+1]...)
|
||||
if t.cb != nil {
|
||||
logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline)
|
||||
cb := t.cb
|
||||
t.cb = nil
|
||||
err := cb()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the next time any of the timers would fire.
|
||||
func (ts *timerSet) remaining() (bool, time.Duration) {
|
||||
for _, t := range ts.ts {
|
||||
if t.cb != nil {
|
||||
return true, time.Until(t.deadline)
|
||||
}
|
||||
}
|
||||
|
||||
return false, time.Duration(0)
|
||||
}
|
||||
|
||||
func (ts *timerSet) cancel(label string) {
|
||||
for _, t := range ts.ts {
|
||||
if t.label == label {
|
||||
t.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *timerSet) getTimer(label string) *timer {
|
||||
for _, t := range ts.ts {
|
||||
if t.label == label && t.cb != nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *timerSet) getAllTimers() []string {
|
||||
var ret []string
|
||||
|
||||
for _, t := range ts.ts {
|
||||
if t.cb != nil {
|
||||
ret = append(ret, t.label)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t *timer) cancel() {
|
||||
logf(logTypeHandshake, "Timer %s cancelled", t.label)
|
||||
t.cb = nil
|
||||
t.label = ""
|
||||
}
|
25
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
25
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
@ -51,11 +51,14 @@ func (l *Listener) Accept() (c net.Conn, err error) {
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
func NewListener(inner net.Listener, config *Config) (net.Listener, error) {
|
||||
if config != nil && config.NonBlocking {
|
||||
return nil, errors.New("listening not possible in non-blocking mode")
|
||||
}
|
||||
l := new(Listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
@ -70,7 +73,7 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config), nil
|
||||
return NewListener(l, config)
|
||||
}
|
||||
|
||||
type TimeoutError struct{}
|
||||
@ -87,6 +90,10 @@ func (TimeoutError) Temporary() bool { return true }
|
||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||
// configuration; see the documentation of Config for the defaults.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
if config != nil && config.NonBlocking {
|
||||
return nil, errors.New("dialing not possible in non-blocking mode")
|
||||
}
|
||||
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
@ -121,16 +128,20 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
||||
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
} else {
|
||||
config = config.Clone()
|
||||
}
|
||||
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := config.Clone()
|
||||
c.ServerName = hostname
|
||||
config = c
|
||||
config.ServerName = hostname
|
||||
|
||||
}
|
||||
|
||||
// Set up DTLS as needed.
|
||||
config.UseDTLS = (network == "udp")
|
||||
|
||||
conn := Client(rawConn, config)
|
||||
|
||||
if timeout == 0 {
|
||||
|
22
vendor/github.com/cheekybits/genny/LICENSE
generated
vendored
Normal file
22
vendor/github.com/cheekybits/genny/LICENSE
generated
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 cheekybits
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
2
vendor/github.com/cheekybits/genny/generic/doc.go
generated
vendored
Normal file
2
vendor/github.com/cheekybits/genny/generic/doc.go
generated
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
// Package generic contains the generic marker types.
|
||||
package generic
|
13
vendor/github.com/cheekybits/genny/generic/generic.go
generated
vendored
Normal file
13
vendor/github.com/cheekybits/genny/generic/generic.go
generated
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
package generic
|
||||
|
||||
// Type is the placeholder type that indicates a generic value.
|
||||
// When genny is executed, variables of this type will be replaced with
|
||||
// references to the specific types.
|
||||
// var GenericType generic.Type
|
||||
type Type interface{}
|
||||
|
||||
// Number is the placehoder type that indiccates a generic numerical value.
|
||||
// When genny is executed, variables of this type will be replaced with
|
||||
// references to the specific types.
|
||||
// var GenericType generic.Number
|
||||
type Number float64
|
21
vendor/github.com/hashicorp/golang-lru/2q.go
generated
vendored
21
vendor/github.com/hashicorp/golang-lru/2q.go
generated
vendored
@ -30,9 +30,9 @@ type TwoQueueCache struct {
|
||||
size int
|
||||
recentSize int
|
||||
|
||||
recent *simplelru.LRU
|
||||
frequent *simplelru.LRU
|
||||
recentEvict *simplelru.LRU
|
||||
recent simplelru.LRUCache
|
||||
frequent simplelru.LRUCache
|
||||
recentEvict simplelru.LRUCache
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
@ -84,7 +84,8 @@ func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCa
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) {
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
@ -105,6 +106,7 @@ func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Add adds a value to the cache.
|
||||
func (c *TwoQueueCache) Add(key, value interface{}) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
@ -160,12 +162,15 @@ func (c *TwoQueueCache) ensureSpace(recentEvict bool) {
|
||||
c.frequent.RemoveOldest()
|
||||
}
|
||||
|
||||
// Len returns the number of items in the cache.
|
||||
func (c *TwoQueueCache) Len() int {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return c.recent.Len() + c.frequent.Len()
|
||||
}
|
||||
|
||||
// Keys returns a slice of the keys in the cache.
|
||||
// The frequently used keys are first in the returned slice.
|
||||
func (c *TwoQueueCache) Keys() []interface{} {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
@ -174,6 +179,7 @@ func (c *TwoQueueCache) Keys() []interface{} {
|
||||
return append(k1, k2...)
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache.
|
||||
func (c *TwoQueueCache) Remove(key interface{}) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
@ -188,6 +194,7 @@ func (c *TwoQueueCache) Remove(key interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// Purge is used to completely clear the cache.
|
||||
func (c *TwoQueueCache) Purge() {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
@ -196,13 +203,17 @@ func (c *TwoQueueCache) Purge() {
|
||||
c.recentEvict.Purge()
|
||||
}
|
||||
|
||||
// Contains is used to check if the cache contains a key
|
||||
// without updating recency or frequency.
|
||||
func (c *TwoQueueCache) Contains(key interface{}) bool {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return c.frequent.Contains(key) || c.recent.Contains(key)
|
||||
}
|
||||
|
||||
func (c *TwoQueueCache) Peek(key interface{}) (interface{}, bool) {
|
||||
// Peek is used to inspect the cache value of a key
|
||||
// without updating recency or frequency.
|
||||
func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
if val, ok := c.frequent.Peek(key); ok {
|
||||
|
16
vendor/github.com/hashicorp/golang-lru/arc.go
generated
vendored
16
vendor/github.com/hashicorp/golang-lru/arc.go
generated
vendored
@ -18,11 +18,11 @@ type ARCCache struct {
|
||||
size int // Size is the total capacity of the cache
|
||||
p int // P is the dynamic preference towards T1 or T2
|
||||
|
||||
t1 *simplelru.LRU // T1 is the LRU for recently accessed items
|
||||
b1 *simplelru.LRU // B1 is the LRU for evictions from t1
|
||||
t1 simplelru.LRUCache // T1 is the LRU for recently accessed items
|
||||
b1 simplelru.LRUCache // B1 is the LRU for evictions from t1
|
||||
|
||||
t2 *simplelru.LRU // T2 is the LRU for frequently accessed items
|
||||
b2 *simplelru.LRU // B2 is the LRU for evictions from t2
|
||||
t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items
|
||||
b2 simplelru.LRUCache // B2 is the LRU for evictions from t2
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
@ -60,11 +60,11 @@ func NewARC(size int) (*ARCCache, error) {
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *ARCCache) Get(key interface{}) (interface{}, bool) {
|
||||
func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
// Ff the value is contained in T1 (recent), then
|
||||
// If the value is contained in T1 (recent), then
|
||||
// promote it to T2 (frequent)
|
||||
if val, ok := c.t1.Peek(key); ok {
|
||||
c.t1.Remove(key)
|
||||
@ -153,7 +153,7 @@ func (c *ARCCache) Add(key, value interface{}) {
|
||||
// Remove from B2
|
||||
c.b2.Remove(key)
|
||||
|
||||
// Add the key to the frequntly used list
|
||||
// Add the key to the frequently used list
|
||||
c.t2.Add(key, value)
|
||||
return
|
||||
}
|
||||
@ -247,7 +247,7 @@ func (c *ARCCache) Contains(key interface{}) bool {
|
||||
|
||||
// Peek is used to inspect the cache value of a key
|
||||
// without updating recency or frequency.
|
||||
func (c *ARCCache) Peek(key interface{}) (interface{}, bool) {
|
||||
func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
if val, ok := c.t1.Peek(key); ok {
|
||||
|
21
vendor/github.com/hashicorp/golang-lru/doc.go
generated
vendored
Normal file
21
vendor/github.com/hashicorp/golang-lru/doc.go
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
// Package lru provides three different LRU caches of varying sophistication.
|
||||
//
|
||||
// Cache is a simple LRU cache. It is based on the
|
||||
// LRU implementation in groupcache:
|
||||
// https://github.com/golang/groupcache/tree/master/lru
|
||||
//
|
||||
// TwoQueueCache tracks frequently used and recently used entries separately.
|
||||
// This avoids a burst of accesses from taking out frequently used entries,
|
||||
// at the cost of about 2x computational overhead and some extra bookkeeping.
|
||||
//
|
||||
// ARCCache is an adaptive replacement cache. It tracks recent evictions as
|
||||
// well as recent usage in both the frequent and recent caches. Its
|
||||
// computational overhead is comparable to TwoQueueCache, but the memory
|
||||
// overhead is linear with the size of the cache.
|
||||
//
|
||||
// ARC has been patented by IBM, so do not use it if that is problematic for
|
||||
// your program.
|
||||
//
|
||||
// All caches in this package take locks while operating, and are therefore
|
||||
// thread-safe for consumers.
|
||||
package lru
|
28
vendor/github.com/hashicorp/golang-lru/lru.go
generated
vendored
28
vendor/github.com/hashicorp/golang-lru/lru.go
generated
vendored
@ -1,6 +1,3 @@
|
||||
// This package provides a simple LRU cache. It is based on the
|
||||
// LRU implementation in groupcache:
|
||||
// https://github.com/golang/groupcache/tree/master/lru
|
||||
package lru
|
||||
|
||||
import (
|
||||
@ -11,11 +8,11 @@ import (
|
||||
|
||||
// Cache is a thread-safe fixed size LRU cache.
|
||||
type Cache struct {
|
||||
lru *simplelru.LRU
|
||||
lru simplelru.LRUCache
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates an LRU of the given size
|
||||
// New creates an LRU of the given size.
|
||||
func New(size int) (*Cache, error) {
|
||||
return NewWithEvict(size, nil)
|
||||
}
|
||||
@ -33,7 +30,7 @@ func NewWithEvict(size int, onEvicted func(key interface{}, value interface{}))
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Purge is used to completely clear the cache
|
||||
// Purge is used to completely clear the cache.
|
||||
func (c *Cache) Purge() {
|
||||
c.lock.Lock()
|
||||
c.lru.Purge()
|
||||
@ -41,30 +38,30 @@ func (c *Cache) Purge() {
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
func (c *Cache) Add(key, value interface{}) bool {
|
||||
func (c *Cache) Add(key, value interface{}) (evicted bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.lru.Add(key, value)
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *Cache) Get(key interface{}) (interface{}, bool) {
|
||||
func (c *Cache) Get(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.lru.Get(key)
|
||||
}
|
||||
|
||||
// Check if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
// Contains checks if a key is in the cache, without updating the
|
||||
// recent-ness or deleting it for being stale.
|
||||
func (c *Cache) Contains(key interface{}) bool {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return c.lru.Contains(key)
|
||||
}
|
||||
|
||||
// Returns the key value (or undefined if not found) without updating
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *Cache) Peek(key interface{}) (interface{}, bool) {
|
||||
func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
return c.lru.Peek(key)
|
||||
@ -73,16 +70,15 @@ func (c *Cache) Peek(key interface{}) (interface{}, bool) {
|
||||
// ContainsOrAdd checks if a key is in the cache without updating the
|
||||
// recent-ness or deleting it for being stale, and if not, adds the value.
|
||||
// Returns whether found and whether an eviction occurred.
|
||||
func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evict bool) {
|
||||
func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.lru.Contains(key) {
|
||||
return true, false
|
||||
} else {
|
||||
evict := c.lru.Add(key, value)
|
||||
return false, evict
|
||||
}
|
||||
evicted = c.lru.Add(key, value)
|
||||
return false, evicted
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache.
|
||||
|
17
vendor/github.com/hashicorp/golang-lru/simplelru/lru.go
generated
vendored
17
vendor/github.com/hashicorp/golang-lru/simplelru/lru.go
generated
vendored
@ -36,7 +36,7 @@ func NewLRU(size int, onEvict EvictCallback) (*LRU, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Purge is used to completely clear the cache
|
||||
// Purge is used to completely clear the cache.
|
||||
func (c *LRU) Purge() {
|
||||
for k, v := range c.items {
|
||||
if c.onEvict != nil {
|
||||
@ -48,7 +48,7 @@ func (c *LRU) Purge() {
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
func (c *LRU) Add(key, value interface{}) bool {
|
||||
func (c *LRU) Add(key, value interface{}) (evicted bool) {
|
||||
// Check for existing item
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.evictList.MoveToFront(ent)
|
||||
@ -78,17 +78,18 @@ func (c *LRU) Get(key interface{}) (value interface{}, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if a key is in the cache, without updating the recent-ness
|
||||
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
func (c *LRU) Contains(key interface{}) (ok bool) {
|
||||
_, ok = c.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Returns the key value (or undefined if not found) without updating
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
|
||||
if ent, ok := c.items[key]; ok {
|
||||
var ent *list.Element
|
||||
if ent, ok = c.items[key]; ok {
|
||||
return ent.Value.(*entry).value, true
|
||||
}
|
||||
return nil, ok
|
||||
@ -96,7 +97,7 @@ func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
|
||||
|
||||
// Remove removes the provided key from the cache, returning if the
|
||||
// key was contained.
|
||||
func (c *LRU) Remove(key interface{}) bool {
|
||||
func (c *LRU) Remove(key interface{}) (present bool) {
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.removeElement(ent)
|
||||
return true
|
||||
@ -105,7 +106,7 @@ func (c *LRU) Remove(key interface{}) bool {
|
||||
}
|
||||
|
||||
// RemoveOldest removes the oldest item from the cache.
|
||||
func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) {
|
||||
func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) {
|
||||
ent := c.evictList.Back()
|
||||
if ent != nil {
|
||||
c.removeElement(ent)
|
||||
@ -116,7 +117,7 @@ func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) {
|
||||
}
|
||||
|
||||
// GetOldest returns the oldest entry
|
||||
func (c *LRU) GetOldest() (interface{}, interface{}, bool) {
|
||||
func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) {
|
||||
ent := c.evictList.Back()
|
||||
if ent != nil {
|
||||
kv := ent.Value.(*entry)
|
||||
|
37
vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go
generated
vendored
Normal file
37
vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go
generated
vendored
Normal file
@ -0,0 +1,37 @@
|
||||
package simplelru
|
||||
|
||||
|
||||
// LRUCache is the interface for simple LRU cache.
|
||||
type LRUCache interface {
|
||||
// Adds a value to the cache, returns true if an eviction occurred and
|
||||
// updates the "recently used"-ness of the key.
|
||||
Add(key, value interface{}) bool
|
||||
|
||||
// Returns key's value from the cache and
|
||||
// updates the "recently used"-ness of the key. #value, isFound
|
||||
Get(key interface{}) (value interface{}, ok bool)
|
||||
|
||||
// Check if a key exsists in cache without updating the recent-ness.
|
||||
Contains(key interface{}) (ok bool)
|
||||
|
||||
// Returns key's value without updating the "recently used"-ness of the key.
|
||||
Peek(key interface{}) (value interface{}, ok bool)
|
||||
|
||||
// Removes a key from the cache.
|
||||
Remove(key interface{}) bool
|
||||
|
||||
// Removes the oldest entry from cache.
|
||||
RemoveOldest() (interface{}, interface{}, bool)
|
||||
|
||||
// Returns the oldest entry from the cache. #key, value, isFound
|
||||
GetOldest() (interface{}, interface{}, bool)
|
||||
|
||||
// Returns a slice of the keys in the cache, from oldest to newest.
|
||||
Keys() []interface{}
|
||||
|
||||
// Returns the number of items in the cache.
|
||||
Len() int
|
||||
|
||||
// Clear all cache entries
|
||||
Purge()
|
||||
}
|
21
vendor/github.com/isofew/go-stun/LICENSE
generated
vendored
Normal file
21
vendor/github.com/isofew/go-stun/LICENSE
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Vasily Vasilyev
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
284
vendor/github.com/isofew/go-stun/stun/agent.go
generated
vendored
Normal file
284
vendor/github.com/isofew/go-stun/stun/agent.go
generated
vendored
Normal file
@ -0,0 +1,284 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DefaultConfig = &Config{
|
||||
RetransmissionTimeout: 500 * time.Millisecond,
|
||||
TransactionTimeout: 39500 * time.Millisecond,
|
||||
Software: "pixelbender/go-stun",
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
ServeSTUN(msg *Message, tr Transport)
|
||||
}
|
||||
|
||||
type HandlerFunc func(msg *Message, tr Transport)
|
||||
|
||||
func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) {
|
||||
h(msg, tr)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// AuthMethod returns a key for MESSAGE-INTEGRITY attribute
|
||||
AuthMethod AuthMethod
|
||||
// Retransmission timeout, default is 500 milliseconds
|
||||
RetransmissionTimeout time.Duration
|
||||
// Transaction timeout, default is 39.5 seconds
|
||||
TransactionTimeout time.Duration
|
||||
// Fingerprint, if true all outgoing messages contain FINGERPRINT attribute
|
||||
Fingerprint bool
|
||||
// Software is a SOFTWARE attribute value for outgoing messages, if not empty
|
||||
Software string
|
||||
// Logf, if set all sent and received messages printed using Logf
|
||||
Logf func(format string, args ...interface{})
|
||||
}
|
||||
|
||||
func (c *Config) attrs() []Attr {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
var a []Attr
|
||||
if c.Software != "" {
|
||||
a = append(a, String(AttrSoftware, c.Software))
|
||||
}
|
||||
if c.Fingerprint {
|
||||
a = append(a, Fingerprint)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (c *Config) Clone() *Config {
|
||||
r := *c
|
||||
return &r
|
||||
}
|
||||
|
||||
type Agent struct {
|
||||
config *Config
|
||||
Handler Handler
|
||||
m mux
|
||||
}
|
||||
|
||||
func NewAgent(config *Config) *Agent {
|
||||
if config == nil {
|
||||
config = DefaultConfig
|
||||
}
|
||||
return &Agent{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) Send(msg *Message, tr Transport) (err error) {
|
||||
msg = &Message{
|
||||
msg.Type,
|
||||
msg.Transaction,
|
||||
append(a.config.attrs(), msg.Attributes...),
|
||||
}
|
||||
if log := a.config.Logf; log != nil {
|
||||
log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
||||
}
|
||||
b := msg.Marshal(getBuffer()[:0])
|
||||
_, err = tr.Write(b)
|
||||
putBuffer(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error {
|
||||
if c, ok := c.(net.PacketConn); ok {
|
||||
return a.ServePacket(c, stop)
|
||||
}
|
||||
var (
|
||||
b = getBuffer()
|
||||
p int
|
||||
)
|
||||
defer putBuffer(b)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
if p >= len(b) {
|
||||
return errBufferOverflow
|
||||
}
|
||||
n, err := c.Read(b[p:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p += n
|
||||
n = 0
|
||||
for n < p {
|
||||
r, err := a.ServeTransport(b[n:p], c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n += r
|
||||
}
|
||||
if n > 0 {
|
||||
if n < p {
|
||||
p = copy(b, b[n:p])
|
||||
} else {
|
||||
p = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error {
|
||||
b := getBuffer()
|
||||
defer putBuffer(b)
|
||||
// don't close the connection since we're going to reuse it
|
||||
// defer c.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
n, addr, err := c.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
a.ServeTransport(b[:n], &packetConn{c, addr})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) {
|
||||
msg := &Message{}
|
||||
n, err = msg.Unmarshal(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
a.ServeSTUN(msg, tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Agent) ServeSTUN(msg *Message, tr Transport) {
|
||||
if log := a.config.Logf; log != nil {
|
||||
log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
||||
}
|
||||
if a.m.serve(msg, tr) {
|
||||
return
|
||||
}
|
||||
if h := a.Handler; h != nil {
|
||||
go h.ServeSTUN(msg, tr)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) {
|
||||
var (
|
||||
start = time.Now()
|
||||
rto = a.config.RetransmissionTimeout
|
||||
udp = to.LocalAddr().Network() == "udp"
|
||||
tx = a.m.newTx()
|
||||
)
|
||||
defer a.m.closeTx(tx)
|
||||
req = &Message{req.Type, tx.id, req.Attributes}
|
||||
if err = a.Send(req, to); err != nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
d := a.config.TransactionTimeout - time.Since(start)
|
||||
if d < 0 {
|
||||
err = errTimeout
|
||||
return
|
||||
}
|
||||
if udp && d > rto {
|
||||
d = rto
|
||||
}
|
||||
res, from, err = tx.Receive(d)
|
||||
if udp && err == errTimeout && d == rto {
|
||||
rto <<= 1
|
||||
a.Send(req, to)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type mux struct {
|
||||
sync.RWMutex
|
||||
t map[string]*transaction
|
||||
}
|
||||
|
||||
func (m *mux) serve(msg *Message, tr Transport) bool {
|
||||
m.RLock()
|
||||
tx, ok := m.t[string(msg.Transaction)]
|
||||
m.RUnlock()
|
||||
if ok {
|
||||
tx.msg, tx.from = msg, tr
|
||||
tx.Done()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mux) newTx() *transaction {
|
||||
tx := &transaction{id: NewTransaction()}
|
||||
m.Lock()
|
||||
if m.t == nil {
|
||||
m.t = make(map[string]*transaction)
|
||||
} else {
|
||||
for m.t[string(tx.id)] != nil {
|
||||
rand.Read(tx.id[4:])
|
||||
}
|
||||
}
|
||||
m.t[string(tx.id)] = tx
|
||||
m.Unlock()
|
||||
return tx
|
||||
}
|
||||
|
||||
func (m *mux) closeTx(tx *transaction) {
|
||||
m.Lock()
|
||||
delete(m.t, string(tx.id))
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
func (m *mux) Close() {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, it := range m.t {
|
||||
it.Close()
|
||||
}
|
||||
m.t = nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
sync.WaitGroup
|
||||
id []byte
|
||||
from Transport
|
||||
msg *Message
|
||||
err error
|
||||
}
|
||||
|
||||
func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
|
||||
tx.Add(1)
|
||||
t := time.AfterFunc(d, tx.timeout)
|
||||
tx.Wait()
|
||||
t.Stop()
|
||||
if err = tx.err; err != nil {
|
||||
return
|
||||
}
|
||||
return tx.msg, tx.from, nil
|
||||
}
|
||||
|
||||
func (tx *transaction) timeout() {
|
||||
tx.err = errTimeout
|
||||
tx.Done()
|
||||
}
|
||||
|
||||
func (tx *transaction) Close() {
|
||||
tx.err = errCanceled
|
||||
tx.Done()
|
||||
}
|
||||
|
||||
var errCanceled = errors.New("stun: transaction canceled")
|
||||
var errTimeout = errors.New("stun: transaction timeout")
|
438
vendor/github.com/isofew/go-stun/stun/attribute.go
generated
vendored
Normal file
438
vendor/github.com/isofew/go-stun/stun/attribute.go
generated
vendored
Normal file
@ -0,0 +1,438 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Attribute represents a STUN attribute.
|
||||
type Attr interface {
|
||||
Type() uint16
|
||||
Marshal(p []byte) []byte
|
||||
Unmarshal(b []byte) error
|
||||
}
|
||||
|
||||
// IP address family
|
||||
const (
|
||||
IPv4 = 0x01
|
||||
IPv6 = 0x02
|
||||
)
|
||||
|
||||
// IP address family
|
||||
const (
|
||||
ChangeIP uint64 = 0x04
|
||||
ChangePort = 0x02
|
||||
)
|
||||
|
||||
func newAttr(typ uint16) Attr {
|
||||
switch typ {
|
||||
case AttrMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress,
|
||||
AttrXorMappedAddress, AttrAlternateServer, AttrResponseOrigin, AttrOtherAddress,
|
||||
AttrResponseAddress, AttrSourceAddress, AttrChangedAddress, AttrReflectedFrom:
|
||||
return &addr{typ: typ}
|
||||
case AttrRequestedAddressFamily, AttrRequestedTransport:
|
||||
return &number{typ: typ, size: 4, pad: 24}
|
||||
case AttrChannelNumber, AttrResponsePort:
|
||||
return &number{typ: typ, size: 4, pad: 16}
|
||||
case AttrLifetime, AttrConnectionID, AttrCacheTimeout,
|
||||
AttrBandwidth, AttrTimerVal,
|
||||
AttrTransactionTransmitCounter,
|
||||
AttrEcnCheck, AttrChangeRequest, AttrPriority:
|
||||
return &number{typ: typ, size: 4}
|
||||
case AttrIceControlled, AttrIceControlling:
|
||||
return &number{typ: typ, size: 8}
|
||||
case AttrUsername, AttrRealm, AttrNonce, AttrSoftware, AttrPassword, AttrThirdPartyAuthorization,
|
||||
AttrData, AttrAccessToken, AttrReservationToken, AttrMobilityTicket, AttrPadding, AttrUnknownAttributes:
|
||||
return &raw{typ: typ}
|
||||
case AttrMessageIntegrity:
|
||||
return &integrity{}
|
||||
case AttrErrorCode:
|
||||
return &Error{}
|
||||
case AttrEvenPort:
|
||||
return &number{typ: typ, size: 1}
|
||||
case AttrDontFragment, AttrUseCandidate:
|
||||
return flag(typ)
|
||||
case AttrFingerprint:
|
||||
return &fingerprint{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AttrName(typ uint16) string {
|
||||
if r, ok := attrNames[typ]; ok {
|
||||
return r
|
||||
}
|
||||
return "0x" + strconv.FormatUint(uint64(typ), 16)
|
||||
}
|
||||
|
||||
func Int(typ uint16, v uint64) Attr {
|
||||
switch typ {
|
||||
case AttrRequestedAddressFamily, AttrRequestedTransport:
|
||||
return &number{typ, 4, 24, v}
|
||||
case AttrChannelNumber, AttrResponsePort:
|
||||
return &number{typ, 4, 16, v}
|
||||
case AttrIceControlled, AttrIceControlling:
|
||||
return &number{typ, 8, 0, v}
|
||||
case AttrEvenPort:
|
||||
return &number{typ, 1, 0, v}
|
||||
default:
|
||||
return &number{typ, 4, 0, v}
|
||||
}
|
||||
}
|
||||
|
||||
type number struct {
|
||||
typ uint16
|
||||
size, pad uint8
|
||||
v uint64
|
||||
}
|
||||
|
||||
func (a *number) Type() uint16 { return a.typ }
|
||||
|
||||
func (a *number) Marshal(p []byte) []byte {
|
||||
r, b := grow(p, int(a.size))
|
||||
switch a.size {
|
||||
case 1:
|
||||
b[0] = byte(a.v)
|
||||
case 4:
|
||||
be.PutUint32(b, uint32(a.v<<a.pad))
|
||||
case 8:
|
||||
be.PutUint64(b, a.v<<a.pad)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *number) Unmarshal(b []byte) error {
|
||||
if len(b) < int(a.size) {
|
||||
return errFormat
|
||||
}
|
||||
switch a.size {
|
||||
case 1:
|
||||
a.v = uint64(b[0])
|
||||
case 4:
|
||||
a.v = uint64(be.Uint32(b) >> a.pad)
|
||||
case 8:
|
||||
a.v = be.Uint64(b) >> a.pad
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *number) String() string {
|
||||
return "0x" + strconv.FormatUint(a.v, 16)
|
||||
}
|
||||
|
||||
func Flag(v uint16) Attr {
|
||||
return flag(v)
|
||||
}
|
||||
|
||||
type flag uint16
|
||||
|
||||
func (attr flag) Type() uint16 { return uint16(attr) }
|
||||
func (flag) Marshal(p []byte) []byte { return p }
|
||||
func (flag) Unmarshal(b []byte) error { return nil }
|
||||
|
||||
// Error represents the ERROR-CODE attribute.
|
||||
type Error struct {
|
||||
Code int
|
||||
Reason string
|
||||
}
|
||||
|
||||
func NewError(code int) *Error {
|
||||
return &Error{code, ErrorText(code)}
|
||||
}
|
||||
|
||||
func (*Error) Type() uint16 { return AttrErrorCode }
|
||||
|
||||
func (e *Error) Marshal(p []byte) []byte {
|
||||
r, b := grow(p, 4+len(e.Reason))
|
||||
b[0] = 0
|
||||
b[1] = 0
|
||||
b[2] = byte(e.Code / 100)
|
||||
b[3] = byte(e.Code % 100)
|
||||
copy(b[4:], e.Reason)
|
||||
return r
|
||||
}
|
||||
|
||||
func (e *Error) Unmarshal(b []byte) error {
|
||||
if len(b) < 4 {
|
||||
return errFormat
|
||||
}
|
||||
e.Code = int(b[2])*100 + int(b[3])
|
||||
e.Reason = getString(b[4:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Error) Error() string { return e.String() }
|
||||
func (e *Error) String() string { return fmt.Sprintf("%d %s", e.Code, e.Reason) }
|
||||
|
||||
// ErrorText returns a text for the STUN error code. It returns the empty string if the code is unknown.
|
||||
func ErrorText(code int) string { return errorText[code] }
|
||||
|
||||
func getString(b []byte) string {
|
||||
for i := len(b); i >= 0; i-- {
|
||||
if b[i-1] > 0 {
|
||||
return string(b[:i])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func Addr(typ uint16, v net.Addr) Attr {
|
||||
ip, port := SockAddr(v)
|
||||
return &addr{typ, ip, port}
|
||||
}
|
||||
|
||||
func SockAddr(v net.Addr) (net.IP, int) {
|
||||
switch a := v.(type) {
|
||||
case *net.UDPAddr:
|
||||
return a.IP, a.Port
|
||||
case *net.TCPAddr:
|
||||
return a.IP, a.Port
|
||||
case *net.IPAddr:
|
||||
return a.IP, 0
|
||||
default:
|
||||
return net.IPv4zero, 0
|
||||
}
|
||||
}
|
||||
|
||||
func sameAddr(a, b net.Addr) bool {
|
||||
aip, aport := SockAddr(a)
|
||||
bip, bport := SockAddr(b)
|
||||
return aip.Equal(bip) && aport == bport
|
||||
}
|
||||
|
||||
func NewAddr(network string, ip net.IP, port int) net.Addr {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
return &net.UDPAddr{IP: ip, Port: port}
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
return &net.TCPAddr{IP: ip, Port: port}
|
||||
}
|
||||
return &net.IPAddr{IP: ip}
|
||||
}
|
||||
|
||||
func IP(typ uint16, ip net.IP) Attr { return &addr{typ, ip, 0} }
|
||||
|
||||
type addr struct {
|
||||
typ uint16
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
func (addr *addr) Type() uint16 { return addr.typ }
|
||||
|
||||
func (addr *addr) Addr(network string) net.Addr {
|
||||
return NewAddr(network, addr.IP, addr.Port)
|
||||
}
|
||||
|
||||
func (addr *addr) Xored() bool {
|
||||
switch addr.typ {
|
||||
case AttrXorMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (addr *addr) Marshal(p []byte) []byte {
|
||||
return addr.MarshalAddr(p, nil)
|
||||
}
|
||||
|
||||
func (addr *addr) MarshalAddr(p, tx []byte) []byte {
|
||||
fam, ip := IPv4, addr.IP.To4()
|
||||
if ip == nil {
|
||||
fam, ip = IPv6, addr.IP
|
||||
}
|
||||
r, b := grow(p, 4+len(ip))
|
||||
b[0] = 0
|
||||
b[1] = byte(fam)
|
||||
if addr.Xored() && tx != nil {
|
||||
be.PutUint16(b[2:], uint16(addr.Port)^0x2112)
|
||||
b = b[4:]
|
||||
for i, it := range ip {
|
||||
b[i] = it ^ tx[i]
|
||||
}
|
||||
} else {
|
||||
be.PutUint16(b[2:], uint16(addr.Port))
|
||||
copy(b[4:], ip)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (addr *addr) Unmarshal(b []byte) error {
|
||||
return addr.UnmarshalAddr(b, nil)
|
||||
}
|
||||
|
||||
func (addr *addr) UnmarshalAddr(b, tx []byte) error {
|
||||
if len(b) < 4 {
|
||||
return errFormat
|
||||
}
|
||||
n, port := net.IPv4len, int(be.Uint16(b[2:]))
|
||||
if b[1] == IPv6 {
|
||||
n = net.IPv6len
|
||||
}
|
||||
if b = b[4:]; len(b) < n {
|
||||
return errFormat
|
||||
}
|
||||
addr.IP = make(net.IP, n)
|
||||
if addr.Xored() && tx != nil {
|
||||
for i, it := range b {
|
||||
addr.IP[i] = it ^ tx[i]
|
||||
}
|
||||
addr.Port = port ^ 0x2112
|
||||
} else {
|
||||
copy(addr.IP, b)
|
||||
addr.Port = port
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (addr *addr) Equal(a *addr) bool {
|
||||
return addr == a || (addr != nil && a != nil && addr.IP.Equal(a.IP) && addr.Port == a.Port)
|
||||
}
|
||||
|
||||
func (addr *addr) String() string {
|
||||
if addr.Port == 0 {
|
||||
return addr.IP.String()
|
||||
}
|
||||
return net.JoinHostPort(addr.IP.String(), strconv.Itoa(addr.Port))
|
||||
}
|
||||
|
||||
func Bytes(typ uint16, v []byte) Attr { return &raw{typ, v} }
|
||||
|
||||
type raw struct {
|
||||
typ uint16
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (attr *raw) Type() uint16 { return attr.typ }
|
||||
func (attr *raw) Marshal(p []byte) []byte { return append(p, attr.data...) }
|
||||
func (attr *raw) Unmarshal(p []byte) error {
|
||||
attr.data = p
|
||||
return nil
|
||||
}
|
||||
func (attr *raw) String() string { return string(attr.data) }
|
||||
|
||||
func String(typ uint16, v string) Attr {
|
||||
return &str{typ, v}
|
||||
}
|
||||
|
||||
type str struct {
|
||||
typ uint16
|
||||
data string
|
||||
}
|
||||
|
||||
func (attr *str) Type() uint16 { return attr.typ }
|
||||
func (attr *str) Marshal(p []byte) []byte { return append(p, attr.data...) }
|
||||
func (attr *str) Unmarshal(p []byte) error {
|
||||
attr.data = string(p)
|
||||
return nil
|
||||
}
|
||||
func (attr *str) String() string { return attr.data }
|
||||
|
||||
func MessageIntegrity(key []byte) Attr {
|
||||
return &integrity{key: key}
|
||||
}
|
||||
|
||||
type integrity struct {
|
||||
key, sum, raw []byte
|
||||
}
|
||||
|
||||
func (*integrity) Type() uint16 {
|
||||
return AttrMessageIntegrity
|
||||
}
|
||||
|
||||
func (attr *integrity) Marshal(p []byte) []byte {
|
||||
return append(p, attr.sum...)
|
||||
}
|
||||
|
||||
func (attr *integrity) Unmarshal(b []byte) error {
|
||||
if len(b) < 20 {
|
||||
return errFormat
|
||||
}
|
||||
attr.sum = b
|
||||
return nil
|
||||
}
|
||||
|
||||
func (attr *integrity) MarshalSum(p, raw []byte) []byte {
|
||||
n := len(raw) - 4
|
||||
be.PutUint16(raw[2:], uint16(n+4))
|
||||
return attr.Sum(attr.key, raw[:n], p)
|
||||
}
|
||||
|
||||
func (attr *integrity) UnmarshalSum(p, raw []byte) error {
|
||||
attr.raw = raw
|
||||
return attr.Unmarshal(p)
|
||||
}
|
||||
|
||||
func (attr *integrity) Sum(key, data, p []byte) []byte {
|
||||
h := hmac.New(sha1.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(p)
|
||||
}
|
||||
|
||||
func (attr *integrity) Check(key []byte) bool {
|
||||
r := attr.raw
|
||||
if len(r) < 44 {
|
||||
return r == nil
|
||||
}
|
||||
be.PutUint16(r[2:], uint16(len(r)-20))
|
||||
h := attr.Sum(key, r[:len(r)-24], nil)
|
||||
return bytes.Equal(h, attr.sum)
|
||||
}
|
||||
|
||||
var Fingerprint Attr = &fingerprint{}
|
||||
|
||||
type fingerprint struct {
|
||||
sum uint32
|
||||
raw []byte
|
||||
}
|
||||
|
||||
func (*fingerprint) Type() uint16 {
|
||||
return AttrFingerprint
|
||||
}
|
||||
|
||||
func (attr *fingerprint) Marshal(p []byte) []byte {
|
||||
r, b := grow(p, 4)
|
||||
be.PutUint32(b, attr.sum)
|
||||
return r
|
||||
}
|
||||
|
||||
func (attr *fingerprint) Unmarshal(b []byte) error {
|
||||
if len(b) < 4 {
|
||||
return errFormat
|
||||
}
|
||||
attr.sum = be.Uint32(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (attr *fingerprint) MarshalSum(p, raw []byte) []byte {
|
||||
n := len(raw) - 4
|
||||
be.PutUint16(raw[2:], uint16(n-12))
|
||||
v := attr.Sum(raw[:n])
|
||||
r, b := grow(p, 4)
|
||||
be.PutUint32(b, v)
|
||||
return r
|
||||
}
|
||||
|
||||
func (attr *fingerprint) UnmarshalSum(p, raw []byte) error {
|
||||
attr.raw = raw
|
||||
return attr.Unmarshal(p)
|
||||
}
|
||||
|
||||
func (attr *fingerprint) Sum(p []byte) uint32 {
|
||||
return crc32.ChecksumIEEE(p) ^ 0x5354554e
|
||||
}
|
||||
|
||||
func (attr *fingerprint) Check() bool {
|
||||
r := attr.raw
|
||||
if len(r) < 28 {
|
||||
return r == nil
|
||||
}
|
||||
be.PutUint16(r[2:], uint16(len(r)-20))
|
||||
return attr.Sum(r[:len(r)-8]) == attr.sum
|
||||
}
|
110
vendor/github.com/isofew/go-stun/stun/conn.go
generated
vendored
Normal file
110
vendor/github.com/isofew/go-stun/stun/conn.go
generated
vendored
Normal file
@ -0,0 +1,110 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
agent *Agent
|
||||
sess *Session
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, config *Config, stop chan struct{}) *Conn {
|
||||
a := NewAgent(config)
|
||||
go a.ServeConn(conn, stop)
|
||||
return &Conn{conn, a, nil}
|
||||
}
|
||||
|
||||
func (c *Conn) Network() string {
|
||||
return c.LocalAddr().Network()
|
||||
}
|
||||
|
||||
func (c *Conn) Discover() (net.Addr, error) {
|
||||
res, err := c.Request(&Message{Type: MethodBinding})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mapped := res.GetAddr(c.Network(), AttrXorMappedAddress, AttrMappedAddress)
|
||||
if mapped != nil {
|
||||
return mapped, nil
|
||||
}
|
||||
return nil, errors.New("stun: bad response, no mapped address")
|
||||
}
|
||||
|
||||
func (c *Conn) Request(req *Message) (res *Message, err error) {
|
||||
res, _, err = c.RequestTransport(req, c.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) RequestTransport(req *Message, to Transport) (res *Message, from Transport, err error) {
|
||||
sess := c.sess
|
||||
auth := c.agent.config.AuthMethod
|
||||
if to == nil {
|
||||
to = c.Conn
|
||||
}
|
||||
for {
|
||||
msg := &Message{
|
||||
req.Type,
|
||||
NewTransaction(),
|
||||
append(sess.attrs(), req.Attributes...),
|
||||
}
|
||||
res, from, err = c.agent.RoundTrip(msg, to)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
code := res.GetError()
|
||||
if code == nil {
|
||||
// FIXME: authorize response...
|
||||
if sess != nil {
|
||||
c.sess = sess
|
||||
}
|
||||
return
|
||||
}
|
||||
err = code
|
||||
switch code.Code {
|
||||
case CodeUnauthorized, CodeStaleNonce:
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
sess = &Session{
|
||||
Realm: res.GetString(AttrRealm),
|
||||
Nonce: res.GetString(AttrNonce),
|
||||
}
|
||||
if err = auth(sess); err != nil {
|
||||
return
|
||||
}
|
||||
auth = nil
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Realm string
|
||||
Nonce string
|
||||
Username string
|
||||
Key []byte
|
||||
}
|
||||
|
||||
func (s *Session) attrs() []Attr {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
var a []Attr
|
||||
if s.Realm != "" {
|
||||
a = append(a, String(AttrRealm, s.Realm))
|
||||
}
|
||||
if s.Nonce != "" {
|
||||
a = append(a, String(AttrNonce, s.Nonce))
|
||||
}
|
||||
if s.Username != "" {
|
||||
a = append(a, String(AttrUsername, s.Username))
|
||||
}
|
||||
if s.Key != nil {
|
||||
a = append(a, MessageIntegrity(s.Key))
|
||||
}
|
||||
return a
|
||||
}
|
203
vendor/github.com/isofew/go-stun/stun/gen.go
generated
vendored
Normal file
203
vendor/github.com/isofew/go-stun/stun/gen.go
generated
vendored
Normal file
@ -0,0 +1,203 @@
|
||||
//+build ignore
|
||||
|
||||
//go:generate go run gen.go
|
||||
// This program generates STUN parameters: methods, attributes and error codes by reading IANA registry.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := generate("http://www.iana.org/assignments/stun-parameters/stun-parameters.xml", "registry.go")
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func loadRegistry(url string) (*Registry, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &Registry{}
|
||||
err = xml.Unmarshal(b, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := regexp.MustCompile("^([\\w-]+)(.*was\\s+([\\w-]+))?")
|
||||
for _, reg := range r.Registry {
|
||||
for _, r := range reg.Records {
|
||||
m := name.FindStringSubmatch(r.Description)
|
||||
if m != nil {
|
||||
r.Value = strings.ToLower(r.Value)
|
||||
r.Name = m[1]
|
||||
if r.Name == "Reserved" && m[3] != "" {
|
||||
r.Name = m[3]
|
||||
r.Deprecated = true
|
||||
}
|
||||
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "rfc")
|
||||
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "RFC-")
|
||||
}
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
Value string `xml:"value"`
|
||||
Description string `xml:"description"`
|
||||
Name string
|
||||
Deprecated bool
|
||||
Ref struct {
|
||||
Type string `xml:"type,attr"`
|
||||
Data string `xml:"data,attr"`
|
||||
} `xml:"xref"`
|
||||
}
|
||||
|
||||
func (r *Record) IsValid() bool {
|
||||
return r.Name != "Reserved" && r.Name != "Unassigned" && (r.Ref.Type == "rfc" || r.Ref.Type == "draft")
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
Title string `xml:"title"`
|
||||
Updated string `xml:"updated"`
|
||||
Registry []struct {
|
||||
Id string `xml:"id,attr"`
|
||||
Records []*Record `xml:"record"`
|
||||
} `xml:"registry"`
|
||||
}
|
||||
|
||||
func (reg *Registry) GetRecords(id string) []*Record {
|
||||
for _, it := range reg.Registry {
|
||||
if it.Id == id {
|
||||
return it.Records
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generate(url, file string) error {
|
||||
reg, err := loadRegistry(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "package stun\n\n")
|
||||
fmt.Fprintf(b, "// Do not edit. This file is generated by 'go generate gen.go'\n")
|
||||
fmt.Fprintf(b, "// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).\n")
|
||||
fmt.Fprintf(b, "// %s, Updated: %s.\n\n", reg.Title, reg.Updated)
|
||||
|
||||
genMethods(reg.GetRecords("stun-parameters-2"), b)
|
||||
genAttributes(reg.GetRecords("stun-parameters-4"), b)
|
||||
genErrors(reg.GetRecords("stun-parameters-6"), b)
|
||||
|
||||
src, err := format.Source(b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = ioutil.WriteFile(file, src, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func genMethods(records []*Record, b *bytes.Buffer) error {
|
||||
c, m := &bytes.Buffer{}, &bytes.Buffer{}
|
||||
ref := ""
|
||||
for _, it := range records {
|
||||
if it.IsValid() {
|
||||
if ref == it.Ref.Data {
|
||||
fmt.Fprintf(c, "Method%v uint16 = %s\n", it.Name, it.Value)
|
||||
} else {
|
||||
ref = it.Ref.Data
|
||||
fmt.Fprintf(c, "Method%v uint16 = %s // RFC %s\n", it.Name, it.Value, ref)
|
||||
}
|
||||
fmt.Fprintf(m, "Method%v: \"%s\",\n", it.Name, it.Name)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(b, "// STUN methods.\n")
|
||||
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||
fmt.Fprintf(b, "// STUN method names.\n")
|
||||
fmt.Fprintf(b, "var methodNames = map[uint16]string{\n%s}\n", m.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func genAttributes(records []*Record, b *bytes.Buffer) error {
|
||||
c, d, m := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}
|
||||
ref := ""
|
||||
for _, it := range records {
|
||||
if it.IsValid() {
|
||||
a := c
|
||||
if it.Deprecated {
|
||||
a = d
|
||||
}
|
||||
v := strings.Replace(it.Name, "_", "-", -1)
|
||||
n := strings.Replace(v, "-", " ", -1)
|
||||
parts := strings.Fields(n)
|
||||
for i, s := range parts {
|
||||
switch s {
|
||||
case "", "ID":
|
||||
default:
|
||||
r, n := utf8.DecodeRuneInString(s)
|
||||
s = string(unicode.ToUpper(r)) + strings.ToLower(s[n:])
|
||||
}
|
||||
parts[i] = s
|
||||
}
|
||||
n = strings.Join(parts, "")
|
||||
if ref == it.Ref.Data || it.Deprecated {
|
||||
fmt.Fprintf(a, "Attr%v uint16 = %s\n", n, it.Value)
|
||||
} else {
|
||||
ref = it.Ref.Data
|
||||
fmt.Fprintf(a, "Attr%v uint16 = %s // RFC %s\n", n, it.Value, ref)
|
||||
}
|
||||
fmt.Fprintf(m, "Attr%v: \"%s\",\n", n, v)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(b, "// STUN attributes.\n")
|
||||
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||
fmt.Fprintf(b, "// Deprecated: For backwards compatibility only.\n")
|
||||
fmt.Fprintf(b, "const (\n%s)\n", d.Bytes())
|
||||
fmt.Fprintf(b, "// STUN attribute names.\n")
|
||||
fmt.Fprintf(b, "var attrNames = map[uint16]string{\n%s}\n", m.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
func genErrors(records []*Record, b *bytes.Buffer) error {
|
||||
c, m := &bytes.Buffer{}, &bytes.Buffer{}
|
||||
ref := ""
|
||||
for _, it := range records {
|
||||
if it.IsValid() {
|
||||
n := strings.Replace(strings.Title(it.Description), " ", "", -1)
|
||||
if ref == it.Ref.Data {
|
||||
fmt.Fprintf(c, "Code%v int = %s\n", n, it.Value)
|
||||
} else {
|
||||
ref = it.Ref.Data
|
||||
fmt.Fprintf(c, "Code%v int = %s // RFC %s\n", n, it.Value, ref)
|
||||
}
|
||||
fmt.Fprintf(m, "Code%v: \"%s\",\n", n, it.Description)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(b, "// STUN error codes.\n")
|
||||
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||
fmt.Fprintf(b, "// STUN error texts.\n")
|
||||
fmt.Fprintf(b, "var errorText = map[int]string{\n%s}\n", m.Bytes())
|
||||
return nil
|
||||
}
|
351
vendor/github.com/isofew/go-stun/stun/message.go
generated
vendored
Normal file
351
vendor/github.com/isofew/go-stun/stun/message.go
generated
vendored
Normal file
@ -0,0 +1,351 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
KindRequest uint16 = 0x0000
|
||||
KindIndication uint16 = 0x0010
|
||||
KindResponse uint16 = 0x0100
|
||||
KindError uint16 = 0x0110
|
||||
)
|
||||
|
||||
// Message represents a STUN message.
|
||||
type Message struct {
|
||||
Type uint16
|
||||
Transaction []byte
|
||||
Attributes []Attr
|
||||
}
|
||||
|
||||
func (m *Message) Marshal(p []byte) []byte {
|
||||
pos := len(p)
|
||||
r, b := grow(p, 20)
|
||||
be.PutUint16(b, m.Type)
|
||||
|
||||
if m.Transaction != nil {
|
||||
copy(b[4:], m.Transaction)
|
||||
} else {
|
||||
copy(b[4:], magicCookie)
|
||||
rand.Read(b[8:20])
|
||||
}
|
||||
|
||||
sort.Sort(byPosition(m.Attributes))
|
||||
for _, attr := range m.Attributes {
|
||||
r = m.marshalAttr(r, attr, pos)
|
||||
}
|
||||
|
||||
be.PutUint16(r[pos+2:], uint16(len(r)-pos-20))
|
||||
return r
|
||||
}
|
||||
|
||||
func (m *Message) marshalAttr(p []byte, attr Attr, pos int) []byte {
|
||||
h := len(p)
|
||||
r, b := grow(p, 4)
|
||||
be.PutUint16(b, attr.Type())
|
||||
|
||||
switch v := attr.(type) {
|
||||
case *addr:
|
||||
r = v.MarshalAddr(r, r[pos+4:])
|
||||
case *integrity:
|
||||
r = v.MarshalSum(r, r[pos:])
|
||||
case *fingerprint:
|
||||
r = v.MarshalSum(r, r[pos:])
|
||||
default:
|
||||
r = v.Marshal(r)
|
||||
}
|
||||
n := len(r) - h - 4
|
||||
be.PutUint16(r[h+2:], uint16(n))
|
||||
|
||||
if pad := n & 3; pad != 0 {
|
||||
r, b = grow(r, 4-pad)
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (m *Message) Unmarshal(b []byte) (n int, err error) {
|
||||
if len(b) < 20 {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
l := int(be.Uint16(b[2:])) + 20
|
||||
if len(b) < l {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
pos, p := 20, make([]byte, l)
|
||||
copy(p, b[:l])
|
||||
|
||||
m.Type = be.Uint16(p)
|
||||
m.Transaction = p[4:20]
|
||||
|
||||
for pos < len(p) {
|
||||
s, attr, err := m.unmarshalAttr(p, pos)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pos += s
|
||||
if attr != nil {
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
}
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (m *Message) unmarshalAttr(p []byte, pos int) (n int, attr Attr, err error) {
|
||||
b := p[pos:]
|
||||
if len(b) < 4 {
|
||||
err = errFormat
|
||||
return
|
||||
}
|
||||
typ := be.Uint16(b)
|
||||
attr, n = newAttr(typ), int(be.Uint16(b[2:]))+4
|
||||
if len(b) < n {
|
||||
err = errFormat
|
||||
return
|
||||
}
|
||||
|
||||
b = b[4:n]
|
||||
if attr != nil {
|
||||
switch v := attr.(type) {
|
||||
case *addr:
|
||||
err = v.UnmarshalAddr(b, m.Transaction)
|
||||
case *integrity:
|
||||
err = v.UnmarshalSum(b, p[:pos+n])
|
||||
case *fingerprint:
|
||||
err = v.UnmarshalSum(b, p[:pos+n])
|
||||
default:
|
||||
err = attr.Unmarshal(b)
|
||||
}
|
||||
} else if typ < 0x8000 {
|
||||
err = errFormat
|
||||
}
|
||||
if err != nil {
|
||||
err = &errAttribute{err, typ}
|
||||
return
|
||||
}
|
||||
if pad := n & 3; pad != 0 {
|
||||
n += 4 - pad
|
||||
if len(p) < pos+n {
|
||||
err = errFormat
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Message) Kind() uint16 {
|
||||
return m.Type & 0x110
|
||||
}
|
||||
|
||||
func (m *Message) Method() uint16 {
|
||||
return m.Type &^ 0x110
|
||||
}
|
||||
|
||||
func (m *Message) Add(attr Attr) {
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
}
|
||||
|
||||
func (m *Message) Set(attr Attr) {
|
||||
m.Del(attr.Type())
|
||||
m.Add(attr)
|
||||
}
|
||||
|
||||
func (m *Message) Del(typ uint16) {
|
||||
n := 0
|
||||
for _, a := range m.Attributes {
|
||||
if a.Type() != typ {
|
||||
m.Attributes[n] = a
|
||||
n++
|
||||
}
|
||||
}
|
||||
m.Attributes = m.Attributes[:n]
|
||||
}
|
||||
|
||||
func (m *Message) Get(typ uint16) (attr Attr) {
|
||||
for _, attr = range m.Attributes {
|
||||
if attr.Type() == typ {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) Has(typ uint16) bool {
|
||||
for _, attr := range m.Attributes {
|
||||
if attr.Type() == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) GetString(typ uint16) string {
|
||||
if str, ok := m.Get(typ).(fmt.Stringer); ok {
|
||||
return str.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Message) GetAddr(network string, typ ...uint16) net.Addr {
|
||||
for _, t := range typ {
|
||||
if addr, ok := m.Get(t).(*addr); ok {
|
||||
return addr.Addr(network)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) GetInt(typ uint16) (v uint64, ok bool) {
|
||||
attr := m.Get(typ)
|
||||
if r, ok := attr.(*number); ok {
|
||||
return r.v, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Message) GetBytes(typ uint16) []byte {
|
||||
if attr, ok := m.Get(typ).(*raw); ok {
|
||||
return attr.data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) GetError() *Error {
|
||||
if err, ok := m.Get(AttrErrorCode).(*Error); ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) CheckIntegrity(key []byte) bool {
|
||||
if attr, ok := m.Get(AttrMessageIntegrity).(*integrity); ok {
|
||||
return attr.Check(key)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) CheckFingerprint() bool {
|
||||
if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok {
|
||||
return attr.Check()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) String() string {
|
||||
sort.Sort(byPosition(m.Attributes))
|
||||
|
||||
// TODO: use sprintf
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
b.WriteString(MethodName(m.Type))
|
||||
b.WriteByte('{')
|
||||
tx := m.Transaction
|
||||
if tx == nil {
|
||||
b.WriteString("nil")
|
||||
} else if bytes.Equal(magicCookie, tx[:4]) {
|
||||
b.WriteString(hex.EncodeToString(tx[4:]))
|
||||
} else {
|
||||
b.WriteString(hex.EncodeToString(tx))
|
||||
}
|
||||
for _, attr := range m.Attributes {
|
||||
b.WriteString(", ")
|
||||
b.WriteString(AttrName(attr.Type()))
|
||||
switch v := attr.(type) {
|
||||
case *raw:
|
||||
b.WriteString(": \"")
|
||||
b.Write(v.data)
|
||||
b.WriteByte('"')
|
||||
case *str:
|
||||
b.WriteString(": \"")
|
||||
b.WriteString(v.data)
|
||||
b.WriteByte('"')
|
||||
case flag, *integrity, *fingerprint:
|
||||
default:
|
||||
b.WriteString(fmt.Sprintf(": %v", attr))
|
||||
}
|
||||
}
|
||||
b.WriteByte('}')
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func MethodName(typ uint16) string {
|
||||
if r, ok := methodNames[typ&^0x110]; ok {
|
||||
switch typ & 0x110 {
|
||||
case KindRequest:
|
||||
return r + "Request"
|
||||
case KindIndication:
|
||||
return r + "Indication"
|
||||
case KindResponse:
|
||||
return r + "Response"
|
||||
case KindError:
|
||||
return r + "Error"
|
||||
}
|
||||
}
|
||||
return "0x" + strconv.FormatUint(uint64(typ), 16)
|
||||
}
|
||||
|
||||
func UnmarshalMessage(b []byte) (*Message, error) {
|
||||
m := &Message{}
|
||||
if _, err := m.Unmarshal(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var magicCookie = []byte{0x21, 0x12, 0xa4, 0x42}
|
||||
var alphanum = dict("01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
|
||||
type dict []byte
|
||||
|
||||
func (d dict) rand(n int) string {
|
||||
m, b := len(d), make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = d[rand.Intn(m)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func NewTransaction() []byte {
|
||||
id := make([]byte, 16)
|
||||
copy(id, magicCookie)
|
||||
rand.Read(id[4:]) // TODO: configure random source
|
||||
return id
|
||||
}
|
||||
|
||||
type byPosition []Attr
|
||||
|
||||
func (s byPosition) Len() int { return len(s) }
|
||||
func (s byPosition) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s byPosition) Less(i, j int) bool {
|
||||
a, b := s[i].Type(), s[j].Type()
|
||||
switch b {
|
||||
case a:
|
||||
return i < j
|
||||
case AttrMessageIntegrity:
|
||||
return a != AttrFingerprint
|
||||
case AttrFingerprint:
|
||||
return true
|
||||
default:
|
||||
return i < j
|
||||
}
|
||||
}
|
||||
|
||||
type errAttribute struct {
|
||||
error
|
||||
typ uint16
|
||||
}
|
||||
|
||||
func (err errAttribute) Error() string {
|
||||
return "attribute " + AttrName(err.typ) + ": " + err.error.Error()
|
||||
}
|
172
vendor/github.com/isofew/go-stun/stun/nat.go
generated
vendored
Normal file
172
vendor/github.com/isofew/go-stun/stun/nat.go
generated
vendored
Normal file
@ -0,0 +1,172 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
EndpointIndependent = "endpoint-independent"
|
||||
AddressDependent = "address-dependent"
|
||||
AddressPortDependent = "address-port-dependent"
|
||||
)
|
||||
|
||||
type Detector struct {
|
||||
*Conn
|
||||
}
|
||||
|
||||
func NewDetector(c *Conn) *Detector {
|
||||
d := &Detector{c}
|
||||
d.agent.Handler = &Server{agent: c.agent}
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *Detector) Hairpinning() error {
|
||||
mapped, err := d.Discover()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := net.Dial(d.Network(), mapped.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// not using stop channel
|
||||
c := NewConn(conn, d.agent.config, make(chan struct{}))
|
||||
defer c.Close()
|
||||
_, err = c.Discover()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Detector) DiscoverChange(change uint64) error {
|
||||
req := &Message{Type: MethodBinding, Attributes: []Attr{Int(AttrChangeRequest, change)}}
|
||||
_, from, err := d.RequestTransport(req, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ip, port := SockAddr(d.RemoteAddr())
|
||||
chip, chport := SockAddr(from.RemoteAddr())
|
||||
if change&ChangeIP != 0 {
|
||||
if ip.Equal(chip) {
|
||||
return errors.New("stun: bad response, ip address is not changed")
|
||||
}
|
||||
} else if change&ChangePort != 0 {
|
||||
if port == chport {
|
||||
return errors.New("stun: bad response, port is not changed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Detector) Filtering() (string, error) {
|
||||
n := d.Network()
|
||||
if n != "udp" {
|
||||
return "", errors.New("stun: filtering test is not applicable to " + n)
|
||||
}
|
||||
_, err := d.Request(&Message{Type: MethodBinding})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = d.DiscoverChange(ChangeIP | ChangePort)
|
||||
switch err {
|
||||
case nil:
|
||||
return EndpointIndependent, nil
|
||||
case errTimeout:
|
||||
err = d.DiscoverChange(ChangePort)
|
||||
switch err {
|
||||
case nil:
|
||||
return AddressDependent, nil
|
||||
case errTimeout:
|
||||
return AddressPortDependent, nil
|
||||
}
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
func (d *Detector) DiscoverOther(addr net.Addr) (net.Addr, error) {
|
||||
n := addr.Network()
|
||||
conn, err := net.Dial(n, addr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// not using stop channel
|
||||
go d.agent.ServeConn(conn, make(chan struct{}))
|
||||
res, _, err := d.RequestTransport(&Message{Type: MethodBinding}, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mapped := res.GetAddr(n, AttrXorMappedAddress, AttrMappedAddress)
|
||||
if mapped != nil {
|
||||
return mapped, nil
|
||||
}
|
||||
return nil, errors.New("stun: bad response, no mapped address")
|
||||
}
|
||||
|
||||
func (d *Detector) Mapping() (string, error) {
|
||||
n := d.Network()
|
||||
msg, err := d.Request(&Message{Type: MethodBinding})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
mapped, other := msg.GetAddr(n, AttrXorMappedAddress), msg.GetAddr(n, AttrOtherAddress)
|
||||
if mapped == nil {
|
||||
return "", errors.New("stun: bad response, no mapped address")
|
||||
}
|
||||
if other == nil {
|
||||
return "", errors.New("stun: bad response, no other address")
|
||||
}
|
||||
ip, _ := SockAddr(mapped)
|
||||
if ip.IsLoopback() {
|
||||
return EndpointIndependent, nil
|
||||
}
|
||||
for _, it := range local {
|
||||
if it.IP.Equal(ip) {
|
||||
return EndpointIndependent, nil
|
||||
}
|
||||
}
|
||||
ip, _ = SockAddr(other)
|
||||
_, port := SockAddr(d.RemoteAddr())
|
||||
a, err := d.DiscoverOther(NewAddr(n, ip, port))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if sameAddr(a, mapped) {
|
||||
return EndpointIndependent, nil
|
||||
}
|
||||
b, err := d.DiscoverOther(other)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if sameAddr(b, a) {
|
||||
return AddressDependent, nil
|
||||
}
|
||||
return AddressPortDependent, nil
|
||||
}
|
||||
|
||||
func LocalAddrs() []*net.IPAddr {
|
||||
return local
|
||||
}
|
||||
|
||||
var local []*net.IPAddr
|
||||
|
||||
func init() {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
addrs, _ := iface.Addrs()
|
||||
for _, it := range addrs {
|
||||
var ip net.IP
|
||||
switch it := it.(type) {
|
||||
case *net.IPNet:
|
||||
ip = it.IP
|
||||
case *net.IPAddr:
|
||||
ip = it.IP
|
||||
}
|
||||
if ip != nil && ip.IsGlobalUnicast() {
|
||||
local = append(local, &net.IPAddr{ip, iface.Name})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
179
vendor/github.com/isofew/go-stun/stun/registry.go
generated
vendored
Normal file
179
vendor/github.com/isofew/go-stun/stun/registry.go
generated
vendored
Normal file
@ -0,0 +1,179 @@
|
||||
package stun
|
||||
|
||||
// Do not edit. This file is generated by 'go generate gen.go'
|
||||
// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).
|
||||
// Session Traversal Utilities for NAT (STUN) Parameters, Updated: 2016-11-14.
|
||||
|
||||
// STUN methods.
|
||||
const (
|
||||
MethodBinding uint16 = 0x001 // RFC 5389
|
||||
MethodSharedSecret uint16 = 0x002
|
||||
MethodAllocate uint16 = 0x003 // RFC 5766
|
||||
MethodRefresh uint16 = 0x004
|
||||
MethodSend uint16 = 0x006
|
||||
MethodData uint16 = 0x007
|
||||
MethodCreatePermission uint16 = 0x008
|
||||
MethodChannelBind uint16 = 0x009
|
||||
MethodConnect uint16 = 0x00a // RFC 6062
|
||||
MethodConnectionBind uint16 = 0x00b
|
||||
MethodConnectionAttempt uint16 = 0x00c
|
||||
)
|
||||
|
||||
// STUN method names.
|
||||
var methodNames = map[uint16]string{
|
||||
MethodBinding: "Binding",
|
||||
MethodSharedSecret: "SharedSecret",
|
||||
MethodAllocate: "Allocate",
|
||||
MethodRefresh: "Refresh",
|
||||
MethodSend: "Send",
|
||||
MethodData: "Data",
|
||||
MethodCreatePermission: "CreatePermission",
|
||||
MethodChannelBind: "ChannelBind",
|
||||
MethodConnect: "Connect",
|
||||
MethodConnectionBind: "ConnectionBind",
|
||||
MethodConnectionAttempt: "ConnectionAttempt",
|
||||
}
|
||||
|
||||
// STUN attributes.
|
||||
const (
|
||||
AttrMappedAddress uint16 = 0x0001 // RFC 5389
|
||||
AttrChangeRequest uint16 = 0x0003 // RFC 5780
|
||||
AttrUsername uint16 = 0x0006 // RFC 5389
|
||||
AttrMessageIntegrity uint16 = 0x0008
|
||||
AttrErrorCode uint16 = 0x0009
|
||||
AttrUnknownAttributes uint16 = 0x000a
|
||||
AttrChannelNumber uint16 = 0x000c // RFC 5766
|
||||
AttrLifetime uint16 = 0x000d
|
||||
AttrXorPeerAddress uint16 = 0x0012
|
||||
AttrData uint16 = 0x0013
|
||||
AttrRealm uint16 = 0x0014 // RFC 5389
|
||||
AttrNonce uint16 = 0x0015
|
||||
AttrXorRelayedAddress uint16 = 0x0016 // RFC 5766
|
||||
AttrRequestedAddressFamily uint16 = 0x0017 // RFC 6156
|
||||
AttrEvenPort uint16 = 0x0018 // RFC 5766
|
||||
AttrRequestedTransport uint16 = 0x0019
|
||||
AttrDontFragment uint16 = 0x001a
|
||||
AttrAccessToken uint16 = 0x001b // RFC 7635
|
||||
AttrXorMappedAddress uint16 = 0x0020 // RFC 5389
|
||||
AttrReservationToken uint16 = 0x0022 // RFC 5766
|
||||
AttrPriority uint16 = 0x0024 // RFC 5245
|
||||
AttrUseCandidate uint16 = 0x0025
|
||||
AttrPadding uint16 = 0x0026 // RFC 5780
|
||||
AttrResponsePort uint16 = 0x0027
|
||||
AttrConnectionID uint16 = 0x002a // RFC 6062
|
||||
AttrSoftware uint16 = 0x8022 // RFC 5389
|
||||
AttrAlternateServer uint16 = 0x8023
|
||||
AttrTransactionTransmitCounter uint16 = 0x8025 // RFC 7982
|
||||
AttrCacheTimeout uint16 = 0x8027 // RFC 5780
|
||||
AttrFingerprint uint16 = 0x8028 // RFC 5389
|
||||
AttrIceControlled uint16 = 0x8029 // RFC 5245
|
||||
AttrIceControlling uint16 = 0x802a
|
||||
AttrResponseOrigin uint16 = 0x802b // RFC 5780
|
||||
AttrOtherAddress uint16 = 0x802c
|
||||
AttrEcnCheck uint16 = 0x802d // RFC 6679
|
||||
AttrThirdPartyAuthorization uint16 = 0x802e // RFC 7635
|
||||
AttrMobilityTicket uint16 = 0x8030 // RFC 8016
|
||||
)
|
||||
|
||||
// Deprecated: For backwards compatibility only.
|
||||
const (
|
||||
AttrResponseAddress uint16 = 0x0002
|
||||
AttrSourceAddress uint16 = 0x0004
|
||||
AttrChangedAddress uint16 = 0x0005
|
||||
AttrPassword uint16 = 0x0007
|
||||
AttrReflectedFrom uint16 = 0x000b
|
||||
AttrBandwidth uint16 = 0x0010
|
||||
AttrTimerVal uint16 = 0x0021
|
||||
)
|
||||
|
||||
// STUN attribute names.
|
||||
var attrNames = map[uint16]string{
|
||||
AttrMappedAddress: "MAPPED-ADDRESS",
|
||||
AttrResponseAddress: "RESPONSE-ADDRESS",
|
||||
AttrChangeRequest: "CHANGE-REQUEST",
|
||||
AttrSourceAddress: "SOURCE-ADDRESS",
|
||||
AttrChangedAddress: "CHANGED-ADDRESS",
|
||||
AttrUsername: "USERNAME",
|
||||
AttrPassword: "PASSWORD",
|
||||
AttrMessageIntegrity: "MESSAGE-INTEGRITY",
|
||||
AttrErrorCode: "ERROR-CODE",
|
||||
AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES",
|
||||
AttrReflectedFrom: "REFLECTED-FROM",
|
||||
AttrChannelNumber: "CHANNEL-NUMBER",
|
||||
AttrLifetime: "LIFETIME",
|
||||
AttrBandwidth: "BANDWIDTH",
|
||||
AttrXorPeerAddress: "XOR-PEER-ADDRESS",
|
||||
AttrData: "DATA",
|
||||
AttrRealm: "REALM",
|
||||
AttrNonce: "NONCE",
|
||||
AttrXorRelayedAddress: "XOR-RELAYED-ADDRESS",
|
||||
AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY",
|
||||
AttrEvenPort: "EVEN-PORT",
|
||||
AttrRequestedTransport: "REQUESTED-TRANSPORT",
|
||||
AttrDontFragment: "DONT-FRAGMENT",
|
||||
AttrAccessToken: "ACCESS-TOKEN",
|
||||
AttrXorMappedAddress: "XOR-MAPPED-ADDRESS",
|
||||
AttrTimerVal: "TIMER-VAL",
|
||||
AttrReservationToken: "RESERVATION-TOKEN",
|
||||
AttrPriority: "PRIORITY",
|
||||
AttrUseCandidate: "USE-CANDIDATE",
|
||||
AttrPadding: "PADDING",
|
||||
AttrResponsePort: "RESPONSE-PORT",
|
||||
AttrConnectionID: "CONNECTION-ID",
|
||||
AttrSoftware: "SOFTWARE",
|
||||
AttrAlternateServer: "ALTERNATE-SERVER",
|
||||
AttrTransactionTransmitCounter: "TRANSACTION-TRANSMIT-COUNTER",
|
||||
AttrCacheTimeout: "CACHE-TIMEOUT",
|
||||
AttrFingerprint: "FINGERPRINT",
|
||||
AttrIceControlled: "ICE-CONTROLLED",
|
||||
AttrIceControlling: "ICE-CONTROLLING",
|
||||
AttrResponseOrigin: "RESPONSE-ORIGIN",
|
||||
AttrOtherAddress: "OTHER-ADDRESS",
|
||||
AttrEcnCheck: "ECN-CHECK",
|
||||
AttrThirdPartyAuthorization: "THIRD-PARTY-AUTHORIZATION",
|
||||
AttrMobilityTicket: "MOBILITY-TICKET",
|
||||
}
|
||||
|
||||
// STUN error codes.
|
||||
const (
|
||||
CodeTryAlternate int = 300 // RFC 5389
|
||||
CodeBadRequest int = 400
|
||||
CodeUnauthorized int = 401
|
||||
CodeForbidden int = 403 // RFC 5766
|
||||
CodeMobilityForbidden int = 405 // RFC 8016
|
||||
CodeUnknownAttribute int = 420 // RFC 5389
|
||||
CodeAllocationMismatch int = 437 // RFC 5766
|
||||
CodeStaleNonce int = 438 // RFC 5389
|
||||
CodeAddressFamilyNotSupported int = 440 // RFC 6156
|
||||
CodeWrongCredentials int = 441 // RFC 5766
|
||||
CodeUnsupportedTransportProtocol int = 442
|
||||
CodePeerAddressFamilyMismatch int = 443 // RFC 6156
|
||||
CodeConnectionAlreadyExists int = 446 // RFC 6062
|
||||
CodeConnectionTimeoutOrFailure int = 447
|
||||
CodeAllocationQuotaReached int = 486 // RFC 5766
|
||||
CodeRoleConflict int = 487 // RFC 5245
|
||||
CodeServerError int = 500 // RFC 5389
|
||||
CodeInsufficientCapacity int = 508 // RFC 5766
|
||||
)
|
||||
|
||||
// STUN error texts.
|
||||
var errorText = map[int]string{
|
||||
CodeTryAlternate: "Try Alternate",
|
||||
CodeBadRequest: "Bad Request",
|
||||
CodeUnauthorized: "Unauthorized",
|
||||
CodeForbidden: "Forbidden",
|
||||
CodeMobilityForbidden: "Mobility Forbidden",
|
||||
CodeUnknownAttribute: "Unknown Attribute",
|
||||
CodeAllocationMismatch: "Allocation Mismatch",
|
||||
CodeStaleNonce: "Stale Nonce",
|
||||
CodeAddressFamilyNotSupported: "Address Family not Supported",
|
||||
CodeWrongCredentials: "Wrong Credentials",
|
||||
CodeUnsupportedTransportProtocol: "Unsupported Transport Protocol",
|
||||
CodePeerAddressFamilyMismatch: "Peer Address Family Mismatch",
|
||||
CodeConnectionAlreadyExists: "Connection Already Exists",
|
||||
CodeConnectionTimeoutOrFailure: "Connection Timeout or Failure",
|
||||
CodeAllocationQuotaReached: "Allocation Quota Reached",
|
||||
CodeRoleConflict: "Role Conflict",
|
||||
CodeServerError: "Server Error",
|
||||
CodeInsufficientCapacity: "Insufficient Capacity",
|
||||
}
|
126
vendor/github.com/isofew/go-stun/stun/server.go
generated
vendored
Normal file
126
vendor/github.com/isofew/go-stun/stun/server.go
generated
vendored
Normal file
@ -0,0 +1,126 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func ListenAndServe(network, laddr string, config *Config) error {
|
||||
srv := NewServer(config)
|
||||
return srv.ListenAndServe(network, laddr)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
agent *Agent
|
||||
|
||||
mu sync.RWMutex
|
||||
conns []net.PacketConn
|
||||
}
|
||||
|
||||
func NewServer(config *Config) *Server {
|
||||
srv := &Server{agent: NewAgent(config)}
|
||||
srv.agent.Handler = srv
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) ListenAndServe(network, laddr string) error {
|
||||
c, err := net.ListenPacket(network, laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.addConn(c)
|
||||
defer srv.removeConn(c)
|
||||
// not using stop channel
|
||||
return srv.agent.ServePacket(c, make(chan struct{}))
|
||||
}
|
||||
|
||||
func (srv *Server) ServeSTUN(msg *Message, from Transport) {
|
||||
if msg.Type == MethodBinding {
|
||||
to := from
|
||||
mapped := from.RemoteAddr()
|
||||
ip, port := SockAddr(from.LocalAddr())
|
||||
|
||||
res := &Message{
|
||||
Type: MethodBinding | KindResponse,
|
||||
Transaction: msg.Transaction,
|
||||
Attributes: []Attr{
|
||||
Addr(AttrXorMappedAddress, mapped),
|
||||
Addr(AttrMappedAddress, mapped),
|
||||
},
|
||||
}
|
||||
|
||||
srv.mu.RLock()
|
||||
defer srv.mu.RUnlock()
|
||||
|
||||
if ch, ok := msg.GetInt(AttrChangeRequest); ok && ch != 0 {
|
||||
for _, c := range srv.conns {
|
||||
chip, chport := SockAddr(c.LocalAddr())
|
||||
if chip.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
if ch&ChangeIP != 0 {
|
||||
if !ip.Equal(chip) {
|
||||
to = &packetConn{c, mapped}
|
||||
break
|
||||
}
|
||||
} else if ch&ChangePort != 0 {
|
||||
if ip.Equal(chip) && port != chport {
|
||||
to = &packetConn{c, mapped}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(srv.conns) < 2 {
|
||||
srv.agent.Send(res, to)
|
||||
return
|
||||
}
|
||||
|
||||
other:
|
||||
for _, a := range srv.conns {
|
||||
aip, aport := SockAddr(a.LocalAddr())
|
||||
if aip.IsUnspecified() || !ip.Equal(aip) || port == aport {
|
||||
continue
|
||||
}
|
||||
for _, b := range srv.conns {
|
||||
bip, bport := SockAddr(b.LocalAddr())
|
||||
if bip.IsUnspecified() || bip.Equal(ip) || aport != bport {
|
||||
continue
|
||||
}
|
||||
res.Set(Addr(AttrOtherAddress, b.LocalAddr()))
|
||||
break other
|
||||
}
|
||||
}
|
||||
|
||||
srv.agent.Send(res, to)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) addConn(c net.PacketConn) {
|
||||
srv.mu.Lock()
|
||||
srv.conns = append(srv.conns, c)
|
||||
srv.mu.Unlock()
|
||||
}
|
||||
|
||||
func (srv *Server) removeConn(c net.PacketConn) {
|
||||
srv.mu.Lock()
|
||||
l := srv.conns
|
||||
for i, it := range l {
|
||||
if it == c {
|
||||
srv.conns = append(l[:i], l[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
}
|
||||
|
||||
func (srv *Server) Close() error {
|
||||
srv.mu.RLock()
|
||||
defer srv.mu.RUnlock()
|
||||
for _, it := range srv.conns {
|
||||
it.Close()
|
||||
}
|
||||
srv.conns = nil
|
||||
return nil
|
||||
}
|
129
vendor/github.com/isofew/go-stun/stun/stun.go
generated
vendored
Normal file
129
vendor/github.com/isofew/go-stun/stun/stun.go
generated
vendored
Normal file
@ -0,0 +1,129 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Discover(uri string) (net.PacketConn, net.Addr, error) {
|
||||
stop := make(chan struct{})
|
||||
conn, err := Dial(uri, nil, stop)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
addr, err := conn.Discover()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
// TODO: hijack
|
||||
// stop reading conn before returning it
|
||||
close(stop)
|
||||
// note. serveconn/packet func is blocked by read/readfrom at the time
|
||||
// we send the signal, which means it will still consume one more
|
||||
// packet and we can only read starting from the second packet.
|
||||
// (not too much of a problem, since we'll punch a few packets anyway)
|
||||
return conn.Conn.(net.PacketConn), addr, nil
|
||||
}
|
||||
|
||||
type AuthMethod func(sess *Session) error
|
||||
|
||||
// LongTermAuthMethod returns AuthMethod for long-term credentials.
|
||||
// Key = MD5(username ":" realm ":" SASLprep(password)).
|
||||
// SASLprep is defined in RFC 4013.
|
||||
func LongTermAuthMethod(username, password string) AuthMethod {
|
||||
return func(sess *Session) error {
|
||||
h := md5.New()
|
||||
h.Write([]byte(username + ":" + sess.Realm + ":" + password))
|
||||
sess.Username = username
|
||||
sess.Key = h.Sum(nil)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ShotTermAuthMethod returns AuthMethod for short-term credentials.
|
||||
// Key = SASLprep(password).
|
||||
// SASLprep is defined in RFC 4013.
|
||||
func ShortTermAuthMethod(password string) AuthMethod {
|
||||
key := []byte(password)
|
||||
return func(sess *Session) error {
|
||||
sess.Key = key
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Dial(uri string, config *Config, stop chan struct{}) (*Conn, error) {
|
||||
secure, network, addr, auth, err := parseURI(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var conn net.Conn
|
||||
if secure {
|
||||
conn, err = tls.Dial(network, addr, nil)
|
||||
} else {
|
||||
if strings.HasPrefix(network, "udp") {
|
||||
conn, err = dialUDP(network, addr)
|
||||
} else {
|
||||
conn, err = dialTCP(network, addr)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth != nil {
|
||||
config = config.Clone()
|
||||
config.AuthMethod = auth
|
||||
}
|
||||
return NewConn(conn, config, stop), nil
|
||||
}
|
||||
|
||||
func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, err error) {
|
||||
var u *url.URL
|
||||
if u, err = url.Parse(uri); err != nil {
|
||||
return
|
||||
}
|
||||
host, port, e := net.SplitHostPort(u.Opaque)
|
||||
if e != nil {
|
||||
host = u.Opaque
|
||||
}
|
||||
if a := u.User; a != nil {
|
||||
if password, ok := a.Password(); ok {
|
||||
auth = LongTermAuthMethod(a.Username(), password)
|
||||
} else {
|
||||
auth = ShortTermAuthMethod(a.Username())
|
||||
}
|
||||
}
|
||||
network = u.Query().Get("transport")
|
||||
if network == "" {
|
||||
network = "udp"
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "stun", "turn":
|
||||
if port == "" {
|
||||
port = "3478"
|
||||
}
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6", "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
err = errors.New("stun: unsupported transport: " + network)
|
||||
}
|
||||
case "stuns", "turns":
|
||||
if port == "" {
|
||||
port = "5478"
|
||||
}
|
||||
secure = true
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
err = errors.New("stun: unsupported transport: " + network)
|
||||
}
|
||||
default:
|
||||
err = errors.New("stun: unsupported scheme " + u.Scheme)
|
||||
}
|
||||
addr = net.JoinHostPort(host, port)
|
||||
return
|
||||
}
|
96
vendor/github.com/isofew/go-stun/stun/transport.go
generated
vendored
Normal file
96
vendor/github.com/isofew/go-stun/stun/transport.go
generated
vendored
Normal file
@ -0,0 +1,96 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Listener interface {
|
||||
Addr() net.Addr
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Transport interface {
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
Write(p []byte) (int, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Marshaler interface {
|
||||
Marshal(b []byte) []byte
|
||||
}
|
||||
|
||||
type TransportHandler interface {
|
||||
ServeTransport(b []byte, tr Transport) (int, error)
|
||||
}
|
||||
|
||||
func dialUDP(network, raddr string) (net.Conn, error) {
|
||||
addr, err := net.ResolveUDPAddr(network, raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP(network, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packetConn{conn, addr}, nil
|
||||
}
|
||||
|
||||
func dialTCP(network, raddr string) (net.Conn, error) {
|
||||
addr, err := net.ResolveTCPAddr(network, raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.DialTCP(network, nil, addr)
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
net.PacketConn
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func (t *packetConn) Read(p []byte) (n int, err error) {
|
||||
n, _, err = t.ReadFrom(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (t *packetConn) Write(p []byte) (int, error) {
|
||||
return t.WriteTo(p, t.addr)
|
||||
}
|
||||
|
||||
func (t *packetConn) RemoteAddr() net.Addr {
|
||||
return t.addr
|
||||
}
|
||||
|
||||
var (
|
||||
errBufferOverflow = errors.New("stun: buffer overflow")
|
||||
errFormat = errors.New("stun: format error")
|
||||
)
|
||||
|
||||
func getBuffer() []byte {
|
||||
return make([]byte, 2048)
|
||||
}
|
||||
|
||||
func putBuffer(b []byte) {
|
||||
if cap(b) >= 2048 {
|
||||
}
|
||||
}
|
||||
|
||||
func grow(p []byte, n int) (b, a []byte) {
|
||||
l := len(p)
|
||||
r := l + n
|
||||
if r > cap(p) {
|
||||
b = make([]byte, (1+((r-1)>>10))<<10)[:r]
|
||||
a = b[l:r]
|
||||
if l > 0 {
|
||||
copy(b, p[:l])
|
||||
}
|
||||
} else {
|
||||
return p[:r], p[l:r]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var be = binary.BigEndian
|
14
vendor/github.com/lucas-clemente/quic-go/Changelog.md
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/Changelog.md
generated
vendored
@ -1,6 +1,18 @@
|
||||
# Changelog
|
||||
|
||||
## v0.6.0 (unreleased)
|
||||
## v0.8.0 (unreleased)
|
||||
|
||||
- Add support for unidirectional streams (for IETF QUIC).
|
||||
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||
|
||||
## v0.7.0 (2018-02-03)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||
- Implement packet pacing.
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
||||
- Added `quic.Config` options for maximal flow control windows
|
||||
|
7
vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go
generated
vendored
@ -1,7 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/clipperhouse/linkedlist"
|
||||
_ "github.com/clipperhouse/slice"
|
||||
_ "github.com/clipperhouse/stringer"
|
||||
)
|
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
@ -1,34 +0,0 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
SendingAllowed() bool
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
ShouldSendRetransmittablePacket() bool
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
||||
SetLowerLimit(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *wire.AckFrame
|
||||
}
|
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go
generated
vendored
@ -1,34 +0,0 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
// +gen linkedlist
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
Frames []wire.Frame
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
|
||||
SendTime time.Time
|
||||
}
|
||||
|
||||
// GetFramesForRetransmission gets all the frames for retransmission
|
||||
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
|
||||
var fs []wire.Frame
|
||||
for _, frame := range p.Frames {
|
||||
switch frame.(type) {
|
||||
case *wire.AckFrame:
|
||||
continue
|
||||
case *wire.StopWaitingFrame:
|
||||
continue
|
||||
}
|
||||
fs = append(fs, frame)
|
||||
}
|
||||
return fs
|
||||
}
|
141
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
141
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
@ -1,141 +0,0 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
lowerLimit protocol.PacketNumber
|
||||
largestObservedReceivedTime time.Time
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
ackSendDelay time.Duration
|
||||
|
||||
packetsReceivedSinceLastAck int
|
||||
retransmittablePacketsReceivedSinceLastAck int
|
||||
ackQueued bool
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: protocol.AckSendDelay,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
|
||||
if packetNumber == 0 {
|
||||
return errInvalidPacketNumber
|
||||
}
|
||||
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = time.Now()
|
||||
}
|
||||
|
||||
if packetNumber <= h.lowerLimit {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, shouldInstigateAck)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLowerLimit sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller or equal than p will not be acked.
|
||||
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
|
||||
h.lowerLimit = p
|
||||
h.packetHistory.DeleteUpTo(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
if shouldInstigateAck {
|
||||
h.retransmittablePacketsReceivedSinceLastAck++
|
||||
}
|
||||
|
||||
// always ack the first packet
|
||||
if h.lastAck == nil {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if h.version < protocol.Version39 {
|
||||
// Always send an ack every 20 packets in order to allow the peer to discard
|
||||
// information from the SentPacketManager and provide an RTT measurement.
|
||||
// From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet.
|
||||
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
|
||||
h.ackQueued = true
|
||||
}
|
||||
}
|
||||
|
||||
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
||||
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
||||
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// check if a new missing range above the previously was created
|
||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if !h.ackQueued && shouldInstigateAck {
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
} else {
|
||||
if h.ackAlarm.IsZero() {
|
||||
h.ackAlarm = time.Now().Add(h.ackSendDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ackRanges := h.packetHistory.GetAckRanges()
|
||||
ack := &wire.AckFrame{
|
||||
LargestAcked: h.largestObserved,
|
||||
LowestAcked: ackRanges[len(ackRanges)-1].First,
|
||||
PacketReceivedTime: h.largestObservedReceivedTime,
|
||||
}
|
||||
|
||||
if len(ackRanges) > 1 {
|
||||
ack.AckRanges = ackRanges
|
||||
}
|
||||
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.packetsReceivedSinceLastAck = 0
|
||||
h.retransmittablePacketsReceivedSinceLastAck = 0
|
||||
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
455
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go
generated
vendored
455
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go
generated
vendored
@ -1,455 +0,0 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// In fraction of an RTT.
|
||||
timeReorderingFraction = 1.0 / 8
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
// Note: This constant is also defined in the congestion package.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
// defaultRTOTimeout is the RTO time on new connections
|
||||
defaultRTOTimeout = 500 * time.Millisecond
|
||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||
minTPLTimeout = 10 * time.Millisecond
|
||||
// Minimum time in the future an RTO alarm may be set for.
|
||||
minRTOTimeout = 200 * time.Millisecond
|
||||
// maxRTOTimeout is the maximum RTO time
|
||||
maxRTOTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
||||
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
||||
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
|
||||
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
||||
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
|
||||
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
)
|
||||
|
||||
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet
|
||||
|
||||
LargestAcked protocol.PacketNumber
|
||||
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
|
||||
packetHistory *PacketList
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
congestion congestion.SendAlgorithm
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
handshakeComplete bool
|
||||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||
handshakeCount uint32
|
||||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
|
||||
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
||||
lossTime time.Time
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: NewPacketList(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
|
||||
if f := h.packetHistory.Front(); f != nil {
|
||||
return f.Value.PacketNumber - 1
|
||||
}
|
||||
return h.LargestAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
|
||||
return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
||||
if packet.PacketNumber <= h.lastSentPacketNumber {
|
||||
return errPacketNumberNotIncreasing
|
||||
}
|
||||
|
||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
||||
return ErrTooManyTrackedSentPackets
|
||||
}
|
||||
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
now := time.Now()
|
||||
|
||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
packet.SendTime = now
|
||||
h.bytesInFlight += packet.Length
|
||||
h.packetHistory.PushBack(*packet)
|
||||
h.numNonRetransmittablePackets = 0
|
||||
} else {
|
||||
h.numNonRetransmittablePackets++
|
||||
}
|
||||
|
||||
h.congestion.OnPacketSent(
|
||||
now,
|
||||
h.bytesInFlight,
|
||||
packet.PacketNumber,
|
||||
packet.Length,
|
||||
isRetransmittable,
|
||||
)
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
||||
return errAckForUnsentPacket
|
||||
}
|
||||
|
||||
// duplicate or out-of-order ACK
|
||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
return ErrDuplicateOrOutOfOrderAck
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
|
||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
||||
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
|
||||
return nil
|
||||
}
|
||||
h.LargestAcked = ackFrame.LargestAcked
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return ErrAckForSkippedPacket
|
||||
}
|
||||
|
||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
||||
|
||||
if rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ackedPackets) > 0 {
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.Value.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
||||
}
|
||||
h.onPacketAcked(p)
|
||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
h.detectLostPackets()
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
||||
var ackedPackets []*PacketElement
|
||||
ackRangeIndex := 0
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
packetNumber := packet.PacketNumber
|
||||
|
||||
// Ignore packets below the LowestAcked
|
||||
if packetNumber < ackFrame.LowestAcked {
|
||||
continue
|
||||
}
|
||||
// Break after LargestAcked is reached
|
||||
if packetNumber > ackFrame.LargestAcked {
|
||||
break
|
||||
}
|
||||
|
||||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if packetNumber >= ackRange.First { // packet i contained in ACK range
|
||||
if packetNumber > ackRange.Last {
|
||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, el)
|
||||
}
|
||||
} else {
|
||||
ackedPackets = append(ackedPackets, el)
|
||||
}
|
||||
}
|
||||
|
||||
return ackedPackets, nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
if packet.PacketNumber == largestAcked {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
|
||||
return true
|
||||
}
|
||||
// Packets are sorted by number, so we can stop searching
|
||||
if packet.PacketNumber > largestAcked {
|
||||
break
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
h.alarm = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO
|
||||
h.alarm = time.Now().Add(h.computeRTOTimeout())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets() {
|
||||
h.lossTime = time.Time{}
|
||||
now := time.Now()
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
||||
var lostPackets []*PacketElement
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
|
||||
if packet.PacketNumber > h.LargestAcked {
|
||||
break
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, el)
|
||||
} else if h.lossTime.IsZero() {
|
||||
// Note: This conditional is only entered once per call
|
||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||
}
|
||||
}
|
||||
|
||||
if len(lostPackets) > 0 {
|
||||
for _, p := range lostPackets {
|
||||
h.queuePacketForRetransmission(p)
|
||||
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() {
|
||||
// TODO(#497): TLP
|
||||
if !h.handshakeComplete {
|
||||
h.queueHandshakePacketsForRetransmission()
|
||||
h.handshakeCount++
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
h.detectLostPackets()
|
||||
} else {
|
||||
// RTO
|
||||
h.retransmitOldestTwoPackets()
|
||||
h.rtoCount++
|
||||
}
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
||||
h.bytesInFlight -= packetElement.Value.Length
|
||||
h.rtoCount = 0
|
||||
h.handshakeCount = 0
|
||||
// TODO(#497): h.tlpCount = 0
|
||||
h.packetHistory.Remove(packetElement)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
||||
if len(h.retransmissionQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
packet := h.retransmissionQueue[0]
|
||||
// Shift the slice and don't retain anything that isn't needed.
|
||||
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
|
||||
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
|
||||
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
|
||||
return packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
||||
return h.largestInOrderAcked() + 1
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
||||
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
|
||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
||||
if congestionLimited {
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
|
||||
h.bytesInFlight,
|
||||
h.congestion.GetCongestionWindow())
|
||||
}
|
||||
// Workaround for #555:
|
||||
// Always allow sending of retransmissions. This should probably be limited
|
||||
// to RTOs, but we currently don't have a nice way of distinguishing them.
|
||||
haveRetransmissions := len(h.retransmissionQueue) > 0
|
||||
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
}
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
||||
packet := &el.Value
|
||||
utils.Debugf(
|
||||
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
|
||||
packet.PacketNumber,
|
||||
h.packetHistory.Len(),
|
||||
)
|
||||
h.queuePacketForRetransmission(el)
|
||||
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
|
||||
var handshakePackets []*PacketElement
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, el)
|
||||
}
|
||||
}
|
||||
for _, el := range handshakePackets {
|
||||
h.queuePacketForRetransmission(el)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
||||
packet := &packetElement.Value
|
||||
h.bytesInFlight -= packet.Length
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, packet)
|
||||
h.packetHistory.Remove(packetElement)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
duration := 2 * h.rttStats.SmoothedRTT()
|
||||
if duration == 0 {
|
||||
duration = 2 * defaultInitialRTT
|
||||
}
|
||||
duration = utils.MaxDuration(duration, minTPLTimeout)
|
||||
// exponential backoff
|
||||
// There's an implicit limit to this set by the handshake timeout.
|
||||
return duration << h.handshakeCount
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
rto := h.congestion.RetransmissionDelay()
|
||||
if rto == 0 {
|
||||
rto = defaultRTOTimeout
|
||||
}
|
||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||
// Exponential backoff
|
||||
rto = rto << h.rtoCount
|
||||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lioa := h.largestInOrderAcked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p <= lioa {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||
}
|
4
vendor/github.com/lucas-clemente/quic-go/appveyor.yml
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/appveyor.yml
generated
vendored
@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
||||
|
||||
install:
|
||||
- rmdir c:\go /s /q
|
||||
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip
|
||||
- 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL
|
||||
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.10.2.windows-amd64.zip
|
||||
- 7z x go1.10.2.windows-amd64.zip -y -oC:\ > NUL
|
||||
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
||||
- echo %PATH%
|
||||
- echo %GOPATH%
|
||||
|
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
@ -8,19 +8,20 @@ import (
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() []byte {
|
||||
return bufferPool.Get().([]byte)
|
||||
func getPacketBuffer() *[]byte {
|
||||
return bufferPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
func putPacketBuffer(buf []byte) {
|
||||
if cap(buf) != int(protocol.MaxReceivePacketSize) {
|
||||
func putPacketBuffer(buf *[]byte) {
|
||||
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
bufferPool.Put(buf[:0])
|
||||
bufferPool.Put(buf)
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
return make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
return &b
|
||||
}
|
||||
}
|
||||
|
414
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
414
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
@ -22,24 +23,29 @@ type client struct {
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
handshakeChan <-chan handshakeEvent
|
||||
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
tls handshake.MintTLS // only used when using TLS
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = utils.GenerateConnectionID
|
||||
generateConnectionID = protocol.GenerateConnectionID
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
)
|
||||
|
||||
@ -57,69 +63,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
@ -129,14 +72,57 @@ func Dial(
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
clientConfig := populateClientConfig(config)
|
||||
version := clientConfig.Versions[0]
|
||||
srcConnID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
||||
destConnID := srcConnID
|
||||
if version.UsesTLS() {
|
||||
destConnID, err = generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// check that all versions are actually supported
|
||||
if config != nil {
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: version,
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
@ -167,6 +153,18 @@ func populateClientConfig(config *Config) *Config {
|
||||
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
@ -175,29 +173,87 @@ func populateClientConfig(config *Config) *Config {
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
KeepAlive: config.KeepAlive,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
if err := c.createNewSession(c.version, nil); err != nil {
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS()
|
||||
} else {
|
||||
err = c.dialGQUIC()
|
||||
}
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC() error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
return c.establishSecureConnection()
|
||||
}
|
||||
|
||||
func (c *client) dialTLS() error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
}
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
c.logger.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
runErr = c.session.run()
|
||||
if runErr == errCloseSessionForNewVersion {
|
||||
// run the new session
|
||||
runErr = c.session.run()
|
||||
}
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
c.logger.Infof("Connection %s closed.", c.srcConnID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
@ -210,96 +266,95 @@ func (c *client) establishSecureConnection() error {
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
case err := <-c.session.handshakeStatus():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
for {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
data := getPacketBuffer()
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.session.Close(err)
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
c.handlePacket(addr, data)
|
||||
if err := c.handlePacket(addr, data[:n]); err != nil {
|
||||
c.logger.Errorf("error handling packet: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
// drop this packet if we can't parse the header
|
||||
return
|
||||
return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
}
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||
return
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
packetData := packet[len(packet)-r.Len():]
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
|
||||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
return
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
return
|
||||
}
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
return
|
||||
}
|
||||
|
||||
isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */
|
||||
|
||||
// handle Version Negotiation Packets
|
||||
if isVersionNegotiationPacket {
|
||||
if hdr.IsVersionNegotiation {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
return
|
||||
return errors.New("received a delayed Version Negotiation Packet")
|
||||
}
|
||||
|
||||
// version negotiation packets have no payload
|
||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if hdr.IsPublicHeader {
|
||||
return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime)
|
||||
}
|
||||
return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
|
||||
}
|
||||
|
||||
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||
// TODO(#1003): add support for server-chosen connection IDs
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
if hdr.IsLongHeader {
|
||||
if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake {
|
||||
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
|
||||
}
|
||||
c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen)
|
||||
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
|
||||
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
|
||||
}
|
||||
packetData = packetData[:int(hdr.PayloadLen)]
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
@ -312,9 +367,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||
return errors.New("Received a spoofed Public Reset")
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
}
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
@ -327,42 +421,66 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
}
|
||||
}
|
||||
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
initialVersion := c.version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
c.destConnID, err = generateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
// create a new session and close the old one
|
||||
// the new session must be created first to update client member variables
|
||||
oldSession := c.session
|
||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(initialVersion, hdr.SupportedVersions)
|
||||
// in gQUIC, there's only one connection ID
|
||||
if !c.version.UsesTLS() {
|
||||
c.srcConnID = c.destConnID
|
||||
}
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
utils.Debugf("createNewSession with initial version %s", initialVersion)
|
||||
c.session, c.handshakeChan, err = newClientSession(
|
||||
func (c *client) createNewGQUICSession() (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.destConnID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.config,
|
||||
c.tls,
|
||||
paramsChan,
|
||||
1,
|
||||
c.logger,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
7
vendor/github.com/lucas-clemente/quic-go/codecov.yml
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/codecov.yml
generated
vendored
@ -1,11 +1,16 @@
|
||||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- ackhandler/packet_linkedlist.go
|
||||
- streams_map_incoming_bidi.go
|
||||
- streams_map_incoming_uni.go
|
||||
- streams_map_outgoing_bidi.go
|
||||
- streams_map_outgoing_uni.go
|
||||
- h2quic/gzipreader.go
|
||||
- h2quic/response.go
|
||||
- internal/ackhandler/packet_linkedlist.go
|
||||
- internal/utils/byteinterval_linkedlist.go
|
||||
- internal/utils/packetinterval_linkedlist.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
|
183
vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go
generated
vendored
183
vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go
generated
vendored
@ -1,183 +0,0 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// Note: This constant is also defined in the ackhandler package.
|
||||
initialRTTus = 100 * 1000
|
||||
rttAlpha float32 = 0.125
|
||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||
rttBeta float32 = 0.25
|
||||
oneMinusBeta float32 = (1 - rttBeta)
|
||||
halfWindow float32 = 0.5
|
||||
quarterWindow float32 = 0.25
|
||||
)
|
||||
|
||||
type rttSample struct {
|
||||
rtt time.Duration
|
||||
time time.Time
|
||||
}
|
||||
|
||||
// RTTStats provides round-trip statistics
|
||||
type RTTStats struct {
|
||||
initialRTTus int64
|
||||
|
||||
recentMinRTTwindow time.Duration
|
||||
minRTT time.Duration
|
||||
latestRTT time.Duration
|
||||
smoothedRTT time.Duration
|
||||
meanDeviation time.Duration
|
||||
|
||||
numMinRTTsamplesRemaining uint32
|
||||
|
||||
newMinRTT rttSample
|
||||
recentMinRTT rttSample
|
||||
halfWindowRTT rttSample
|
||||
quarterWindowRTT rttSample
|
||||
}
|
||||
|
||||
// NewRTTStats makes a properly initialized RTTStats object
|
||||
func NewRTTStats() *RTTStats {
|
||||
return &RTTStats{
|
||||
initialRTTus: initialRTTus,
|
||||
recentMinRTTwindow: utils.InfDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// InitialRTTus is the initial RTT in us
|
||||
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
|
||||
|
||||
// MinRTT Returns the minRTT for the entire connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
||||
|
||||
// LatestRTT returns the most recent rtt measurement.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
||||
|
||||
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
|
||||
// minRTT for the entire connection if SampleNewMinRtt was never called.
|
||||
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
|
||||
|
||||
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
||||
|
||||
// GetQuarterWindowRTT gets the quarter window RTT
|
||||
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
|
||||
|
||||
// GetHalfWindowRTT gets the half window RTT
|
||||
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
|
||||
|
||||
// MeanDeviation gets the mean deviation
|
||||
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
||||
|
||||
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
|
||||
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
|
||||
r.recentMinRTTwindow = recentMinRTTwindow
|
||||
}
|
||||
|
||||
// UpdateRTT updates the RTT based on a new sample.
|
||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
||||
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
|
||||
return
|
||||
}
|
||||
|
||||
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
|
||||
// ackDelay but the raw observed sendDelta, since poor clock granularity at
|
||||
// the client may cause a high ackDelay to result in underestimation of the
|
||||
// r.minRTT.
|
||||
if r.minRTT == 0 || r.minRTT > sendDelta {
|
||||
r.minRTT = sendDelta
|
||||
}
|
||||
r.updateRecentMinRTT(sendDelta, now)
|
||||
|
||||
// Correct for ackDelay if information received from the peer results in a
|
||||
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
|
||||
// measure for smoothedRTT.
|
||||
sample := sendDelta
|
||||
if sample > ackDelay {
|
||||
sample -= ackDelay
|
||||
}
|
||||
r.latestRTT = sample
|
||||
// First time call.
|
||||
if r.smoothedRTT == 0 {
|
||||
r.smoothedRTT = sample
|
||||
r.meanDeviation = sample / 2
|
||||
} else {
|
||||
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
||||
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
|
||||
if r.numMinRTTsamplesRemaining > 0 {
|
||||
r.numMinRTTsamplesRemaining--
|
||||
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
|
||||
r.newMinRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
if r.numMinRTTsamplesRemaining == 0 {
|
||||
r.recentMinRTT = r.newMinRTT
|
||||
r.halfWindowRTT = r.newMinRTT
|
||||
r.quarterWindowRTT = r.newMinRTT
|
||||
}
|
||||
}
|
||||
|
||||
// Update the three recent rtt samples.
|
||||
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
|
||||
r.recentMinRTT = rttSample{rtt: sample, time: now}
|
||||
r.halfWindowRTT = r.recentMinRTT
|
||||
r.quarterWindowRTT = r.recentMinRTT
|
||||
} else if sample <= r.halfWindowRTT.rtt {
|
||||
r.halfWindowRTT = rttSample{rtt: sample, time: now}
|
||||
r.quarterWindowRTT = r.halfWindowRTT
|
||||
} else if sample <= r.quarterWindowRTT.rtt {
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
|
||||
// Expire old min rtt samples.
|
||||
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
|
||||
r.recentMinRTT = r.halfWindowRTT
|
||||
r.halfWindowRTT = r.quarterWindowRTT
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
|
||||
r.halfWindowRTT = r.quarterWindowRTT
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
|
||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
||||
}
|
||||
}
|
||||
|
||||
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
|
||||
// |numSamples| UpdateRTT calls.
|
||||
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
|
||||
r.numMinRTTsamplesRemaining = numSamples
|
||||
r.newMinRTT = rttSample{}
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
||||
func (r *RTTStats) OnConnectionMigration() {
|
||||
r.latestRTT = 0
|
||||
r.minRTT = 0
|
||||
r.smoothedRTT = 0
|
||||
r.meanDeviation = 0
|
||||
r.initialRTTus = initialRTTus
|
||||
r.numMinRTTsamplesRemaining = 0
|
||||
r.recentMinRTTwindow = utils.InfDuration
|
||||
r.recentMinRTT = rttSample{}
|
||||
r.halfWindowRTT = rttSample{}
|
||||
r.quarterWindowRTT = rttSample{}
|
||||
}
|
||||
|
||||
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
||||
// is larger. The mean deviation is increased to the most recent deviation if
|
||||
// it's larger.
|
||||
func (r *RTTStats) ExpireSmoothedMetrics() {
|
||||
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
|
||||
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
|
||||
}
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamI interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStreamI = &cryptoStream{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStream{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPosition = offset
|
||||
}
|
109
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
109
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
@ -16,23 +16,48 @@ type StreamID = protocol.StreamID
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
const VersionGQUIC39 = protocol.Version39
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
||||
// An ErrorCode is an application-defined error code.
|
||||
type ErrorCode = protocol.ApplicationErrorCode
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
type Stream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Reader
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
StreamID() StreamID
|
||||
// Reset closes the stream with an error.
|
||||
Reset(error)
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// It must not be called after Close.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
CancelWrite(ErrorCode) error
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
CancelRead(ErrorCode) error
|
||||
// The context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
@ -53,18 +78,63 @@ type Stream interface {
|
||||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Read
|
||||
io.Reader
|
||||
// see Stream.CancelRead
|
||||
CancelRead(ErrorCode) error
|
||||
// see Stream.SetReadDealine
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Write
|
||||
io.Writer
|
||||
// see Stream.Close
|
||||
io.Closer
|
||||
// see Stream.CancelWrite
|
||||
CancelWrite(ErrorCode) error
|
||||
// see Stream.Context
|
||||
Context() context.Context
|
||||
// see Stream.SetWriteDeadline
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||
type StreamError interface {
|
||||
error
|
||||
Canceled() bool
|
||||
ErrorCode() ErrorCode
|
||||
}
|
||||
|
||||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
|
||||
AcceptStream() (Stream, error)
|
||||
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
|
||||
// New streams always have the smallest possible stream ID.
|
||||
// TODO: Enable testing for the special error
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
// It always picks the smallest possible stream ID.
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenStreamSync() (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenUniStream() (SendStream, error)
|
||||
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
@ -74,13 +144,9 @@ type Session interface {
|
||||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
// The communication is encrypted, but not yet forward secure.
|
||||
type NonFWSession interface {
|
||||
Session
|
||||
WaitUntilHandshakeComplete() error
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
@ -113,6 +179,17 @@ type Config struct {
|
||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||
MaxReceiveConnectionFlowControlWindow uint64
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingStreams int
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// This value doesn't have any effect in Google QUIC.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingUniStreams int
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
}
|
||||
|
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
package ackhandler
|
||||
|
||||
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
|
46
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
46
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet)
|
||||
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode() SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||
// It always returns a number greater or equal than 1.
|
||||
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm() error
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
IgnoreBelow(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *wire.AckFrame
|
||||
}
|
29
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
29
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
PacketType protocol.PacketType
|
||||
Frames []wire.Frame
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
SendTime time.Time
|
||||
|
||||
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||
|
||||
// There are two reasons why a packet cannot be retransmitted:
|
||||
// * it was already retransmitted
|
||||
// * this packet is a retransmission, and we already received an ACK for the original packet
|
||||
canBeRetransmitted bool
|
||||
includedInBytesInFlight bool
|
||||
retransmittedAs []protocol.PacketNumber
|
||||
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
|
||||
retransmissionOf protocol.PacketNumber
|
||||
}
|
@ -1,13 +1,10 @@
|
||||
// Generated by: main
|
||||
// TypeWriter: linkedlist
|
||||
// Directive: +gen on Packet
|
||||
// This file was automatically generated by genny.
|
||||
// Any changes will be lost if this file is regenerated.
|
||||
// see https://github.com/cheekybits/genny
|
||||
|
||||
package ackhandler
|
||||
|
||||
// List is a modification of http://golang.org/pkg/container/list/
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// Linked list implementation from the Go standard library.
|
||||
|
||||
// PacketElement is an element of a linked list.
|
||||
type PacketElement struct {
|
||||
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PacketList represents a doubly linked list.
|
||||
// The zero value for PacketList is an empty list ready to use.
|
||||
// PacketList is a linked list of Packets.
|
||||
type PacketList struct {
|
||||
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
|
||||
// The complexity is O(1).
|
||||
func (l *PacketList) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil.
|
||||
// Front returns the first element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Front() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil.
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Back() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero PacketList value.
|
||||
// lazyInit lazily initializes a zero List value.
|
||||
func (l *PacketList) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at).
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
|
||||
return l.insert(&PacketElement{Value: v}, at)
|
||||
}
|
||||
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) Remove(e *PacketElement) Packet {
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero PacketElement) and l.remove will crash
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToFront(e *PacketElement) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToBack(e *PacketElement) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in PacketList.Remove about initialization of l
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of an other list at the back of list l.
|
||||
// The lists l and other may be the same.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushBackList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of an other list at the front of list l.
|
||||
// The lists l and other may be the same.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushFrontList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
215
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
215
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
@ -0,0 +1,215 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedReceivedTime time.Time
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
ackSendDelay time.Duration
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
packetsReceivedSinceLastAck int
|
||||
retransmittablePacketsReceivedSinceLastAck int
|
||||
ackQueued bool
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
const (
|
||||
// maximum delay that can be applied to an ACK for a retransmittable packet
|
||||
ackSendDelay = 25 * time.Millisecond
|
||||
// initial maximum number of retransmittable packets received before sending an ack.
|
||||
initialRetransmittablePacketsBeforeAck = 2
|
||||
// number of retransmittable that an ACK is sent for
|
||||
retransmittablePacketsBeforeAck = 10
|
||||
// 1/5 RTT delay when doing ack decimation
|
||||
ackDecimationDelay = 1.0 / 4
|
||||
// 1/8 RTT delay when doing ack decimation
|
||||
shortAckDecimationDelay = 1.0 / 8
|
||||
// Minimum number of packets received before ack decimation is enabled.
|
||||
// This intends to avoid the beginning of slow start, when CWNDs may be
|
||||
// rapidly increasing.
|
||||
minReceivedBeforeAckDecimation = 100
|
||||
// Maximum number of packets to ack immediately after a missing packet for
|
||||
// fast retransmission to kick in at the sender. This limit is created to
|
||||
// reduce the number of acks sent that have no benefit for fast retransmission.
|
||||
// Set to the number of nacks needed for fast retransmit plus one for protection
|
||||
// against an ack loss
|
||||
maxPacketsAfterNewMissing = 4
|
||||
)
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler(
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: ackSendDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if packetNumber < h.ignoreBelow {
|
||||
return nil
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(packetNumber)
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||
if p <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = p
|
||||
h.packetHistory.DeleteBelow(p)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||
}
|
||||
|
||||
// maybeQueueAck queues an ACK, if necessary.
|
||||
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
|
||||
// in ACK_DECIMATION_WITH_REORDERING mode.
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
// always ack the first packet
|
||||
if h.lastAck == nil {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
|
||||
}
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if !h.ackQueued && shouldInstigateAck {
|
||||
h.retransmittablePacketsReceivedSinceLastAck++
|
||||
|
||||
if packetNumber > minReceivedBeforeAckDecimation {
|
||||
// ack up to 10 packets at once
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
|
||||
}
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
|
||||
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
|
||||
h.ackAlarm = rcvTime.Add(ackDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// send an ACK every 2 retransmittable packets
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(ackSendDelay)
|
||||
}
|
||||
}
|
||||
// If there are new missing packets to report, set a short timer to send an ACK.
|
||||
if h.hasNewMissingPackets() {
|
||||
// wait the minimum of 1/8 min RTT and the existing ack time
|
||||
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
|
||||
ackTime := rcvTime.Add(ackDelay)
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
|
||||
h.ackAlarm = ackTime
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||
now := time.Now()
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
|
||||
ack := &wire.AckFrame{
|
||||
AckRanges: h.packetHistory.GetAckRanges(),
|
||||
DelayTime: now.Sub(h.largestObservedReceivedTime),
|
||||
}
|
||||
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.packetsReceivedSinceLastAck = 0
|
||||
h.retransmittablePacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
@ -74,17 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all entries up to (and including) p
|
||||
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
||||
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p <= h.lowestInReceivedPacketNumbers {
|
||||
return
|
||||
}
|
||||
h.lowestInReceivedPacketNumbers = p
|
||||
|
||||
nextEl := h.ranges.Front()
|
||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
|
||||
if p >= el.Value.Start && p < el.Value.End {
|
||||
el.Value.Start = p + 1
|
||||
} else if el.Value.End <= p { // delete a whole range
|
||||
if p > el.Value.Start && p <= el.Value.End {
|
||||
el.Value.Start = p
|
||||
} else if el.Value.End < p { // delete a whole range
|
||||
h.ranges.Remove(el)
|
||||
} else { // no ranges affected. Nothing to do
|
||||
return
|
||||
@ -101,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
||||
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||
i := 0
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
|
||||
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
|
||||
i++
|
||||
}
|
||||
return ackRanges
|
||||
@ -111,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.First = r.Start
|
||||
ackRange.Last = r.End
|
||||
ackRange.Smallest = r.Start
|
||||
ackRange.Largest = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
40
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
40
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendRetransmission means that retransmissions should be sent
|
||||
SendRetransmission
|
||||
// SendRTO means that an RTO probe packet should be sent
|
||||
SendRTO
|
||||
// SendTLP means that a TLP probe packet should be sent
|
||||
SendTLP
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendRetransmission:
|
||||
return "retransmission"
|
||||
case SendRTO:
|
||||
return "rto"
|
||||
case SendTLP:
|
||||
return "tlp"
|
||||
case SendAny:
|
||||
return "any"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
649
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
649
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
@ -0,0 +1,649 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// In fraction of an RTT.
|
||||
timeReorderingFraction = 1.0 / 8
|
||||
// defaultRTOTimeout is the RTO time on new connections
|
||||
defaultRTOTimeout = 500 * time.Millisecond
|
||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||
minTPLTimeout = 10 * time.Millisecond
|
||||
// Maximum number of tail loss probes before an RTO fires.
|
||||
maxTLPs = 2
|
||||
// Minimum time in the future an RTO alarm may be set for.
|
||||
minRTOTimeout = 200 * time.Millisecond
|
||||
// maxRTOTimeout is the maximum RTO time
|
||||
maxRTOTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastSentRetransmittablePacketTime time.Time
|
||||
lastSentHandshakePacketTime time.Time
|
||||
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
largestSentBeforeRTO protocol.PacketNumber
|
||||
|
||||
packetHistory *sentPacketHistory
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
congestion congestion.SendAlgorithm
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
handshakeComplete bool
|
||||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||
handshakeCount uint32
|
||||
|
||||
// The number of times a TLP has been sent without receiving an ack.
|
||||
tlpCount uint32
|
||||
allowTLP bool
|
||||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
// The number of RTO probe packets that should be sent.
|
||||
numRTOs int
|
||||
|
||||
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
||||
lossTime time.Time
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: newSentPacketHistory(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
return p.PacketNumber
|
||||
}
|
||||
return h.largestAcked + 1
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
h.packetHistory.SentPacket(packet)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
var p []*Packet
|
||||
for _, packet := range packets {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
p = append(p, packet)
|
||||
}
|
||||
}
|
||||
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
packet.largestAcked = ackFrame.LargestAcked()
|
||||
}
|
||||
}
|
||||
|
||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.lastSentHandshakePacketTime = packet.SendTime
|
||||
}
|
||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||
packet.includedInBytesInFlight = true
|
||||
h.bytesInFlight += packet.Length
|
||||
packet.canBeRetransmitted = true
|
||||
if h.numRTOs > 0 {
|
||||
h.numRTOs--
|
||||
}
|
||||
h.allowTLP = false
|
||||
}
|
||||
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
|
||||
|
||||
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||
return isRetransmittable
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
if largestAcked > h.lastSentPacketNumber {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
}
|
||||
|
||||
// duplicate or out of order ACK
|
||||
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
|
||||
return nil
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
|
||||
}
|
||||
if err := h.onPacketAcked(p, rcvTime); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.includedInBytesInFlight {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestPacketNotConfirmedAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
|
||||
var ackedPackets []*Packet
|
||||
ackRangeIndex := 0
|
||||
lowestAcked := ackFrame.LowestAcked()
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
// Ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
} else {
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(ackedPackets))
|
||||
for i, p := range ackedPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
return ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
h.alarm = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
if !h.handshakeComplete {
|
||||
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO or TLP alarm
|
||||
alarmDuration := h.computeRTOTimeout()
|
||||
if h.tlpCount < maxTLPs {
|
||||
tlpAlarm := h.computeTLPTimeout()
|
||||
// if the RTO duration is shorter than the TLP duration, use the RTO duration
|
||||
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
|
||||
}
|
||||
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
|
||||
h.lossTime = time.Time{}
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
||||
var lostPackets []*Packet
|
||||
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
|
||||
if packet.PacketNumber > h.largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, packet)
|
||||
} else if h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
|
||||
}
|
||||
// Note: This conditional is only entered once per call
|
||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(lostPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(lostPackets))
|
||||
for i, p := range lostPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
|
||||
for _, p := range lostPackets {
|
||||
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
if p.canBeRetransmitted {
|
||||
// queue the packet for retransmission, and report the loss to the congestion controller
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() error {
|
||||
now := time.Now()
|
||||
|
||||
var err error
|
||||
if !h.handshakeComplete {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in handshake mode")
|
||||
}
|
||||
h.handshakeCount++
|
||||
err = h.queueHandshakePacketsForRetransmission()
|
||||
} else if !h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in loss timer mode")
|
||||
}
|
||||
// Early retransmit or time loss detection
|
||||
err = h.detectLostPackets(now, h.bytesInFlight)
|
||||
} else if h.tlpCount < maxTLPs {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in TLP mode")
|
||||
}
|
||||
h.allowTLP = true
|
||||
h.tlpCount++
|
||||
} else {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in RTO mode")
|
||||
}
|
||||
// RTO
|
||||
h.rtoCount++
|
||||
h.numRTOs += 2
|
||||
err = h.queueRTOs()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
|
||||
// This happens if a packet and its retransmissions is acked in the same ACK.
|
||||
// As soon as we process the first one, this will remove all the retransmissions,
|
||||
// so we won't find the retransmitted packet number later.
|
||||
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only report the acking of this packet to the congestion controller if:
|
||||
// * it is a retransmittable packet
|
||||
// * this packet wasn't retransmitted yet
|
||||
if p.isRetransmission {
|
||||
// that the parent doesn't exist is expected to happen every time the original packet was already acked
|
||||
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
|
||||
if len(parent.retransmittedAs) == 1 {
|
||||
parent.retransmittedAs = nil
|
||||
} else {
|
||||
// remove this packet from the slice of retransmission
|
||||
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
|
||||
for _, pn := range parent.retransmittedAs {
|
||||
if pn != p.PacketNumber {
|
||||
retransmittedAs = append(retransmittedAs, pn)
|
||||
}
|
||||
}
|
||||
parent.retransmittedAs = retransmittedAs
|
||||
}
|
||||
}
|
||||
}
|
||||
// this also applies to packets that have been retransmitted as probe packets
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
}
|
||||
if h.rtoCount > 0 {
|
||||
h.verifyRTO(p.PacketNumber)
|
||||
}
|
||||
if err := h.stopRetransmissionsFor(p); err != nil {
|
||||
return err
|
||||
}
|
||||
h.rtoCount = 0
|
||||
h.tlpCount = 0
|
||||
h.handshakeCount = 0
|
||||
return h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range p.retransmittedAs {
|
||||
packet := h.packetHistory.GetPacket(r)
|
||||
if packet == nil {
|
||||
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
|
||||
}
|
||||
h.stopRetransmissionsFor(packet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
|
||||
if pn <= h.largestSentBeforeRTO {
|
||||
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
|
||||
// Replace SRTT with latest_rtt and increase the variance to prevent
|
||||
// a spurious RTO from happening again.
|
||||
h.rttStats.ExpireSmoothedMetrics()
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
||||
if len(h.retransmissionQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
packet := h.retransmissionQueue[0]
|
||||
// Shift the slice and don't retain anything that isn't needed.
|
||||
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
|
||||
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
|
||||
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
|
||||
return packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||
|
||||
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
}
|
||||
return SendNone
|
||||
}
|
||||
if h.allowTLP {
|
||||
return SendTLP
|
||||
}
|
||||
if h.numRTOs > 0 {
|
||||
return SendRTO
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
// Send retransmissions first, if there are any.
|
||||
if len(h.retransmissionQueue) > 0 {
|
||||
return SendRetransmission
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
return SendAny
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
return h.nextPacketSendTime
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||
if h.numRTOs > 0 {
|
||||
// RTO probes should not be paced, but must be sent immediately.
|
||||
return h.numRTOs
|
||||
}
|
||||
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||
return 1
|
||||
}
|
||||
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||
}
|
||||
|
||||
// retransmit the oldest two packets
|
||||
func (h *sentPacketHandler) queueRTOs() error {
|
||||
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||
// Queue the first two outstanding packets for retransmission.
|
||||
// This does NOT declare this packets as lost:
|
||||
// They are still tracked in the packet history and count towards the bytes in flight.
|
||||
for i := 0; i < 2; i++ {
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
h.logger.Debugf("Queueing packet %#x for retransmission (RTO)", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
||||
if !p.canBeRetransmitted {
|
||||
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
|
||||
}
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
|
||||
// exponential backoff
|
||||
// There's an implicit limit to this set by the handshake timeout.
|
||||
return duration << h.handshakeCount
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
|
||||
// TODO(#1236): include the max_ack_delay
|
||||
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
var rto time.Duration
|
||||
rtt := h.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
rto = defaultRTOTimeout
|
||||
} else {
|
||||
rto = rtt + 4*h.rttStats.MeanDeviation()
|
||||
}
|
||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||
// Exponential backoff
|
||||
rto = rto << h.rtoCount
|
||||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lowestUnacked := h.lowestUnacked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p < lowestUnacked {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||
}
|
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
@ -0,0 +1,127 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sentPacketHistory struct {
|
||||
packetList *PacketList
|
||||
packetMap map[protocol.PacketNumber]*PacketElement
|
||||
|
||||
firstOutstanding *PacketElement
|
||||
}
|
||||
|
||||
func newSentPacketHistory() *sentPacketHistory {
|
||||
return &sentPacketHistory{
|
||||
packetList: NewPacketList(),
|
||||
packetMap: make(map[protocol.PacketNumber]*PacketElement),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacket(p *Packet) {
|
||||
h.sentPacketImpl(p)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
||||
el := h.packetList.PushBack(*p)
|
||||
h.packetMap[p.PacketNumber] = el
|
||||
if h.firstOutstanding == nil {
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
return el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
retransmission, ok := h.packetMap[retransmissionOf]
|
||||
// The retransmitted packet is not present anymore.
|
||||
// This can happen if it was acked in between dequeueing of the retransmission and sending.
|
||||
// Just treat the retransmissions as normal packets.
|
||||
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
|
||||
if !ok {
|
||||
for _, packet := range packets {
|
||||
h.sentPacketImpl(packet)
|
||||
}
|
||||
return
|
||||
}
|
||||
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
|
||||
for i, packet := range packets {
|
||||
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
|
||||
el := h.sentPacketImpl(packet)
|
||||
el.Value.isRetransmission = true
|
||||
el.Value.retransmissionOf = retransmissionOf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
|
||||
if el, ok := h.packetMap[p]; ok {
|
||||
return &el.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
// The callback must not modify the history.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
|
||||
cont := true
|
||||
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
|
||||
var err error
|
||||
cont, err = cb(&el.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirstOutStanding returns the first outstanding packet.
|
||||
// It must not be modified (e.g. retransmitted).
|
||||
// Use DequeueFirstPacketForRetransmission() to retransmit it.
|
||||
func (h *sentPacketHistory) FirstOutstanding() *Packet {
|
||||
if h.firstOutstanding == nil {
|
||||
return nil
|
||||
}
|
||||
return &h.firstOutstanding.Value
|
||||
}
|
||||
|
||||
// QueuePacketForRetransmission marks a packet for retransmission.
|
||||
// A packet can only be queued once.
|
||||
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[pn]
|
||||
if !ok {
|
||||
return fmt.Errorf("sent packet history: packet %d not found", pn)
|
||||
}
|
||||
el.Value.canBeRetransmitted = false
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
|
||||
// This is necessary every time the first outstanding packet is deleted or retransmitted.
|
||||
func (h *sentPacketHistory) readjustFirstOutstanding() {
|
||||
el := h.firstOutstanding.Next()
|
||||
for el != nil && !el.Value.canBeRetransmitted {
|
||||
el = el.Next()
|
||||
}
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packetMap)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[p]
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", p)
|
||||
}
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
h.packetList.Remove(el)
|
||||
delete(h.packetMap, p)
|
||||
return nil
|
||||
}
|
@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
if ack.LargestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = ack.LargestAcked + 1
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = largestAcked + 1
|
||||
}
|
||||
}
|
||||
|
@ -16,11 +16,10 @@ import (
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling
|
||||
// round trip time.
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const cubeScale = 40
|
||||
const cubeCongestionWindowScale = 410
|
||||
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale
|
||||
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
|
||||
|
||||
const defaultNumConnections = 2
|
||||
|
||||
@ -32,39 +31,35 @@ const beta float32 = 0.7
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// If true, Cubic's epoch is shifted when the sender is application-limited.
|
||||
const shiftQuicCubicEpochWhenAppLimited = true
|
||||
|
||||
const maxCubicTimeInterval = 30 * time.Millisecond
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
// Time when sender went into application-limited period. Zero if not in
|
||||
// application-limited period.
|
||||
appLimitedStartTime time.Time
|
||||
// Time when we updated last_congestion_window.
|
||||
lastUpdateTime time.Time
|
||||
// Last congestion window (in packets) used.
|
||||
lastCongestionWindow protocol.PacketNumber
|
||||
// Max congestion window (in packets) used just before last loss event.
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.PacketNumber
|
||||
// Number of acked packets since the cycle started (epoch).
|
||||
ackedPacketsCount protocol.PacketNumber
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.PacketNumber
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.PacketNumber
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.PacketNumber
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.appLimitedStartTime = time.Time{}
|
||||
c.lastUpdateTime = time.Time{}
|
||||
c.lastCongestionWindow = 0
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedPacketsCount = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
if shiftQuicCubicEpochWhenAppLimited {
|
||||
// When sender is not using the available congestion window, Cubic's epoch
|
||||
// should not continue growing. Record the time when sender goes into an
|
||||
// app-limited period here, to compensate later when cwnd growth happens.
|
||||
if c.appLimitedStartTime.IsZero() {
|
||||
c.appLimitedStartTime = c.clock.Now()
|
||||
}
|
||||
} else {
|
||||
// When sender is not using the available congestion window, Cubic's epoch
|
||||
// should not continue growing. Reset the epoch when in such a period.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber {
|
||||
if currentCongestionWindow < c.lastMaxCongestionWindow {
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow))
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta())
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber {
|
||||
c.ackedPacketsCount++ // Packets acked.
|
||||
currentTime := c.clock.Now()
|
||||
|
||||
// Cubic is "independent" of RTT, the update is limited by the time elapsed.
|
||||
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) {
|
||||
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow)
|
||||
}
|
||||
c.lastCongestionWindow = currentCongestionWindow
|
||||
c.lastUpdateTime = currentTime
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = currentTime // Start of epoch.
|
||||
c.ackedPacketsCount = 1 // Reset count.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
} else {
|
||||
// If sender was app-limited, then freeze congestion window growth during
|
||||
// app-limited period. Continue growth now by shifting the epoch-start
|
||||
// through the app-limited period.
|
||||
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
|
||||
shift := currentTime.Sub(c.appLimitedStartTime)
|
||||
c.epoch = c.epoch.Add(shift)
|
||||
c.appLimitedStartTime = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
// Right-shifts of negative, signed numbers have
|
||||
// implementation-dependent behavior. Force the offset to be
|
||||
// positive, similar to the kernel implementation.
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
|
||||
var targetCongestionWindow protocol.PacketNumber
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// With dynamic beta/alpha based on number of active streams, it is possible
|
||||
// for the required_ack_count to become much lower than acked_packets_count_
|
||||
// suddenly, leading to more than one iteration through the following loop.
|
||||
for {
|
||||
// Update estimated TCP congestion_window.
|
||||
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha())
|
||||
if c.ackedPacketsCount < requiredAckCount {
|
||||
break
|
||||
}
|
||||
c.ackedPacketsCount -= requiredAckCount
|
||||
c.estimatedTCPcongestionWindow++
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
@ -8,9 +8,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||
defaultMinimumCongestionWindow protocol.PacketNumber = 2
|
||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
@ -31,12 +31,6 @@ type cubicSender struct {
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Congestion window in packets.
|
||||
congestionWindow protocol.PacketNumber
|
||||
|
||||
// Slow start congestion window in packets, aka ssthresh.
|
||||
slowstartThreshold protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
@ -44,24 +38,35 @@ type cubicSender struct {
|
||||
// When true, exit slow start with large cutback of congestion window.
|
||||
slowStartLargeReduction bool
|
||||
|
||||
// Minimum congestion window in packets.
|
||||
minCongestionWindow protocol.PacketNumber
|
||||
// Congestion window in packets.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Maximum number of outstanding packets for tcp.
|
||||
maxTCPCongestionWindow protocol.PacketNumber
|
||||
// Minimum congestion window in packets.
|
||||
minCongestionWindow protocol.ByteCount
|
||||
|
||||
// Maximum congestion window.
|
||||
maxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowstartThreshold protocol.ByteCount
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
congestionWindowCount protocol.ByteCount
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.PacketNumber
|
||||
initialMaxCongestionWindow protocol.PacketNumber
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
minSlowStartExitWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
var _ SendAlgorithm = &cubicSender{}
|
||||
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo {
|
||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
|
||||
return &cubicSender{
|
||||
rttStats: rttStats,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
@ -69,28 +74,37 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
||||
congestionWindow: initialCongestionWindow,
|
||||
minCongestionWindow: defaultMinimumCongestionWindow,
|
||||
slowstartThreshold: initialMaxCongestionWindow,
|
||||
maxTCPCongestionWindow: initialMaxCongestionWindow,
|
||||
maxCongestionWindow: initialMaxCongestionWindow,
|
||||
numConnections: defaultNumConnections,
|
||||
cubic: NewCubic(clock),
|
||||
reno: reno,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
|
||||
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if c.GetCongestionWindow() > bytesInFlight {
|
||||
return 0
|
||||
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
|
||||
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||
delay = delay * 8 / 5
|
||||
}
|
||||
return utils.InfDuration
|
||||
return delay
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
// Only update bytesInFlight for data packets.
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
bytesInFlight protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
if !isRetransmittable {
|
||||
return false
|
||||
return
|
||||
}
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
@ -98,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
@ -110,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
|
||||
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) ExitSlowstart() {
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
|
||||
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
@ -131,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
c.prr.OnPacketAcked(ackedBytes)
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight)
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) OnPacketLost(
|
||||
packetNumber protocol.PacketNumber,
|
||||
lostBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
@ -152,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
||||
c.stats.slowstartPacketsLost++
|
||||
c.stats.slowstartBytesLost += lostBytes
|
||||
if c.slowStartLargeReduction {
|
||||
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS {
|
||||
// Reduce congestion window by 1 for every mss of bytes lost.
|
||||
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
|
||||
}
|
||||
// Reduce congestion window by lost_bytes for every loss.
|
||||
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
}
|
||||
@ -166,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
||||
c.stats.slowstartPacketsLost++
|
||||
}
|
||||
|
||||
c.prr.OnPacketLost(bytesInFlight)
|
||||
c.prr.OnPacketLost(priorInFlight)
|
||||
|
||||
// TODO(chromium): Separate out all of slow start into a separate class.
|
||||
if c.slowStartLargeReduction && c.InSlowStart() {
|
||||
c.congestionWindow = c.congestionWindow - 1
|
||||
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||
}
|
||||
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
||||
} else if c.reno {
|
||||
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta())
|
||||
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
// Enforce a minimum congestion window.
|
||||
if c.congestionWindow < c.minCongestionWindow {
|
||||
c.congestionWindow = c.minCongestionWindow
|
||||
}
|
||||
@ -184,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.congestionWindowCount = 0
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
func (c *cubicSender) RenoBeta() float32 {
|
||||
@ -197,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(bytesInFlight) {
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxTCPCongestionWindow {
|
||||
if c.congestionWindow >= c.maxCongestionWindow {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow++
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.congestionWindowCount++
|
||||
c.numAckedPackets++
|
||||
// Divide by num_connections to smoothly increase the CWND at a faster
|
||||
// rate than conventional Reno.
|
||||
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow {
|
||||
c.congestionWindow++
|
||||
c.congestionWindowCount = 0
|
||||
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT()))
|
||||
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,21 +306,13 @@ func (c *cubicSender) OnConnectionMigration() {
|
||||
c.largestSentAtLastCutback = 0
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.congestionWindowCount = 0
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowstartThreshold = c.initialMaxCongestionWindow
|
||||
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow
|
||||
c.maxCongestionWindow = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
// SetSlowStartLargeReduction allows enabling the SSLR experiment
|
||||
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
|
||||
c.slowStartLargeReduction = enabled
|
||||
}
|
||||
|
||||
// RetransmissionDelay gives the time to retransmission
|
||||
func (c *cubicSender) RetransmissionDelay() time.Duration {
|
||||
if c.rttStats.SmoothedRTT() == 0 {
|
||||
return 0
|
||||
}
|
||||
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
|
||||
}
|
@ -8,16 +8,15 @@ import (
|
||||
|
||||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
SetNumEmulatedConnections(n int)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
OnConnectionMigration()
|
||||
RetransmissionDelay() time.Duration
|
||||
|
||||
// Experiments
|
||||
SetSlowStartLargeReduction(enabled bool)
|
||||
@ -31,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
|
||||
// Stuff only used in testing
|
||||
|
||||
HybridSlowStart() *HybridSlowStart
|
||||
SlowstartThreshold() protocol.PacketNumber
|
||||
SlowstartThreshold() protocol.ByteCount
|
||||
RenoBeta() float32
|
||||
InRecovery() bool
|
||||
}
|
@ -1,10 +1,7 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
||||
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
|
||||
// OnPacketLost should be called on the first loss that triggers a recovery
|
||||
// period and all other methods in this class should only be called when in
|
||||
// recovery.
|
||||
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) {
|
||||
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
|
||||
p.bytesSentSinceLoss = 0
|
||||
p.bytesInFlightBeforeLoss = bytesInFlight
|
||||
p.bytesInFlightBeforeLoss = priorInFlight
|
||||
p.bytesDeliveredSinceLoss = 0
|
||||
p.ackCountSinceLoss = 0
|
||||
}
|
||||
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
|
||||
p.ackCountSinceLoss++
|
||||
}
|
||||
|
||||
// TimeUntilSend calculates the time until a packet can be sent
|
||||
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration {
|
||||
// CanSend returns if packets can be sent
|
||||
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
|
||||
// Return QuicTime::Zero In order to ensure limited transmit always works.
|
||||
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
|
||||
return 0
|
||||
return true
|
||||
}
|
||||
if congestionWindow > bytesInFlight {
|
||||
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
|
||||
// of sending the entire available window. This prevents burst retransmits
|
||||
// when more packets are lost than the CWND reduction.
|
||||
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
|
||||
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss {
|
||||
return utils.InfDuration
|
||||
}
|
||||
return 0
|
||||
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
|
||||
}
|
||||
// Implement Proportional Rate Reduction (RFC6937).
|
||||
// Checks a simplified version of the PRR formula that doesn't use division:
|
||||
// AvailableSendWindow =
|
||||
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
|
||||
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss {
|
||||
return 0
|
||||
}
|
||||
return utils.InfDuration
|
||||
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
|
||||
}
|
101
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
Normal file
@ -0,0 +1,101 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
rttAlpha float32 = 0.125
|
||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||
rttBeta float32 = 0.25
|
||||
oneMinusBeta float32 = (1 - rttBeta)
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// RTTStats provides round-trip statistics
|
||||
type RTTStats struct {
|
||||
minRTT time.Duration
|
||||
latestRTT time.Duration
|
||||
smoothedRTT time.Duration
|
||||
meanDeviation time.Duration
|
||||
}
|
||||
|
||||
// NewRTTStats makes a properly initialized RTTStats object
|
||||
func NewRTTStats() *RTTStats {
|
||||
return &RTTStats{}
|
||||
}
|
||||
|
||||
// MinRTT Returns the minRTT for the entire connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
||||
|
||||
// LatestRTT returns the most recent rtt measurement.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
||||
|
||||
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
||||
|
||||
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
|
||||
// If no valid updates have occurred, it returns the initial RTT.
|
||||
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
|
||||
if r.smoothedRTT != 0 {
|
||||
return r.smoothedRTT
|
||||
}
|
||||
return defaultInitialRTT
|
||||
}
|
||||
|
||||
// MeanDeviation gets the mean deviation
|
||||
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
||||
|
||||
// UpdateRTT updates the RTT based on a new sample.
|
||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
|
||||
// ackDelay but the raw observed sendDelta, since poor clock granularity at
|
||||
// the client may cause a high ackDelay to result in underestimation of the
|
||||
// r.minRTT.
|
||||
if r.minRTT == 0 || r.minRTT > sendDelta {
|
||||
r.minRTT = sendDelta
|
||||
}
|
||||
|
||||
// Correct for ackDelay if information received from the peer results in a
|
||||
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||
// sendDelta.
|
||||
sample := sendDelta
|
||||
if sample-r.minRTT >= ackDelay {
|
||||
sample -= ackDelay
|
||||
}
|
||||
r.latestRTT = sample
|
||||
// First time call.
|
||||
if r.smoothedRTT == 0 {
|
||||
r.smoothedRTT = sample
|
||||
r.meanDeviation = sample / 2
|
||||
} else {
|
||||
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
||||
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
||||
func (r *RTTStats) OnConnectionMigration() {
|
||||
r.latestRTT = 0
|
||||
r.minRTT = 0
|
||||
r.smoothedRTT = 0
|
||||
r.meanDeviation = 0
|
||||
}
|
||||
|
||||
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
||||
// is larger. The mean deviation is increased to the most recent deviation if
|
||||
// it's larger.
|
||||
func (r *RTTStats) ExpireSmoothedMetrics() {
|
||||
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
|
||||
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
|
||||
}
|
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
||||
return cert.Certificate[0], nil
|
||||
}
|
||||
|
||||
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
c := cc.config
|
||||
c, err := maybeGetConfigForClient(c, sni)
|
||||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
conf := c.config
|
||||
conf, err := maybeGetConfigForClient(conf, sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
||||
|
||||
if c.GetCertificate != nil {
|
||||
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if conf.GetCertificate != nil {
|
||||
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if cert != nil || err != nil {
|
||||
return cert, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Certificates) == 0 {
|
||||
if len(conf.Certificates) == 0 {
|
||||
return nil, errNoMatchingCertificate
|
||||
}
|
||||
|
||||
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
||||
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
||||
// There's only one choice, so no point doing any work.
|
||||
return &c.Certificates[0], nil
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
name := strings.ToLower(sni)
|
||||
@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if cert, ok := c.NameToCertificate[name]; ok {
|
||||
if cert, ok := conf.NameToCertificate[name]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := c.NameToCertificate[candidate]; ok {
|
||||
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing matches, return the first certificate.
|
||||
return &c.Certificates[0], nil
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
||||
|
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go
generated
vendored
@ -18,6 +18,7 @@ type CertManager interface {
|
||||
GetLeafCertHash() (uint64, error)
|
||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||
Verify(hostname string) error
|
||||
GetChain() []*x509.Certificate
|
||||
}
|
||||
|
||||
type certManager struct {
|
||||
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certManager) GetChain() []*x509.Certificate {
|
||||
return c.chain
|
||||
}
|
||||
|
||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||
return getCommonCertificateHashes()
|
||||
}
|
||||
|
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
@ -1,61 +0,0 @@
|
||||
// +build ignore
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/aead/chacha20"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadChacha20Poly1305 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
|
||||
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
|
||||
}
|
||||
// copy because ChaCha20Poly1305 expects array pointers
|
||||
var MyKey, OtherKey [32]byte
|
||||
copy(MyKey[:], myKey)
|
||||
copy(OtherKey[:], otherKey)
|
||||
|
||||
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadChacha20Poly1305{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
71
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go
generated
vendored
71
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go
generated
vendored
@ -1,71 +0,0 @@
|
||||
// +build ignore
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chacha20poly1305", func() {
|
||||
var (
|
||||
alice, bob AEAD
|
||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
keyAlice = make([]byte, 32)
|
||||
keyBob = make([]byte, 32)
|
||||
ivAlice = make([]byte, 4)
|
||||
ivBob = make([]byte, 4)
|
||||
rand.Reader.Read(keyAlice)
|
||||
rand.Reader.Read(keyBob)
|
||||
rand.Reader.Read(ivAlice)
|
||||
rand.Reader.Read(ivBob)
|
||||
var err error
|
||||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("seals and opens", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("seals and opens reverse", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("has the proper length", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
Expect(b).To(HaveLen(6 + 12))
|
||||
})
|
||||
|
||||
It("fails with wrong aad", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects wrong key and iv sizes", func() {
|
||||
var err error
|
||||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
|
||||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
|
||||
Expect(err).To(MatchError(e))
|
||||
})
|
||||
})
|
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
|
||||
if _, err := rand.Read(c.secret[:]); err != nil {
|
||||
return nil, errors.New("Curve25519: could not create private key")
|
||||
}
|
||||
// See https://cr.yp.to/ecdh.html
|
||||
c.secret[0] &= 248
|
||||
c.secret[31] &= 127
|
||||
c.secret[31] |= 64
|
||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
||||
return c, nil
|
||||
}
|
||||
|
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
@ -1,13 +1,16 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
|
||||
)
|
||||
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
@ -16,6 +19,14 @@ type TLSExporter interface {
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
func qhkdfExpand(secret []byte, label string, length int) []byte {
|
||||
qlabel := make([]byte, 2+1+5+len(label))
|
||||
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||
qlabel[2] = uint8(5 + len(label))
|
||||
copy(qlabel[3:], []byte("QUIC "+label))
|
||||
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
|
||||
}
|
||||
|
||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||
var myLabel, otherLabel string
|
||||
@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
|
||||
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
|
||||
key = qhkdfExpand(secret, "key", cs.KeyLen)
|
||||
iv = qhkdfExpand(secret, "iv", cs.IvLen)
|
||||
return key, iv, nil
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
|
||||
} else {
|
||||
info.Write([]byte("QUIC key expansion\x00"))
|
||||
}
|
||||
utils.BigEndian.WriteUint64(&info, uint64(connID))
|
||||
info.Write(connID)
|
||||
info.Write(chlo)
|
||||
info.Write(scfg)
|
||||
info.Write(cert)
|
||||
|
17
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
17
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
@ -2,13 +2,12 @@ package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
|
||||
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||
|
||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||
@ -28,17 +27,15 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
connID := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
||||
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
|
||||
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
||||
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
||||
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
|
||||
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
|
||||
key = qhkdfExpand(secret, "key", 16)
|
||||
iv = qhkdfExpand(secret, "iv", 12)
|
||||
return
|
||||
}
|
||||
|
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
@ -1,10 +1,11 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/fnv128a"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||
}
|
||||
|
||||
hash := fnv128a.New()
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src[12:])
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
||||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
testHigh, testLow := hash.Sum128()
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
low := binary.LittleEndian.Uint64(src)
|
||||
high := binary.LittleEndian.Uint32(src[8:])
|
||||
|
||||
if uint32(testHigh&0xffffffff) != high || testLow != low {
|
||||
return nil, errors.New("NullAEAD: failed to authenticate received data")
|
||||
if !bytes.Equal(sum[:12], src[:12]) {
|
||||
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
||||
}
|
||||
return src[12:], nil
|
||||
}
|
||||
@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
||||
dst = dst[:12+len(src)]
|
||||
}
|
||||
|
||||
hash := fnv128a.New()
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src)
|
||||
|
||||
@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
||||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
|
||||
high, low := hash.Sum128()
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
copy(dst[12:], src)
|
||||
binary.LittleEndian.PutUint64(dst, low)
|
||||
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
||||
copy(dst, sum[:12])
|
||||
return dst
|
||||
}
|
||||
|
||||
func (n *nullAEADFNV128a) Overhead() int {
|
||||
return 12
|
||||
}
|
||||
|
||||
func reverse(a []byte) {
|
||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||
a[left], a[right] = a[right], a[left]
|
||||
}
|
||||
}
|
||||
|
76
vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go
generated
vendored
@ -1,76 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// StkSource is used to create and verify source address tokens
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type stkSource struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
const stkKeySize = 16
|
||||
|
||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
||||
// at 16 :)
|
||||
const stkNonceSize = 16
|
||||
|
||||
// NewStkSource creates a source for source address tokens
|
||||
func NewStkSource() (StkSource, error) {
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := deriveKey(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &stkSource{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, stkNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < stkNonceSize {
|
||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:stkNonceSize]
|
||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
||||
key := make([]byte, stkKeySize)
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
82
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
82
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
@ -4,41 +4,38 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
|
||||
lastWindowUpdateTime time.Time
|
||||
// for receiving data
|
||||
mutex sync.RWMutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowIncrement protocol.ByteCount
|
||||
maxReceiveWindowIncrement protocol.ByteCount
|
||||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||
// it returns true if the window was actually updated
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
}
|
||||
@ -57,52 +54,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
||||
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||
diff := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than half of it was already consumed
|
||||
if diff >= (c.receiveWindowIncrement / 2) {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowIncrement()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.maybeAdjustWindowSize()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) IsBlocked() bool {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
return c.sendWindowSize() == 0
|
||||
}
|
||||
|
||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
||||
func (c *baseFlowController) maybeAdjustWindowIncrement() {
|
||||
if c.lastWindowUpdateTime.IsZero() {
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
||||
return
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
}
|
||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch() {
|
||||
c.epochStartTime = time.Now()
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
|
@ -2,16 +2,18 @@ package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
lastBlockedAt protocol.ByteCount
|
||||
baseFlowController
|
||||
|
||||
queueWindowUpdate func()
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
@ -21,25 +23,37 @@ var _ ConnectionFlowController = &connectionFlowController{}
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(),
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ConnectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowIncrement: receiveWindow,
|
||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
logger: logger,
|
||||
},
|
||||
queueWindowUpdate: queueWindowUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
@ -52,26 +66,34 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowIncrement < c.receiveWindowIncrement {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if inc > c.receiveWindowIncrement {
|
||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
||||
if inc > c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
@ -5,17 +5,19 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
IsBlocked() bool
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
MaybeQueueWindowUpdate() // queues a window update, if necessary
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsBlocked() (bool, protocol.ByteCount)
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||
@ -25,13 +27,15 @@ type StreamFlowController interface {
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowIncrement(protocol.ByteCount)
|
||||
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount) error
|
||||
}
|
||||
|
62
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
62
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
@ -3,7 +3,7 @@ package flowcontrol
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
@ -14,6 +14,8 @@ type streamFlowController struct {
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
|
||||
@ -30,18 +32,22 @@ func NewStreamFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(protocol.StreamID),
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowIncrement: receiveWindow,
|
||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -102,9 +108,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
window := c.baseFlowController.sendWindowSize()
|
||||
if c.contributesToConnection {
|
||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||
@ -112,17 +115,44 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return window
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// IsBlocked says if it is blocked by stream-level flow control.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 {
|
||||
return false, 0
|
||||
}
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||
c.mutex.Lock()
|
||||
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
|
||||
if c.receivedFinalOffset {
|
||||
c.mutex.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
@ -6,7 +6,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -29,17 +29,17 @@ type token struct {
|
||||
|
||||
// A CookieGenerator generates Cookies
|
||||
type CookieGenerator struct {
|
||||
cookieSource crypto.StkSource
|
||||
cookieProtector mint.CookieProtector
|
||||
}
|
||||
|
||||
// NewCookieGenerator initializes a new CookieGenerator
|
||||
func NewCookieGenerator() (*CookieGenerator, error) {
|
||||
stkSource, err := crypto.NewStkSource()
|
||||
cookieProtector, err := mint.NewDefaultCookieProtector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CookieGenerator{
|
||||
cookieSource: stkSource,
|
||||
cookieProtector: cookieProtector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.cookieSource.NewToken(data)
|
||||
return g.cookieProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a Cookie
|
||||
@ -62,7 +62,7 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.cookieSource.DecodeToken(encrypted)
|
||||
data, err := g.cookieProtector.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
26
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
@ -7,36 +7,44 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type cookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
|
||||
// A CookieHandler generates and validates cookies.
|
||||
// The cookie is sent in the TLS Retry.
|
||||
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
|
||||
type CookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
cookieGenerator *CookieGenerator
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &cookieHandler{}
|
||||
var _ mint.CookieHandler = &CookieHandler{}
|
||||
|
||||
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
|
||||
// NewCookieHandler creates a new CookieHandler.
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cookieHandler{
|
||||
return &CookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
// Generate a new cookie for a mint connection.
|
||||
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
// Validate a cookie.
|
||||
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
return false
|
||||
}
|
||||
return h.callback(conn.RemoteAddr(), data)
|
||||
|
90
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
90
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
@ -38,23 +38,24 @@ type cryptoSetupClient struct {
|
||||
lastSentCHLO []byte
|
||||
certManager crypto.CertManager
|
||||
|
||||
divNonceChan chan []byte
|
||||
divNonceChan chan struct{}
|
||||
diversificationNonce []byte
|
||||
|
||||
clientHelloCounter int
|
||||
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
||||
receivedSecurePacket bool
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupClient{}
|
||||
@ -74,15 +75,17 @@ func NewCryptoSetupClient(
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupClient{
|
||||
divNonceChan := make(chan struct{})
|
||||
cs := &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
@ -90,19 +93,20 @@ func NewCryptoSetupClient(
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
}, nil
|
||||
divNonceChan: divNonceChan,
|
||||
logger: logger,
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
messageChan := make(chan HandshakeMessage)
|
||||
errorChan := make(chan error)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
@ -116,37 +120,30 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
}()
|
||||
|
||||
for {
|
||||
err := h.maybeUpgradeCrypto()
|
||||
if err != nil {
|
||||
if err := h.maybeUpgradeCrypto(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
sendCHLO := h.secureAEAD == nil
|
||||
h.mutex.RUnlock()
|
||||
|
||||
if sendCHLO {
|
||||
err = h.sendCHLO()
|
||||
if err != nil {
|
||||
if err := h.sendCHLO(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var message HandshakeMessage
|
||||
select {
|
||||
case divNonce := <-h.divNonceChan:
|
||||
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
case <-h.divNonceChan:
|
||||
// there's no message to process, but we should try upgrading the crypto again
|
||||
continue
|
||||
case message = <-messageChan:
|
||||
case err = <-errorChan:
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
|
||||
utils.Debugf("Got %s", message)
|
||||
h.logger.Debugf("Got %s", message)
|
||||
switch message.Tag {
|
||||
case TagREJ:
|
||||
if err := h.handleREJMessage(message.Data); err != nil {
|
||||
@ -159,8 +156,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
@ -211,7 +208,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
||||
|
||||
err = h.certManager.Verify(h.hostname)
|
||||
if err != nil {
|
||||
utils.Infof("Certificate validation failed: %s", err.Error())
|
||||
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
}
|
||||
@ -219,7 +216,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
||||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
||||
if !validProof {
|
||||
utils.Infof("Server proof verification failed")
|
||||
h.logger.Infof("Server proof verification failed")
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
|
||||
@ -277,6 +274,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
@ -322,6 +320,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
|
||||
if h.secureAEAD != nil {
|
||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return data, protocol.EncryptionSecure, nil
|
||||
}
|
||||
@ -373,16 +372,28 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
|
||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
||||
panic("not needed for cryptoSetupClient")
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||
PeerCertificates: h.certManager.GetChain(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
||||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
||||
h.mutex.Lock()
|
||||
if len(h.diversificationNonce) > 0 {
|
||||
defer h.mutex.Unlock()
|
||||
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
h.mutex.Unlock()
|
||||
h.divNonceChan <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
@ -403,7 +414,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
||||
Data: tags,
|
||||
}
|
||||
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
message.Write(b)
|
||||
|
||||
_, err = h.cryptoStream.Write(b.Bytes())
|
||||
@ -462,7 +473,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
||||
for _, tag := range tags {
|
||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||
}
|
||||
paddingSize := protocol.ClientHelloMinimumSize - size
|
||||
paddingSize := protocol.MinClientHelloSize - size
|
||||
if paddingSize > 0 {
|
||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||
}
|
||||
@ -500,10 +511,9 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
98
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
98
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
@ -19,10 +19,12 @@ import (
|
||||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
// KeyExchangeFunction is used to make a new KEX
|
||||
type KeyExchangeFunction func() crypto.KeyExchange
|
||||
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
||||
|
||||
// The CryptoSetupServer handles all things crypto for the Session
|
||||
type cryptoSetupServer struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
connID protocol.ConnectionID
|
||||
remoteAddr net.Addr
|
||||
scfg *ServerConfig
|
||||
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
|
||||
|
||||
receivedParams bool
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
@ -51,7 +53,9 @@ type cryptoSetupServer struct {
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
mutex sync.RWMutex
|
||||
sni string // need to fill out the ConnectionState
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
@ -71,32 +75,36 @@ func NewCryptoSetup(
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
version protocol.VersionNumber,
|
||||
divNonce []byte,
|
||||
scfg *ServerConfig,
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupServer{
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
diversificationNonce: divNonce,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -112,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
||||
utils.Debugf("Got %s", message)
|
||||
h.logger.Debugf("Got %s", message)
|
||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -139,6 +147,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
||||
if sni == "" {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
h.sni = sni
|
||||
|
||||
// prevent version downgrade attacks
|
||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
||||
@ -182,7 +191,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
||||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||
return false, err
|
||||
}
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.sentSHLO)
|
||||
return true, nil
|
||||
}
|
||||
@ -205,10 +214,11 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
||||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
close(h.aeadChanged)
|
||||
close(h.handshakeEvent)
|
||||
}
|
||||
return res, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
@ -219,6 +229,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
||||
if h.secureAEAD != nil {
|
||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return res, protocol.EncryptionSecure, nil
|
||||
}
|
||||
@ -294,17 +305,13 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("STK invalid: %s", err.Error())
|
||||
h.logger.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
if len(chlo) < protocol.ClientHelloMinimumSize {
|
||||
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
|
||||
}
|
||||
|
||||
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -341,7 +348,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
||||
|
||||
var serverReply bytes.Buffer
|
||||
message.Write(&serverReply)
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return serverReply.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -365,11 +372,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.diversificationNonce = make([]byte, 32)
|
||||
if _, err = rand.Read(h.diversificationNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientNonce := cryptoData[TagNONC]
|
||||
err = h.validateClientNonce(clientNonce)
|
||||
if err != nil {
|
||||
@ -400,14 +402,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
var fsNonce bytes.Buffer
|
||||
fsNonce.Write(clientNonce)
|
||||
fsNonce.Write(serverNonce)
|
||||
ephermalKex := h.keyExchange()
|
||||
ephermalKex, err := h.keyExchange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -427,6 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
||||
|
||||
replyMap := h.params.getHelloMap()
|
||||
// add crypto parameters
|
||||
@ -445,21 +451,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
||||
}
|
||||
var reply bytes.Buffer
|
||||
message.Write(&reply)
|
||||
utils.Debugf("Sending %s", message)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return reply.Bytes(), nil
|
||||
}
|
||||
|
||||
// DiversificationNonce returns the diversification nonce
|
||||
func (h *cryptoSetupServer) DiversificationNonce() []byte {
|
||||
return h.diversificationNonce
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
ServerName: h.sni,
|
||||
HandshakeComplete: h.receivedForwardSecurePacket,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
|
189
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
189
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
@ -1,10 +1,9 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
@ -12,6 +11,9 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
|
||||
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
@ -20,64 +22,33 @@ type cryptoSetupTLS struct {
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
tls mintTLS
|
||||
conn *fakeConn
|
||||
|
||||
nextPacketType protocol.PacketType
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
tlsConfig *tls.Config,
|
||||
remoteAddr net.Addr,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
checkCookie func(net.Addr, *Cookie) bool,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
tls MintTLS,
|
||||
cryptoStream *CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConf.RequireCookie = true
|
||||
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveServer,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
mintConn := mint.Server(conn, mintConf)
|
||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
) CryptoSetupTLS {
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveServer,
|
||||
tls: &mintController{mintConn},
|
||||
conn: conn,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
}, nil
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||
@ -85,59 +56,44 @@ func NewCryptoSetupTLSClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
hostname string,
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
initialVersion protocol.VersionNumber,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
handshakeEvent chan<- struct{},
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConf.ServerName = hostname
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveClient,
|
||||
}
|
||||
mintConn := mint.Client(conn, mintConf)
|
||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
tls: tls,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
nextPacketType: protocol.PacketTypeInitial,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
|
||||
// send out that data now
|
||||
if _, err := h.cryptoStream.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
handshakeLoop:
|
||||
for {
|
||||
switch alert := h.tls.Handshake(); alert {
|
||||
case mint.AlertNoAlert: // handshake complete
|
||||
break handshakeLoop
|
||||
case mint.AlertWouldBlock:
|
||||
h.determineNextPacketType()
|
||||
if err := h.conn.Continue(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
switch h.tls.State() {
|
||||
case mint.StateClientStart: // this happens if a stateless retry is performed
|
||||
return ErrCloseSessionForRetry
|
||||
case mint.StateClientConnected, mint.StateServerConnected:
|
||||
break handshakeLoop
|
||||
}
|
||||
}
|
||||
|
||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||
@ -148,28 +104,23 @@ handshakeLoop:
|
||||
h.aead = aead
|
||||
h.mutex.Unlock()
|
||||
|
||||
// signal to the outside world that the handshake completed
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.aead != nil {
|
||||
data, err := h.aead.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return data, protocol.EncryptionForwardSecure, nil
|
||||
if h.aead == nil {
|
||||
return nil, errors.New("no 1-RTT sealer")
|
||||
}
|
||||
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return data, protocol.EncryptionUnencrypted, nil
|
||||
return h.aead.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
@ -204,39 +155,13 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) determineNextPacketType() error {
|
||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
state := h.tls.State().HandshakeState
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
switch state {
|
||||
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
|
||||
h.nextPacketType = protocol.PacketTypeRetry
|
||||
case "ServerStateWaitFinished":
|
||||
h.nextPacketType = protocol.PacketTypeHandshake
|
||||
default:
|
||||
// TODO: accept 0-RTT data
|
||||
return fmt.Errorf("Unexpected handshake state: %s", state)
|
||||
}
|
||||
return nil
|
||||
mintConnState := h.tls.ConnectionState()
|
||||
return ConnectionState{
|
||||
// TODO: set the ServerName, once mint exports it
|
||||
HandshakeComplete: h.aead != nil,
|
||||
PeerCertificates: mintConnState.PeerCertificates,
|
||||
}
|
||||
// client
|
||||
if state != "ClientStateWaitSH" {
|
||||
h.nextPacketType = protocol.PacketTypeHandshake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.nextPacketType
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
Normal file
@ -0,0 +1,101 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The CryptoStreamConn is used as the net.Conn passed to mint.
|
||||
// It has two operating modes:
|
||||
// 1. It can read and write to bytes.Buffers.
|
||||
// 2. It can use a quic.Stream for reading and writing.
|
||||
// The buffer-mode is only used by the server, in order to statelessly handle retries.
|
||||
type CryptoStreamConn struct {
|
||||
remoteAddr net.Addr
|
||||
|
||||
// the buffers are used before the session is initialized
|
||||
readBuf bytes.Buffer
|
||||
writeBuf bytes.Buffer
|
||||
|
||||
// stream will be set once the session is initialized
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
var _ net.Conn = &CryptoStreamConn{}
|
||||
|
||||
// NewCryptoStreamConn creates a new CryptoStreamConn
|
||||
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
|
||||
return &CryptoStreamConn{remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
return c.readBuf.Read(b)
|
||||
}
|
||||
|
||||
// AddDataForReading adds data to the read buffer.
|
||||
// This data will ONLY be read when the stream has not been set.
|
||||
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
|
||||
c.readBuf.Write(data)
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
return c.writeBuf.Write(p)
|
||||
}
|
||||
|
||||
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
|
||||
func (c *CryptoStreamConn) GetDataForWriting() []byte {
|
||||
defer c.writeBuf.Reset()
|
||||
data := make([]byte, c.writeBuf.Len())
|
||||
copy(data, c.writeBuf.Bytes())
|
||||
return data
|
||||
}
|
||||
|
||||
// SetStream sets the stream.
|
||||
// After setting the stream, the read and write buffer won't be used any more.
|
||||
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
|
||||
c.stream = stream
|
||||
}
|
||||
|
||||
// Flush copies the contents of the write buffer to the stream
|
||||
func (c *CryptoStreamConn) Flush() (int, error) {
|
||||
n, err := io.Copy(c.stream, &c.writeBuf)
|
||||
return int(n), err
|
||||
}
|
||||
|
||||
// Close is not implemented
|
||||
func (c *CryptoStreamConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr is not implemented
|
||||
func (c *CryptoStreamConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address
|
||||
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
// SetReadDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -24,27 +23,26 @@ var (
|
||||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
||||
// small time span.
|
||||
func getEphermalKEX() (res crypto.KeyExchange) {
|
||||
func getEphermalKEX() (crypto.KeyExchange, error) {
|
||||
kexMutex.RLock()
|
||||
res = kexCurrent
|
||||
res := kexCurrent
|
||||
t := kexCurrentTime
|
||||
kexMutex.RUnlock()
|
||||
if res != nil && time.Since(t) < kexLifetime {
|
||||
return res
|
||||
return res, nil
|
||||
}
|
||||
|
||||
kexMutex.Lock()
|
||||
defer kexMutex.Unlock()
|
||||
// Check if still unfulfilled
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
utils.Errorf("could not set KEX: %s", err.Error())
|
||||
return kexCurrent
|
||||
return nil, err
|
||||
}
|
||||
kexCurrent = kex
|
||||
kexCurrentTime = time.Now()
|
||||
return kexCurrent
|
||||
return kexCurrent, nil
|
||||
}
|
||||
return kexCurrent
|
||||
return kexCurrent, nil
|
||||
}
|
||||
|
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
||||
|
||||
offset := uint32(0)
|
||||
for i, t := range h.getTagsSorted() {
|
||||
v := data[Tag(t)]
|
||||
v := data[t]
|
||||
b.Write(v)
|
||||
offset += uint32(len(v))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
||||
@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
|
||||
func (h HandshakeMessage) String() string {
|
||||
var pad string
|
||||
res := tagToString(h.Tag) + ":\n"
|
||||
for _, t := range h.getTagsSorted() {
|
||||
tag := Tag(t)
|
||||
for _, tag := range h.getTagsSorted() {
|
||||
if tag == TagPAD {
|
||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
||||
} else {
|
||||
|
57
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
57
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
@ -1,6 +1,11 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
@ -10,16 +15,54 @@ type Sealer interface {
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||
// It provides the parameters sent by the peer on a channel.
|
||||
type TLSExtensionHandler interface {
|
||||
Send(mint.HandshakeType, *mint.ExtensionList) error
|
||||
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
||||
GetPeerParams() <-chan TransportParameters
|
||||
}
|
||||
|
||||
// MintTLS combines some methods needed to interact with mint.
|
||||
type MintTLS interface {
|
||||
crypto.TLSExporter
|
||||
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.State
|
||||
ConnectionState() mint.ConnectionState
|
||||
|
||||
SetCryptoStream(io.ReadWriter)
|
||||
}
|
||||
|
||||
type baseCryptoSetup interface {
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// CryptoSetup is the crypto setup used by gQUIC
|
||||
type CryptoSetup interface {
|
||||
baseCryptoSetup
|
||||
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
}
|
||||
|
||||
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
||||
type CryptoSetupTLS interface {
|
||||
baseCryptoSetup
|
||||
|
||||
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
type ConnectionState struct {
|
||||
HandshakeComplete bool // handshake is complete
|
||||
ServerName string // server name requested by client, if any (server side only)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||
}
|
||||
|
127
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go
generated
vendored
127
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go
generated
vendored
@ -1,127 +0,0 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
gocrypto "crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
||||
mconf := &mint.Config{
|
||||
NonBlocking: true,
|
||||
CipherSuites: []mint.CipherSuite{
|
||||
mint.TLS_AES_128_GCM_SHA256,
|
||||
mint.TLS_AES_256_GCM_SHA384,
|
||||
},
|
||||
}
|
||||
if tlsConf != nil {
|
||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
||||
for i, certChain := range tlsConf.Certificates {
|
||||
mconf.Certificates[i] = &mint.Certificate{
|
||||
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
||||
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
||||
}
|
||||
for j, cert := range certChain.Certificate {
|
||||
c, err := x509.ParseCertificate(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mconf.Certificates[i].Chain[j] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mconf, nil
|
||||
}
|
||||
|
||||
type mintTLS interface {
|
||||
// These two methods are the same as the crypto.TLSExporter interface.
|
||||
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.ConnectionState
|
||||
}
|
||||
|
||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
||||
|
||||
type mintController struct {
|
||||
conn *mint.Conn
|
||||
}
|
||||
|
||||
var _ mintTLS = &mintController{}
|
||||
|
||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
return mc.conn.State().CipherSuite
|
||||
}
|
||||
|
||||
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
||||
}
|
||||
|
||||
func (mc *mintController) Handshake() mint.Alert {
|
||||
return mc.conn.Handshake()
|
||||
}
|
||||
|
||||
func (mc *mintController) State() mint.ConnectionState {
|
||||
return mc.conn.State()
|
||||
}
|
||||
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
remoteAddr net.Addr
|
||||
|
||||
blockRead bool
|
||||
writeBuffer bytes.Buffer
|
||||
}
|
||||
|
||||
var _ net.Conn = &fakeConn{}
|
||||
|
||||
func (c *fakeConn) Read(b []byte) (int, error) {
|
||||
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
|
||||
return 0, nil
|
||||
}
|
||||
c.blockRead = true // block the next Read call
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
// Buffer all writes by the server.
|
||||
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
|
||||
return c.writeBuffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *fakeConn) Continue() error {
|
||||
c.blockRead = false
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return nil
|
||||
}
|
||||
// write all contents of the write buffer to the stream.
|
||||
_, err := c.stream.Write(c.writeBuffer.Bytes())
|
||||
c.writeBuffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go
generated
vendored
@ -9,10 +9,10 @@ import (
|
||||
|
||||
// ServerConfig is a server config
|
||||
type ServerConfig struct {
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
@ -36,10 +36,10 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
|
||||
}
|
||||
|
||||
return &ServerConfig{
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user