update vendor
This commit is contained in:
parent
030fbd8521
commit
25d223c384
6
vendor/github.com/bifurcation/mint/README.md
generated
vendored
6
vendor/github.com/bifurcation/mint/README.md
generated
vendored
@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns
|
|||||||
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
|
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
|
||||||
off.
|
off.
|
||||||
|
|
||||||
|
## DTLS Support
|
||||||
|
|
||||||
|
Mint has partial support for DTLS, but that support is not yet complete
|
||||||
|
and may still contain serious defects.
|
||||||
|
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
Installation is the same as for any other Go package:
|
Installation is the same as for any other Go package:
|
||||||
|
2
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
2
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
@ -46,6 +46,7 @@ const (
|
|||||||
AlertBadCertificateHashValue Alert = 114
|
AlertBadCertificateHashValue Alert = 114
|
||||||
AlertUnknownPSKIdentity Alert = 115
|
AlertUnknownPSKIdentity Alert = 115
|
||||||
AlertNoApplicationProtocol Alert = 120
|
AlertNoApplicationProtocol Alert = 120
|
||||||
|
AlertStatelessRetry Alert = 253
|
||||||
AlertWouldBlock Alert = 254
|
AlertWouldBlock Alert = 254
|
||||||
AlertNoAlert Alert = 255
|
AlertNoAlert Alert = 255
|
||||||
)
|
)
|
||||||
@ -82,6 +83,7 @@ var alertText = map[Alert]string{
|
|||||||
AlertUnknownPSKIdentity: "unknown PSK identity",
|
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||||
AlertNoApplicationProtocol: "no application protocol",
|
AlertNoApplicationProtocol: "no application protocol",
|
||||||
AlertNoRenegotiation: "no renegotiation",
|
AlertNoRenegotiation: "no renegotiation",
|
||||||
|
AlertStatelessRetry: "stateless retry",
|
||||||
AlertWouldBlock: "would have blocked",
|
AlertWouldBlock: "would have blocked",
|
||||||
AlertNoAlert: "no alert",
|
AlertNoAlert: "no alert",
|
||||||
}
|
}
|
||||||
|
479
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
479
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
@ -3,6 +3,7 @@ package mint
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
"hash"
|
"hash"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -49,29 +50,31 @@ import (
|
|||||||
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||||
|
|
||||||
type ClientStateStart struct {
|
type clientStateStart struct {
|
||||||
Caps Capabilities
|
Config *Config
|
||||||
Opts ConnectionOptions
|
Opts ConnectionOptions
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
|
||||||
cookie []byte
|
cookie []byte
|
||||||
firstClientHello *HandshakeMessage
|
firstClientHello *HandshakeMessage
|
||||||
helloRetryRequest *HandshakeMessage
|
helloRetryRequest *HandshakeMessage
|
||||||
|
hsCtx *HandshakeContext
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateStart{}
|
||||||
if hm != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
|
||||||
return nil, nil, AlertUnexpectedMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func (state clientStateStart) State() State {
|
||||||
|
return StateClientStart
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
// key_shares
|
// key_shares
|
||||||
offeredDH := map[NamedGroup][]byte{}
|
offeredDH := map[NamedGroup][]byte{}
|
||||||
ks := KeyShareExtension{
|
ks := KeyShareExtension{
|
||||||
HandshakeType: HandshakeTypeClientHello,
|
HandshakeType: HandshakeTypeClientHello,
|
||||||
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
Shares: make([]KeyShareEntry, len(state.Config.Groups)),
|
||||||
}
|
}
|
||||||
for i, group := range state.Caps.Groups {
|
for i, group := range state.Config.Groups {
|
||||||
pub, priv, err := newKeyShare(group)
|
pub, priv, err := newKeyShare(group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||||
@ -86,10 +89,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||||
|
|
||||||
// supported_versions, supported_groups, signature_algorithms, server_name
|
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||||
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}}
|
||||||
sni := ServerNameExtension(state.Opts.ServerName)
|
sni := ServerNameExtension(state.Opts.ServerName)
|
||||||
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
sg := SupportedGroupsExtension{Groups: state.Config.Groups}
|
||||||
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes}
|
||||||
|
|
||||||
state.Params.ServerName = state.Opts.ServerName
|
state.Params.ServerName = state.Opts.ServerName
|
||||||
|
|
||||||
@ -101,7 +104,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
|
|
||||||
// Construct base ClientHello
|
// Construct base ClientHello
|
||||||
ch := &ClientHelloBody{
|
ch := &ClientHelloBody{
|
||||||
CipherSuites: state.Caps.CipherSuites,
|
LegacyVersion: wireVersion(state.hsCtx.hIn),
|
||||||
|
CipherSuites: state.Config.CipherSuites,
|
||||||
}
|
}
|
||||||
_, err := prng.Read(ch.Random[:])
|
_, err := prng.Read(ch.Random[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,8 +137,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run the external extension handler.
|
// Run the external extension handler.
|
||||||
if state.Caps.ExtensionHandler != nil {
|
if state.Config.ExtensionHandler != nil {
|
||||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
@ -150,7 +154,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
var earlySecret []byte
|
var earlySecret []byte
|
||||||
var clientEarlyTrafficKeys keySet
|
var clientEarlyTrafficKeys keySet
|
||||||
var clientHello *HandshakeMessage
|
var clientHello *HandshakeMessage
|
||||||
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok {
|
||||||
offeredPSK = key
|
offeredPSK = key
|
||||||
|
|
||||||
// Narrow ciphersuites to ones that match PSK hash
|
// Narrow ciphersuites to ones that match PSK hash
|
||||||
@ -168,8 +172,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
}
|
}
|
||||||
ch.CipherSuites = compatibleSuites
|
ch.CipherSuites = compatibleSuites
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Check that the ticket can be used for early
|
||||||
|
// data.
|
||||||
// Signal early data if we're going to do it
|
// Signal early data if we're going to do it
|
||||||
if len(state.Opts.EarlyData) > 0 {
|
if state.Config.AllowEarlyData && state.helloRetryRequest == nil {
|
||||||
state.Params.ClientSendingEarlyData = true
|
state.Params.ClientSendingEarlyData = true
|
||||||
ed = &EarlyDataExtension{}
|
ed = &EarlyDataExtension{}
|
||||||
err = ch.Extensions.Add(ed)
|
err = ch.Extensions.Add(ed)
|
||||||
@ -180,11 +186,11 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Signal supported PSK key exchange modes
|
// Signal supported PSK key exchange modes
|
||||||
if len(state.Caps.PSKModes) == 0 {
|
if len(state.Config.PSKModes) == 0 {
|
||||||
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes}
|
||||||
err = ch.Extensions.Add(kem)
|
err = ch.Extensions.Add(kem)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||||
@ -241,7 +247,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
|
|
||||||
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||||
// this one should too.
|
// this one should too.
|
||||||
clientHello, _ = HandshakeMessageFromBody(ch)
|
clientHello, _ = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
|
||||||
|
|
||||||
// Compute early traffic keys
|
// Compute early traffic keys
|
||||||
h := params.Hash.New()
|
h := params.Hash.New()
|
||||||
@ -251,11 +257,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||||
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||||
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||||
} else if len(state.Opts.EarlyData) > 0 {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
|
||||||
return nil, nil, AlertInternalError
|
|
||||||
} else {
|
} else {
|
||||||
clientHello, err = HandshakeMessageFromBody(ch)
|
clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
@ -263,10 +266,12 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||||
nextState := ClientStateWaitSH{
|
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
|
||||||
Caps: state.Caps,
|
nextState := clientStateWaitSH{
|
||||||
|
Config: state.Config,
|
||||||
Opts: state.Opts,
|
Opts: state.Opts,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
OfferedDH: offeredDH,
|
OfferedDH: offeredDH,
|
||||||
OfferedPSK: offeredPSK,
|
OfferedPSK: offeredPSK,
|
||||||
|
|
||||||
@ -279,22 +284,23 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
toSend := []HandshakeAction{
|
toSend := []HandshakeAction{
|
||||||
SendHandshakeMessage{clientHello},
|
QueueHandshakeMessage{clientHello},
|
||||||
|
SendQueuedHandshake{},
|
||||||
}
|
}
|
||||||
if state.Params.ClientSendingEarlyData {
|
if state.Params.ClientSendingEarlyData {
|
||||||
toSend = append(toSend, []HandshakeAction{
|
toSend = append(toSend, []HandshakeAction{
|
||||||
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
|
||||||
SendEarlyData{},
|
|
||||||
}...)
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nextState, toSend, AlertNoAlert
|
return nextState, toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitSH struct {
|
type clientStateWaitSH struct {
|
||||||
Caps Capabilities
|
Config *Config
|
||||||
Opts ConnectionOptions
|
Opts ConnectionOptions
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
OfferedDH map[NamedGroup][]byte
|
OfferedDH map[NamedGroup][]byte
|
||||||
OfferedPSK PreSharedKey
|
OfferedPSK PreSharedKey
|
||||||
PSK []byte
|
PSK []byte
|
||||||
@ -307,49 +313,73 @@ type ClientStateWaitSH struct {
|
|||||||
clientHello *HandshakeMessage
|
clientHello *HandshakeMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitSH{}
|
||||||
if hm == nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
func (state clientStateWaitSH) State() State {
|
||||||
|
return StateClientWaitSH
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
|
|
||||||
|
if hm == nil || hm.msgType != HandshakeTypeServerHello {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyGeneric, err := hm.ToBody()
|
sh := &ServerHelloBody{}
|
||||||
|
if _, err := sh.Unmarshal(hm.body); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message")
|
||||||
|
return nil, nil, AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common SH/HRR processing first.
|
||||||
|
// 1. Check that sh.version is TLS 1.2
|
||||||
|
if sh.Version != tls12Version {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version)
|
||||||
|
return nil, nil, AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check that it responded with a valid version.
|
||||||
|
supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err)
|
||||||
return nil, nil, AlertDecodeError
|
return nil, nil, AlertDecodeError
|
||||||
}
|
}
|
||||||
|
if !foundSupportedVersions {
|
||||||
switch body := bodyGeneric.(type) {
|
logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension")
|
||||||
case *HelloRetryRequestBody:
|
return nil, nil, AlertMissingExtension
|
||||||
hrr := body
|
|
||||||
|
|
||||||
if state.helloRetryRequest != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
|
||||||
return nil, nil, AlertUnexpectedMessage
|
|
||||||
}
|
}
|
||||||
|
if supportedVersions.Versions[0] != supportedVersion {
|
||||||
// Check that the version sent by the server is the one we support
|
logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0])
|
||||||
if hrr.Version != supportedVersion {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
|
||||||
return nil, nil, AlertProtocolVersion
|
return nil, nil, AlertProtocolVersion
|
||||||
}
|
}
|
||||||
|
// 3. Check that the server provided a supported ciphersuite
|
||||||
// Check that the server provided a supported ciphersuite
|
|
||||||
supportedCipherSuite := false
|
supportedCipherSuite := false
|
||||||
for _, suite := range state.Caps.CipherSuites {
|
for _, suite := range state.Config.CipherSuites {
|
||||||
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||||
}
|
}
|
||||||
if !supportedCipherSuite {
|
if !supportedCipherSuite {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||||
return nil, nil, AlertHandshakeFailure
|
return nil, nil, AlertHandshakeFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now check for the sentinel.
|
||||||
|
|
||||||
|
if sh.Random == hrrRandomSentinel {
|
||||||
|
// This is actually HRR.
|
||||||
|
hrr := sh
|
||||||
|
|
||||||
// Narrow the supported ciphersuites to the server-provided one
|
// Narrow the supported ciphersuites to the server-provided one
|
||||||
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||||
|
|
||||||
// Handle external extensions.
|
// Handle external extensions.
|
||||||
if state.Caps.ExtensionHandler != nil {
|
if state.Config.ExtensionHandler != nil {
|
||||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
@ -358,10 +388,14 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
|
|||||||
|
|
||||||
// The only thing we know how to respond to in an HRR is the Cookie
|
// The only thing we know how to respond to in an HRR is the Cookie
|
||||||
// extension, so if there is either no Cookie extension or anything other
|
// extension, so if there is either no Cookie extension or anything other
|
||||||
// than a Cookie extension, we have to fail.
|
// than a Cookie extension and SupportedVersions we have to fail.
|
||||||
serverCookie := new(CookieExtension)
|
serverCookie := new(CookieExtension)
|
||||||
foundCookie := hrr.Extensions.Find(serverCookie)
|
foundCookie, err := hrr.Extensions.Find(serverCookie)
|
||||||
if !foundCookie || len(hrr.Extensions) != 1 {
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
if !foundCookie || len(hrr.Extensions) != 2 {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||||
return nil, nil, AlertIllegalParameter
|
return nil, nil, AlertIllegalParameter
|
||||||
}
|
}
|
||||||
@ -376,37 +410,26 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
|
|||||||
body: h.Sum(nil),
|
body: h.Sum(nil),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.hsCtx.receivedEndOfFlight()
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT
|
||||||
|
// mode. In DTLS, we also need to bump the sequence number.
|
||||||
|
// This is a pre-existing defect in Mint. Issue #175.
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||||
return ClientStateStart{
|
return clientStateStart{
|
||||||
Caps: state.Caps,
|
Config: state.Config,
|
||||||
Opts: state.Opts,
|
Opts: state.Opts,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cookie: serverCookie.Cookie,
|
cookie: serverCookie.Cookie,
|
||||||
firstClientHello: firstClientHello,
|
firstClientHello: firstClientHello,
|
||||||
helloRetryRequest: hm,
|
helloRetryRequest: hm,
|
||||||
}.Next(nil)
|
}, []HandshakeAction{ResetOut{1}}, AlertNoAlert
|
||||||
|
|
||||||
case *ServerHelloBody:
|
|
||||||
sh := body
|
|
||||||
|
|
||||||
// Check that the version sent by the server is the one we support
|
|
||||||
if sh.Version != supportedVersion {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
|
||||||
return nil, nil, AlertProtocolVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the server provided a supported ciphersuite
|
|
||||||
supportedCipherSuite := false
|
|
||||||
for _, suite := range state.Caps.CipherSuites {
|
|
||||||
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
|
||||||
}
|
|
||||||
if !supportedCipherSuite {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
|
||||||
return nil, nil, AlertHandshakeFailure
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is SH.
|
||||||
// Handle external extensions.
|
// Handle external extensions.
|
||||||
if state.Caps.ExtensionHandler != nil {
|
if state.Config.ExtensionHandler != nil {
|
||||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
@ -417,15 +440,22 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
|
|||||||
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||||
|
|
||||||
foundPSK := sh.Extensions.Find(&serverPSK)
|
foundExts, err := sh.Extensions.Parse(
|
||||||
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
[]ExtensionBody{
|
||||||
|
&serverPSK,
|
||||||
|
&serverKeyShare,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) {
|
||||||
state.Params.UsingPSK = true
|
state.Params.UsingPSK = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var dhSecret []byte
|
var dhSecret []byte
|
||||||
if foundKeyShare {
|
if foundExts[ExtensionTypeKeyShare] {
|
||||||
sks := serverKeyShare.Shares[0]
|
sks := serverKeyShare.Shares[0]
|
||||||
priv, ok := state.OfferedDH[sks.Group]
|
priv, ok := state.OfferedDH[sks.Group]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -488,114 +518,152 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
|
|||||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||||
|
|
||||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||||
nextState := ClientStateWaitEE{
|
nextState := clientStateWaitEE{
|
||||||
Caps: state.Caps,
|
Config: state.Config,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: params,
|
cryptoParams: params,
|
||||||
handshakeHash: handshakeHash,
|
handshakeHash: handshakeHash,
|
||||||
certificates: state.Caps.Certificates,
|
|
||||||
masterSecret: masterSecret,
|
masterSecret: masterSecret,
|
||||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||||
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||||
}
|
}
|
||||||
toSend := []HandshakeAction{
|
toSend := []HandshakeAction{
|
||||||
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
|
||||||
}
|
}
|
||||||
return nextState, toSend, AlertNoAlert
|
// We're definitely not going to have to send anything with
|
||||||
|
// early data.
|
||||||
|
if !state.Params.ClientSendingEarlyData {
|
||||||
|
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
|
||||||
|
KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)})
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
return nextState, toSend, AlertNoAlert
|
||||||
return nil, nil, AlertUnexpectedMessage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitEE struct {
|
type clientStateWaitEE struct {
|
||||||
Caps Capabilities
|
Config *Config
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
handshakeHash hash.Hash
|
handshakeHash hash.Hash
|
||||||
certificates []*Certificate
|
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
clientHandshakeTrafficSecret []byte
|
clientHandshakeTrafficSecret []byte
|
||||||
serverHandshakeTrafficSecret []byte
|
serverHandshakeTrafficSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitEE{}
|
||||||
|
|
||||||
|
func (state clientStateWaitEE) State() State {
|
||||||
|
return StateClientWaitEE
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
ee := EncryptedExtensionsBody{}
|
ee := EncryptedExtensionsBody{}
|
||||||
_, err := ee.Unmarshal(hm.body)
|
if err := safeUnmarshal(&ee, hm.body); err != nil {
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||||
return nil, nil, AlertDecodeError
|
return nil, nil, AlertDecodeError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle external extensions.
|
// Handle external extensions.
|
||||||
if state.Caps.ExtensionHandler != nil {
|
if state.Config.ExtensionHandler != nil {
|
||||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverALPN := ALPNExtension{}
|
serverALPN := &ALPNExtension{}
|
||||||
serverEarlyData := EarlyDataExtension{}
|
serverEarlyData := &EarlyDataExtension{}
|
||||||
|
|
||||||
gotALPN := ee.Extensions.Find(&serverALPN)
|
foundExts, err := ee.Extensions.Parse(
|
||||||
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
[]ExtensionBody{
|
||||||
|
serverALPN,
|
||||||
|
serverEarlyData,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
|
|
||||||
if gotALPN && len(serverALPN.Protocols) > 0 {
|
state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData]
|
||||||
|
|
||||||
|
if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 {
|
||||||
state.Params.NextProto = serverALPN.Protocols[0]
|
state.Params.NextProto = serverALPN.Protocols[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
state.handshakeHash.Write(hm.Marshal())
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
|
toSend := []HandshakeAction{}
|
||||||
|
|
||||||
|
if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData {
|
||||||
|
// We didn't get 0-RTT, so rekey to handshake.
|
||||||
|
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
|
||||||
|
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
|
||||||
|
}
|
||||||
|
|
||||||
if state.Params.UsingPSK {
|
if state.Params.UsingPSK {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||||
nextState := ClientStateWaitFinished{
|
nextState := clientStateWaitFinished{
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
certificates: state.Config.Certificates,
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
}
|
}
|
||||||
return nextState, nil, AlertNoAlert
|
return nextState, toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||||
nextState := ClientStateWaitCertCR{
|
nextState := clientStateWaitCertCR{
|
||||||
AuthCertificate: state.AuthCertificate,
|
Config: state.Config,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
}
|
}
|
||||||
return nextState, nil, AlertNoAlert
|
return nextState, toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitCertCR struct {
|
type clientStateWaitCertCR struct {
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
Config *Config
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
handshakeHash hash.Hash
|
handshakeHash hash.Hash
|
||||||
certificates []*Certificate
|
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
clientHandshakeTrafficSecret []byte
|
clientHandshakeTrafficSecret []byte
|
||||||
serverHandshakeTrafficSecret []byte
|
serverHandshakeTrafficSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitCertCR{}
|
||||||
|
|
||||||
|
func (state clientStateWaitCertCR) State() State {
|
||||||
|
return StateClientWaitCertCR
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
if hm == nil {
|
if hm == nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
@ -612,12 +680,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
|
|||||||
switch body := bodyGeneric.(type) {
|
switch body := bodyGeneric.(type) {
|
||||||
case *CertificateBody:
|
case *CertificateBody:
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||||
nextState := ClientStateWaitCV{
|
nextState := clientStateWaitCV{
|
||||||
AuthCertificate: state.AuthCertificate,
|
Config: state.Config,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
|
||||||
serverCertificate: body,
|
serverCertificate: body,
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
@ -635,12 +703,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
|
|||||||
state.Params.UsingClientAuth = true
|
state.Params.UsingClientAuth = true
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||||
nextState := ClientStateWaitCert{
|
nextState := clientStateWaitCert{
|
||||||
AuthCertificate: state.AuthCertificate,
|
Config: state.Config,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
|
||||||
serverCertificateRequest: body,
|
serverCertificateRequest: body,
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
@ -652,13 +720,13 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
|
|||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitCert struct {
|
type clientStateWaitCert struct {
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
Config *Config
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
handshakeHash hash.Hash
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
certificates []*Certificate
|
|
||||||
serverCertificateRequest *CertificateRequestBody
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
@ -666,15 +734,24 @@ type ClientStateWaitCert struct {
|
|||||||
serverHandshakeTrafficSecret []byte
|
serverHandshakeTrafficSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitCert{}
|
||||||
|
|
||||||
|
func (state clientStateWaitCert) State() State {
|
||||||
|
return StateClientWaitCert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
cert := &CertificateBody{}
|
cert := &CertificateBody{}
|
||||||
_, err := cert.Unmarshal(hm.body)
|
if err := safeUnmarshal(cert, hm.body); err != nil {
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||||
return nil, nil, AlertDecodeError
|
return nil, nil, AlertDecodeError
|
||||||
}
|
}
|
||||||
@ -682,12 +759,12 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H
|
|||||||
state.handshakeHash.Write(hm.Marshal())
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||||
nextState := ClientStateWaitCV{
|
nextState := clientStateWaitCV{
|
||||||
AuthCertificate: state.AuthCertificate,
|
Config: state.Config,
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
|
||||||
serverCertificate: cert,
|
serverCertificate: cert,
|
||||||
serverCertificateRequest: state.serverCertificateRequest,
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
@ -697,13 +774,13 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H
|
|||||||
return nextState, nil, AlertNoAlert
|
return nextState, nil, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitCV struct {
|
type clientStateWaitCV struct {
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
Config *Config
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
handshakeHash hash.Hash
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
certificates []*Certificate
|
|
||||||
serverCertificate *CertificateBody
|
serverCertificate *CertificateBody
|
||||||
serverCertificateRequest *CertificateRequestBody
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
|
||||||
@ -712,15 +789,24 @@ type ClientStateWaitCV struct {
|
|||||||
serverHandshakeTrafficSecret []byte
|
serverHandshakeTrafficSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitCV{}
|
||||||
|
|
||||||
|
func (state clientStateWaitCV) State() State {
|
||||||
|
return StateClientWaitCV
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
certVerify := CertificateVerifyBody{}
|
certVerify := CertificateVerifyBody{}
|
||||||
_, err := certVerify.Unmarshal(hm.body)
|
if err := safeUnmarshal(&certVerify, hm.body); err != nil {
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||||
return nil, nil, AlertDecodeError
|
return nil, nil, AlertDecodeError
|
||||||
}
|
}
|
||||||
@ -734,46 +820,89 @@ func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []Han
|
|||||||
return nil, nil, AlertHandshakeFailure
|
return nil, nil, AlertHandshakeFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.AuthCertificate != nil {
|
certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList))
|
||||||
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
rawCerts := make([][]byte, len(state.serverCertificate.CertificateList))
|
||||||
|
for i, certEntry := range state.serverCertificate.CertificateList {
|
||||||
|
certs[i] = certEntry.CertData
|
||||||
|
rawCerts[i] = certEntry.CertData.Raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifiedChains [][]*x509.Certificate
|
||||||
|
if !state.Config.InsecureSkipVerify {
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: state.Config.RootCAs,
|
||||||
|
CurrentTime: state.Config.time(),
|
||||||
|
DNSName: state.Config.ServerName,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, cert := range certs {
|
||||||
|
if i == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
verifiedChains, err = certs[0].Verify(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err)
|
||||||
|
return nil, nil, AlertBadCertificate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Config.VerifyPeerCertificate != nil {
|
||||||
|
if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err)
|
||||||
return nil, nil, AlertBadCertificate
|
return nil, nil, AlertBadCertificate
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.handshakeHash.Write(hm.Marshal())
|
state.handshakeHash.Write(hm.Marshal())
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||||
nextState := ClientStateWaitFinished{
|
nextState := clientStateWaitFinished{
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
handshakeHash: state.handshakeHash,
|
handshakeHash: state.handshakeHash,
|
||||||
certificates: state.certificates,
|
certificates: state.Config.Certificates,
|
||||||
serverCertificateRequest: state.serverCertificateRequest,
|
serverCertificateRequest: state.serverCertificateRequest,
|
||||||
masterSecret: state.masterSecret,
|
masterSecret: state.masterSecret,
|
||||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||||
|
peerCertificates: certs,
|
||||||
|
verifiedChains: verifiedChains,
|
||||||
}
|
}
|
||||||
return nextState, nil, AlertNoAlert
|
return nextState, nil, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientStateWaitFinished struct {
|
type clientStateWaitFinished struct {
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
handshakeHash hash.Hash
|
handshakeHash hash.Hash
|
||||||
|
|
||||||
certificates []*Certificate
|
certificates []*Certificate
|
||||||
serverCertificateRequest *CertificateRequestBody
|
serverCertificateRequest *CertificateRequestBody
|
||||||
|
peerCertificates []*x509.Certificate
|
||||||
|
verifiedChains [][]*x509.Certificate
|
||||||
|
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
clientHandshakeTrafficSecret []byte
|
clientHandshakeTrafficSecret []byte
|
||||||
serverHandshakeTrafficSecret []byte
|
serverHandshakeTrafficSecret []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
var _ HandshakeState = &clientStateWaitFinished{}
|
||||||
|
|
||||||
|
func (state clientStateWaitFinished) State() State {
|
||||||
|
return StateClientWaitFinished
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
hm, alert := hr.ReadMessage()
|
||||||
|
if alert != AlertNoAlert {
|
||||||
|
return nil, nil, alert
|
||||||
|
}
|
||||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
@ -788,8 +917,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||||
|
|
||||||
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||||
_, err := fin.Unmarshal(hm.body)
|
if err := safeUnmarshal(fin, hm.body); err != nil {
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||||
return nil, nil, AlertDecodeError
|
return nil, nil, AlertDecodeError
|
||||||
}
|
}
|
||||||
@ -822,25 +950,32 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
toSend := []HandshakeAction{}
|
toSend := []HandshakeAction{}
|
||||||
|
|
||||||
if state.Params.UsingEarlyData {
|
if state.Params.UsingEarlyData {
|
||||||
|
logf(logTypeHandshake, "Sending end of early data")
|
||||||
// Note: We only send EOED if the server is actually going to use the early
|
// Note: We only send EOED if the server is actually going to use the early
|
||||||
// data. Otherwise, it will never see it, and the transcripts will
|
// data. Otherwise, it will never see it, and the transcripts will
|
||||||
// mismatch.
|
// mismatch.
|
||||||
// EOED marshal is infallible
|
// EOED marshal is infallible
|
||||||
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
eoedm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||||
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
toSend = append(toSend, QueueHandshakeMessage{eoedm})
|
||||||
|
|
||||||
state.handshakeHash.Write(eoedm.Marshal())
|
state.handshakeHash.Write(eoedm.Marshal())
|
||||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||||
}
|
|
||||||
|
|
||||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
// And then rekey to handshake
|
||||||
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
|
||||||
|
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
|
||||||
|
}
|
||||||
|
|
||||||
if state.Params.UsingClientAuth {
|
if state.Params.UsingClientAuth {
|
||||||
// Extract constraints from certicateRequest
|
// Extract constraints from certicateRequest
|
||||||
schemes := SignatureAlgorithmsExtension{}
|
schemes := SignatureAlgorithmsExtension{}
|
||||||
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err)
|
||||||
|
return nil, nil, AlertDecodeError
|
||||||
|
}
|
||||||
if !gotSchemes {
|
if !gotSchemes {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found")
|
||||||
return nil, nil, AlertIllegalParameter
|
return nil, nil, AlertIllegalParameter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -851,13 +986,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||||
|
|
||||||
certificate := &CertificateBody{}
|
certificate := &CertificateBody{}
|
||||||
certm, err := HandshakeMessageFromBody(certificate)
|
certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
toSend = append(toSend, QueueHandshakeMessage{certm})
|
||||||
state.handshakeHash.Write(certm.Marshal())
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
} else {
|
} else {
|
||||||
// Create and send Certificate, CertificateVerify
|
// Create and send Certificate, CertificateVerify
|
||||||
@ -867,13 +1002,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
for i, entry := range cert.Chain {
|
for i, entry := range cert.Chain {
|
||||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||||
}
|
}
|
||||||
certm, err := HandshakeMessageFromBody(certificate)
|
certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
toSend = append(toSend, QueueHandshakeMessage{certm})
|
||||||
state.handshakeHash.Write(certm.Marshal())
|
state.handshakeHash.Write(certm.Marshal())
|
||||||
|
|
||||||
hcv := state.handshakeHash.Sum(nil)
|
hcv := state.handshakeHash.Sum(nil)
|
||||||
@ -887,13 +1022,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
toSend = append(toSend, QueueHandshakeMessage{certvm})
|
||||||
state.handshakeHash.Write(certvm.Marshal())
|
state.handshakeHash.Write(certvm.Marshal())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -909,7 +1044,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
VerifyDataLen: len(clientFinishedData),
|
VerifyDataLen: len(clientFinishedData),
|
||||||
VerifyData: clientFinishedData,
|
VerifyData: clientFinishedData,
|
||||||
}
|
}
|
||||||
finm, err := HandshakeMessageFromBody(fin)
|
finm, err := state.hsCtx.hOut.HandshakeMessageFromBody(fin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||||
return nil, nil, AlertInternalError
|
return nil, nil, AlertInternalError
|
||||||
@ -923,20 +1058,26 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
|
|||||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||||
|
|
||||||
toSend = append(toSend, []HandshakeAction{
|
toSend = append(toSend, []HandshakeAction{
|
||||||
SendHandshakeMessage{finm},
|
QueueHandshakeMessage{finm},
|
||||||
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
SendQueuedHandshake{},
|
||||||
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
RekeyIn{epoch: EpochApplicationData, KeySet: serverTrafficKeys},
|
||||||
|
RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys},
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
|
state.hsCtx.receivedEndOfFlight()
|
||||||
|
|
||||||
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||||
nextState := StateConnected{
|
nextState := stateConnected{
|
||||||
Params: state.Params,
|
Params: state.Params,
|
||||||
|
hsCtx: state.hsCtx,
|
||||||
isClient: true,
|
isClient: true,
|
||||||
cryptoParams: state.cryptoParams,
|
cryptoParams: state.cryptoParams,
|
||||||
resumptionSecret: resumptionSecret,
|
resumptionSecret: resumptionSecret,
|
||||||
clientTrafficSecret: clientTrafficSecret,
|
clientTrafficSecret: clientTrafficSecret,
|
||||||
serverTrafficSecret: serverTrafficSecret,
|
serverTrafficSecret: serverTrafficSecret,
|
||||||
exporterSecret: exporterSecret,
|
exporterSecret: exporterSecret,
|
||||||
|
peerCertificates: state.peerCertificates,
|
||||||
|
verifiedChains: state.verifiedChains,
|
||||||
}
|
}
|
||||||
return nextState, toSend, AlertNoAlert
|
return nextState, toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
118
vendor/github.com/bifurcation/mint/common.go
generated
vendored
118
vendor/github.com/bifurcation/mint/common.go
generated
vendored
@ -5,9 +5,14 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
supportedVersion uint16 = 0x7f15 // draft-21
|
supportedVersion uint16 = 0x7f16 // draft-22
|
||||||
|
tls12Version uint16 = 0x0303
|
||||||
|
tls10Version uint16 = 0x0301
|
||||||
|
dtls12WireVersion uint16 = 0xfefd
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
// Flags for some minor compat issues
|
// Flags for some minor compat issues
|
||||||
allowWrongVersionNumber = true
|
allowWrongVersionNumber = true
|
||||||
allowPKCS1 = true
|
allowPKCS1 = true
|
||||||
@ -20,6 +25,7 @@ const (
|
|||||||
RecordTypeAlert RecordType = 21
|
RecordTypeAlert RecordType = 21
|
||||||
RecordTypeHandshake RecordType = 22
|
RecordTypeHandshake RecordType = 22
|
||||||
RecordTypeApplicationData RecordType = 23
|
RecordTypeApplicationData RecordType = 23
|
||||||
|
RecordTypeAck RecordType = 25
|
||||||
)
|
)
|
||||||
|
|
||||||
// enum {...} HandshakeType;
|
// enum {...} HandshakeType;
|
||||||
@ -42,6 +48,13 @@ const (
|
|||||||
HandshakeTypeMessageHash HandshakeType = 254
|
HandshakeTypeMessageHash HandshakeType = 254
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var hrrRandomSentinel = [32]byte{
|
||||||
|
0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11,
|
||||||
|
0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
|
||||||
|
0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e,
|
||||||
|
0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
|
||||||
|
}
|
||||||
|
|
||||||
// uint8 CipherSuite[2];
|
// uint8 CipherSuite[2];
|
||||||
type CipherSuite uint16
|
type CipherSuite uint16
|
||||||
|
|
||||||
@ -150,3 +163,104 @@ const (
|
|||||||
KeyUpdateNotRequested KeyUpdateRequest = 0
|
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||||
KeyUpdateRequested KeyUpdateRequest = 1
|
KeyUpdateRequested KeyUpdateRequest = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type State uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateInit = 0
|
||||||
|
|
||||||
|
// states valid for the client
|
||||||
|
StateClientStart State = iota
|
||||||
|
StateClientWaitSH
|
||||||
|
StateClientWaitEE
|
||||||
|
StateClientWaitCert
|
||||||
|
StateClientWaitCV
|
||||||
|
StateClientWaitFinished
|
||||||
|
StateClientWaitCertCR
|
||||||
|
StateClientConnected
|
||||||
|
// states valid for the server
|
||||||
|
StateServerStart State = iota
|
||||||
|
StateServerRecvdCH
|
||||||
|
StateServerNegotiated
|
||||||
|
StateServerReadPastEarlyData
|
||||||
|
StateServerWaitEOED
|
||||||
|
StateServerWaitFlight2
|
||||||
|
StateServerWaitCert
|
||||||
|
StateServerWaitCV
|
||||||
|
StateServerWaitFinished
|
||||||
|
StateServerConnected
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s State) String() string {
|
||||||
|
switch s {
|
||||||
|
case StateClientStart:
|
||||||
|
return "Client START"
|
||||||
|
case StateClientWaitSH:
|
||||||
|
return "Client WAIT_SH"
|
||||||
|
case StateClientWaitEE:
|
||||||
|
return "Client WAIT_EE"
|
||||||
|
case StateClientWaitCert:
|
||||||
|
return "Client WAIT_CERT"
|
||||||
|
case StateClientWaitCV:
|
||||||
|
return "Client WAIT_CV"
|
||||||
|
case StateClientWaitFinished:
|
||||||
|
return "Client WAIT_FINISHED"
|
||||||
|
case StateClientWaitCertCR:
|
||||||
|
return "Client WAIT_CERT_CR"
|
||||||
|
case StateClientConnected:
|
||||||
|
return "Client CONNECTED"
|
||||||
|
case StateServerStart:
|
||||||
|
return "Server START"
|
||||||
|
case StateServerRecvdCH:
|
||||||
|
return "Server RECVD_CH"
|
||||||
|
case StateServerNegotiated:
|
||||||
|
return "Server NEGOTIATED"
|
||||||
|
case StateServerReadPastEarlyData:
|
||||||
|
return "Server READ_PAST_EARLY_DATA"
|
||||||
|
case StateServerWaitEOED:
|
||||||
|
return "Server WAIT_EOED"
|
||||||
|
case StateServerWaitFlight2:
|
||||||
|
return "Server WAIT_FLIGHT2"
|
||||||
|
case StateServerWaitCert:
|
||||||
|
return "Server WAIT_CERT"
|
||||||
|
case StateServerWaitCV:
|
||||||
|
return "Server WAIT_CV"
|
||||||
|
case StateServerWaitFinished:
|
||||||
|
return "Server WAIT_FINISHED"
|
||||||
|
case StateServerConnected:
|
||||||
|
return "Server CONNECTED"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown state: %d", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Epochs for DTLS (also used for key phase labelling)
|
||||||
|
type Epoch uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
EpochClear Epoch = 0
|
||||||
|
EpochEarlyData Epoch = 1
|
||||||
|
EpochHandshakeData Epoch = 2
|
||||||
|
EpochApplicationData Epoch = 3
|
||||||
|
EpochUpdate Epoch = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e Epoch) label() string {
|
||||||
|
switch e {
|
||||||
|
case EpochClear:
|
||||||
|
return "clear"
|
||||||
|
case EpochEarlyData:
|
||||||
|
return "early data"
|
||||||
|
case EpochHandshakeData:
|
||||||
|
return "handshake"
|
||||||
|
case EpochApplicationData:
|
||||||
|
return "application data"
|
||||||
|
}
|
||||||
|
return "Application data (updated)"
|
||||||
|
}
|
||||||
|
|
||||||
|
func assert(b bool) {
|
||||||
|
if !b {
|
||||||
|
panic("Assertion failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
497
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
497
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
@ -4,6 +4,7 @@ import (
|
|||||||
"crypto"
|
"crypto"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -12,8 +13,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var WouldBlock = fmt.Errorf("Would have blocked")
|
|
||||||
|
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
Chain []*x509.Certificate
|
Chain []*x509.Certificate
|
||||||
PrivateKey crypto.Signer
|
PrivateKey crypto.Signer
|
||||||
@ -36,16 +35,20 @@ type PreSharedKeyCache interface {
|
|||||||
Size() int
|
Size() int
|
||||||
}
|
}
|
||||||
|
|
||||||
type PSKMapCache map[string]PreSharedKey
|
// A CookieHandler can be used to give the application more fine-grained control over Cookies.
|
||||||
|
// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie.
|
||||||
// A CookieHandler does two things:
|
// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie.
|
||||||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
|
||||||
// - validates this byte string echoed by the client in the ClientHello
|
|
||||||
type CookieHandler interface {
|
type CookieHandler interface {
|
||||||
|
// Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||||
|
// If Generate returns nil, mint will not send a HelloRetryRequest.
|
||||||
Generate(*Conn) ([]byte, error)
|
Generate(*Conn) ([]byte, error)
|
||||||
|
// Validate is called when receiving a ClientHello containing a Cookie.
|
||||||
|
// If validation failed, the handshake is aborted.
|
||||||
Validate(*Conn, []byte) bool
|
Validate(*Conn, []byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PSKMapCache map[string]PreSharedKey
|
||||||
|
|
||||||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||||
psk, ok = cache[key]
|
psk, ok = cache[key]
|
||||||
return
|
return
|
||||||
@ -74,14 +77,49 @@ type Config struct {
|
|||||||
AllowEarlyData bool
|
AllowEarlyData bool
|
||||||
// Require the client to echo a cookie.
|
// Require the client to echo a cookie.
|
||||||
RequireCookie bool
|
RequireCookie bool
|
||||||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
// A CookieHandler can be used to set and validate a cookie.
|
||||||
// The default cookie handler uses 32 random bytes as a cookie.
|
// The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector.
|
||||||
|
// If no CookieHandler is set, mint will always send a cookie.
|
||||||
|
// The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent.
|
||||||
CookieHandler CookieHandler
|
CookieHandler CookieHandler
|
||||||
|
// The CookieProtector is used to encrypt / decrypt cookies.
|
||||||
|
// It should make sure that the Cookie cannot be read and tampered with by the client.
|
||||||
|
// If non-blocking mode is used, and cookies are required, this field has to be set.
|
||||||
|
// In blocking mode, a default cookie protector is used, if this is unused.
|
||||||
|
CookieProtector CookieProtector
|
||||||
|
// The ExtensionHandler is used to add custom extensions.
|
||||||
|
ExtensionHandler AppExtensionHandler
|
||||||
RequireClientAuth bool
|
RequireClientAuth bool
|
||||||
|
|
||||||
|
// Time returns the current time as the number of seconds since the epoch.
|
||||||
|
// If Time is nil, TLS uses time.Now.
|
||||||
|
Time func() time.Time
|
||||||
|
// RootCAs defines the set of root certificate authorities
|
||||||
|
// that clients use when verifying server certificates.
|
||||||
|
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||||
|
RootCAs *x509.CertPool
|
||||||
|
// InsecureSkipVerify controls whether a client verifies the
|
||||||
|
// server's certificate chain and host name.
|
||||||
|
// If InsecureSkipVerify is true, TLS accepts any certificate
|
||||||
|
// presented by the server and any host name in that certificate.
|
||||||
|
// In this mode, TLS is susceptible to man-in-the-middle attacks.
|
||||||
|
// This should be used only for testing.
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
|
||||||
// Shared fields
|
// Shared fields
|
||||||
Certificates []*Certificate
|
Certificates []*Certificate
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
// VerifyPeerCertificate, if not nil, is called after normal
|
||||||
|
// certificate verification by either a TLS client or server. It
|
||||||
|
// receives the raw ASN.1 certificates provided by the peer and also
|
||||||
|
// any verified chains that normal processing found. If it returns a
|
||||||
|
// non-nil error, the handshake is aborted and that error results.
|
||||||
|
//
|
||||||
|
// If normal verification fails then the handshake will abort before
|
||||||
|
// considering this callback. If normal verification is disabled by
|
||||||
|
// setting InsecureSkipVerify then this callback will be considered but
|
||||||
|
// the verifiedChains argument will always be nil.
|
||||||
|
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||||
|
|
||||||
CipherSuites []CipherSuite
|
CipherSuites []CipherSuite
|
||||||
Groups []NamedGroup
|
Groups []NamedGroup
|
||||||
SignatureSchemes []SignatureScheme
|
SignatureSchemes []SignatureScheme
|
||||||
@ -89,6 +127,7 @@ type Config struct {
|
|||||||
PSKs PreSharedKeyCache
|
PSKs PreSharedKeyCache
|
||||||
PSKModes []PSKKeyExchangeMode
|
PSKModes []PSKKeyExchangeMode
|
||||||
NonBlocking bool
|
NonBlocking bool
|
||||||
|
UseDTLS bool
|
||||||
|
|
||||||
// The same config object can be shared among different connections, so it
|
// The same config object can be shared among different connections, so it
|
||||||
// needs its own mutex
|
// needs its own mutex
|
||||||
@ -110,10 +149,16 @@ func (c *Config) Clone() *Config {
|
|||||||
EarlyDataLifetime: c.EarlyDataLifetime,
|
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||||
AllowEarlyData: c.AllowEarlyData,
|
AllowEarlyData: c.AllowEarlyData,
|
||||||
RequireCookie: c.RequireCookie,
|
RequireCookie: c.RequireCookie,
|
||||||
|
CookieHandler: c.CookieHandler,
|
||||||
|
CookieProtector: c.CookieProtector,
|
||||||
|
ExtensionHandler: c.ExtensionHandler,
|
||||||
RequireClientAuth: c.RequireClientAuth,
|
RequireClientAuth: c.RequireClientAuth,
|
||||||
|
Time: c.Time,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
|
||||||
Certificates: c.Certificates,
|
Certificates: c.Certificates,
|
||||||
AuthCertificate: c.AuthCertificate,
|
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||||
CipherSuites: c.CipherSuites,
|
CipherSuites: c.CipherSuites,
|
||||||
Groups: c.Groups,
|
Groups: c.Groups,
|
||||||
SignatureSchemes: c.SignatureSchemes,
|
SignatureSchemes: c.SignatureSchemes,
|
||||||
@ -121,6 +166,7 @@ func (c *Config) Clone() *Config {
|
|||||||
PSKs: c.PSKs,
|
PSKs: c.PSKs,
|
||||||
PSKModes: c.PSKModes,
|
PSKModes: c.PSKModes,
|
||||||
NonBlocking: c.NonBlocking,
|
NonBlocking: c.NonBlocking,
|
||||||
|
UseDTLS: c.UseDTLS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,28 +193,6 @@ func (c *Config) Init(isClient bool) error {
|
|||||||
if len(c.PSKModes) == 0 {
|
if len(c.PSKModes) == 0 {
|
||||||
c.PSKModes = defaultPSKModes
|
c.PSKModes = defaultPSKModes
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is no certificate, generate one
|
|
||||||
if !isClient && len(c.Certificates) == 0 {
|
|
||||||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
|
||||||
priv, err := newSigningKey(RSA_PSS_SHA256)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Certificates = []*Certificate{
|
|
||||||
{
|
|
||||||
Chain: []*x509.Certificate{cert},
|
|
||||||
PrivateKey: priv,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,6 +207,14 @@ func (c *Config) ValidForClient() bool {
|
|||||||
return len(c.ServerName) > 0
|
return len(c.ServerName) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) time() time.Time {
|
||||||
|
t := c.Time
|
||||||
|
if t == nil {
|
||||||
|
t = time.Now
|
||||||
|
}
|
||||||
|
return t()
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultSupportedCipherSuites = []CipherSuite{
|
defaultSupportedCipherSuites = []CipherSuite{
|
||||||
TLS_AES_128_GCM_SHA256,
|
TLS_AES_128_GCM_SHA256,
|
||||||
@ -214,10 +246,13 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
HandshakeState string // string representation of the handshake state.
|
HandshakeState State
|
||||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||||
|
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
|
||||||
NextProto string // Selected ALPN proto
|
NextProto string // Selected ALPN proto
|
||||||
|
UsingPSK bool // Are we using PSK.
|
||||||
|
UsingEarlyData bool // Did we negotiate 0-RTT.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||||
@ -228,9 +263,7 @@ type Conn struct {
|
|||||||
conn net.Conn
|
conn net.Conn
|
||||||
isClient bool
|
isClient bool
|
||||||
|
|
||||||
EarlyData []byte
|
state stateConnected
|
||||||
|
|
||||||
state StateConnected
|
|
||||||
hState HandshakeState
|
hState HandshakeState
|
||||||
handshakeMutex sync.Mutex
|
handshakeMutex sync.Mutex
|
||||||
handshakeAlert Alert
|
handshakeAlert Alert
|
||||||
@ -238,18 +271,28 @@ type Conn struct {
|
|||||||
|
|
||||||
readBuffer []byte
|
readBuffer []byte
|
||||||
in, out *RecordLayer
|
in, out *RecordLayer
|
||||||
hIn, hOut *HandshakeLayer
|
hsCtx *HandshakeContext
|
||||||
|
|
||||||
extHandler AppExtensionHandler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||||
c := &Conn{conn: conn, config: config, isClient: isClient}
|
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
|
||||||
c.in = NewRecordLayer(c.conn)
|
if !config.UseDTLS {
|
||||||
c.out = NewRecordLayer(c.conn)
|
c.in = NewRecordLayerTLS(c.conn, directionRead)
|
||||||
c.hIn = NewHandshakeLayer(c.in)
|
c.out = NewRecordLayerTLS(c.conn, directionWrite)
|
||||||
c.hIn.nonblocking = c.config.NonBlocking
|
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
|
||||||
c.hOut = NewHandshakeLayer(c.out)
|
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
|
||||||
|
} else {
|
||||||
|
c.in = NewRecordLayerDTLS(c.conn, directionRead)
|
||||||
|
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
|
||||||
|
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
|
||||||
|
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
|
||||||
|
c.hsCtx.timeoutMS = initialTimeout
|
||||||
|
c.hsCtx.timers = newTimerSet()
|
||||||
|
c.hsCtx.waitingNextFlight = true
|
||||||
|
}
|
||||||
|
c.in.label = c.label()
|
||||||
|
c.out.label = c.label()
|
||||||
|
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -267,8 +310,12 @@ func (c *Conn) consumeRecord() error {
|
|||||||
// We do not support fragmentation of post-handshake handshake messages.
|
// We do not support fragmentation of post-handshake handshake messages.
|
||||||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||||
start := 0
|
start := 0
|
||||||
|
headerLen := handshakeHeaderLenTLS
|
||||||
|
if c.config.UseDTLS {
|
||||||
|
headerLen = handshakeHeaderLenDTLS
|
||||||
|
}
|
||||||
for start < len(pt.fragment) {
|
for start < len(pt.fragment) {
|
||||||
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
if len(pt.fragment[start:]) < headerLen {
|
||||||
return fmt.Errorf("Post-handshake handshake message too short for header")
|
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -276,14 +323,15 @@ func (c *Conn) consumeRecord() error {
|
|||||||
hm.msgType = HandshakeType(pt.fragment[start])
|
hm.msgType = HandshakeType(pt.fragment[start])
|
||||||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||||
|
|
||||||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
if len(pt.fragment[start+headerLen:]) < hmLen {
|
||||||
return fmt.Errorf("Post-handshake handshake message too short for body")
|
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||||
}
|
}
|
||||||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen]
|
||||||
|
|
||||||
// Advance state machine
|
|
||||||
state, actions, alert := c.state.Next(hm)
|
|
||||||
|
|
||||||
|
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||||
|
// authentication, we'll need to allow transitions other than
|
||||||
|
// Connected -> Connected
|
||||||
|
state, actions, alert := c.state.ProcessMessage(hm)
|
||||||
if alert != AlertNoAlert {
|
if alert != AlertNoAlert {
|
||||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
c.sendAlert(alert)
|
c.sendAlert(alert)
|
||||||
@ -299,18 +347,15 @@ func (c *Conn) consumeRecord() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
|
||||||
// authentication, we'll need to allow transitions other than
|
|
||||||
// Connected -> Connected
|
|
||||||
var connected bool
|
var connected bool
|
||||||
c.state, connected = state.(StateConnected)
|
c.state, connected = state.(stateConnected)
|
||||||
if !connected {
|
if !connected {
|
||||||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||||
c.sendAlert(alert)
|
c.sendAlert(alert)
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
start += handshakeHeaderLen + hmLen
|
start += headerLen + hmLen
|
||||||
}
|
}
|
||||||
case RecordTypeAlert:
|
case RecordTypeAlert:
|
||||||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
@ -332,17 +377,54 @@ func (c *Conn) consumeRecord() error {
|
|||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case RecordTypeAck:
|
||||||
|
if !c.hsCtx.hIn.datagram {
|
||||||
|
logf(logTypeHandshake, "Received ACK in TLS mode")
|
||||||
|
return AlertUnexpectedMessage
|
||||||
|
}
|
||||||
|
return c.hsCtx.processAck(pt.fragment)
|
||||||
|
|
||||||
case RecordTypeApplicationData:
|
case RecordTypeApplicationData:
|
||||||
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readPartial(in *[]byte, buffer []byte) int {
|
||||||
|
logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
|
||||||
|
read := copy(buffer, *in)
|
||||||
|
*in = (*in)[read:]
|
||||||
|
|
||||||
|
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
// Read application data up to the size of buffer. Handshake and alert records
|
// Read application data up to the size of buffer. Handshake and alert records
|
||||||
// are consumed by the Conn object directly.
|
// are consumed by the Conn object directly.
|
||||||
func (c *Conn) Read(buffer []byte) (int, error) {
|
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||||
|
if _, connected := c.hState.(stateConnected); !connected {
|
||||||
|
// Clients can't call Read prior to handshake completion.
|
||||||
|
if c.isClient {
|
||||||
|
return 0, errors.New("Read called before the handshake completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Neither can servers that don't allow early data.
|
||||||
|
if !c.config.AllowEarlyData {
|
||||||
|
return 0, errors.New("Read called before the handshake completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's no early data, then return WouldBlock
|
||||||
|
if len(c.hsCtx.earlyData) == 0 {
|
||||||
|
return 0, AlertWouldBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
return readPartial(&c.hsCtx.earlyData, buffer), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The handshake is now connected.
|
||||||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||||
if alert := c.Handshake(); alert != AlertNoAlert {
|
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||||
return 0, alert
|
return 0, alert
|
||||||
@ -352,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run our timers.
|
||||||
|
if c.config.UseDTLS {
|
||||||
|
if err := c.hsCtx.timers.check(time.Now()); err != nil {
|
||||||
|
return 0, AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Lock the input channel
|
// Lock the input channel
|
||||||
c.in.Lock()
|
c.in.Lock()
|
||||||
defer c.in.Unlock()
|
defer c.in.Unlock()
|
||||||
@ -361,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) {
|
|||||||
// err can be nil if consumeRecord processed a non app-data
|
// err can be nil if consumeRecord processed a non app-data
|
||||||
// record.
|
// record.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.config.NonBlocking || err != WouldBlock {
|
if c.config.NonBlocking || err != AlertWouldBlock {
|
||||||
logf(logTypeIO, "conn.Read returns err=%v", err)
|
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var read int
|
return readPartial(&c.readBuffer, buffer), nil
|
||||||
n := len(buffer)
|
|
||||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
|
||||||
if len(c.readBuffer) <= n {
|
|
||||||
buffer = buffer[:len(c.readBuffer)]
|
|
||||||
copy(buffer, c.readBuffer)
|
|
||||||
read = len(c.readBuffer)
|
|
||||||
c.readBuffer = c.readBuffer[:0]
|
|
||||||
} else {
|
|
||||||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
|
||||||
copy(buffer[:n], c.readBuffer[:n])
|
|
||||||
c.readBuffer = c.readBuffer[n:]
|
|
||||||
read = n
|
|
||||||
}
|
|
||||||
|
|
||||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
|
||||||
return read, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write application data
|
// Write application data
|
||||||
@ -393,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) {
|
|||||||
c.out.Lock()
|
c.out.Lock()
|
||||||
defer c.out.Unlock()
|
defer c.out.Unlock()
|
||||||
|
|
||||||
|
if !c.Writable() {
|
||||||
|
return 0, errors.New("Write called before the handshake completed (and early data not in use)")
|
||||||
|
}
|
||||||
|
|
||||||
// Send full-size fragments
|
// Send full-size fragments
|
||||||
var start int
|
var start int
|
||||||
sent := 0
|
sent := 0
|
||||||
@ -495,84 +572,44 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch action := actionGeneric.(type) {
|
switch action := actionGeneric.(type) {
|
||||||
case SendHandshakeMessage:
|
case QueueHandshakeMessage:
|
||||||
err := c.hOut.WriteMessage(action.Message)
|
logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType)
|
||||||
|
err := c.hsCtx.hOut.QueueMessage(action.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||||
return AlertInternalError
|
return AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case SendQueuedHandshake:
|
||||||
|
_, err := c.hsCtx.hOut.SendQueuedMessages()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
if c.config.UseDTLS {
|
||||||
|
c.hsCtx.timers.start(retransmitTimerLabel,
|
||||||
|
c.hsCtx.handshakeRetransmit,
|
||||||
|
c.hsCtx.timeoutMS)
|
||||||
|
}
|
||||||
case RekeyIn:
|
case RekeyIn:
|
||||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
|
||||||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||||
return AlertInternalError
|
return AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
case RekeyOut:
|
case RekeyOut:
|
||||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
|
||||||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||||
return AlertInternalError
|
return AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
case SendEarlyData:
|
case ResetOut:
|
||||||
logf(logTypeHandshake, "%s Sending early data...", label)
|
logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
|
||||||
_, err := c.Write(c.EarlyData)
|
c.out.ResetClear(action.seq)
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
|
||||||
return AlertInternalError
|
|
||||||
}
|
|
||||||
|
|
||||||
case ReadPastEarlyData:
|
|
||||||
logf(logTypeHandshake, "%s Reading past early data...", label)
|
|
||||||
// Scan past all records that fail to decrypt
|
|
||||||
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
_, ok := err.(DecryptError)
|
|
||||||
|
|
||||||
for ok {
|
|
||||||
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
_, ok = err.(DecryptError)
|
|
||||||
}
|
|
||||||
|
|
||||||
case ReadEarlyData:
|
|
||||||
logf(logTypeHandshake, "%s Reading early data...", label)
|
|
||||||
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
|
||||||
return AlertInternalError
|
|
||||||
}
|
|
||||||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
|
||||||
|
|
||||||
for t == RecordTypeApplicationData {
|
|
||||||
// Read a record into the buffer. Note that this is safe
|
|
||||||
// in blocking mode because we read the record in in
|
|
||||||
// PeekRecordType.
|
|
||||||
pt, err := c.in.ReadRecord()
|
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
|
||||||
return AlertInternalError
|
|
||||||
}
|
|
||||||
|
|
||||||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
|
||||||
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
|
||||||
|
|
||||||
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
|
||||||
if err != nil {
|
|
||||||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
|
||||||
return AlertInternalError
|
|
||||||
}
|
|
||||||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
|
||||||
}
|
|
||||||
logf(logTypeHandshake, "%s Done reading early data", label)
|
|
||||||
|
|
||||||
case StorePSK:
|
case StorePSK:
|
||||||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||||
@ -585,7 +622,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
|||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
logf(logTypeHandshake, "%s Unknown action type", label)
|
||||||
|
assert(false)
|
||||||
return AlertInternalError
|
return AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -602,33 +640,13 @@ func (c *Conn) HandshakeSetup() Alert {
|
|||||||
return AlertInternalError
|
return AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set things up
|
|
||||||
caps := Capabilities{
|
|
||||||
CipherSuites: c.config.CipherSuites,
|
|
||||||
Groups: c.config.Groups,
|
|
||||||
SignatureSchemes: c.config.SignatureSchemes,
|
|
||||||
PSKs: c.config.PSKs,
|
|
||||||
PSKModes: c.config.PSKModes,
|
|
||||||
AllowEarlyData: c.config.AllowEarlyData,
|
|
||||||
RequireCookie: c.config.RequireCookie,
|
|
||||||
CookieHandler: c.config.CookieHandler,
|
|
||||||
RequireClientAuth: c.config.RequireClientAuth,
|
|
||||||
NextProtos: c.config.NextProtos,
|
|
||||||
Certificates: c.config.Certificates,
|
|
||||||
ExtensionHandler: c.extHandler,
|
|
||||||
}
|
|
||||||
opts := ConnectionOptions{
|
opts := ConnectionOptions{
|
||||||
ServerName: c.config.ServerName,
|
ServerName: c.config.ServerName,
|
||||||
NextProtos: c.config.NextProtos,
|
NextProtos: c.config.NextProtos,
|
||||||
EarlyData: c.EarlyData,
|
|
||||||
}
|
|
||||||
|
|
||||||
if caps.RequireCookie && caps.CookieHandler == nil {
|
|
||||||
caps.CookieHandler = &defaultCookieHandler{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.isClient {
|
if c.isClient {
|
||||||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
|
||||||
if alert != AlertNoAlert {
|
if alert != AlertNoAlert {
|
||||||
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||||
return alert
|
return alert
|
||||||
@ -642,14 +660,56 @@ func (c *Conn) HandshakeSetup() Alert {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
state = ServerStateStart{Caps: caps, conn: c}
|
if c.config.RequireCookie && c.config.CookieProtector == nil {
|
||||||
|
logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.")
|
||||||
|
if c.config.NonBlocking {
|
||||||
|
logf(logTypeHandshake, "Not possible in non-blocking mode.")
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
c.config.CookieProtector, err = NewDefaultCookieProtector()
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.hState = state
|
c.hState = state
|
||||||
|
|
||||||
return AlertNoAlert
|
return AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type handshakeMessageReader interface {
|
||||||
|
ReadMessage() (*HandshakeMessage, Alert)
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshakeMessageReaderImpl struct {
|
||||||
|
hsCtx *HandshakeContext
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
|
||||||
|
|
||||||
|
func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
|
||||||
|
var hm *HandshakeMessage
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
hm, err = r.hsCtx.hIn.ReadMessage()
|
||||||
|
if err == AlertWouldBlock {
|
||||||
|
return nil, AlertWouldBlock
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logf(logTypeHandshake, "Error reading message: %v", err)
|
||||||
|
return nil, AlertCloseNotify
|
||||||
|
}
|
||||||
|
if hm != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hm, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||||
// determines whether a client or server handshake is performed. If a
|
// determines whether a client or server handshake is performed. If a
|
||||||
// handshake has already been performed, then its result will be returned.
|
// handshake has already been performed, then its result will be returned.
|
||||||
@ -669,48 +729,48 @@ func (c *Conn) Handshake() Alert {
|
|||||||
return AlertNoAlert
|
return AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
var alert Alert
|
|
||||||
if c.hState == nil {
|
if c.hState == nil {
|
||||||
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label)
|
||||||
alert = c.HandshakeSetup()
|
alert := c.HandshakeSetup()
|
||||||
if alert != AlertNoAlert {
|
if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) {
|
||||||
return alert
|
return alert
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
|
||||||
state := c.hState
|
state := c.hState
|
||||||
_, connected := state.(StateConnected)
|
_, connected := state.(stateConnected)
|
||||||
|
|
||||||
|
hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
|
||||||
|
for !connected {
|
||||||
|
var alert Alert
|
||||||
var actions []HandshakeAction
|
var actions []HandshakeAction
|
||||||
|
|
||||||
for !connected {
|
// Advance the state machine
|
||||||
// Read a handshake message
|
state, actions, alert = state.Next(hmr)
|
||||||
hm, err := c.hIn.ReadMessage()
|
if alert == AlertWouldBlock {
|
||||||
if err == WouldBlock {
|
logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
|
||||||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
// If we blocked, then run our timers to see if any have expired.
|
||||||
|
if c.hsCtx.hIn.datagram {
|
||||||
|
if err := c.hsCtx.timers.check(time.Now()); err != nil {
|
||||||
|
return AlertInternalError
|
||||||
|
}
|
||||||
|
}
|
||||||
return AlertWouldBlock
|
return AlertWouldBlock
|
||||||
}
|
}
|
||||||
if err != nil {
|
if alert == AlertCloseNotify {
|
||||||
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
logf(logTypeHandshake, "%s Error reading message: %s", label, alert)
|
||||||
c.sendAlert(AlertCloseNotify)
|
c.sendAlert(AlertCloseNotify)
|
||||||
return AlertCloseNotify
|
return AlertCloseNotify
|
||||||
}
|
}
|
||||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
if alert != AlertNoAlert && alert != AlertStatelessRetry {
|
||||||
|
|
||||||
// Advance the state machine
|
|
||||||
state, actions, alert = state.Next(hm)
|
|
||||||
|
|
||||||
if alert != AlertNoAlert {
|
|
||||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||||
return alert
|
return alert
|
||||||
}
|
}
|
||||||
|
|
||||||
for index, action := range actions {
|
for index, action := range actions {
|
||||||
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||||
alert = c.takeAction(action)
|
if alert := c.takeAction(action); alert != AlertNoAlert {
|
||||||
if alert != AlertNoAlert {
|
|
||||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||||
c.sendAlert(alert)
|
c.sendAlert(alert)
|
||||||
return alert
|
return alert
|
||||||
@ -719,14 +779,14 @@ func (c *Conn) Handshake() Alert {
|
|||||||
|
|
||||||
c.hState = state
|
c.hState = state
|
||||||
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||||
|
_, connected = state.(stateConnected)
|
||||||
|
if connected {
|
||||||
|
c.state = state.(stateConnected)
|
||||||
|
c.handshakeComplete = true
|
||||||
|
|
||||||
_, connected = state.(StateConnected)
|
if !c.isClient {
|
||||||
}
|
// Send NewSessionTicket if configured to
|
||||||
|
if c.config.SendSessionTickets {
|
||||||
c.state = state.(StateConnected)
|
|
||||||
|
|
||||||
// Send NewSessionTicket if acting as server
|
|
||||||
if !c.isClient && c.config.SendSessionTickets {
|
|
||||||
actions, alert := c.state.NewSessionTicket(
|
actions, alert := c.state.NewSessionTicket(
|
||||||
c.config.TicketLen,
|
c.config.TicketLen,
|
||||||
c.config.TicketLifetime,
|
c.config.TicketLifetime,
|
||||||
@ -742,7 +802,25 @@ func (c *Conn) Handshake() Alert {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.handshakeComplete = true
|
// If there is early data, move it into the main buffer
|
||||||
|
if c.hsCtx.earlyData != nil {
|
||||||
|
c.readBuffer = c.hsCtx.earlyData
|
||||||
|
c.hsCtx.earlyData = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
assert(c.hsCtx.earlyData == nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.NonBlocking {
|
||||||
|
if alert == AlertStatelessRetry {
|
||||||
|
return AlertStatelessRetry
|
||||||
|
}
|
||||||
|
return AlertNoAlert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return AlertNoAlert
|
return AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -775,12 +853,15 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) GetHsState() string {
|
func (c *Conn) GetHsState() State {
|
||||||
return reflect.TypeOf(c.hState).Name()
|
if c.hState == nil {
|
||||||
|
return StateInit
|
||||||
|
}
|
||||||
|
return c.hState.State()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||||
_, connected := c.hState.(StateConnected)
|
_, connected := c.hState.(stateConnected)
|
||||||
if !connected {
|
if !connected {
|
||||||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||||
}
|
}
|
||||||
@ -796,7 +877,7 @@ func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]b
|
|||||||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) State() ConnectionState {
|
func (c *Conn) ConnectionState() ConnectionState {
|
||||||
state := ConnectionState{
|
state := ConnectionState{
|
||||||
HandshakeState: c.GetHsState(),
|
HandshakeState: c.GetHsState(),
|
||||||
}
|
}
|
||||||
@ -804,16 +885,32 @@ func (c *Conn) State() ConnectionState {
|
|||||||
if c.handshakeComplete {
|
if c.handshakeComplete {
|
||||||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||||
state.NextProto = c.state.Params.NextProto
|
state.NextProto = c.state.Params.NextProto
|
||||||
|
state.VerifiedChains = c.state.verifiedChains
|
||||||
|
state.PeerCertificates = c.state.peerCertificates
|
||||||
|
state.UsingPSK = c.state.Params.UsingPSK
|
||||||
|
state.UsingEarlyData = c.state.Params.UsingEarlyData
|
||||||
}
|
}
|
||||||
|
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
func (c *Conn) Writable() bool {
|
||||||
if c.hState != nil {
|
// If we're connected, we're writable.
|
||||||
return fmt.Errorf("Can't set extension handler after setup")
|
if _, connected := c.hState.(stateConnected); connected {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
c.extHandler = h
|
// If we're a client in 0-RTT, then we're writable.
|
||||||
return nil
|
if c.isClient && c.out.cipher.epoch == EpochEarlyData {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) label() string {
|
||||||
|
if c.isClient {
|
||||||
|
return "client"
|
||||||
|
}
|
||||||
|
return "server"
|
||||||
}
|
}
|
||||||
|
86
vendor/github.com/bifurcation/mint/cookie-protector.go
generated
vendored
Normal file
86
vendor/github.com/bifurcation/mint/cookie-protector.go
generated
vendored
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CookieProtector is used to create and verify a cookie
|
||||||
|
type CookieProtector interface {
|
||||||
|
// NewToken creates a new token
|
||||||
|
NewToken([]byte) ([]byte, error)
|
||||||
|
// DecodeToken decodes a token
|
||||||
|
DecodeToken([]byte) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const cookieSecretSize = 32
|
||||||
|
const cookieNonceSize = 32
|
||||||
|
|
||||||
|
// The DefaultCookieProtector is a simple implementation for the CookieProtector.
|
||||||
|
type DefaultCookieProtector struct {
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ CookieProtector = &DefaultCookieProtector{}
|
||||||
|
|
||||||
|
// NewDefaultCookieProtector creates a source for source address tokens
|
||||||
|
func NewDefaultCookieProtector() (CookieProtector, error) {
|
||||||
|
secret := make([]byte, cookieSecretSize)
|
||||||
|
if _, err := rand.Read(secret); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DefaultCookieProtector{secret: secret}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToken encodes data into a new token.
|
||||||
|
func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) {
|
||||||
|
nonce := make([]byte, cookieNonceSize)
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeToken decodes a token.
|
||||||
|
func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) {
|
||||||
|
if len(p) < cookieNonceSize {
|
||||||
|
return nil, fmt.Errorf("Token too short: %d", len(p))
|
||||||
|
}
|
||||||
|
nonce := p[:cookieNonceSize]
|
||||||
|
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||||
|
h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source"))
|
||||||
|
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||||
|
if _, err := io.ReadFull(h, key); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
aeadNonce := make([]byte, 12)
|
||||||
|
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
c, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
aead, err := cipher.NewGCM(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return aead, aeadNonce, nil
|
||||||
|
}
|
81
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
81
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
@ -331,40 +331,6 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
|
||||||
sigAlg, ok := x509AlgMap[alg]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
|
||||||
}
|
|
||||||
if len(name) == 0 {
|
|
||||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: serial,
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
|
||||||
SignatureAlgorithm: sigAlg,
|
|
||||||
Subject: pkix.Name{CommonName: name},
|
|
||||||
DNSNames: []string{name},
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
}
|
|
||||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// It is safe to ignore the error here because we're parsing known-good data
|
|
||||||
cert, _ := x509.ParseCertificate(der)
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// XXX(rlb): Copied from crypto/x509
|
// XXX(rlb): Copied from crypto/x509
|
||||||
type ecdsaSignature struct {
|
type ecdsaSignature struct {
|
||||||
R, S *big.Int
|
R, S *big.Int
|
||||||
@ -652,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
|||||||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
|
||||||
|
priv, err := newSigningKey(alg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := newSelfSigned(name, alg, priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return priv, cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||||
|
sigAlg, ok := x509AlgMap[alg]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||||
|
}
|
||||||
|
if len(name) == 0 {
|
||||||
|
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||||
|
SignatureAlgorithm: sigAlg,
|
||||||
|
Subject: pkix.Name{CommonName: name},
|
||||||
|
DNSNames: []string{name},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is safe to ignore the error here because we're parsing known-good data
|
||||||
|
cert, _ := x509.ParseCertificate(der)
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
222
vendor/github.com/bifurcation/mint/dtls.go
generated
vendored
Normal file
222
vendor/github.com/bifurcation/mint/dtls.go
generated
vendored
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/bifurcation/mint/syntax"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
initialMtu = 1200
|
||||||
|
initialTimeout = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
// labels for timers
|
||||||
|
const (
|
||||||
|
retransmitTimerLabel = "handshake retransmit"
|
||||||
|
ackTimerLabel = "ack timer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SentHandshakeFragment struct {
|
||||||
|
seq uint32
|
||||||
|
offset int
|
||||||
|
fragLength int
|
||||||
|
record uint64
|
||||||
|
acked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DtlsAck struct {
|
||||||
|
RecordNumbers []uint64 `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func wireVersion(h *HandshakeLayer) uint16 {
|
||||||
|
if h.datagram {
|
||||||
|
return dtls12WireVersion
|
||||||
|
}
|
||||||
|
return tls12Version
|
||||||
|
}
|
||||||
|
|
||||||
|
func dtlsConvertVersion(version uint16) uint16 {
|
||||||
|
if version == tls12Version {
|
||||||
|
return dtls12WireVersion
|
||||||
|
}
|
||||||
|
if version == tls10Version {
|
||||||
|
return 0xfeff
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Move these to state-machine.go
|
||||||
|
func (h *HandshakeContext) handshakeRetransmit() error {
|
||||||
|
if _, err := h.hOut.SendQueuedMessages(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.timers.start(retransmitTimerLabel,
|
||||||
|
h.handshakeRetransmit,
|
||||||
|
h.timeoutMS)
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Back off timer
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) sendAck() error {
|
||||||
|
toack := h.hIn.recvdRecords
|
||||||
|
|
||||||
|
count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
|
||||||
|
if len(toack) > count {
|
||||||
|
toack = toack[:count]
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "Sending ACK: [%x]", toack)
|
||||||
|
|
||||||
|
ack := &DtlsAck{toack}
|
||||||
|
body, err := syntax.Marshal(&ack)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = h.hOut.conn.WriteRecord(&TLSPlaintext{
|
||||||
|
contentType: RecordTypeAck,
|
||||||
|
fragment: body,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) processAck(data []byte) error {
|
||||||
|
// Cancel the retransmit timer because we will be resending
|
||||||
|
// and possibly re-arming later.
|
||||||
|
h.timers.cancel(retransmitTimerLabel)
|
||||||
|
|
||||||
|
ack := &DtlsAck{}
|
||||||
|
read, err := syntax.Unmarshal(data, &ack)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(data) != read {
|
||||||
|
return fmt.Errorf("Invalid encoding: Extra data not consumed")
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
|
||||||
|
|
||||||
|
for _, r := range ack.RecordNumbers {
|
||||||
|
for _, m := range h.sentFragments {
|
||||||
|
if r == m.record {
|
||||||
|
logf(logTypeHandshake, "Marking %v %v(%v) as acked",
|
||||||
|
m.seq, m.offset, m.fragLength)
|
||||||
|
m.acked = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := h.hOut.SendQueuedMessages()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
logf(logTypeHandshake, "All messages ACKed")
|
||||||
|
h.hOut.ClearQueuedMessages()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the timer
|
||||||
|
h.timers.start(retransmitTimerLabel,
|
||||||
|
h.handshakeRetransmit,
|
||||||
|
h.timeoutMS)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
|
||||||
|
return c.hsCtx.timers.remaining()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) receivedHandshakeMessage() {
|
||||||
|
logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
|
||||||
|
// This just enables tests.
|
||||||
|
if h.hIn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.hIn.datagram {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.waitingNextFlight {
|
||||||
|
logf(logTypeHandshake, "Received the start of the flight")
|
||||||
|
|
||||||
|
// Clear the outgoing DTLS queue and terminate the retransmit timer
|
||||||
|
h.hOut.ClearQueuedMessages()
|
||||||
|
h.timers.cancel(retransmitTimerLabel)
|
||||||
|
|
||||||
|
// OK, we're not waiting any more.
|
||||||
|
h.waitingNextFlight = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now pre-emptively arm the ACK timer if it's not armed already.
|
||||||
|
// We'll automatically dis-arm it at the end of the handshake.
|
||||||
|
if h.timers.getTimer(ackTimerLabel) == nil {
|
||||||
|
h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) receivedEndOfFlight() {
|
||||||
|
logf(logTypeHandshake, "%p Received the end of the flight", h)
|
||||||
|
if !h.hIn.datagram {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty incoming queue
|
||||||
|
h.hIn.queued = nil
|
||||||
|
|
||||||
|
// Note that we are waiting for the next flight.
|
||||||
|
h.waitingNextFlight = true
|
||||||
|
|
||||||
|
// Clear the ACK queue.
|
||||||
|
h.hIn.recvdRecords = nil
|
||||||
|
|
||||||
|
// Disarm the ACK timer
|
||||||
|
h.timers.cancel(ackTimerLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) receivedFinalFlight() {
|
||||||
|
logf(logTypeHandshake, "%p Received final flight", h)
|
||||||
|
if !h.hIn.datagram {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disarm the ACK timer
|
||||||
|
h.timers.cancel(ackTimerLabel)
|
||||||
|
|
||||||
|
// But send an ACK immediately.
|
||||||
|
h.sendAck()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
|
||||||
|
logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
|
||||||
|
for _, f := range h.sentFragments {
|
||||||
|
if !f.acked {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.seq != seq {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.offset > offset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, we know that the stored fragment starts
|
||||||
|
// at or before what we want to send, so check where the end
|
||||||
|
// is.
|
||||||
|
if f.offset+f.fragLength < offset+fraglen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
98
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
98
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
@ -3,7 +3,6 @@ package mint
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/bifurcation/mint/syntax"
|
"github.com/bifurcation/mint/syntax"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
|
||||||
|
found := make(map[ExtensionType]bool)
|
||||||
|
|
||||||
|
for _, dst := range dsts {
|
||||||
for _, ext := range el {
|
for _, ext := range el {
|
||||||
if ext.ExtensionType == dst.Type() {
|
if ext.ExtensionType == dst.Type() {
|
||||||
_, err := dst.Unmarshal(ext.ExtensionData)
|
if found[dst.Type()] {
|
||||||
return err == nil
|
return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
err := safeUnmarshal(dst, ext.ExtensionData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
found[dst.Type()] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
}
|
||||||
|
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
|
||||||
|
for _, ext := range el {
|
||||||
|
if ext.ExtensionType == dst.Type() {
|
||||||
|
err := safeUnmarshal(dst, ext.ExtensionData)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// struct {
|
// struct {
|
||||||
@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
|||||||
// ProtocolVersion versions<2..254>;
|
// ProtocolVersion versions<2..254>;
|
||||||
// } SupportedVersions;
|
// } SupportedVersions;
|
||||||
type SupportedVersionsExtension struct {
|
type SupportedVersionsExtension struct {
|
||||||
|
HandshakeType HandshakeType
|
||||||
|
Versions []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type SupportedVersionsClientHelloInner struct {
|
||||||
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SupportedVersionsServerHelloInner struct {
|
||||||
|
Version uint16
|
||||||
|
}
|
||||||
|
|
||||||
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||||
return ExtensionTypeSupportedVersions
|
return ExtensionTypeSupportedVersions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||||
return syntax.Marshal(sv)
|
switch sv.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
|
||||||
|
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
|
||||||
|
return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||||
return syntax.Unmarshal(data, sv)
|
switch sv.HandshakeType {
|
||||||
|
case HandshakeTypeClientHello:
|
||||||
|
var inner SupportedVersionsClientHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sv.Versions = inner.Versions
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
|
||||||
|
var inner SupportedVersionsServerHelloInner
|
||||||
|
read, err := syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sv.Versions = []uint16{inner.Version}
|
||||||
|
return read, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// struct {
|
// struct {
|
||||||
@ -562,25 +624,3 @@ func (c CookieExtension) Marshal() ([]byte, error) {
|
|||||||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||||
return syntax.Unmarshal(data, c)
|
return syntax.Unmarshal(data, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultCookieLength is the default length of a cookie
|
|
||||||
const defaultCookieLength = 32
|
|
||||||
|
|
||||||
type defaultCookieHandler struct {
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CookieHandler = &defaultCookieHandler{}
|
|
||||||
|
|
||||||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
|
||||||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
|
||||||
h.data = make([]byte, defaultCookieLength)
|
|
||||||
if _, err := prng.Read(h.data); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return h.data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
|
||||||
return bytes.Equal(h.data, data)
|
|
||||||
}
|
|
||||||
|
8
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
8
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
@ -66,8 +66,8 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
|||||||
f.remainder = f.remainder[copied:]
|
f.remainder = f.remainder[copied:]
|
||||||
f.writeOffset += copied
|
f.writeOffset += copied
|
||||||
if f.writeOffset < len(f.working) {
|
if f.writeOffset < len(f.working) {
|
||||||
logf(logTypeFrameReader, "Read would have blocked 1")
|
logf(logTypeVerbose, "Read would have blocked 1")
|
||||||
return nil, nil, WouldBlock
|
return nil, nil, AlertWouldBlock
|
||||||
}
|
}
|
||||||
// Reset the write offset, because we are now full.
|
// Reset the write offset, because we are now full.
|
||||||
f.writeOffset = 0
|
f.writeOffset = 0
|
||||||
@ -93,6 +93,6 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
|||||||
f.state = kFrameReaderBody
|
f.state = kFrameReaderBody
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeFrameReader, "Read would have blocked 2")
|
logf(logTypeVerbose, "Read would have blocked 2")
|
||||||
return nil, nil, WouldBlock
|
return nil, nil, AlertWouldBlock
|
||||||
}
|
}
|
||||||
|
432
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
432
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
@ -7,7 +7,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
handshakeHeaderLen = 4 // handshake message header length
|
handshakeHeaderLenTLS = 4 // handshake message header length
|
||||||
|
handshakeHeaderLenDTLS = 12 // handshake message header length
|
||||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,28 +28,42 @@ const (
|
|||||||
// opaque msg<0..2^24-1>
|
// opaque msg<0..2^24-1>
|
||||||
// } Handshake;
|
// } Handshake;
|
||||||
//
|
//
|
||||||
// TODO: File a spec bug
|
|
||||||
type HandshakeMessage struct {
|
type HandshakeMessage struct {
|
||||||
// Omitted: length
|
|
||||||
msgType HandshakeType
|
msgType HandshakeType
|
||||||
|
seq uint32
|
||||||
body []byte
|
body []byte
|
||||||
|
datagram bool
|
||||||
|
offset uint32 // Used for DTLS
|
||||||
|
length uint32
|
||||||
|
cipher *cipherState
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: This could be done with the `syntax` module, using the simplified
|
// Note: This could be done with the `syntax` module, using the simplified
|
||||||
// syntax as discussed above. However, since this is so simple, there's not
|
// syntax as discussed above. However, since this is so simple, there's not
|
||||||
// much benefit to doing so.
|
// much benefit to doing so.
|
||||||
|
// When datagram is set, we marshal this as a whole DTLS record.
|
||||||
func (hm *HandshakeMessage) Marshal() []byte {
|
func (hm *HandshakeMessage) Marshal() []byte {
|
||||||
if hm == nil {
|
if hm == nil {
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
msgLen := len(hm.body)
|
fragLen := len(hm.body)
|
||||||
data := make([]byte, 4+len(hm.body))
|
var data []byte
|
||||||
data[0] = byte(hm.msgType)
|
|
||||||
data[1] = byte(msgLen >> 16)
|
if hm.datagram {
|
||||||
data[2] = byte(msgLen >> 8)
|
data = make([]byte, handshakeHeaderLenDTLS+fragLen)
|
||||||
data[3] = byte(msgLen)
|
} else {
|
||||||
copy(data[4:], hm.body)
|
data = make([]byte, handshakeHeaderLenTLS+fragLen)
|
||||||
|
}
|
||||||
|
tmp := data
|
||||||
|
tmp = encodeUint(uint64(hm.msgType), 1, tmp)
|
||||||
|
tmp = encodeUint(uint64(hm.length), 3, tmp)
|
||||||
|
if hm.datagram {
|
||||||
|
tmp = encodeUint(uint64(hm.seq), 2, tmp)
|
||||||
|
tmp = encodeUint(uint64(hm.offset), 3, tmp)
|
||||||
|
tmp = encodeUint(uint64(fragLen), 3, tmp)
|
||||||
|
}
|
||||||
|
copy(tmp, hm.body)
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,8 +76,6 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
|||||||
body = new(ClientHelloBody)
|
body = new(ClientHelloBody)
|
||||||
case HandshakeTypeServerHello:
|
case HandshakeTypeServerHello:
|
||||||
body = new(ServerHelloBody)
|
body = new(ServerHelloBody)
|
||||||
case HandshakeTypeHelloRetryRequest:
|
|
||||||
body = new(HelloRetryRequestBody)
|
|
||||||
case HandshakeTypeEncryptedExtensions:
|
case HandshakeTypeEncryptedExtensions:
|
||||||
body = new(EncryptedExtensionsBody)
|
body = new(EncryptedExtensionsBody)
|
||||||
case HandshakeTypeCertificate:
|
case HandshakeTypeCertificate:
|
||||||
@ -83,62 +96,104 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
|||||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := body.Unmarshal(hm.body)
|
err := safeUnmarshal(body, hm.body)
|
||||||
return body, err
|
return body, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||||
data, err := body.Marshal()
|
data, err := body.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &HandshakeMessage{
|
m := &HandshakeMessage{
|
||||||
msgType: body.Type(),
|
msgType: body.Type(),
|
||||||
body: data,
|
body: data,
|
||||||
}, nil
|
seq: h.msgSeq,
|
||||||
|
datagram: h.datagram,
|
||||||
|
length: uint32(len(data)),
|
||||||
|
}
|
||||||
|
h.msgSeq++
|
||||||
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandshakeLayer struct {
|
type HandshakeLayer struct {
|
||||||
|
ctx *HandshakeContext // The handshake we are attached to
|
||||||
nonblocking bool // Should we operate in nonblocking mode
|
nonblocking bool // Should we operate in nonblocking mode
|
||||||
conn *RecordLayer // Used for reading/writing records
|
conn *RecordLayer // Used for reading/writing records
|
||||||
frame *frameReader // The buffered frame reader
|
frame *frameReader // The buffered frame reader
|
||||||
|
datagram bool // Is this DTLS?
|
||||||
|
msgSeq uint32 // The DTLS message sequence number
|
||||||
|
queued []*HandshakeMessage // In/out queue
|
||||||
|
sent []*HandshakeMessage // Sent messages for DTLS
|
||||||
|
recvdRecords []uint64 // Records we have received.
|
||||||
|
maxFragmentLen int
|
||||||
}
|
}
|
||||||
|
|
||||||
type handshakeLayerFrameDetails struct{}
|
type handshakeLayerFrameDetails struct {
|
||||||
|
datagram bool
|
||||||
|
}
|
||||||
|
|
||||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||||
return handshakeHeaderLen
|
if d.datagram {
|
||||||
|
return handshakeHeaderLenDTLS
|
||||||
|
}
|
||||||
|
return handshakeHeaderLenTLS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||||
return handshakeHeaderLen + maxFragmentLen
|
return d.headerLen() + maxFragmentLen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
logf(logTypeIO, "Header=%x", hdr)
|
logf(logTypeIO, "Header=%x", hdr)
|
||||||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
// The length of this fragment (as opposed to the message)
|
||||||
|
// is always the last three bytes for both TLS and DTLS
|
||||||
|
val, _ := decodeUint(hdr[len(hdr)-3:], 3)
|
||||||
|
return int(val), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||||
h := HandshakeLayer{}
|
h := HandshakeLayer{}
|
||||||
|
h.ctx = c
|
||||||
h.conn = r
|
h.conn = r
|
||||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
h.datagram = false
|
||||||
|
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
|
||||||
|
h.maxFragmentLen = maxFragmentLen
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||||
|
h := HandshakeLayer{}
|
||||||
|
h.ctx = c
|
||||||
|
h.conn = r
|
||||||
|
h.datagram = true
|
||||||
|
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
|
||||||
|
h.maxFragmentLen = initialMtu // Not quite right
|
||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HandshakeLayer) readRecord() error {
|
func (h *HandshakeLayer) readRecord() error {
|
||||||
logf(logTypeIO, "Trying to read record")
|
logf(logTypeVerbose, "Trying to read record")
|
||||||
pt, err := h.conn.ReadRecord()
|
pt, err := h.conn.readRecordAnyEpoch()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if pt.contentType != RecordTypeHandshake &&
|
switch pt.contentType {
|
||||||
pt.contentType != RecordTypeAlert {
|
case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
|
||||||
|
default:
|
||||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pt.contentType == RecordTypeAck {
|
||||||
|
if !h.datagram {
|
||||||
|
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
|
||||||
|
}
|
||||||
|
logf(logTypeIO, "read ACK")
|
||||||
|
return h.ctx.processAck(pt.fragment)
|
||||||
|
}
|
||||||
|
|
||||||
if pt.contentType == RecordTypeAlert {
|
if pt.contentType == RecordTypeAlert {
|
||||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||||
if len(pt.fragment) < 2 {
|
if len(pt.fragment) < 2 {
|
||||||
@ -148,7 +203,19 @@ func (h *HandshakeLayer) readRecord() error {
|
|||||||
return Alert(pt.fragment[1])
|
return Alert(pt.fragment[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
assert(h.ctx.hIn.conn != nil)
|
||||||
|
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
|
||||||
|
// This is out of order but we're dropping it.
|
||||||
|
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
|
||||||
|
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anything else shouldn't happen.
|
||||||
|
return AlertIllegalParameter
|
||||||
|
}
|
||||||
|
|
||||||
|
h.recvdRecords = append(h.recvdRecords, pt.seq)
|
||||||
h.frame.addChunk(pt.fragment)
|
h.frame.addChunk(pt.fragment)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -171,83 +238,314 @@ func (h *HandshakeLayer) sendAlert(err Alert) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
|
||||||
|
h.msgSeq = seq + 1
|
||||||
|
var i int
|
||||||
|
var m *HandshakeMessage
|
||||||
|
for i, m = range h.queued {
|
||||||
|
if m.seq > seq {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.queued = h.queued[i:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
|
||||||
|
if hm.seq < h.msgSeq {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
|
||||||
|
// out of order.
|
||||||
|
h.ctx.receivedHandshakeMessage()
|
||||||
|
|
||||||
|
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||||
|
// TODO(ekr@rtfm.com): Check the length?
|
||||||
|
// This is complete.
|
||||||
|
h.noteMessageDelivered(hm.seq)
|
||||||
|
return hm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now insert sorted.
|
||||||
|
var i int
|
||||||
|
for i = 0; i < len(h.queued); i++ {
|
||||||
|
f := h.queued[i]
|
||||||
|
if hm.seq < f.seq {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hm.offset < f.offset {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
|
||||||
|
tmp = append(tmp, h.queued[:i]...)
|
||||||
|
tmp = append(tmp, hm)
|
||||||
|
tmp = append(tmp, h.queued[i:]...)
|
||||||
|
h.queued = tmp
|
||||||
|
|
||||||
|
return h.checkMessageAvailable()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
|
||||||
|
if len(h.queued) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hm := h.queued[0]
|
||||||
|
if hm.seq != h.msgSeq {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||||
|
// TODO(ekr@rtfm.com): Check the length?
|
||||||
|
// This is complete.
|
||||||
|
h.noteMessageDelivered(hm.seq)
|
||||||
|
return hm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OK, this at least might complete the message.
|
||||||
|
end := uint32(0)
|
||||||
|
buf := make([]byte, hm.length)
|
||||||
|
|
||||||
|
for _, f := range h.queued {
|
||||||
|
// Out of fragments
|
||||||
|
if f.seq > hm.seq {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.length != uint32(len(buf)) {
|
||||||
|
return nil, fmt.Errorf("Mismatched DTLS length")
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.offset > end {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.offset+uint32(len(f.body)) > end {
|
||||||
|
// OK, this is adding something we don't know about
|
||||||
|
copy(buf[f.offset:], f.body)
|
||||||
|
end = f.offset + uint32(len(f.body))
|
||||||
|
if end == hm.length {
|
||||||
|
h2 := *hm
|
||||||
|
h2.offset = 0
|
||||||
|
h2.body = buf
|
||||||
|
h.noteMessageDelivered(hm.seq)
|
||||||
|
return &h2, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||||
var hdr, body []byte
|
var hdr, body []byte
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for {
|
hm, err := h.checkMessageAvailable()
|
||||||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
if err != nil {
|
||||||
if h.frame.needed() > 0 {
|
|
||||||
logf(logTypeHandshake, "Trying to read a new record")
|
|
||||||
err = h.readRecord()
|
|
||||||
}
|
|
||||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if hm != nil {
|
||||||
|
return hm, nil
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||||
|
if h.frame.needed() > 0 {
|
||||||
|
logf(logTypeVerbose, "Trying to read a new record")
|
||||||
|
err = h.readRecord()
|
||||||
|
|
||||||
|
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hdr, body, err = h.frame.process()
|
hdr, body, err = h.frame.process()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeHandshake, "read handshake message")
|
logf(logTypeHandshake, "read handshake message")
|
||||||
|
|
||||||
hm := &HandshakeMessage{}
|
hm = &HandshakeMessage{}
|
||||||
hm.msgType = HandshakeType(hdr[0])
|
hm.msgType = HandshakeType(hdr[0])
|
||||||
|
hm.datagram = h.datagram
|
||||||
hm.body = make([]byte, len(body))
|
hm.body = make([]byte, len(body))
|
||||||
copy(hm.body, body)
|
copy(hm.body, body)
|
||||||
|
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||||
|
if h.datagram {
|
||||||
|
tmp, hdr := decodeUint(hdr[1:], 3)
|
||||||
|
hm.length = uint32(tmp)
|
||||||
|
tmp, hdr = decodeUint(hdr, 2)
|
||||||
|
hm.seq = uint32(tmp)
|
||||||
|
tmp, hdr = decodeUint(hdr, 3)
|
||||||
|
hm.offset = uint32(tmp)
|
||||||
|
|
||||||
|
return h.newFragmentReceived(hm)
|
||||||
|
}
|
||||||
|
|
||||||
|
hm.length = uint32(len(body))
|
||||||
return hm, nil
|
return hm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
|
||||||
return h.WriteMessages([]*HandshakeMessage{hm})
|
hm.cipher = h.conn.cipher
|
||||||
|
h.queued = append(h.queued, hm)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
|
||||||
|
logf(logTypeHandshake, "Sending outgoing messages")
|
||||||
|
count, err := h.WriteMessages(h.queued)
|
||||||
|
if !h.datagram {
|
||||||
|
h.ClearQueuedMessages()
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) ClearQueuedMessages() {
|
||||||
|
logf(logTypeHandshake, "Clearing outgoing hs message queue")
|
||||||
|
h.queued = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
|
||||||
|
var buf []byte
|
||||||
|
|
||||||
|
// Figure out if we're going to want the full header or just
|
||||||
|
// the body
|
||||||
|
hdrlen := 0
|
||||||
|
if hm.datagram {
|
||||||
|
hdrlen = handshakeHeaderLenDTLS
|
||||||
|
} else if start == 0 {
|
||||||
|
hdrlen = handshakeHeaderLenTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the amount of body we can fit in
|
||||||
|
room -= hdrlen
|
||||||
|
if room == 0 {
|
||||||
|
// This works because we are doing one record per
|
||||||
|
// message
|
||||||
|
panic("Too short max fragment len")
|
||||||
|
}
|
||||||
|
bodylen := len(hm.body) - start
|
||||||
|
if bodylen > room {
|
||||||
|
bodylen = room
|
||||||
|
}
|
||||||
|
body := hm.body[start : start+bodylen]
|
||||||
|
|
||||||
|
// Now see if this chunk has been ACKed. This doesn't produce ideal
|
||||||
|
// retransmission but is simple.
|
||||||
|
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
|
||||||
|
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
|
||||||
|
return false, start + bodylen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the data.
|
||||||
|
if hdrlen > 0 {
|
||||||
|
hm2 := *hm
|
||||||
|
hm2.offset = uint32(start)
|
||||||
|
hm2.body = body
|
||||||
|
buf = hm2.Marshal()
|
||||||
|
hm = &hm2
|
||||||
|
} else {
|
||||||
|
buf = body
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.datagram {
|
||||||
|
// Remember that we sent this.
|
||||||
|
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
|
||||||
|
hm.seq,
|
||||||
|
start,
|
||||||
|
len(body),
|
||||||
|
h.conn.cipher.combineSeq(true),
|
||||||
|
false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true, start + bodylen, h.conn.writeRecordWithPadding(
|
||||||
|
&TLSPlaintext{
|
||||||
|
contentType: RecordTypeHandshake,
|
||||||
|
fragment: buf,
|
||||||
|
},
|
||||||
|
hm.cipher, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
|
||||||
|
start := int(0)
|
||||||
|
|
||||||
|
if len(hm.body) > maxHandshakeMessageLen {
|
||||||
|
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
|
||||||
|
}
|
||||||
|
|
||||||
|
written := 0
|
||||||
|
wrote := false
|
||||||
|
|
||||||
|
// Always make one pass through to allow EOED (which is empty).
|
||||||
|
for {
|
||||||
|
var err error
|
||||||
|
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if wrote {
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
if start >= len(hm.body) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
|
||||||
|
written := 0
|
||||||
for _, hm := range hms {
|
for _, hm := range hms {
|
||||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||||
|
|
||||||
|
wrote, err := h.WriteMessage(hm)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
}
|
}
|
||||||
|
written += wrote
|
||||||
// Write out headers and bodies
|
|
||||||
buffer := []byte{}
|
|
||||||
for _, msg := range hms {
|
|
||||||
msgLen := len(msg.body)
|
|
||||||
if msgLen > maxHandshakeMessageLen {
|
|
||||||
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
|
||||||
}
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
buffer = append(buffer, msg.Marshal()...)
|
func encodeUint(v uint64, size int, out []byte) []byte {
|
||||||
|
for i := size - 1; i >= 0; i-- {
|
||||||
|
out[i] = byte(v & 0xff)
|
||||||
|
v >>= 8
|
||||||
}
|
}
|
||||||
|
return out[size:]
|
||||||
|
}
|
||||||
|
|
||||||
// Send full-size fragments
|
func decodeUint(in []byte, size int) (uint64, []byte) {
|
||||||
var start int
|
val := uint64(0)
|
||||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
|
||||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
|
||||||
contentType: RecordTypeHandshake,
|
|
||||||
fragment: buffer[start : start+maxFragmentLen],
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send a final partial fragment if necessary
|
|
||||||
if start < len(buffer) {
|
|
||||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
|
||||||
contentType: RecordTypeHandshake,
|
|
||||||
fragment: buffer[start:],
|
|
||||||
})
|
|
||||||
|
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
val <<= 8
|
||||||
|
val += uint64(in[i])
|
||||||
|
}
|
||||||
|
return val, in[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
type marshalledPDU interface {
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
Unmarshal(data []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
|
||||||
|
read, err := pdu.Unmarshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if len(data) != read {
|
||||||
|
return fmt.Errorf("Invalid encoding: Extra data not consumed")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
109
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
109
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
@ -25,15 +25,14 @@ type HandshakeMessageBody interface {
|
|||||||
// Extension extensions<0..2^16-1>;
|
// Extension extensions<0..2^16-1>;
|
||||||
// } ClientHello;
|
// } ClientHello;
|
||||||
type ClientHelloBody struct {
|
type ClientHelloBody struct {
|
||||||
// Omitted: clientVersion
|
LegacyVersion uint16
|
||||||
// Omitted: legacySessionID
|
|
||||||
// Omitted: legacyCompressionMethods
|
|
||||||
Random [32]byte
|
Random [32]byte
|
||||||
|
LegacySessionID []byte
|
||||||
CipherSuites []CipherSuite
|
CipherSuites []CipherSuite
|
||||||
Extensions ExtensionList
|
Extensions ExtensionList
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientHelloBodyInner struct {
|
type clientHelloBodyInnerTLS struct {
|
||||||
LegacyVersion uint16
|
LegacyVersion uint16
|
||||||
Random [32]byte
|
Random [32]byte
|
||||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||||
@ -42,40 +41,86 @@ type clientHelloBodyInner struct {
|
|||||||
Extensions []Extension `tls:"head=2"`
|
Extensions []Extension `tls:"head=2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type clientHelloBodyInnerDTLS struct {
|
||||||
|
LegacyVersion uint16
|
||||||
|
Random [32]byte
|
||||||
|
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||||
|
EmptyCookie uint8
|
||||||
|
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||||
|
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||||
|
Extensions []Extension `tls:"head=2"`
|
||||||
|
}
|
||||||
|
|
||||||
func (ch ClientHelloBody) Type() HandshakeType {
|
func (ch ClientHelloBody) Type() HandshakeType {
|
||||||
return HandshakeTypeClientHello
|
return HandshakeTypeClientHello
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||||
return syntax.Marshal(clientHelloBodyInner{
|
if ch.LegacyVersion == tls12Version {
|
||||||
LegacyVersion: 0x0303,
|
return syntax.Marshal(clientHelloBodyInnerTLS{
|
||||||
|
LegacyVersion: ch.LegacyVersion,
|
||||||
Random: ch.Random,
|
Random: ch.Random,
|
||||||
LegacySessionID: []byte{},
|
LegacySessionID: []byte{},
|
||||||
CipherSuites: ch.CipherSuites,
|
CipherSuites: ch.CipherSuites,
|
||||||
LegacyCompressionMethods: []byte{0},
|
LegacyCompressionMethods: []byte{0},
|
||||||
Extensions: ch.Extensions,
|
Extensions: ch.Extensions,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
return syntax.Marshal(clientHelloBodyInnerDTLS{
|
||||||
|
LegacyVersion: ch.LegacyVersion,
|
||||||
|
Random: ch.Random,
|
||||||
|
LegacySessionID: []byte{},
|
||||||
|
CipherSuites: ch.CipherSuites,
|
||||||
|
LegacyCompressionMethods: []byte{0},
|
||||||
|
Extensions: ch.Extensions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||||
var inner clientHelloBodyInner
|
var read int
|
||||||
read, err := syntax.Unmarshal(data, &inner)
|
var err error
|
||||||
|
|
||||||
|
// Note that this might be 0, in which case we do TLS. That
|
||||||
|
// makes the tests easier.
|
||||||
|
if ch.LegacyVersion != dtls12WireVersion {
|
||||||
|
var inner clientHelloBodyInnerTLS
|
||||||
|
read, err = syntax.Unmarshal(data, &inner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are strict about these things because we only support 1.3
|
|
||||||
if inner.LegacyVersion != 0x0303 {
|
|
||||||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ch.LegacyVersion = inner.LegacyVersion
|
||||||
ch.Random = inner.Random
|
ch.Random = inner.Random
|
||||||
|
ch.LegacySessionID = inner.LegacySessionID
|
||||||
ch.CipherSuites = inner.CipherSuites
|
ch.CipherSuites = inner.CipherSuites
|
||||||
ch.Extensions = inner.Extensions
|
ch.Extensions = inner.Extensions
|
||||||
|
} else {
|
||||||
|
var inner clientHelloBodyInnerDTLS
|
||||||
|
read, err = syntax.Unmarshal(data, &inner)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if inner.EmptyCookie != 0 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Invalid cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||||
|
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.LegacyVersion = inner.LegacyVersion
|
||||||
|
ch.Random = inner.Random
|
||||||
|
ch.LegacySessionID = inner.LegacySessionID
|
||||||
|
ch.CipherSuites = inner.CipherSuites
|
||||||
|
ch.Extensions = inner.Extensions
|
||||||
|
}
|
||||||
return read, nil
|
return read, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,10 +135,15 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||||
}
|
}
|
||||||
|
|
||||||
chm, err := HandshakeMessageFromBody(&ch)
|
body, err := ch.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
chm := &HandshakeMessage{
|
||||||
|
msgType: ch.Type(),
|
||||||
|
body: body,
|
||||||
|
length: uint32(len(body)),
|
||||||
|
}
|
||||||
chData := chm.Marshal()
|
chData := chm.Marshal()
|
||||||
|
|
||||||
psk := PreSharedKeyExtension{
|
psk := PreSharedKeyExtension{
|
||||||
@ -116,38 +166,19 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// struct {
|
// struct {
|
||||||
// ProtocolVersion server_version;
|
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||||
// CipherSuite cipher_suite;
|
|
||||||
// Extension extensions<2..2^16-1>;
|
|
||||||
// } HelloRetryRequest;
|
|
||||||
type HelloRetryRequestBody struct {
|
|
||||||
Version uint16
|
|
||||||
CipherSuite CipherSuite
|
|
||||||
Extensions ExtensionList `tls:"head=2,min=2"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
|
||||||
return HandshakeTypeHelloRetryRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
|
||||||
return syntax.Marshal(hrr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
|
||||||
return syntax.Unmarshal(data, hrr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// struct {
|
|
||||||
// ProtocolVersion version;
|
|
||||||
// Random random;
|
// Random random;
|
||||||
|
// opaque legacy_session_id_echo<0..32>;
|
||||||
// CipherSuite cipher_suite;
|
// CipherSuite cipher_suite;
|
||||||
// Extension extensions<0..2^16-1>;
|
// uint8 legacy_compression_method = 0;
|
||||||
|
// Extension extensions<6..2^16-1>;
|
||||||
// } ServerHello;
|
// } ServerHello;
|
||||||
type ServerHelloBody struct {
|
type ServerHelloBody struct {
|
||||||
Version uint16
|
Version uint16
|
||||||
Random [32]byte
|
Random [32]byte
|
||||||
|
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||||
CipherSuite CipherSuite
|
CipherSuite CipherSuite
|
||||||
|
LegacyCompressionMethod uint8
|
||||||
Extensions ExtensionList `tls:"head=2"`
|
Extensions ExtensionList `tls:"head=2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
11
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
11
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
@ -148,7 +148,7 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(candidatesByName) == 0 {
|
if len(candidatesByName) == 0 {
|
||||||
return nil, 0, fmt.Errorf("No certificates available for server name")
|
return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates = candidatesByName
|
candidates = candidatesByName
|
||||||
@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
|
|||||||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||||
}
|
}
|
||||||
|
|
||||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) {
|
||||||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
using = gotEarlyData && usingPSK && allowEarlyData
|
||||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
rejected = gotEarlyData && !using
|
||||||
return usingEarlyData
|
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||||
|
302
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
302
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
@ -1,7 +1,6 @@
|
|||||||
package mint
|
package mint
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -10,7 +9,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
sequenceNumberLen = 8 // sequence number length
|
sequenceNumberLen = 8 // sequence number length
|
||||||
recordHeaderLen = 5 // record header length
|
recordHeaderLenTLS = 5 // record header length (TLS)
|
||||||
|
recordHeaderLenDTLS = 13 // record header length (DTLS)
|
||||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,9 +20,16 @@ func (err DecryptError) Error() string {
|
|||||||
return string(err)
|
return string(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type direction uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
directionWrite = direction(1)
|
||||||
|
directionRead = direction(2)
|
||||||
|
)
|
||||||
|
|
||||||
// struct {
|
// struct {
|
||||||
// ContentType type;
|
// ContentType type;
|
||||||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
// ProtocolVersion record_version [0301 for CH, 0303 for others]
|
||||||
// uint16 length;
|
// uint16 length;
|
||||||
// opaque fragment[TLSPlaintext.length];
|
// opaque fragment[TLSPlaintext.length];
|
||||||
// } TLSPlaintext;
|
// } TLSPlaintext;
|
||||||
@ -30,87 +37,177 @@ type TLSPlaintext struct {
|
|||||||
// Omitted: record_version (static)
|
// Omitted: record_version (static)
|
||||||
// Omitted: length (computed from fragment)
|
// Omitted: length (computed from fragment)
|
||||||
contentType RecordType
|
contentType RecordType
|
||||||
|
epoch Epoch
|
||||||
|
seq uint64
|
||||||
fragment []byte
|
fragment []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cipherState struct {
|
||||||
|
epoch Epoch // DTLS epoch
|
||||||
|
ivLength int // Length of the seq and nonce fields
|
||||||
|
seq uint64 // Zero-padded sequence number
|
||||||
|
iv []byte // Buffer for the IV
|
||||||
|
cipher cipher.AEAD // AEAD cipher
|
||||||
|
}
|
||||||
|
|
||||||
type RecordLayer struct {
|
type RecordLayer struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
label string
|
||||||
|
direction direction
|
||||||
|
version uint16 // The current version number
|
||||||
conn io.ReadWriter // The underlying connection
|
conn io.ReadWriter // The underlying connection
|
||||||
frame *frameReader // The buffered frame reader
|
frame *frameReader // The buffered frame reader
|
||||||
nextData []byte // The next record to send
|
nextData []byte // The next record to send
|
||||||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||||
cachedError error // Error on the last record read
|
cachedError error // Error on the last record read
|
||||||
|
|
||||||
ivLength int // Length of the seq and nonce fields
|
cipher *cipherState
|
||||||
seq []byte // Zero-padded sequence number
|
readCiphers map[Epoch]*cipherState
|
||||||
nonce []byte // Buffer for per-record nonces
|
|
||||||
cipher cipher.AEAD // AEAD cipher
|
datagram bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type recordLayerFrameDetails struct{}
|
type recordLayerFrameDetails struct {
|
||||||
|
datagram bool
|
||||||
|
}
|
||||||
|
|
||||||
func (d recordLayerFrameDetails) headerLen() int {
|
func (d recordLayerFrameDetails) headerLen() int {
|
||||||
return recordHeaderLen
|
if d.datagram {
|
||||||
|
return recordHeaderLenDTLS
|
||||||
|
}
|
||||||
|
return recordHeaderLenTLS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d recordLayerFrameDetails) defaultReadLen() int {
|
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||||
return recordHeaderLen + maxFragmentLen
|
return d.headerLen() + maxFragmentLen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||||
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
func newCipherStateNull() *cipherState {
|
||||||
|
return &cipherState{EpochClear, 0, 0, nil, nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
|
||||||
|
cipher, err := factory(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
|
||||||
r := RecordLayer{}
|
r := RecordLayer{}
|
||||||
|
r.label = ""
|
||||||
|
r.direction = dir
|
||||||
r.conn = conn
|
r.conn = conn
|
||||||
r.frame = newFrameReader(recordLayerFrameDetails{})
|
r.frame = newFrameReader(recordLayerFrameDetails{false})
|
||||||
r.ivLength = 0
|
r.cipher = newCipherStateNull()
|
||||||
|
r.version = tls10Version
|
||||||
return &r
|
return &r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
|
||||||
var err error
|
r := RecordLayer{}
|
||||||
r.cipher, err = cipher(key)
|
r.label = ""
|
||||||
|
r.direction = dir
|
||||||
|
r.conn = conn
|
||||||
|
r.frame = newFrameReader(recordLayerFrameDetails{true})
|
||||||
|
r.cipher = newCipherStateNull()
|
||||||
|
r.readCiphers = make(map[Epoch]*cipherState, 0)
|
||||||
|
r.readCiphers[0] = r.cipher
|
||||||
|
r.datagram = true
|
||||||
|
return &r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) SetVersion(v uint16) {
|
||||||
|
r.version = v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) ResetClear(seq uint64) {
|
||||||
|
r.cipher = newCipherStateNull()
|
||||||
|
r.cipher.seq = seq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
|
||||||
|
cipher, err := newCipherStateAead(epoch, factory, key, iv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
r.cipher = cipher
|
||||||
r.ivLength = len(iv)
|
if r.datagram && r.direction == directionRead {
|
||||||
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
r.readCiphers[epoch] = cipher
|
||||||
r.nonce = make([]byte, r.ivLength)
|
}
|
||||||
copy(r.nonce, iv)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) incrementSequenceNumber() {
|
// TODO(ekr@rtfm.com): This is never used, which is a bug.
|
||||||
if r.ivLength == 0 {
|
func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
|
||||||
|
if !r.datagram {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
_, ok := r.readCiphers[epoch]
|
||||||
r.seq[i]++
|
assert(ok)
|
||||||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
delete(r.readCiphers, epoch)
|
||||||
if r.seq[i] != 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not allowed to let sequence number wrap.
|
|
||||||
// Instead, must renegotiate before it does.
|
|
||||||
// Not likely enough to bother.
|
|
||||||
panic("TLS: sequence number wraparound")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
func (c *cipherState) combineSeq(datagram bool) uint64 {
|
||||||
|
seq := c.seq
|
||||||
|
if datagram {
|
||||||
|
seq |= uint64(c.epoch) << 48
|
||||||
|
}
|
||||||
|
return seq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cipherState) computeNonce(seq uint64) []byte {
|
||||||
|
nonce := make([]byte, len(c.iv))
|
||||||
|
copy(nonce, c.iv)
|
||||||
|
|
||||||
|
s := seq
|
||||||
|
|
||||||
|
offset := len(c.iv)
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
nonce[(offset-i)-1] ^= byte(s & 0xff)
|
||||||
|
s >>= 8
|
||||||
|
}
|
||||||
|
logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
|
||||||
|
|
||||||
|
return nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cipherState) incrementSequenceNumber() {
|
||||||
|
if c.seq >= (1<<48 - 1) {
|
||||||
|
// Not allowed to let sequence number wrap.
|
||||||
|
// Instead, must renegotiate before it does.
|
||||||
|
// Not likely enough to bother. This is the
|
||||||
|
// DTLS limit.
|
||||||
|
panic("TLS: sequence number wraparound")
|
||||||
|
}
|
||||||
|
c.seq++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cipherState) overhead() int {
|
||||||
|
if c.cipher == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.cipher.Overhead()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||||
|
assert(r.direction == directionWrite)
|
||||||
|
logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
|
||||||
// Expand the fragment to hold contentType, padding, and overhead
|
// Expand the fragment to hold contentType, padding, and overhead
|
||||||
originalLen := len(pt.fragment)
|
originalLen := len(pt.fragment)
|
||||||
plaintextLen := originalLen + 1 + padLen
|
plaintextLen := originalLen + 1 + padLen
|
||||||
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
ciphertextLen := plaintextLen + cipher.overhead()
|
||||||
|
|
||||||
// Assemble the revised plaintext
|
// Assemble the revised plaintext
|
||||||
out := &TLSPlaintext{
|
out := &TLSPlaintext{
|
||||||
|
|
||||||
contentType: RecordTypeApplicationData,
|
contentType: RecordTypeApplicationData,
|
||||||
fragment: make([]byte, ciphertextLen),
|
fragment: make([]byte, ciphertextLen),
|
||||||
}
|
}
|
||||||
@ -122,25 +219,28 @@ func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
|||||||
|
|
||||||
// Encrypt the fragment
|
// Encrypt the fragment
|
||||||
payload := out.fragment[:plaintextLen]
|
payload := out.fragment[:plaintextLen]
|
||||||
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
|
||||||
if len(pt.fragment) < r.cipher.Overhead() {
|
assert(r.direction == directionRead)
|
||||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
|
||||||
|
if len(pt.fragment) < r.cipher.overhead() {
|
||||||
|
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
|
||||||
return nil, 0, DecryptError(msg)
|
return nil, 0, DecryptError(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
decryptLen := len(pt.fragment) - r.cipher.overhead()
|
||||||
out := &TLSPlaintext{
|
out := &TLSPlaintext{
|
||||||
contentType: pt.contentType,
|
contentType: pt.contentType,
|
||||||
fragment: make([]byte, decryptLen),
|
fragment: make([]byte, decryptLen),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt
|
// Decrypt
|
||||||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
_, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
|
||||||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
|||||||
|
|
||||||
// Truncate the message to remove contentType, padding, overhead
|
// Truncate the message to remove contentType, padding, overhead
|
||||||
out.fragment = out.fragment[:newLen]
|
out.fragment = out.fragment[:newLen]
|
||||||
|
out.seq = seq
|
||||||
return out, padLen, nil
|
return out, padLen, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
for {
|
for {
|
||||||
pt, err = r.nextRecord()
|
pt, err = r.nextRecord(false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if !block || err != WouldBlock {
|
if !block || err != AlertWouldBlock {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||||
pt, err := r.nextRecord()
|
pt, err := r.nextRecord(false)
|
||||||
|
|
||||||
// Consume the cached record if there was one
|
// Consume the cached record if there was one
|
||||||
r.cachedRecord = nil
|
r.cachedRecord = nil
|
||||||
@ -184,9 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
|||||||
return pt, err
|
return pt, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
|
||||||
|
pt, err := r.nextRecord(true)
|
||||||
|
|
||||||
|
// Consume the cached record if there was one
|
||||||
|
r.cachedRecord = nil
|
||||||
|
r.cachedError = nil
|
||||||
|
|
||||||
|
return pt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
|
||||||
|
cipher := r.cipher
|
||||||
if r.cachedRecord != nil {
|
if r.cachedRecord != nil {
|
||||||
logf(logTypeIO, "Returning cached record")
|
logf(logTypeIO, "%s Returning cached record", r.label)
|
||||||
return r.cachedRecord, r.cachedError
|
return r.cachedRecord, r.cachedError
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,34 +306,35 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
|||||||
//
|
//
|
||||||
// 1. We get a frame
|
// 1. We get a frame
|
||||||
// 2. We try to read off the socket and get nothing, in which case
|
// 2. We try to read off the socket and get nothing, in which case
|
||||||
// return WouldBlock
|
// returnAlertWouldBlock
|
||||||
// 3. We get an error.
|
// 3. We get an error.
|
||||||
err := WouldBlock
|
var err error
|
||||||
|
err = AlertWouldBlock
|
||||||
var header, body []byte
|
var header, body []byte
|
||||||
|
|
||||||
for err != nil {
|
for err != nil {
|
||||||
if r.frame.needed() > 0 {
|
if r.frame.needed() > 0 {
|
||||||
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
|
||||||
n, err := r.conn.Read(buf)
|
n, err := r.conn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeIO, "Error reading, %v", err)
|
logf(logTypeIO, "%s Error reading, %v", r.label, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return nil, WouldBlock
|
return nil, AlertWouldBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeIO, "Read %v bytes", n)
|
logf(logTypeIO, "%s Read %v bytes", r.label, n)
|
||||||
|
|
||||||
buf = buf[:n]
|
buf = buf[:n]
|
||||||
r.frame.addChunk(buf)
|
r.frame.addChunk(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
header, body, err = r.frame.process()
|
header, body, err = r.frame.process()
|
||||||
// Loop around on WouldBlock to see if some
|
// Loop around onAlertWouldBlock to see if some
|
||||||
// data is now available.
|
// data is now available.
|
||||||
if err != nil && err != WouldBlock {
|
if err != nil && err != AlertWouldBlock {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
|||||||
switch RecordType(header[0]) {
|
switch RecordType(header[0]) {
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
|
||||||
pt.contentType = RecordType(header[0])
|
pt.contentType = RecordType(header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,7 +354,8 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate size < max
|
// Validate size < max
|
||||||
size := (int(header[3]) << 8) + int(header[4])
|
size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
|
||||||
|
|
||||||
if size > maxFragmentLen+256 {
|
if size > maxFragmentLen+256 {
|
||||||
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||||
}
|
}
|
||||||
@ -249,33 +363,67 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
|||||||
pt.fragment = make([]byte, size)
|
pt.fragment = make([]byte, size)
|
||||||
copy(pt.fragment, body)
|
copy(pt.fragment, body)
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
|
||||||
|
|
||||||
// Attempt to decrypt fragment
|
// Attempt to decrypt fragment
|
||||||
if r.cipher != nil {
|
seq := cipher.seq
|
||||||
pt, _, err = r.decrypt(pt)
|
if r.datagram {
|
||||||
|
// TODO(ekr@rtfm.com): Handle duplicates.
|
||||||
|
seq, _ = decodeUint(header[3:11], 8)
|
||||||
|
epoch := Epoch(seq >> 48)
|
||||||
|
|
||||||
|
// Look up the cipher suite from the epoch
|
||||||
|
c, ok := r.readCiphers[epoch]
|
||||||
|
if !ok {
|
||||||
|
logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
|
||||||
|
return nil, AlertWouldBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
if epoch != cipher.epoch {
|
||||||
|
logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
|
||||||
|
cipher.epoch, allowOldEpoch)
|
||||||
|
if !allowOldEpoch {
|
||||||
|
return nil, AlertWouldBlock
|
||||||
|
}
|
||||||
|
cipher = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cipher.cipher != nil {
|
||||||
|
logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
|
||||||
|
pt, _, err = r.decrypt(pt, seq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logf(logTypeIO, "%s Decryption failed", r.label)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pt.epoch = cipher.epoch
|
||||||
|
|
||||||
// Check that plaintext length is not too long
|
// Check that plaintext length is not too long
|
||||||
if len(pt.fragment) > maxFragmentLen {
|
if len(pt.fragment) > maxFragmentLen {
|
||||||
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||||
}
|
}
|
||||||
|
|
||||||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
|
||||||
|
|
||||||
r.cachedRecord = pt
|
r.cachedRecord = pt
|
||||||
r.incrementSequenceNumber()
|
cipher.incrementSequenceNumber()
|
||||||
return pt, nil
|
return pt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||||
return r.WriteRecordWithPadding(pt, 0)
|
return r.writeRecordWithPadding(pt, r.cipher, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||||
if r.cipher != nil {
|
return r.writeRecordWithPadding(pt, r.cipher, padLen)
|
||||||
pt = r.encrypt(pt, padLen)
|
}
|
||||||
|
|
||||||
|
func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
|
||||||
|
seq := cipher.combineSeq(r.datagram)
|
||||||
|
if cipher.cipher != nil {
|
||||||
|
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
|
||||||
|
pt = r.encrypt(cipher, seq, pt, padLen)
|
||||||
} else if padLen > 0 {
|
} else if padLen > 0 {
|
||||||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||||
}
|
}
|
||||||
@ -285,12 +433,26 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
length := len(pt.fragment)
|
length := len(pt.fragment)
|
||||||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
var header []byte
|
||||||
|
|
||||||
|
if !r.datagram {
|
||||||
|
header = []byte{byte(pt.contentType),
|
||||||
|
byte(r.version >> 8), byte(r.version & 0xff),
|
||||||
|
byte(length >> 8), byte(length)}
|
||||||
|
} else {
|
||||||
|
header = make([]byte, 13)
|
||||||
|
version := dtlsConvertVersion(r.version)
|
||||||
|
copy(header, []byte{byte(pt.contentType),
|
||||||
|
byte(version >> 8), byte(version & 0xff),
|
||||||
|
})
|
||||||
|
encodeUint(seq, 8, header[3:])
|
||||||
|
encodeUint(uint64(length), 2, header[11:])
|
||||||
|
}
|
||||||
record := append(header, pt.fragment...)
|
record := append(header, pt.fragment...)
|
||||||
|
|
||||||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
|
||||||
|
|
||||||
r.incrementSequenceNumber()
|
cipher.incrementSequenceNumber()
|
||||||
_, err := r.conn.Write(record)
|
_, err := r.conn.Write(record)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
671
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
671
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
File diff suppressed because it is too large
Load Diff
111
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
111
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
@ -1,6 +1,7 @@
|
|||||||
package mint
|
package mint
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -8,32 +9,35 @@ import (
|
|||||||
// state transitions.
|
// state transitions.
|
||||||
type HandshakeAction interface{}
|
type HandshakeAction interface{}
|
||||||
|
|
||||||
type SendHandshakeMessage struct {
|
type QueueHandshakeMessage struct {
|
||||||
Message *HandshakeMessage
|
Message *HandshakeMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SendQueuedHandshake struct{}
|
||||||
|
|
||||||
type SendEarlyData struct{}
|
type SendEarlyData struct{}
|
||||||
|
|
||||||
type ReadEarlyData struct{}
|
|
||||||
|
|
||||||
type ReadPastEarlyData struct{}
|
|
||||||
|
|
||||||
type RekeyIn struct {
|
type RekeyIn struct {
|
||||||
Label string
|
epoch Epoch
|
||||||
KeySet keySet
|
KeySet keySet
|
||||||
}
|
}
|
||||||
|
|
||||||
type RekeyOut struct {
|
type RekeyOut struct {
|
||||||
Label string
|
epoch Epoch
|
||||||
KeySet keySet
|
KeySet keySet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ResetOut struct {
|
||||||
|
seq uint64
|
||||||
|
}
|
||||||
|
|
||||||
type StorePSK struct {
|
type StorePSK struct {
|
||||||
PSK PreSharedKey
|
PSK PreSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandshakeState interface {
|
type HandshakeState interface {
|
||||||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert)
|
||||||
|
State() State
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppExtensionHandler interface {
|
type AppExtensionHandler interface {
|
||||||
@ -41,35 +45,11 @@ type AppExtensionHandler interface {
|
|||||||
Receive(hs HandshakeType, el *ExtensionList) error
|
Receive(hs HandshakeType, el *ExtensionList) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capabilities objects represent the capabilities of a TLS client or server,
|
|
||||||
// as an input to TLS negotiation
|
|
||||||
type Capabilities struct {
|
|
||||||
// For both client and server
|
|
||||||
CipherSuites []CipherSuite
|
|
||||||
Groups []NamedGroup
|
|
||||||
SignatureSchemes []SignatureScheme
|
|
||||||
PSKs PreSharedKeyCache
|
|
||||||
Certificates []*Certificate
|
|
||||||
AuthCertificate func(chain []CertificateEntry) error
|
|
||||||
ExtensionHandler AppExtensionHandler
|
|
||||||
|
|
||||||
// For client
|
|
||||||
PSKModes []PSKKeyExchangeMode
|
|
||||||
|
|
||||||
// For server
|
|
||||||
NextProtos []string
|
|
||||||
AllowEarlyData bool
|
|
||||||
RequireCookie bool
|
|
||||||
CookieHandler CookieHandler
|
|
||||||
RequireClientAuth bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionOptions objects represent per-connection settings for a client
|
// ConnectionOptions objects represent per-connection settings for a client
|
||||||
// initiating a connection
|
// initiating a connection
|
||||||
type ConnectionOptions struct {
|
type ConnectionOptions struct {
|
||||||
ServerName string
|
ServerName string
|
||||||
NextProtos []string
|
NextProtos []string
|
||||||
EarlyData []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionParameters objects represent the parameters negotiated for a
|
// ConnectionParameters objects represent the parameters negotiated for a
|
||||||
@ -79,6 +59,7 @@ type ConnectionParameters struct {
|
|||||||
UsingDH bool
|
UsingDH bool
|
||||||
ClientSendingEarlyData bool
|
ClientSendingEarlyData bool
|
||||||
UsingEarlyData bool
|
UsingEarlyData bool
|
||||||
|
RejectedEarlyData bool
|
||||||
UsingClientAuth bool
|
UsingClientAuth bool
|
||||||
|
|
||||||
CipherSuite CipherSuite
|
CipherSuite CipherSuite
|
||||||
@ -86,18 +67,50 @@ type ConnectionParameters struct {
|
|||||||
NextProto string
|
NextProto string
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateConnected is symmetric between client and server
|
// Working state for the handshake.
|
||||||
type StateConnected struct {
|
type HandshakeContext struct {
|
||||||
|
timeoutMS uint32
|
||||||
|
timers *timerSet
|
||||||
|
recvdRecords []uint64
|
||||||
|
sentFragments []*SentHandshakeFragment
|
||||||
|
hIn, hOut *HandshakeLayer
|
||||||
|
waitingNextFlight bool
|
||||||
|
earlyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HandshakeContext) SetVersion(version uint16) {
|
||||||
|
if hc.hIn.conn != nil {
|
||||||
|
hc.hIn.conn.SetVersion(version)
|
||||||
|
}
|
||||||
|
if hc.hOut.conn != nil {
|
||||||
|
hc.hOut.conn.SetVersion(version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateConnected is symmetric between client and server
|
||||||
|
type stateConnected struct {
|
||||||
Params ConnectionParameters
|
Params ConnectionParameters
|
||||||
|
hsCtx *HandshakeContext
|
||||||
isClient bool
|
isClient bool
|
||||||
cryptoParams CipherSuiteParams
|
cryptoParams CipherSuiteParams
|
||||||
resumptionSecret []byte
|
resumptionSecret []byte
|
||||||
clientTrafficSecret []byte
|
clientTrafficSecret []byte
|
||||||
serverTrafficSecret []byte
|
serverTrafficSecret []byte
|
||||||
exporterSecret []byte
|
exporterSecret []byte
|
||||||
|
peerCertificates []*x509.Certificate
|
||||||
|
verifiedChains [][]*x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
var _ HandshakeState = &stateConnected{}
|
||||||
|
|
||||||
|
func (state stateConnected) State() State {
|
||||||
|
if state.isClient {
|
||||||
|
return StateClientConnected
|
||||||
|
}
|
||||||
|
return StateServerConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||||
var trafficKeys keySet
|
var trafficKeys keySet
|
||||||
if state.isClient {
|
if state.isClient {
|
||||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||||
@ -109,20 +122,21 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct
|
|||||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||||
return nil, AlertInternalError
|
return nil, AlertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend := []HandshakeAction{
|
toSend := []HandshakeAction{
|
||||||
SendHandshakeMessage{kum},
|
QueueHandshakeMessage{kum},
|
||||||
RekeyOut{Label: "update", KeySet: trafficKeys},
|
SendQueuedHandshake{},
|
||||||
|
RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys},
|
||||||
}
|
}
|
||||||
return toSend, AlertNoAlert
|
return toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||||
tkt, err := NewSessionTicket(length, lifetime)
|
tkt, err := NewSessionTicket(length, lifetime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||||
@ -149,7 +163,7 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
|
|||||||
TicketAgeAdd: tkt.TicketAgeAdd,
|
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||||
}
|
}
|
||||||
|
|
||||||
tktm, err := HandshakeMessageFromBody(tkt)
|
tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||||
return nil, AlertInternalError
|
return nil, AlertInternalError
|
||||||
@ -157,12 +171,18 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
|
|||||||
|
|
||||||
toSend := []HandshakeAction{
|
toSend := []HandshakeAction{
|
||||||
StorePSK{newPSK},
|
StorePSK{newPSK},
|
||||||
SendHandshakeMessage{tktm},
|
QueueHandshakeMessage{tktm},
|
||||||
|
SendQueuedHandshake{},
|
||||||
}
|
}
|
||||||
return toSend, AlertNoAlert
|
return toSend, AlertNoAlert
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
// Next does nothing for this state.
|
||||||
|
func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
|
return state, nil, AlertNoAlert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||||
if hm == nil {
|
if hm == nil {
|
||||||
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||||
return nil, nil, AlertUnexpectedMessage
|
return nil, nil, AlertUnexpectedMessage
|
||||||
@ -187,20 +207,18 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
|
|||||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}}
|
||||||
|
|
||||||
// If requested, roll outbound keys and send a KeyUpdate
|
// If requested, roll outbound keys and send a KeyUpdate
|
||||||
if body.KeyUpdateRequest == KeyUpdateRequested {
|
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||||
|
logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest)
|
||||||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||||
if alert != AlertNoAlert {
|
if alert != AlertNoAlert {
|
||||||
return nil, nil, alert
|
return nil, nil, alert
|
||||||
}
|
}
|
||||||
|
|
||||||
toSend = append(toSend, moreToSend...)
|
toSend = append(toSend, moreToSend...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return state, toSend, AlertNoAlert
|
return state, toSend, AlertNoAlert
|
||||||
|
|
||||||
case *NewSessionTicketBody:
|
case *NewSessionTicketBody:
|
||||||
// XXX: Allow NewSessionTicket in both directions?
|
// XXX: Allow NewSessionTicket in both directions?
|
||||||
if !state.isClient {
|
if !state.isClient {
|
||||||
@ -209,7 +227,6 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
|
|||||||
|
|
||||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||||
|
|
||||||
psk := PreSharedKey{
|
psk := PreSharedKey{
|
||||||
CipherSuite: state.cryptoParams.Suite,
|
CipherSuite: state.cryptoParams.Suite,
|
||||||
IsResumption: true,
|
IsResumption: true,
|
||||||
|
10
vendor/github.com/bifurcation/mint/syntax/README.md
generated
vendored
10
vendor/github.com/bifurcation/mint/syntax/README.md
generated
vendored
@ -72,3 +72,13 @@ The available annotations right now are all related to vectors:
|
|||||||
fragment[TLSPlaintext.length]`. Note, however, that in cases where the length
|
fragment[TLSPlaintext.length]`. Note, however, that in cases where the length
|
||||||
immediately preceds the array, these can be reframed as vectors with
|
immediately preceds the array, these can be reframed as vectors with
|
||||||
appropriate sizes.
|
appropriate sizes.
|
||||||
|
|
||||||
|
|
||||||
|
QUIC Extensions Syntax
|
||||||
|
======================
|
||||||
|
syntax also supports some minor extensions to allow implementing QUIC.
|
||||||
|
|
||||||
|
* The `varint` annotation describes a QUIC-style varint
|
||||||
|
* `head=none` means no header, i.e., the bytes are encoded directly on the wire.
|
||||||
|
On reading, the decoder will consume all available data.
|
||||||
|
* `head=varint` means to encode the header as a varint
|
||||||
|
177
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
177
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
@ -16,12 +16,22 @@ func Unmarshal(data []byte, v interface{}) (int, error) {
|
|||||||
return d.unmarshal(v)
|
return d.unmarshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmarshaler is the interface implemented by types that can
|
||||||
|
// unmarshal a TLS description of themselves. Note that unlike the
|
||||||
|
// JSON unmarshaler interface, it is not known a priori how much of
|
||||||
|
// the input data will be consumed. So the Unmarshaler must state
|
||||||
|
// how much of the input data it consumed.
|
||||||
|
type Unmarshaler interface {
|
||||||
|
UnmarshalTLS([]byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
// These are the options that can be specified in the struct tag. Right now,
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
// all of them apply to variable-length vectors and nothing else
|
// all of them apply to variable-length vectors and nothing else
|
||||||
type decOpts struct {
|
type decOpts struct {
|
||||||
head uint // length of length in bytes
|
head uint // length of length in bytes
|
||||||
min uint // minimum size in bytes
|
min uint // minimum size in bytes
|
||||||
max uint // maximum size in bytes
|
max uint // maximum size in bytes
|
||||||
|
varint bool // whether to decode as a varint
|
||||||
}
|
}
|
||||||
|
|
||||||
type decodeState struct {
|
type decodeState struct {
|
||||||
@ -65,8 +75,14 @@ func typeDecoder(t reflect.Type) decoderFunc {
|
|||||||
return newTypeDecoder(t)
|
return newTypeDecoder(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
|
||||||
|
)
|
||||||
|
|
||||||
func newTypeDecoder(t reflect.Type) decoderFunc {
|
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) {
|
||||||
|
return unmarshalerDecoder
|
||||||
|
}
|
||||||
|
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
@ -77,6 +93,8 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
|
|||||||
return newSliceDecoder(t)
|
return newSliceDecoder(t)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
return newStructDecoder(t)
|
return newStructDecoder(t)
|
||||||
|
case reflect.Ptr:
|
||||||
|
return newPointerDecoder(t)
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
}
|
}
|
||||||
@ -84,35 +102,87 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
|
|||||||
|
|
||||||
///// Specific decoders below
|
///// Specific decoders below
|
||||||
|
|
||||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
var uintLen int
|
um, ok := v.Interface().(Unmarshaler)
|
||||||
switch v.Elem().Kind() {
|
if !ok {
|
||||||
case reflect.Uint8:
|
panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder"))
|
||||||
uintLen = 1
|
|
||||||
case reflect.Uint16:
|
|
||||||
uintLen = 2
|
|
||||||
case reflect.Uint32:
|
|
||||||
uintLen = 4
|
|
||||||
case reflect.Uint64:
|
|
||||||
uintLen = 8
|
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := make([]byte, uintLen)
|
read, err := um.UnmarshalTLS(d.Bytes())
|
||||||
n, err := d.Read(buf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
if n != uintLen {
|
|
||||||
|
if read > d.Len() {
|
||||||
|
panic(fmt.Errorf("Invalid return value from UnmarshalTLS"))
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Next(read)
|
||||||
|
return read
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
if opts.varint {
|
||||||
|
return varintDecoder(d, v, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
uintLen := int(v.Elem().Type().Size())
|
||||||
|
buf := d.Next(uintLen)
|
||||||
|
if len(buf) != uintLen {
|
||||||
panic(fmt.Errorf("Insufficient data to read uint"))
|
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return setUintFromBuffer(v, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
l, val := readVarint(d)
|
||||||
|
|
||||||
|
uintLen := int(v.Elem().Type().Size())
|
||||||
|
if uintLen < l {
|
||||||
|
panic(fmt.Errorf("Uint too small to fit varint: %d < %d", uintLen, l))
|
||||||
|
}
|
||||||
|
|
||||||
|
v.Elem().SetUint(val)
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func readVarint(d *decodeState) (int, uint64) {
|
||||||
|
// Read the first octet and decide the size of the presented varint
|
||||||
|
first := d.Next(1)
|
||||||
|
if len(first) != 1 {
|
||||||
|
panic(fmt.Errorf("Insufficient data to read varint length"))
|
||||||
|
}
|
||||||
|
|
||||||
|
twoBits := uint(first[0] >> 6)
|
||||||
|
varintLen := 1 << twoBits
|
||||||
|
|
||||||
|
rest := d.Next(varintLen - 1)
|
||||||
|
if len(rest) != varintLen-1 {
|
||||||
|
panic(fmt.Errorf("Insufficient data to read varint"))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := append(first, rest...)
|
||||||
|
buf[0] &= 0x3f
|
||||||
|
|
||||||
|
return len(buf), decodeUintFromBuffer(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeUintFromBuffer(buf []byte) uint64 {
|
||||||
val := uint64(0)
|
val := uint64(0)
|
||||||
for _, b := range buf {
|
for _, b := range buf {
|
||||||
val = (val << 8) + uint64(b)
|
val = (val << 8) + uint64(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
v.Elem().SetUint(val)
|
return val
|
||||||
return uintLen
|
}
|
||||||
|
|
||||||
|
func setUintFromBuffer(v reflect.Value, buf []byte) int {
|
||||||
|
v.Elem().SetUint(decodeUintFromBuffer(buf))
|
||||||
|
return len(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////
|
//////////
|
||||||
@ -143,44 +213,57 @@ type sliceDecoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
var length uint64
|
||||||
|
var read int
|
||||||
|
var data []byte
|
||||||
|
|
||||||
if opts.head == 0 {
|
if opts.head == 0 {
|
||||||
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
lengthBytes := make([]byte, opts.head)
|
// If the caller indicated there is no header, then read everything from the buffer
|
||||||
n, err := d.Read(lengthBytes)
|
if opts.head == headValueNoHead {
|
||||||
if err != nil {
|
for {
|
||||||
panic(err)
|
chunk := d.Next(1024)
|
||||||
|
data = append(data, chunk...)
|
||||||
|
if len(chunk) != 1024 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if uint(n) != opts.head {
|
|
||||||
panic(fmt.Errorf("Not enough data to read header"))
|
|
||||||
}
|
}
|
||||||
|
length = uint64(len(data))
|
||||||
length := uint(0)
|
if opts.max > 0 && length > uint64(opts.max) {
|
||||||
for _, b := range lengthBytes {
|
|
||||||
length = (length << 8) + uint(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.max > 0 && length > opts.max {
|
|
||||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||||
}
|
}
|
||||||
if length < opts.min {
|
if length < uint64(opts.min) {
|
||||||
|
panic(fmt.Errorf("Length of vector below declared min"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if opts.head != headValueVarint {
|
||||||
|
lengthBytes := d.Next(int(opts.head))
|
||||||
|
if len(lengthBytes) != int(opts.head) {
|
||||||
|
panic(fmt.Errorf("Not enough data to read header"))
|
||||||
|
}
|
||||||
|
read = len(lengthBytes)
|
||||||
|
length = decodeUintFromBuffer(lengthBytes)
|
||||||
|
} else {
|
||||||
|
read, length = readVarint(d)
|
||||||
|
}
|
||||||
|
if opts.max > 0 && length > uint64(opts.max) {
|
||||||
|
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||||
|
}
|
||||||
|
if length < uint64(opts.min) {
|
||||||
panic(fmt.Errorf("Length of vector below declared min"))
|
panic(fmt.Errorf("Length of vector below declared min"))
|
||||||
}
|
}
|
||||||
|
|
||||||
data := make([]byte, length)
|
data = d.Next(int(length))
|
||||||
n, err = d.Read(data)
|
if len(data) != int(length) {
|
||||||
if err != nil {
|
panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length))
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
if uint(n) != length {
|
|
||||||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
elemBuf := &decodeState{}
|
elemBuf := &decodeState{}
|
||||||
elemBuf.Write(data)
|
elemBuf.Write(data)
|
||||||
elems := []reflect.Value{}
|
elems := []reflect.Value{}
|
||||||
read := int(opts.head)
|
|
||||||
for elemBuf.Len() > 0 {
|
for elemBuf.Len() > 0 {
|
||||||
elem := reflect.New(sd.elementType)
|
elem := reflect.New(sd.elementType)
|
||||||
read += sd.elementDec(elemBuf, elem, opts)
|
read += sd.elementDec(elemBuf, elem, opts)
|
||||||
@ -234,6 +317,7 @@ func newStructDecoder(t reflect.Type) decoderFunc {
|
|||||||
head: tagOpts["head"],
|
head: tagOpts["head"],
|
||||||
max: tagOpts["max"],
|
max: tagOpts["max"],
|
||||||
min: tagOpts["min"],
|
min: tagOpts["min"],
|
||||||
|
varint: tagOpts[varintOption] > 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
sd.fieldDecs[i] = typeDecoder(f.Type)
|
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||||
@ -241,3 +325,20 @@ func newStructDecoder(t reflect.Type) decoderFunc {
|
|||||||
|
|
||||||
return sd.decode
|
return sd.decode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type pointerDecoder struct {
|
||||||
|
base decoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||||
|
v.Elem().Set(reflect.New(v.Elem().Type().Elem()))
|
||||||
|
return pd.base(d, v.Elem(), opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPointerDecoder(t reflect.Type) decoderFunc {
|
||||||
|
baseDecoder := typeDecoder(t.Elem())
|
||||||
|
pd := pointerDecoder{base: baseDecoder}
|
||||||
|
return pd.decode
|
||||||
|
}
|
||||||
|
133
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
133
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
@ -16,12 +16,19 @@ func Marshal(v interface{}) ([]byte, error) {
|
|||||||
return e.Bytes(), nil
|
return e.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Marshaler is the interface implemented by types that
|
||||||
|
// have a defined TLS encoding.
|
||||||
|
type Marshaler interface {
|
||||||
|
MarshalTLS() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
// These are the options that can be specified in the struct tag. Right now,
|
// These are the options that can be specified in the struct tag. Right now,
|
||||||
// all of them apply to variable-length vectors and nothing else
|
// all of them apply to variable-length vectors and nothing else
|
||||||
type encOpts struct {
|
type encOpts struct {
|
||||||
head uint // length of length in bytes
|
head uint // length of length in bytes
|
||||||
min uint // minimum size in bytes
|
min uint // minimum size in bytes
|
||||||
max uint // maximum size in bytes
|
max uint // maximum size in bytes
|
||||||
|
varint bool // whether to encode as a varint
|
||||||
}
|
}
|
||||||
|
|
||||||
type encodeState struct {
|
type encodeState struct {
|
||||||
@ -62,8 +69,14 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
|||||||
return newTypeEncoder(t)
|
return newTypeEncoder(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
|
||||||
|
)
|
||||||
|
|
||||||
func newTypeEncoder(t reflect.Type) encoderFunc {
|
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
if t.Implements(marshalerType) {
|
||||||
|
return marshalerEncoder
|
||||||
|
}
|
||||||
|
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
@ -74,6 +87,8 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
|
|||||||
return newSliceEncoder(t)
|
return newSliceEncoder(t)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
return newStructEncoder(t)
|
return newStructEncoder(t)
|
||||||
|
case reflect.Ptr:
|
||||||
|
return newPointerEncoder(t)
|
||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||||
}
|
}
|
||||||
@ -81,19 +96,65 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
|
|||||||
|
|
||||||
///// Specific encoders below
|
///// Specific encoders below
|
||||||
|
|
||||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
u := v.Uint()
|
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||||
switch v.Type().Kind() {
|
panic(fmt.Errorf("Cannot encode nil pointer"))
|
||||||
case reflect.Uint8:
|
|
||||||
e.WriteByte(byte(u))
|
|
||||||
case reflect.Uint16:
|
|
||||||
e.Write([]byte{byte(u >> 8), byte(u)})
|
|
||||||
case reflect.Uint32:
|
|
||||||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
|
||||||
case reflect.Uint64:
|
|
||||||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
|
||||||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m, ok := v.Interface().(Marshaler)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder"))
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := m.MarshalTLS()
|
||||||
|
if err == nil {
|
||||||
|
_, err = e.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
if opts.varint {
|
||||||
|
varintEncoder(e, v, opts)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeUint(e, v.Uint(), int(v.Type().Size()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
writeVarint(e, v.Uint())
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeVarint(e *encodeState, u uint64) {
|
||||||
|
if (u >> 62) > 0 {
|
||||||
|
panic(fmt.Errorf("uint value is too big for varint"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var varintLen int
|
||||||
|
for _, len := range []uint{1, 2, 4, 8} {
|
||||||
|
if u < (uint64(1) << (8*len - 2)) {
|
||||||
|
varintLen = int(len)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen]
|
||||||
|
shift := uint(8*varintLen - 2)
|
||||||
|
writeUint(e, u|(twoBits<<shift), varintLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeUint(e *encodeState, u uint64, len int) {
|
||||||
|
data := make([]byte, len)
|
||||||
|
for i := 0; i < len; i += 1 {
|
||||||
|
data[i] = byte(u >> uint(8*(len-i-1)))
|
||||||
|
}
|
||||||
|
e.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////
|
//////////
|
||||||
@ -121,27 +182,34 @@ type sliceEncoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
if opts.head == 0 {
|
|
||||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
|
||||||
}
|
|
||||||
|
|
||||||
arrayState := &encodeState{}
|
arrayState := &encodeState{}
|
||||||
se.ae.encode(arrayState, v, opts)
|
se.ae.encode(arrayState, v, opts)
|
||||||
|
|
||||||
n := uint(arrayState.Len())
|
n := uint(arrayState.Len())
|
||||||
|
if opts.head == 0 {
|
||||||
|
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||||
|
}
|
||||||
|
|
||||||
if opts.max > 0 && n > opts.max {
|
if opts.max > 0 && n > opts.max {
|
||||||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||||
}
|
}
|
||||||
if n>>(8*opts.head) > 0 {
|
|
||||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
|
||||||
}
|
|
||||||
if n < opts.min {
|
if n < opts.min {
|
||||||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
switch opts.head {
|
||||||
e.WriteByte(byte(n >> (8 * uint(i))))
|
case headValueNoHead:
|
||||||
|
// None.
|
||||||
|
case headValueVarint:
|
||||||
|
writeVarint(e, uint64(n))
|
||||||
|
default:
|
||||||
|
if n>>(8*opts.head) > 0 {
|
||||||
|
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
writeUint(e, uint64(n), int(opts.head))
|
||||||
|
}
|
||||||
|
|
||||||
e.Write(arrayState.Bytes())
|
e.Write(arrayState.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,9 +247,30 @@ func newStructEncoder(t reflect.Type) encoderFunc {
|
|||||||
head: tagOpts["head"],
|
head: tagOpts["head"],
|
||||||
max: tagOpts["max"],
|
max: tagOpts["max"],
|
||||||
min: tagOpts["min"],
|
min: tagOpts["min"],
|
||||||
|
varint: tagOpts[varintOption] > 0,
|
||||||
}
|
}
|
||||||
se.fieldEncs[i] = typeEncoder(f.Type)
|
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
return se.encode
|
return se.encode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////
|
||||||
|
|
||||||
|
type pointerEncoder struct {
|
||||||
|
base encoderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
|
if v.IsNil() {
|
||||||
|
panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pe.base(e, v.Elem(), opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPointerEncoder(t reflect.Type) encoderFunc {
|
||||||
|
baseEncoder := typeEncoder(t.Elem())
|
||||||
|
pe := pointerEncoder{base: baseEncoder}
|
||||||
|
return pe.encode
|
||||||
|
}
|
||||||
|
28
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
28
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
@ -5,16 +5,27 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// `tls:"head=2,min=2,max=255"`
|
// `tls:"head=2,min=2,max=255,varint"`
|
||||||
|
|
||||||
type tagOptions map[string]uint
|
type tagOptions map[string]uint
|
||||||
|
|
||||||
|
var (
|
||||||
|
varintOption = "varint"
|
||||||
|
|
||||||
|
headOptionNone = "none"
|
||||||
|
headOptionVarint = "varint"
|
||||||
|
headValueNoHead = uint(255)
|
||||||
|
headValueVarint = uint(254)
|
||||||
|
)
|
||||||
|
|
||||||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||||
// name=value pairs, where the values MUST be unsigned integers
|
// name=value pairs, where the values MUST be unsigned integers, or in
|
||||||
|
// the special case of head, "none" or "varint"
|
||||||
func parseTag(tag string) tagOptions {
|
func parseTag(tag string) tagOptions {
|
||||||
opts := tagOptions{}
|
opts := tagOptions{}
|
||||||
for _, token := range strings.Split(tag, ",") {
|
for _, token := range strings.Split(tag, ",") {
|
||||||
if strings.Index(token, "=") == -1 {
|
if token == varintOption {
|
||||||
|
opts[varintOption] = 1
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,7 +33,16 @@ func parseTag(tag string) tagOptions {
|
|||||||
if len(parts[0]) == 0 {
|
if len(parts[0]) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
|
||||||
|
if len(parts) == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts[0] == "head" && parts[1] == headOptionNone {
|
||||||
|
opts[parts[0]] = headValueNoHead
|
||||||
|
} else if parts[0] == "head" && parts[1] == headOptionVarint {
|
||||||
|
opts[parts[0]] = headValueVarint
|
||||||
|
} else if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||||
opts[parts[0]] = uint(val)
|
opts[parts[0]] = uint(val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
122
vendor/github.com/bifurcation/mint/timer.go
generated
vendored
Normal file
122
vendor/github.com/bifurcation/mint/timer.go
generated
vendored
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package mint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a simple timer implementation. Timers are stored in a sorted
|
||||||
|
// list.
|
||||||
|
// TODO(ekr@rtfm.com): Add a way to uncouple these from the system
|
||||||
|
// clock.
|
||||||
|
type timerCb func() error
|
||||||
|
|
||||||
|
type timer struct {
|
||||||
|
label string
|
||||||
|
cb timerCb
|
||||||
|
deadline time.Time
|
||||||
|
duration uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type timerSet struct {
|
||||||
|
ts []*timer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTimerSet() *timerSet {
|
||||||
|
return &timerSet{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer {
|
||||||
|
now := time.Now()
|
||||||
|
t := timer{
|
||||||
|
label,
|
||||||
|
cb,
|
||||||
|
now.Add(time.Millisecond * time.Duration(delayMs)),
|
||||||
|
delayMs,
|
||||||
|
}
|
||||||
|
logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline)
|
||||||
|
|
||||||
|
var i int
|
||||||
|
ntimers := len(ts.ts)
|
||||||
|
for i = 0; i < ntimers; i++ {
|
||||||
|
if t.deadline.Before(ts.ts[i].deadline) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp := make([]*timer, 0, ntimers+1)
|
||||||
|
tmp = append(tmp, ts.ts[:i]...)
|
||||||
|
tmp = append(tmp, &t)
|
||||||
|
tmp = append(tmp, ts.ts[i:]...)
|
||||||
|
ts.ts = tmp
|
||||||
|
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ekr@rtfm.com): optimize this now that the list is sorted.
|
||||||
|
// We should be able to do just one list manipulation, as long
|
||||||
|
// as we're careful about how we handle inserts during callbacks.
|
||||||
|
func (ts *timerSet) check(now time.Time) error {
|
||||||
|
for i, t := range ts.ts {
|
||||||
|
if now.After(t.deadline) {
|
||||||
|
ts.ts = append(ts.ts[:i], ts.ts[:i+1]...)
|
||||||
|
if t.cb != nil {
|
||||||
|
logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline)
|
||||||
|
cb := t.cb
|
||||||
|
t.cb = nil
|
||||||
|
err := cb()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the next time any of the timers would fire.
|
||||||
|
func (ts *timerSet) remaining() (bool, time.Duration) {
|
||||||
|
for _, t := range ts.ts {
|
||||||
|
if t.cb != nil {
|
||||||
|
return true, time.Until(t.deadline)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, time.Duration(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *timerSet) cancel(label string) {
|
||||||
|
for _, t := range ts.ts {
|
||||||
|
if t.label == label {
|
||||||
|
t.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *timerSet) getTimer(label string) *timer {
|
||||||
|
for _, t := range ts.ts {
|
||||||
|
if t.label == label && t.cb != nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *timerSet) getAllTimers() []string {
|
||||||
|
var ret []string
|
||||||
|
|
||||||
|
for _, t := range ts.ts {
|
||||||
|
if t.cb != nil {
|
||||||
|
ret = append(ret, t.label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *timer) cancel() {
|
||||||
|
logf(logTypeHandshake, "Timer %s cancelled", t.label)
|
||||||
|
t.cb = nil
|
||||||
|
t.label = ""
|
||||||
|
}
|
25
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
25
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
@ -51,11 +51,14 @@ func (l *Listener) Accept() (c net.Conn, err error) {
|
|||||||
// Listener and wraps each connection with Server.
|
// Listener and wraps each connection with Server.
|
||||||
// The configuration config must be non-nil and must include
|
// The configuration config must be non-nil and must include
|
||||||
// at least one certificate or else set GetCertificate.
|
// at least one certificate or else set GetCertificate.
|
||||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
func NewListener(inner net.Listener, config *Config) (net.Listener, error) {
|
||||||
|
if config != nil && config.NonBlocking {
|
||||||
|
return nil, errors.New("listening not possible in non-blocking mode")
|
||||||
|
}
|
||||||
l := new(Listener)
|
l := new(Listener)
|
||||||
l.Listener = inner
|
l.Listener = inner
|
||||||
l.config = config
|
l.config = config
|
||||||
return l
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen creates a TLS listener accepting connections on the
|
// Listen creates a TLS listener accepting connections on the
|
||||||
@ -70,7 +73,7 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewListener(l, config), nil
|
return NewListener(l, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TimeoutError struct{}
|
type TimeoutError struct{}
|
||||||
@ -87,6 +90,10 @@ func (TimeoutError) Temporary() bool { return true }
|
|||||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
// configuration; see the documentation of Config for the defaults.
|
// configuration; see the documentation of Config for the defaults.
|
||||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
if config != nil && config.NonBlocking {
|
||||||
|
return nil, errors.New("dialing not possible in non-blocking mode")
|
||||||
|
}
|
||||||
|
|
||||||
// We want the Timeout and Deadline values from dialer to cover the
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
// whole process: TCP connection and TLS handshake. This means that we
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
// also need to start our own timers now.
|
// also need to start our own timers now.
|
||||||
@ -121,16 +128,20 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
|||||||
|
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = &Config{}
|
config = &Config{}
|
||||||
|
} else {
|
||||||
|
config = config.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no ServerName is set, infer the ServerName
|
// If no ServerName is set, infer the ServerName
|
||||||
// from the hostname we're connecting to.
|
// from the hostname we're connecting to.
|
||||||
if config.ServerName == "" {
|
if config.ServerName == "" {
|
||||||
// Make a copy to avoid polluting argument or default.
|
config.ServerName = hostname
|
||||||
c := config.Clone()
|
|
||||||
c.ServerName = hostname
|
|
||||||
config = c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up DTLS as needed.
|
||||||
|
config.UseDTLS = (network == "udp")
|
||||||
|
|
||||||
conn := Client(rawConn, config)
|
conn := Client(rawConn, config)
|
||||||
|
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
|
22
vendor/github.com/cheekybits/genny/LICENSE
generated
vendored
Normal file
22
vendor/github.com/cheekybits/genny/LICENSE
generated
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2014 cheekybits
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
2
vendor/github.com/cheekybits/genny/generic/doc.go
generated
vendored
Normal file
2
vendor/github.com/cheekybits/genny/generic/doc.go
generated
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package generic contains the generic marker types.
|
||||||
|
package generic
|
13
vendor/github.com/cheekybits/genny/generic/generic.go
generated
vendored
Normal file
13
vendor/github.com/cheekybits/genny/generic/generic.go
generated
vendored
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package generic
|
||||||
|
|
||||||
|
// Type is the placeholder type that indicates a generic value.
|
||||||
|
// When genny is executed, variables of this type will be replaced with
|
||||||
|
// references to the specific types.
|
||||||
|
// var GenericType generic.Type
|
||||||
|
type Type interface{}
|
||||||
|
|
||||||
|
// Number is the placehoder type that indiccates a generic numerical value.
|
||||||
|
// When genny is executed, variables of this type will be replaced with
|
||||||
|
// references to the specific types.
|
||||||
|
// var GenericType generic.Number
|
||||||
|
type Number float64
|
21
vendor/github.com/hashicorp/golang-lru/2q.go
generated
vendored
21
vendor/github.com/hashicorp/golang-lru/2q.go
generated
vendored
@ -30,9 +30,9 @@ type TwoQueueCache struct {
|
|||||||
size int
|
size int
|
||||||
recentSize int
|
recentSize int
|
||||||
|
|
||||||
recent *simplelru.LRU
|
recent simplelru.LRUCache
|
||||||
frequent *simplelru.LRU
|
frequent simplelru.LRUCache
|
||||||
recentEvict *simplelru.LRU
|
recentEvict simplelru.LRUCache
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +84,8 @@ func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCa
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) {
|
// Get looks up a key's value from the cache.
|
||||||
|
func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
@ -105,6 +106,7 @@ func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add adds a value to the cache.
|
||||||
func (c *TwoQueueCache) Add(key, value interface{}) {
|
func (c *TwoQueueCache) Add(key, value interface{}) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
@ -160,12 +162,15 @@ func (c *TwoQueueCache) ensureSpace(recentEvict bool) {
|
|||||||
c.frequent.RemoveOldest()
|
c.frequent.RemoveOldest()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Len returns the number of items in the cache.
|
||||||
func (c *TwoQueueCache) Len() int {
|
func (c *TwoQueueCache) Len() int {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
return c.recent.Len() + c.frequent.Len()
|
return c.recent.Len() + c.frequent.Len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keys returns a slice of the keys in the cache.
|
||||||
|
// The frequently used keys are first in the returned slice.
|
||||||
func (c *TwoQueueCache) Keys() []interface{} {
|
func (c *TwoQueueCache) Keys() []interface{} {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
@ -174,6 +179,7 @@ func (c *TwoQueueCache) Keys() []interface{} {
|
|||||||
return append(k1, k2...)
|
return append(k1, k2...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove removes the provided key from the cache.
|
||||||
func (c *TwoQueueCache) Remove(key interface{}) {
|
func (c *TwoQueueCache) Remove(key interface{}) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
@ -188,6 +194,7 @@ func (c *TwoQueueCache) Remove(key interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Purge is used to completely clear the cache.
|
||||||
func (c *TwoQueueCache) Purge() {
|
func (c *TwoQueueCache) Purge() {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
@ -196,13 +203,17 @@ func (c *TwoQueueCache) Purge() {
|
|||||||
c.recentEvict.Purge()
|
c.recentEvict.Purge()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Contains is used to check if the cache contains a key
|
||||||
|
// without updating recency or frequency.
|
||||||
func (c *TwoQueueCache) Contains(key interface{}) bool {
|
func (c *TwoQueueCache) Contains(key interface{}) bool {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
return c.frequent.Contains(key) || c.recent.Contains(key)
|
return c.frequent.Contains(key) || c.recent.Contains(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TwoQueueCache) Peek(key interface{}) (interface{}, bool) {
|
// Peek is used to inspect the cache value of a key
|
||||||
|
// without updating recency or frequency.
|
||||||
|
func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
if val, ok := c.frequent.Peek(key); ok {
|
if val, ok := c.frequent.Peek(key); ok {
|
||||||
|
16
vendor/github.com/hashicorp/golang-lru/arc.go
generated
vendored
16
vendor/github.com/hashicorp/golang-lru/arc.go
generated
vendored
@ -18,11 +18,11 @@ type ARCCache struct {
|
|||||||
size int // Size is the total capacity of the cache
|
size int // Size is the total capacity of the cache
|
||||||
p int // P is the dynamic preference towards T1 or T2
|
p int // P is the dynamic preference towards T1 or T2
|
||||||
|
|
||||||
t1 *simplelru.LRU // T1 is the LRU for recently accessed items
|
t1 simplelru.LRUCache // T1 is the LRU for recently accessed items
|
||||||
b1 *simplelru.LRU // B1 is the LRU for evictions from t1
|
b1 simplelru.LRUCache // B1 is the LRU for evictions from t1
|
||||||
|
|
||||||
t2 *simplelru.LRU // T2 is the LRU for frequently accessed items
|
t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items
|
||||||
b2 *simplelru.LRU // B2 is the LRU for evictions from t2
|
b2 simplelru.LRUCache // B2 is the LRU for evictions from t2
|
||||||
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
@ -60,11 +60,11 @@ func NewARC(size int) (*ARCCache, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get looks up a key's value from the cache.
|
// Get looks up a key's value from the cache.
|
||||||
func (c *ARCCache) Get(key interface{}) (interface{}, bool) {
|
func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
// Ff the value is contained in T1 (recent), then
|
// If the value is contained in T1 (recent), then
|
||||||
// promote it to T2 (frequent)
|
// promote it to T2 (frequent)
|
||||||
if val, ok := c.t1.Peek(key); ok {
|
if val, ok := c.t1.Peek(key); ok {
|
||||||
c.t1.Remove(key)
|
c.t1.Remove(key)
|
||||||
@ -153,7 +153,7 @@ func (c *ARCCache) Add(key, value interface{}) {
|
|||||||
// Remove from B2
|
// Remove from B2
|
||||||
c.b2.Remove(key)
|
c.b2.Remove(key)
|
||||||
|
|
||||||
// Add the key to the frequntly used list
|
// Add the key to the frequently used list
|
||||||
c.t2.Add(key, value)
|
c.t2.Add(key, value)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -247,7 +247,7 @@ func (c *ARCCache) Contains(key interface{}) bool {
|
|||||||
|
|
||||||
// Peek is used to inspect the cache value of a key
|
// Peek is used to inspect the cache value of a key
|
||||||
// without updating recency or frequency.
|
// without updating recency or frequency.
|
||||||
func (c *ARCCache) Peek(key interface{}) (interface{}, bool) {
|
func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
if val, ok := c.t1.Peek(key); ok {
|
if val, ok := c.t1.Peek(key); ok {
|
||||||
|
21
vendor/github.com/hashicorp/golang-lru/doc.go
generated
vendored
Normal file
21
vendor/github.com/hashicorp/golang-lru/doc.go
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
// Package lru provides three different LRU caches of varying sophistication.
|
||||||
|
//
|
||||||
|
// Cache is a simple LRU cache. It is based on the
|
||||||
|
// LRU implementation in groupcache:
|
||||||
|
// https://github.com/golang/groupcache/tree/master/lru
|
||||||
|
//
|
||||||
|
// TwoQueueCache tracks frequently used and recently used entries separately.
|
||||||
|
// This avoids a burst of accesses from taking out frequently used entries,
|
||||||
|
// at the cost of about 2x computational overhead and some extra bookkeeping.
|
||||||
|
//
|
||||||
|
// ARCCache is an adaptive replacement cache. It tracks recent evictions as
|
||||||
|
// well as recent usage in both the frequent and recent caches. Its
|
||||||
|
// computational overhead is comparable to TwoQueueCache, but the memory
|
||||||
|
// overhead is linear with the size of the cache.
|
||||||
|
//
|
||||||
|
// ARC has been patented by IBM, so do not use it if that is problematic for
|
||||||
|
// your program.
|
||||||
|
//
|
||||||
|
// All caches in this package take locks while operating, and are therefore
|
||||||
|
// thread-safe for consumers.
|
||||||
|
package lru
|
28
vendor/github.com/hashicorp/golang-lru/lru.go
generated
vendored
28
vendor/github.com/hashicorp/golang-lru/lru.go
generated
vendored
@ -1,6 +1,3 @@
|
|||||||
// This package provides a simple LRU cache. It is based on the
|
|
||||||
// LRU implementation in groupcache:
|
|
||||||
// https://github.com/golang/groupcache/tree/master/lru
|
|
||||||
package lru
|
package lru
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -11,11 +8,11 @@ import (
|
|||||||
|
|
||||||
// Cache is a thread-safe fixed size LRU cache.
|
// Cache is a thread-safe fixed size LRU cache.
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
lru *simplelru.LRU
|
lru simplelru.LRUCache
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an LRU of the given size
|
// New creates an LRU of the given size.
|
||||||
func New(size int) (*Cache, error) {
|
func New(size int) (*Cache, error) {
|
||||||
return NewWithEvict(size, nil)
|
return NewWithEvict(size, nil)
|
||||||
}
|
}
|
||||||
@ -33,7 +30,7 @@ func NewWithEvict(size int, onEvicted func(key interface{}, value interface{}))
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purge is used to completely clear the cache
|
// Purge is used to completely clear the cache.
|
||||||
func (c *Cache) Purge() {
|
func (c *Cache) Purge() {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
c.lru.Purge()
|
c.lru.Purge()
|
||||||
@ -41,30 +38,30 @@ func (c *Cache) Purge() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||||
func (c *Cache) Add(key, value interface{}) bool {
|
func (c *Cache) Add(key, value interface{}) (evicted bool) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
return c.lru.Add(key, value)
|
return c.lru.Add(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get looks up a key's value from the cache.
|
// Get looks up a key's value from the cache.
|
||||||
func (c *Cache) Get(key interface{}) (interface{}, bool) {
|
func (c *Cache) Get(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
return c.lru.Get(key)
|
return c.lru.Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a key is in the cache, without updating the recent-ness
|
// Contains checks if a key is in the cache, without updating the
|
||||||
// or deleting it for being stale.
|
// recent-ness or deleting it for being stale.
|
||||||
func (c *Cache) Contains(key interface{}) bool {
|
func (c *Cache) Contains(key interface{}) bool {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
return c.lru.Contains(key)
|
return c.lru.Contains(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the key value (or undefined if not found) without updating
|
// Peek returns the key value (or undefined if not found) without updating
|
||||||
// the "recently used"-ness of the key.
|
// the "recently used"-ness of the key.
|
||||||
func (c *Cache) Peek(key interface{}) (interface{}, bool) {
|
func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) {
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
return c.lru.Peek(key)
|
return c.lru.Peek(key)
|
||||||
@ -73,16 +70,15 @@ func (c *Cache) Peek(key interface{}) (interface{}, bool) {
|
|||||||
// ContainsOrAdd checks if a key is in the cache without updating the
|
// ContainsOrAdd checks if a key is in the cache without updating the
|
||||||
// recent-ness or deleting it for being stale, and if not, adds the value.
|
// recent-ness or deleting it for being stale, and if not, adds the value.
|
||||||
// Returns whether found and whether an eviction occurred.
|
// Returns whether found and whether an eviction occurred.
|
||||||
func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evict bool) {
|
func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
defer c.lock.Unlock()
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
if c.lru.Contains(key) {
|
if c.lru.Contains(key) {
|
||||||
return true, false
|
return true, false
|
||||||
} else {
|
|
||||||
evict := c.lru.Add(key, value)
|
|
||||||
return false, evict
|
|
||||||
}
|
}
|
||||||
|
evicted = c.lru.Add(key, value)
|
||||||
|
return false, evicted
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove removes the provided key from the cache.
|
// Remove removes the provided key from the cache.
|
||||||
|
17
vendor/github.com/hashicorp/golang-lru/simplelru/lru.go
generated
vendored
17
vendor/github.com/hashicorp/golang-lru/simplelru/lru.go
generated
vendored
@ -36,7 +36,7 @@ func NewLRU(size int, onEvict EvictCallback) (*LRU, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purge is used to completely clear the cache
|
// Purge is used to completely clear the cache.
|
||||||
func (c *LRU) Purge() {
|
func (c *LRU) Purge() {
|
||||||
for k, v := range c.items {
|
for k, v := range c.items {
|
||||||
if c.onEvict != nil {
|
if c.onEvict != nil {
|
||||||
@ -48,7 +48,7 @@ func (c *LRU) Purge() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||||
func (c *LRU) Add(key, value interface{}) bool {
|
func (c *LRU) Add(key, value interface{}) (evicted bool) {
|
||||||
// Check for existing item
|
// Check for existing item
|
||||||
if ent, ok := c.items[key]; ok {
|
if ent, ok := c.items[key]; ok {
|
||||||
c.evictList.MoveToFront(ent)
|
c.evictList.MoveToFront(ent)
|
||||||
@ -78,17 +78,18 @@ func (c *LRU) Get(key interface{}) (value interface{}, ok bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a key is in the cache, without updating the recent-ness
|
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||||
// or deleting it for being stale.
|
// or deleting it for being stale.
|
||||||
func (c *LRU) Contains(key interface{}) (ok bool) {
|
func (c *LRU) Contains(key interface{}) (ok bool) {
|
||||||
_, ok = c.items[key]
|
_, ok = c.items[key]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the key value (or undefined if not found) without updating
|
// Peek returns the key value (or undefined if not found) without updating
|
||||||
// the "recently used"-ness of the key.
|
// the "recently used"-ness of the key.
|
||||||
func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
|
func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
|
||||||
if ent, ok := c.items[key]; ok {
|
var ent *list.Element
|
||||||
|
if ent, ok = c.items[key]; ok {
|
||||||
return ent.Value.(*entry).value, true
|
return ent.Value.(*entry).value, true
|
||||||
}
|
}
|
||||||
return nil, ok
|
return nil, ok
|
||||||
@ -96,7 +97,7 @@ func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
|
|||||||
|
|
||||||
// Remove removes the provided key from the cache, returning if the
|
// Remove removes the provided key from the cache, returning if the
|
||||||
// key was contained.
|
// key was contained.
|
||||||
func (c *LRU) Remove(key interface{}) bool {
|
func (c *LRU) Remove(key interface{}) (present bool) {
|
||||||
if ent, ok := c.items[key]; ok {
|
if ent, ok := c.items[key]; ok {
|
||||||
c.removeElement(ent)
|
c.removeElement(ent)
|
||||||
return true
|
return true
|
||||||
@ -105,7 +106,7 @@ func (c *LRU) Remove(key interface{}) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOldest removes the oldest item from the cache.
|
// RemoveOldest removes the oldest item from the cache.
|
||||||
func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) {
|
func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) {
|
||||||
ent := c.evictList.Back()
|
ent := c.evictList.Back()
|
||||||
if ent != nil {
|
if ent != nil {
|
||||||
c.removeElement(ent)
|
c.removeElement(ent)
|
||||||
@ -116,7 +117,7 @@ func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOldest returns the oldest entry
|
// GetOldest returns the oldest entry
|
||||||
func (c *LRU) GetOldest() (interface{}, interface{}, bool) {
|
func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) {
|
||||||
ent := c.evictList.Back()
|
ent := c.evictList.Back()
|
||||||
if ent != nil {
|
if ent != nil {
|
||||||
kv := ent.Value.(*entry)
|
kv := ent.Value.(*entry)
|
||||||
|
37
vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go
generated
vendored
Normal file
37
vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go
generated
vendored
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package simplelru
|
||||||
|
|
||||||
|
|
||||||
|
// LRUCache is the interface for simple LRU cache.
|
||||||
|
type LRUCache interface {
|
||||||
|
// Adds a value to the cache, returns true if an eviction occurred and
|
||||||
|
// updates the "recently used"-ness of the key.
|
||||||
|
Add(key, value interface{}) bool
|
||||||
|
|
||||||
|
// Returns key's value from the cache and
|
||||||
|
// updates the "recently used"-ness of the key. #value, isFound
|
||||||
|
Get(key interface{}) (value interface{}, ok bool)
|
||||||
|
|
||||||
|
// Check if a key exsists in cache without updating the recent-ness.
|
||||||
|
Contains(key interface{}) (ok bool)
|
||||||
|
|
||||||
|
// Returns key's value without updating the "recently used"-ness of the key.
|
||||||
|
Peek(key interface{}) (value interface{}, ok bool)
|
||||||
|
|
||||||
|
// Removes a key from the cache.
|
||||||
|
Remove(key interface{}) bool
|
||||||
|
|
||||||
|
// Removes the oldest entry from cache.
|
||||||
|
RemoveOldest() (interface{}, interface{}, bool)
|
||||||
|
|
||||||
|
// Returns the oldest entry from the cache. #key, value, isFound
|
||||||
|
GetOldest() (interface{}, interface{}, bool)
|
||||||
|
|
||||||
|
// Returns a slice of the keys in the cache, from oldest to newest.
|
||||||
|
Keys() []interface{}
|
||||||
|
|
||||||
|
// Returns the number of items in the cache.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
// Clear all cache entries
|
||||||
|
Purge()
|
||||||
|
}
|
21
vendor/github.com/isofew/go-stun/LICENSE
generated
vendored
Normal file
21
vendor/github.com/isofew/go-stun/LICENSE
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2016 Vasily Vasilyev
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
284
vendor/github.com/isofew/go-stun/stun/agent.go
generated
vendored
Normal file
284
vendor/github.com/isofew/go-stun/stun/agent.go
generated
vendored
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultConfig = &Config{
|
||||||
|
RetransmissionTimeout: 500 * time.Millisecond,
|
||||||
|
TransactionTimeout: 39500 * time.Millisecond,
|
||||||
|
Software: "pixelbender/go-stun",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler interface {
|
||||||
|
ServeSTUN(msg *Message, tr Transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandlerFunc func(msg *Message, tr Transport)
|
||||||
|
|
||||||
|
func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) {
|
||||||
|
h(msg, tr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// AuthMethod returns a key for MESSAGE-INTEGRITY attribute
|
||||||
|
AuthMethod AuthMethod
|
||||||
|
// Retransmission timeout, default is 500 milliseconds
|
||||||
|
RetransmissionTimeout time.Duration
|
||||||
|
// Transaction timeout, default is 39.5 seconds
|
||||||
|
TransactionTimeout time.Duration
|
||||||
|
// Fingerprint, if true all outgoing messages contain FINGERPRINT attribute
|
||||||
|
Fingerprint bool
|
||||||
|
// Software is a SOFTWARE attribute value for outgoing messages, if not empty
|
||||||
|
Software string
|
||||||
|
// Logf, if set all sent and received messages printed using Logf
|
||||||
|
Logf func(format string, args ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) attrs() []Attr {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var a []Attr
|
||||||
|
if c.Software != "" {
|
||||||
|
a = append(a, String(AttrSoftware, c.Software))
|
||||||
|
}
|
||||||
|
if c.Fingerprint {
|
||||||
|
a = append(a, Fingerprint)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Clone() *Config {
|
||||||
|
r := *c
|
||||||
|
return &r
|
||||||
|
}
|
||||||
|
|
||||||
|
type Agent struct {
|
||||||
|
config *Config
|
||||||
|
Handler Handler
|
||||||
|
m mux
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAgent(config *Config) *Agent {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultConfig
|
||||||
|
}
|
||||||
|
return &Agent{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) Send(msg *Message, tr Transport) (err error) {
|
||||||
|
msg = &Message{
|
||||||
|
msg.Type,
|
||||||
|
msg.Transaction,
|
||||||
|
append(a.config.attrs(), msg.Attributes...),
|
||||||
|
}
|
||||||
|
if log := a.config.Logf; log != nil {
|
||||||
|
log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
||||||
|
}
|
||||||
|
b := msg.Marshal(getBuffer()[:0])
|
||||||
|
_, err = tr.Write(b)
|
||||||
|
putBuffer(b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error {
|
||||||
|
if c, ok := c.(net.PacketConn); ok {
|
||||||
|
return a.ServePacket(c, stop)
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
b = getBuffer()
|
||||||
|
p int
|
||||||
|
)
|
||||||
|
defer putBuffer(b)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if p >= len(b) {
|
||||||
|
return errBufferOverflow
|
||||||
|
}
|
||||||
|
n, err := c.Read(b[p:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p += n
|
||||||
|
n = 0
|
||||||
|
for n < p {
|
||||||
|
r, err := a.ServeTransport(b[n:p], c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n += r
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
if n < p {
|
||||||
|
p = copy(b, b[n:p])
|
||||||
|
} else {
|
||||||
|
p = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error {
|
||||||
|
b := getBuffer()
|
||||||
|
defer putBuffer(b)
|
||||||
|
// don't close the connection since we're going to reuse it
|
||||||
|
// defer c.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
n, addr, err := c.ReadFrom(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
a.ServeTransport(b[:n], &packetConn{c, addr})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) {
|
||||||
|
msg := &Message{}
|
||||||
|
n, err = msg.Unmarshal(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.ServeSTUN(msg, tr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) ServeSTUN(msg *Message, tr Transport) {
|
||||||
|
if log := a.config.Logf; log != nil {
|
||||||
|
log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
|
||||||
|
}
|
||||||
|
if a.m.serve(msg, tr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h := a.Handler; h != nil {
|
||||||
|
go h.ServeSTUN(msg, tr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) {
|
||||||
|
var (
|
||||||
|
start = time.Now()
|
||||||
|
rto = a.config.RetransmissionTimeout
|
||||||
|
udp = to.LocalAddr().Network() == "udp"
|
||||||
|
tx = a.m.newTx()
|
||||||
|
)
|
||||||
|
defer a.m.closeTx(tx)
|
||||||
|
req = &Message{req.Type, tx.id, req.Attributes}
|
||||||
|
if err = a.Send(req, to); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
d := a.config.TransactionTimeout - time.Since(start)
|
||||||
|
if d < 0 {
|
||||||
|
err = errTimeout
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if udp && d > rto {
|
||||||
|
d = rto
|
||||||
|
}
|
||||||
|
res, from, err = tx.Receive(d)
|
||||||
|
if udp && err == errTimeout && d == rto {
|
||||||
|
rto <<= 1
|
||||||
|
a.Send(req, to)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mux struct {
|
||||||
|
sync.RWMutex
|
||||||
|
t map[string]*transaction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mux) serve(msg *Message, tr Transport) bool {
|
||||||
|
m.RLock()
|
||||||
|
tx, ok := m.t[string(msg.Transaction)]
|
||||||
|
m.RUnlock()
|
||||||
|
if ok {
|
||||||
|
tx.msg, tx.from = msg, tr
|
||||||
|
tx.Done()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mux) newTx() *transaction {
|
||||||
|
tx := &transaction{id: NewTransaction()}
|
||||||
|
m.Lock()
|
||||||
|
if m.t == nil {
|
||||||
|
m.t = make(map[string]*transaction)
|
||||||
|
} else {
|
||||||
|
for m.t[string(tx.id)] != nil {
|
||||||
|
rand.Read(tx.id[4:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.t[string(tx.id)] = tx
|
||||||
|
m.Unlock()
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mux) closeTx(tx *transaction) {
|
||||||
|
m.Lock()
|
||||||
|
delete(m.t, string(tx.id))
|
||||||
|
m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mux) Close() {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
for _, it := range m.t {
|
||||||
|
it.Close()
|
||||||
|
}
|
||||||
|
m.t = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type transaction struct {
|
||||||
|
sync.WaitGroup
|
||||||
|
id []byte
|
||||||
|
from Transport
|
||||||
|
msg *Message
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
|
||||||
|
tx.Add(1)
|
||||||
|
t := time.AfterFunc(d, tx.timeout)
|
||||||
|
tx.Wait()
|
||||||
|
t.Stop()
|
||||||
|
if err = tx.err; err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return tx.msg, tx.from, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *transaction) timeout() {
|
||||||
|
tx.err = errTimeout
|
||||||
|
tx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *transaction) Close() {
|
||||||
|
tx.err = errCanceled
|
||||||
|
tx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
var errCanceled = errors.New("stun: transaction canceled")
|
||||||
|
var errTimeout = errors.New("stun: transaction timeout")
|
438
vendor/github.com/isofew/go-stun/stun/attribute.go
generated
vendored
Normal file
438
vendor/github.com/isofew/go-stun/stun/attribute.go
generated
vendored
Normal file
@ -0,0 +1,438 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Attribute represents a STUN attribute.
|
||||||
|
type Attr interface {
|
||||||
|
Type() uint16
|
||||||
|
Marshal(p []byte) []byte
|
||||||
|
Unmarshal(b []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP address family
|
||||||
|
const (
|
||||||
|
IPv4 = 0x01
|
||||||
|
IPv6 = 0x02
|
||||||
|
)
|
||||||
|
|
||||||
|
// IP address family
|
||||||
|
const (
|
||||||
|
ChangeIP uint64 = 0x04
|
||||||
|
ChangePort = 0x02
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAttr(typ uint16) Attr {
|
||||||
|
switch typ {
|
||||||
|
case AttrMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress,
|
||||||
|
AttrXorMappedAddress, AttrAlternateServer, AttrResponseOrigin, AttrOtherAddress,
|
||||||
|
AttrResponseAddress, AttrSourceAddress, AttrChangedAddress, AttrReflectedFrom:
|
||||||
|
return &addr{typ: typ}
|
||||||
|
case AttrRequestedAddressFamily, AttrRequestedTransport:
|
||||||
|
return &number{typ: typ, size: 4, pad: 24}
|
||||||
|
case AttrChannelNumber, AttrResponsePort:
|
||||||
|
return &number{typ: typ, size: 4, pad: 16}
|
||||||
|
case AttrLifetime, AttrConnectionID, AttrCacheTimeout,
|
||||||
|
AttrBandwidth, AttrTimerVal,
|
||||||
|
AttrTransactionTransmitCounter,
|
||||||
|
AttrEcnCheck, AttrChangeRequest, AttrPriority:
|
||||||
|
return &number{typ: typ, size: 4}
|
||||||
|
case AttrIceControlled, AttrIceControlling:
|
||||||
|
return &number{typ: typ, size: 8}
|
||||||
|
case AttrUsername, AttrRealm, AttrNonce, AttrSoftware, AttrPassword, AttrThirdPartyAuthorization,
|
||||||
|
AttrData, AttrAccessToken, AttrReservationToken, AttrMobilityTicket, AttrPadding, AttrUnknownAttributes:
|
||||||
|
return &raw{typ: typ}
|
||||||
|
case AttrMessageIntegrity:
|
||||||
|
return &integrity{}
|
||||||
|
case AttrErrorCode:
|
||||||
|
return &Error{}
|
||||||
|
case AttrEvenPort:
|
||||||
|
return &number{typ: typ, size: 1}
|
||||||
|
case AttrDontFragment, AttrUseCandidate:
|
||||||
|
return flag(typ)
|
||||||
|
case AttrFingerprint:
|
||||||
|
return &fingerprint{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AttrName(typ uint16) string {
|
||||||
|
if r, ok := attrNames[typ]; ok {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
return "0x" + strconv.FormatUint(uint64(typ), 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Int(typ uint16, v uint64) Attr {
|
||||||
|
switch typ {
|
||||||
|
case AttrRequestedAddressFamily, AttrRequestedTransport:
|
||||||
|
return &number{typ, 4, 24, v}
|
||||||
|
case AttrChannelNumber, AttrResponsePort:
|
||||||
|
return &number{typ, 4, 16, v}
|
||||||
|
case AttrIceControlled, AttrIceControlling:
|
||||||
|
return &number{typ, 8, 0, v}
|
||||||
|
case AttrEvenPort:
|
||||||
|
return &number{typ, 1, 0, v}
|
||||||
|
default:
|
||||||
|
return &number{typ, 4, 0, v}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type number struct {
|
||||||
|
typ uint16
|
||||||
|
size, pad uint8
|
||||||
|
v uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *number) Type() uint16 { return a.typ }
|
||||||
|
|
||||||
|
func (a *number) Marshal(p []byte) []byte {
|
||||||
|
r, b := grow(p, int(a.size))
|
||||||
|
switch a.size {
|
||||||
|
case 1:
|
||||||
|
b[0] = byte(a.v)
|
||||||
|
case 4:
|
||||||
|
be.PutUint32(b, uint32(a.v<<a.pad))
|
||||||
|
case 8:
|
||||||
|
be.PutUint64(b, a.v<<a.pad)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *number) Unmarshal(b []byte) error {
|
||||||
|
if len(b) < int(a.size) {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
switch a.size {
|
||||||
|
case 1:
|
||||||
|
a.v = uint64(b[0])
|
||||||
|
case 4:
|
||||||
|
a.v = uint64(be.Uint32(b) >> a.pad)
|
||||||
|
case 8:
|
||||||
|
a.v = be.Uint64(b) >> a.pad
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *number) String() string {
|
||||||
|
return "0x" + strconv.FormatUint(a.v, 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Flag(v uint16) Attr {
|
||||||
|
return flag(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
type flag uint16
|
||||||
|
|
||||||
|
func (attr flag) Type() uint16 { return uint16(attr) }
|
||||||
|
func (flag) Marshal(p []byte) []byte { return p }
|
||||||
|
func (flag) Unmarshal(b []byte) error { return nil }
|
||||||
|
|
||||||
|
// Error represents the ERROR-CODE attribute.
|
||||||
|
type Error struct {
|
||||||
|
Code int
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewError(code int) *Error {
|
||||||
|
return &Error{code, ErrorText(code)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Error) Type() uint16 { return AttrErrorCode }
|
||||||
|
|
||||||
|
func (e *Error) Marshal(p []byte) []byte {
|
||||||
|
r, b := grow(p, 4+len(e.Reason))
|
||||||
|
b[0] = 0
|
||||||
|
b[1] = 0
|
||||||
|
b[2] = byte(e.Code / 100)
|
||||||
|
b[3] = byte(e.Code % 100)
|
||||||
|
copy(b[4:], e.Reason)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Unmarshal(b []byte) error {
|
||||||
|
if len(b) < 4 {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
e.Code = int(b[2])*100 + int(b[3])
|
||||||
|
e.Reason = getString(b[4:])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string { return e.String() }
|
||||||
|
func (e *Error) String() string { return fmt.Sprintf("%d %s", e.Code, e.Reason) }
|
||||||
|
|
||||||
|
// ErrorText returns a text for the STUN error code. It returns the empty string if the code is unknown.
|
||||||
|
func ErrorText(code int) string { return errorText[code] }
|
||||||
|
|
||||||
|
func getString(b []byte) string {
|
||||||
|
for i := len(b); i >= 0; i-- {
|
||||||
|
if b[i-1] > 0 {
|
||||||
|
return string(b[:i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func Addr(typ uint16, v net.Addr) Attr {
|
||||||
|
ip, port := SockAddr(v)
|
||||||
|
return &addr{typ, ip, port}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SockAddr(v net.Addr) (net.IP, int) {
|
||||||
|
switch a := v.(type) {
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return a.IP, a.Port
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return a.IP, a.Port
|
||||||
|
case *net.IPAddr:
|
||||||
|
return a.IP, 0
|
||||||
|
default:
|
||||||
|
return net.IPv4zero, 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameAddr(a, b net.Addr) bool {
|
||||||
|
aip, aport := SockAddr(a)
|
||||||
|
bip, bport := SockAddr(b)
|
||||||
|
return aip.Equal(bip) && aport == bport
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAddr(network string, ip net.IP, port int) net.Addr {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
return &net.UDPAddr{IP: ip, Port: port}
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
return &net.TCPAddr{IP: ip, Port: port}
|
||||||
|
}
|
||||||
|
return &net.IPAddr{IP: ip}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IP(typ uint16, ip net.IP) Attr { return &addr{typ, ip, 0} }
|
||||||
|
|
||||||
|
type addr struct {
|
||||||
|
typ uint16
|
||||||
|
IP net.IP
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) Type() uint16 { return addr.typ }
|
||||||
|
|
||||||
|
func (addr *addr) Addr(network string) net.Addr {
|
||||||
|
return NewAddr(network, addr.IP, addr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) Xored() bool {
|
||||||
|
switch addr.typ {
|
||||||
|
case AttrXorMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) Marshal(p []byte) []byte {
|
||||||
|
return addr.MarshalAddr(p, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) MarshalAddr(p, tx []byte) []byte {
|
||||||
|
fam, ip := IPv4, addr.IP.To4()
|
||||||
|
if ip == nil {
|
||||||
|
fam, ip = IPv6, addr.IP
|
||||||
|
}
|
||||||
|
r, b := grow(p, 4+len(ip))
|
||||||
|
b[0] = 0
|
||||||
|
b[1] = byte(fam)
|
||||||
|
if addr.Xored() && tx != nil {
|
||||||
|
be.PutUint16(b[2:], uint16(addr.Port)^0x2112)
|
||||||
|
b = b[4:]
|
||||||
|
for i, it := range ip {
|
||||||
|
b[i] = it ^ tx[i]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
be.PutUint16(b[2:], uint16(addr.Port))
|
||||||
|
copy(b[4:], ip)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) Unmarshal(b []byte) error {
|
||||||
|
return addr.UnmarshalAddr(b, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) UnmarshalAddr(b, tx []byte) error {
|
||||||
|
if len(b) < 4 {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
n, port := net.IPv4len, int(be.Uint16(b[2:]))
|
||||||
|
if b[1] == IPv6 {
|
||||||
|
n = net.IPv6len
|
||||||
|
}
|
||||||
|
if b = b[4:]; len(b) < n {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
addr.IP = make(net.IP, n)
|
||||||
|
if addr.Xored() && tx != nil {
|
||||||
|
for i, it := range b {
|
||||||
|
addr.IP[i] = it ^ tx[i]
|
||||||
|
}
|
||||||
|
addr.Port = port ^ 0x2112
|
||||||
|
} else {
|
||||||
|
copy(addr.IP, b)
|
||||||
|
addr.Port = port
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) Equal(a *addr) bool {
|
||||||
|
return addr == a || (addr != nil && a != nil && addr.IP.Equal(a.IP) && addr.Port == a.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (addr *addr) String() string {
|
||||||
|
if addr.Port == 0 {
|
||||||
|
return addr.IP.String()
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(addr.IP.String(), strconv.Itoa(addr.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Bytes(typ uint16, v []byte) Attr { return &raw{typ, v} }
|
||||||
|
|
||||||
|
type raw struct {
|
||||||
|
typ uint16
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *raw) Type() uint16 { return attr.typ }
|
||||||
|
func (attr *raw) Marshal(p []byte) []byte { return append(p, attr.data...) }
|
||||||
|
func (attr *raw) Unmarshal(p []byte) error {
|
||||||
|
attr.data = p
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (attr *raw) String() string { return string(attr.data) }
|
||||||
|
|
||||||
|
func String(typ uint16, v string) Attr {
|
||||||
|
return &str{typ, v}
|
||||||
|
}
|
||||||
|
|
||||||
|
type str struct {
|
||||||
|
typ uint16
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *str) Type() uint16 { return attr.typ }
|
||||||
|
func (attr *str) Marshal(p []byte) []byte { return append(p, attr.data...) }
|
||||||
|
func (attr *str) Unmarshal(p []byte) error {
|
||||||
|
attr.data = string(p)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (attr *str) String() string { return attr.data }
|
||||||
|
|
||||||
|
func MessageIntegrity(key []byte) Attr {
|
||||||
|
return &integrity{key: key}
|
||||||
|
}
|
||||||
|
|
||||||
|
type integrity struct {
|
||||||
|
key, sum, raw []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*integrity) Type() uint16 {
|
||||||
|
return AttrMessageIntegrity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) Marshal(p []byte) []byte {
|
||||||
|
return append(p, attr.sum...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) Unmarshal(b []byte) error {
|
||||||
|
if len(b) < 20 {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
attr.sum = b
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) MarshalSum(p, raw []byte) []byte {
|
||||||
|
n := len(raw) - 4
|
||||||
|
be.PutUint16(raw[2:], uint16(n+4))
|
||||||
|
return attr.Sum(attr.key, raw[:n], p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) UnmarshalSum(p, raw []byte) error {
|
||||||
|
attr.raw = raw
|
||||||
|
return attr.Unmarshal(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) Sum(key, data, p []byte) []byte {
|
||||||
|
h := hmac.New(sha1.New, key)
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *integrity) Check(key []byte) bool {
|
||||||
|
r := attr.raw
|
||||||
|
if len(r) < 44 {
|
||||||
|
return r == nil
|
||||||
|
}
|
||||||
|
be.PutUint16(r[2:], uint16(len(r)-20))
|
||||||
|
h := attr.Sum(key, r[:len(r)-24], nil)
|
||||||
|
return bytes.Equal(h, attr.sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
var Fingerprint Attr = &fingerprint{}
|
||||||
|
|
||||||
|
type fingerprint struct {
|
||||||
|
sum uint32
|
||||||
|
raw []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fingerprint) Type() uint16 {
|
||||||
|
return AttrFingerprint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) Marshal(p []byte) []byte {
|
||||||
|
r, b := grow(p, 4)
|
||||||
|
be.PutUint32(b, attr.sum)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) Unmarshal(b []byte) error {
|
||||||
|
if len(b) < 4 {
|
||||||
|
return errFormat
|
||||||
|
}
|
||||||
|
attr.sum = be.Uint32(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) MarshalSum(p, raw []byte) []byte {
|
||||||
|
n := len(raw) - 4
|
||||||
|
be.PutUint16(raw[2:], uint16(n-12))
|
||||||
|
v := attr.Sum(raw[:n])
|
||||||
|
r, b := grow(p, 4)
|
||||||
|
be.PutUint32(b, v)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) UnmarshalSum(p, raw []byte) error {
|
||||||
|
attr.raw = raw
|
||||||
|
return attr.Unmarshal(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) Sum(p []byte) uint32 {
|
||||||
|
return crc32.ChecksumIEEE(p) ^ 0x5354554e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attr *fingerprint) Check() bool {
|
||||||
|
r := attr.raw
|
||||||
|
if len(r) < 28 {
|
||||||
|
return r == nil
|
||||||
|
}
|
||||||
|
be.PutUint16(r[2:], uint16(len(r)-20))
|
||||||
|
return attr.Sum(r[:len(r)-8]) == attr.sum
|
||||||
|
}
|
110
vendor/github.com/isofew/go-stun/stun/conn.go
generated
vendored
Normal file
110
vendor/github.com/isofew/go-stun/stun/conn.go
generated
vendored
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
agent *Agent
|
||||||
|
sess *Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(conn net.Conn, config *Config, stop chan struct{}) *Conn {
|
||||||
|
a := NewAgent(config)
|
||||||
|
go a.ServeConn(conn, stop)
|
||||||
|
return &Conn{conn, a, nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Network() string {
|
||||||
|
return c.LocalAddr().Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Discover() (net.Addr, error) {
|
||||||
|
res, err := c.Request(&Message{Type: MethodBinding})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mapped := res.GetAddr(c.Network(), AttrXorMappedAddress, AttrMappedAddress)
|
||||||
|
if mapped != nil {
|
||||||
|
return mapped, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("stun: bad response, no mapped address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Request(req *Message) (res *Message, err error) {
|
||||||
|
res, _, err = c.RequestTransport(req, c.Conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) RequestTransport(req *Message, to Transport) (res *Message, from Transport, err error) {
|
||||||
|
sess := c.sess
|
||||||
|
auth := c.agent.config.AuthMethod
|
||||||
|
if to == nil {
|
||||||
|
to = c.Conn
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
msg := &Message{
|
||||||
|
req.Type,
|
||||||
|
NewTransaction(),
|
||||||
|
append(sess.attrs(), req.Attributes...),
|
||||||
|
}
|
||||||
|
res, from, err = c.agent.RoundTrip(msg, to)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := res.GetError()
|
||||||
|
if code == nil {
|
||||||
|
// FIXME: authorize response...
|
||||||
|
if sess != nil {
|
||||||
|
c.sess = sess
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = code
|
||||||
|
switch code.Code {
|
||||||
|
case CodeUnauthorized, CodeStaleNonce:
|
||||||
|
if auth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sess = &Session{
|
||||||
|
Realm: res.GetString(AttrRealm),
|
||||||
|
Nonce: res.GetString(AttrNonce),
|
||||||
|
}
|
||||||
|
if err = auth(sess); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth = nil
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
Realm string
|
||||||
|
Nonce string
|
||||||
|
Username string
|
||||||
|
Key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) attrs() []Attr {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var a []Attr
|
||||||
|
if s.Realm != "" {
|
||||||
|
a = append(a, String(AttrRealm, s.Realm))
|
||||||
|
}
|
||||||
|
if s.Nonce != "" {
|
||||||
|
a = append(a, String(AttrNonce, s.Nonce))
|
||||||
|
}
|
||||||
|
if s.Username != "" {
|
||||||
|
a = append(a, String(AttrUsername, s.Username))
|
||||||
|
}
|
||||||
|
if s.Key != nil {
|
||||||
|
a = append(a, MessageIntegrity(s.Key))
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
203
vendor/github.com/isofew/go-stun/stun/gen.go
generated
vendored
Normal file
203
vendor/github.com/isofew/go-stun/stun/gen.go
generated
vendored
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
//+build ignore
|
||||||
|
|
||||||
|
//go:generate go run gen.go
|
||||||
|
// This program generates STUN parameters: methods, attributes and error codes by reading IANA registry.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"go/format"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
err := generate("http://www.iana.org/assignments/stun-parameters/stun-parameters.xml", "registry.go")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadRegistry(url string) (*Registry, error) {
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r := &Registry{}
|
||||||
|
err = xml.Unmarshal(b, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
name := regexp.MustCompile("^([\\w-]+)(.*was\\s+([\\w-]+))?")
|
||||||
|
for _, reg := range r.Registry {
|
||||||
|
for _, r := range reg.Records {
|
||||||
|
m := name.FindStringSubmatch(r.Description)
|
||||||
|
if m != nil {
|
||||||
|
r.Value = strings.ToLower(r.Value)
|
||||||
|
r.Name = m[1]
|
||||||
|
if r.Name == "Reserved" && m[3] != "" {
|
||||||
|
r.Name = m[3]
|
||||||
|
r.Deprecated = true
|
||||||
|
}
|
||||||
|
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "rfc")
|
||||||
|
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "RFC-")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Record struct {
|
||||||
|
Value string `xml:"value"`
|
||||||
|
Description string `xml:"description"`
|
||||||
|
Name string
|
||||||
|
Deprecated bool
|
||||||
|
Ref struct {
|
||||||
|
Type string `xml:"type,attr"`
|
||||||
|
Data string `xml:"data,attr"`
|
||||||
|
} `xml:"xref"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Record) IsValid() bool {
|
||||||
|
return r.Name != "Reserved" && r.Name != "Unassigned" && (r.Ref.Type == "rfc" || r.Ref.Type == "draft")
|
||||||
|
}
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
Title string `xml:"title"`
|
||||||
|
Updated string `xml:"updated"`
|
||||||
|
Registry []struct {
|
||||||
|
Id string `xml:"id,attr"`
|
||||||
|
Records []*Record `xml:"record"`
|
||||||
|
} `xml:"registry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reg *Registry) GetRecords(id string) []*Record {
|
||||||
|
for _, it := range reg.Registry {
|
||||||
|
if it.Id == id {
|
||||||
|
return it.Records
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generate(url, file string) error {
|
||||||
|
reg, err := loadRegistry(url)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
fmt.Fprintf(b, "package stun\n\n")
|
||||||
|
fmt.Fprintf(b, "// Do not edit. This file is generated by 'go generate gen.go'\n")
|
||||||
|
fmt.Fprintf(b, "// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).\n")
|
||||||
|
fmt.Fprintf(b, "// %s, Updated: %s.\n\n", reg.Title, reg.Updated)
|
||||||
|
|
||||||
|
genMethods(reg.GetRecords("stun-parameters-2"), b)
|
||||||
|
genAttributes(reg.GetRecords("stun-parameters-4"), b)
|
||||||
|
genErrors(reg.GetRecords("stun-parameters-6"), b)
|
||||||
|
|
||||||
|
src, err := format.Source(b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = ioutil.WriteFile(file, src, 0644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func genMethods(records []*Record, b *bytes.Buffer) error {
|
||||||
|
c, m := &bytes.Buffer{}, &bytes.Buffer{}
|
||||||
|
ref := ""
|
||||||
|
for _, it := range records {
|
||||||
|
if it.IsValid() {
|
||||||
|
if ref == it.Ref.Data {
|
||||||
|
fmt.Fprintf(c, "Method%v uint16 = %s\n", it.Name, it.Value)
|
||||||
|
} else {
|
||||||
|
ref = it.Ref.Data
|
||||||
|
fmt.Fprintf(c, "Method%v uint16 = %s // RFC %s\n", it.Name, it.Value, ref)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(m, "Method%v: \"%s\",\n", it.Name, it.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(b, "// STUN methods.\n")
|
||||||
|
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||||
|
fmt.Fprintf(b, "// STUN method names.\n")
|
||||||
|
fmt.Fprintf(b, "var methodNames = map[uint16]string{\n%s}\n", m.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func genAttributes(records []*Record, b *bytes.Buffer) error {
|
||||||
|
c, d, m := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}
|
||||||
|
ref := ""
|
||||||
|
for _, it := range records {
|
||||||
|
if it.IsValid() {
|
||||||
|
a := c
|
||||||
|
if it.Deprecated {
|
||||||
|
a = d
|
||||||
|
}
|
||||||
|
v := strings.Replace(it.Name, "_", "-", -1)
|
||||||
|
n := strings.Replace(v, "-", " ", -1)
|
||||||
|
parts := strings.Fields(n)
|
||||||
|
for i, s := range parts {
|
||||||
|
switch s {
|
||||||
|
case "", "ID":
|
||||||
|
default:
|
||||||
|
r, n := utf8.DecodeRuneInString(s)
|
||||||
|
s = string(unicode.ToUpper(r)) + strings.ToLower(s[n:])
|
||||||
|
}
|
||||||
|
parts[i] = s
|
||||||
|
}
|
||||||
|
n = strings.Join(parts, "")
|
||||||
|
if ref == it.Ref.Data || it.Deprecated {
|
||||||
|
fmt.Fprintf(a, "Attr%v uint16 = %s\n", n, it.Value)
|
||||||
|
} else {
|
||||||
|
ref = it.Ref.Data
|
||||||
|
fmt.Fprintf(a, "Attr%v uint16 = %s // RFC %s\n", n, it.Value, ref)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(m, "Attr%v: \"%s\",\n", n, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(b, "// STUN attributes.\n")
|
||||||
|
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||||
|
fmt.Fprintf(b, "// Deprecated: For backwards compatibility only.\n")
|
||||||
|
fmt.Fprintf(b, "const (\n%s)\n", d.Bytes())
|
||||||
|
fmt.Fprintf(b, "// STUN attribute names.\n")
|
||||||
|
fmt.Fprintf(b, "var attrNames = map[uint16]string{\n%s}\n", m.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func genErrors(records []*Record, b *bytes.Buffer) error {
|
||||||
|
c, m := &bytes.Buffer{}, &bytes.Buffer{}
|
||||||
|
ref := ""
|
||||||
|
for _, it := range records {
|
||||||
|
if it.IsValid() {
|
||||||
|
n := strings.Replace(strings.Title(it.Description), " ", "", -1)
|
||||||
|
if ref == it.Ref.Data {
|
||||||
|
fmt.Fprintf(c, "Code%v int = %s\n", n, it.Value)
|
||||||
|
} else {
|
||||||
|
ref = it.Ref.Data
|
||||||
|
fmt.Fprintf(c, "Code%v int = %s // RFC %s\n", n, it.Value, ref)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(m, "Code%v: \"%s\",\n", n, it.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(b, "// STUN error codes.\n")
|
||||||
|
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
|
||||||
|
fmt.Fprintf(b, "// STUN error texts.\n")
|
||||||
|
fmt.Fprintf(b, "var errorText = map[int]string{\n%s}\n", m.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
351
vendor/github.com/isofew/go-stun/stun/message.go
generated
vendored
Normal file
351
vendor/github.com/isofew/go-stun/stun/message.go
generated
vendored
Normal file
@ -0,0 +1,351 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
KindRequest uint16 = 0x0000
|
||||||
|
KindIndication uint16 = 0x0010
|
||||||
|
KindResponse uint16 = 0x0100
|
||||||
|
KindError uint16 = 0x0110
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a STUN message.
|
||||||
|
type Message struct {
|
||||||
|
Type uint16
|
||||||
|
Transaction []byte
|
||||||
|
Attributes []Attr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Marshal(p []byte) []byte {
|
||||||
|
pos := len(p)
|
||||||
|
r, b := grow(p, 20)
|
||||||
|
be.PutUint16(b, m.Type)
|
||||||
|
|
||||||
|
if m.Transaction != nil {
|
||||||
|
copy(b[4:], m.Transaction)
|
||||||
|
} else {
|
||||||
|
copy(b[4:], magicCookie)
|
||||||
|
rand.Read(b[8:20])
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(byPosition(m.Attributes))
|
||||||
|
for _, attr := range m.Attributes {
|
||||||
|
r = m.marshalAttr(r, attr, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
be.PutUint16(r[pos+2:], uint16(len(r)-pos-20))
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) marshalAttr(p []byte, attr Attr, pos int) []byte {
|
||||||
|
h := len(p)
|
||||||
|
r, b := grow(p, 4)
|
||||||
|
be.PutUint16(b, attr.Type())
|
||||||
|
|
||||||
|
switch v := attr.(type) {
|
||||||
|
case *addr:
|
||||||
|
r = v.MarshalAddr(r, r[pos+4:])
|
||||||
|
case *integrity:
|
||||||
|
r = v.MarshalSum(r, r[pos:])
|
||||||
|
case *fingerprint:
|
||||||
|
r = v.MarshalSum(r, r[pos:])
|
||||||
|
default:
|
||||||
|
r = v.Marshal(r)
|
||||||
|
}
|
||||||
|
n := len(r) - h - 4
|
||||||
|
be.PutUint16(r[h+2:], uint16(n))
|
||||||
|
|
||||||
|
if pad := n & 3; pad != 0 {
|
||||||
|
r, b = grow(r, 4-pad)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Unmarshal(b []byte) (n int, err error) {
|
||||||
|
if len(b) < 20 {
|
||||||
|
err = io.EOF
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l := int(be.Uint16(b[2:])) + 20
|
||||||
|
if len(b) < l {
|
||||||
|
err = io.EOF
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pos, p := 20, make([]byte, l)
|
||||||
|
copy(p, b[:l])
|
||||||
|
|
||||||
|
m.Type = be.Uint16(p)
|
||||||
|
m.Transaction = p[4:20]
|
||||||
|
|
||||||
|
for pos < len(p) {
|
||||||
|
s, attr, err := m.unmarshalAttr(p, pos)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
pos += s
|
||||||
|
if attr != nil {
|
||||||
|
m.Attributes = append(m.Attributes, attr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) unmarshalAttr(p []byte, pos int) (n int, attr Attr, err error) {
|
||||||
|
b := p[pos:]
|
||||||
|
if len(b) < 4 {
|
||||||
|
err = errFormat
|
||||||
|
return
|
||||||
|
}
|
||||||
|
typ := be.Uint16(b)
|
||||||
|
attr, n = newAttr(typ), int(be.Uint16(b[2:]))+4
|
||||||
|
if len(b) < n {
|
||||||
|
err = errFormat
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b = b[4:n]
|
||||||
|
if attr != nil {
|
||||||
|
switch v := attr.(type) {
|
||||||
|
case *addr:
|
||||||
|
err = v.UnmarshalAddr(b, m.Transaction)
|
||||||
|
case *integrity:
|
||||||
|
err = v.UnmarshalSum(b, p[:pos+n])
|
||||||
|
case *fingerprint:
|
||||||
|
err = v.UnmarshalSum(b, p[:pos+n])
|
||||||
|
default:
|
||||||
|
err = attr.Unmarshal(b)
|
||||||
|
}
|
||||||
|
} else if typ < 0x8000 {
|
||||||
|
err = errFormat
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = &errAttribute{err, typ}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pad := n & 3; pad != 0 {
|
||||||
|
n += 4 - pad
|
||||||
|
if len(p) < pos+n {
|
||||||
|
err = errFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Kind() uint16 {
|
||||||
|
return m.Type & 0x110
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Method() uint16 {
|
||||||
|
return m.Type &^ 0x110
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Add(attr Attr) {
|
||||||
|
m.Attributes = append(m.Attributes, attr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Set(attr Attr) {
|
||||||
|
m.Del(attr.Type())
|
||||||
|
m.Add(attr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Del(typ uint16) {
|
||||||
|
n := 0
|
||||||
|
for _, a := range m.Attributes {
|
||||||
|
if a.Type() != typ {
|
||||||
|
m.Attributes[n] = a
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.Attributes = m.Attributes[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Get(typ uint16) (attr Attr) {
|
||||||
|
for _, attr = range m.Attributes {
|
||||||
|
if attr.Type() == typ {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Has(typ uint16) bool {
|
||||||
|
for _, attr := range m.Attributes {
|
||||||
|
if attr.Type() == typ {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) GetString(typ uint16) string {
|
||||||
|
if str, ok := m.Get(typ).(fmt.Stringer); ok {
|
||||||
|
return str.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) GetAddr(network string, typ ...uint16) net.Addr {
|
||||||
|
for _, t := range typ {
|
||||||
|
if addr, ok := m.Get(t).(*addr); ok {
|
||||||
|
return addr.Addr(network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) GetInt(typ uint16) (v uint64, ok bool) {
|
||||||
|
attr := m.Get(typ)
|
||||||
|
if r, ok := attr.(*number); ok {
|
||||||
|
return r.v, true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) GetBytes(typ uint16) []byte {
|
||||||
|
if attr, ok := m.Get(typ).(*raw); ok {
|
||||||
|
return attr.data
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) GetError() *Error {
|
||||||
|
if err, ok := m.Get(AttrErrorCode).(*Error); ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) CheckIntegrity(key []byte) bool {
|
||||||
|
if attr, ok := m.Get(AttrMessageIntegrity).(*integrity); ok {
|
||||||
|
return attr.Check(key)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) CheckFingerprint() bool {
|
||||||
|
if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok {
|
||||||
|
return attr.Check()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) String() string {
|
||||||
|
sort.Sort(byPosition(m.Attributes))
|
||||||
|
|
||||||
|
// TODO: use sprintf
|
||||||
|
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
b.WriteString(MethodName(m.Type))
|
||||||
|
b.WriteByte('{')
|
||||||
|
tx := m.Transaction
|
||||||
|
if tx == nil {
|
||||||
|
b.WriteString("nil")
|
||||||
|
} else if bytes.Equal(magicCookie, tx[:4]) {
|
||||||
|
b.WriteString(hex.EncodeToString(tx[4:]))
|
||||||
|
} else {
|
||||||
|
b.WriteString(hex.EncodeToString(tx))
|
||||||
|
}
|
||||||
|
for _, attr := range m.Attributes {
|
||||||
|
b.WriteString(", ")
|
||||||
|
b.WriteString(AttrName(attr.Type()))
|
||||||
|
switch v := attr.(type) {
|
||||||
|
case *raw:
|
||||||
|
b.WriteString(": \"")
|
||||||
|
b.Write(v.data)
|
||||||
|
b.WriteByte('"')
|
||||||
|
case *str:
|
||||||
|
b.WriteString(": \"")
|
||||||
|
b.WriteString(v.data)
|
||||||
|
b.WriteByte('"')
|
||||||
|
case flag, *integrity, *fingerprint:
|
||||||
|
default:
|
||||||
|
b.WriteString(fmt.Sprintf(": %v", attr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteByte('}')
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func MethodName(typ uint16) string {
|
||||||
|
if r, ok := methodNames[typ&^0x110]; ok {
|
||||||
|
switch typ & 0x110 {
|
||||||
|
case KindRequest:
|
||||||
|
return r + "Request"
|
||||||
|
case KindIndication:
|
||||||
|
return r + "Indication"
|
||||||
|
case KindResponse:
|
||||||
|
return r + "Response"
|
||||||
|
case KindError:
|
||||||
|
return r + "Error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "0x" + strconv.FormatUint(uint64(typ), 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalMessage(b []byte) (*Message, error) {
|
||||||
|
m := &Message{}
|
||||||
|
if _, err := m.Unmarshal(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var magicCookie = []byte{0x21, 0x12, 0xa4, 0x42}
|
||||||
|
var alphanum = dict("01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||||
|
|
||||||
|
type dict []byte
|
||||||
|
|
||||||
|
func (d dict) rand(n int) string {
|
||||||
|
m, b := len(d), make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = d[rand.Intn(m)]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTransaction() []byte {
|
||||||
|
id := make([]byte, 16)
|
||||||
|
copy(id, magicCookie)
|
||||||
|
rand.Read(id[4:]) // TODO: configure random source
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
type byPosition []Attr
|
||||||
|
|
||||||
|
func (s byPosition) Len() int { return len(s) }
|
||||||
|
func (s byPosition) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
func (s byPosition) Less(i, j int) bool {
|
||||||
|
a, b := s[i].Type(), s[j].Type()
|
||||||
|
switch b {
|
||||||
|
case a:
|
||||||
|
return i < j
|
||||||
|
case AttrMessageIntegrity:
|
||||||
|
return a != AttrFingerprint
|
||||||
|
case AttrFingerprint:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return i < j
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errAttribute struct {
|
||||||
|
error
|
||||||
|
typ uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err errAttribute) Error() string {
|
||||||
|
return "attribute " + AttrName(err.typ) + ": " + err.error.Error()
|
||||||
|
}
|
172
vendor/github.com/isofew/go-stun/stun/nat.go
generated
vendored
Normal file
172
vendor/github.com/isofew/go-stun/stun/nat.go
generated
vendored
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointIndependent = "endpoint-independent"
|
||||||
|
AddressDependent = "address-dependent"
|
||||||
|
AddressPortDependent = "address-port-dependent"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
*Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector(c *Conn) *Detector {
|
||||||
|
d := &Detector{c}
|
||||||
|
d.agent.Handler = &Server{agent: c.agent}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Hairpinning() error {
|
||||||
|
mapped, err := d.Discover()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn, err := net.Dial(d.Network(), mapped.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// not using stop channel
|
||||||
|
c := NewConn(conn, d.agent.config, make(chan struct{}))
|
||||||
|
defer c.Close()
|
||||||
|
_, err = c.Discover()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) DiscoverChange(change uint64) error {
|
||||||
|
req := &Message{Type: MethodBinding, Attributes: []Attr{Int(AttrChangeRequest, change)}}
|
||||||
|
_, from, err := d.RequestTransport(req, d)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ip, port := SockAddr(d.RemoteAddr())
|
||||||
|
chip, chport := SockAddr(from.RemoteAddr())
|
||||||
|
if change&ChangeIP != 0 {
|
||||||
|
if ip.Equal(chip) {
|
||||||
|
return errors.New("stun: bad response, ip address is not changed")
|
||||||
|
}
|
||||||
|
} else if change&ChangePort != 0 {
|
||||||
|
if port == chport {
|
||||||
|
return errors.New("stun: bad response, port is not changed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Filtering() (string, error) {
|
||||||
|
n := d.Network()
|
||||||
|
if n != "udp" {
|
||||||
|
return "", errors.New("stun: filtering test is not applicable to " + n)
|
||||||
|
}
|
||||||
|
_, err := d.Request(&Message{Type: MethodBinding})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
err = d.DiscoverChange(ChangeIP | ChangePort)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return EndpointIndependent, nil
|
||||||
|
case errTimeout:
|
||||||
|
err = d.DiscoverChange(ChangePort)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return AddressDependent, nil
|
||||||
|
case errTimeout:
|
||||||
|
return AddressPortDependent, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) DiscoverOther(addr net.Addr) (net.Addr, error) {
|
||||||
|
n := addr.Network()
|
||||||
|
conn, err := net.Dial(n, addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
// not using stop channel
|
||||||
|
go d.agent.ServeConn(conn, make(chan struct{}))
|
||||||
|
res, _, err := d.RequestTransport(&Message{Type: MethodBinding}, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mapped := res.GetAddr(n, AttrXorMappedAddress, AttrMappedAddress)
|
||||||
|
if mapped != nil {
|
||||||
|
return mapped, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("stun: bad response, no mapped address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Mapping() (string, error) {
|
||||||
|
n := d.Network()
|
||||||
|
msg, err := d.Request(&Message{Type: MethodBinding})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
mapped, other := msg.GetAddr(n, AttrXorMappedAddress), msg.GetAddr(n, AttrOtherAddress)
|
||||||
|
if mapped == nil {
|
||||||
|
return "", errors.New("stun: bad response, no mapped address")
|
||||||
|
}
|
||||||
|
if other == nil {
|
||||||
|
return "", errors.New("stun: bad response, no other address")
|
||||||
|
}
|
||||||
|
ip, _ := SockAddr(mapped)
|
||||||
|
if ip.IsLoopback() {
|
||||||
|
return EndpointIndependent, nil
|
||||||
|
}
|
||||||
|
for _, it := range local {
|
||||||
|
if it.IP.Equal(ip) {
|
||||||
|
return EndpointIndependent, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ip, _ = SockAddr(other)
|
||||||
|
_, port := SockAddr(d.RemoteAddr())
|
||||||
|
a, err := d.DiscoverOther(NewAddr(n, ip, port))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if sameAddr(a, mapped) {
|
||||||
|
return EndpointIndependent, nil
|
||||||
|
}
|
||||||
|
b, err := d.DiscoverOther(other)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if sameAddr(b, a) {
|
||||||
|
return AddressDependent, nil
|
||||||
|
}
|
||||||
|
return AddressPortDependent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LocalAddrs() []*net.IPAddr {
|
||||||
|
return local
|
||||||
|
}
|
||||||
|
|
||||||
|
var local []*net.IPAddr
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
addrs, _ := iface.Addrs()
|
||||||
|
for _, it := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch it := it.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = it.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = it.IP
|
||||||
|
}
|
||||||
|
if ip != nil && ip.IsGlobalUnicast() {
|
||||||
|
local = append(local, &net.IPAddr{ip, iface.Name})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
179
vendor/github.com/isofew/go-stun/stun/registry.go
generated
vendored
Normal file
179
vendor/github.com/isofew/go-stun/stun/registry.go
generated
vendored
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
// Do not edit. This file is generated by 'go generate gen.go'
|
||||||
|
// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).
|
||||||
|
// Session Traversal Utilities for NAT (STUN) Parameters, Updated: 2016-11-14.
|
||||||
|
|
||||||
|
// STUN methods.
|
||||||
|
const (
|
||||||
|
MethodBinding uint16 = 0x001 // RFC 5389
|
||||||
|
MethodSharedSecret uint16 = 0x002
|
||||||
|
MethodAllocate uint16 = 0x003 // RFC 5766
|
||||||
|
MethodRefresh uint16 = 0x004
|
||||||
|
MethodSend uint16 = 0x006
|
||||||
|
MethodData uint16 = 0x007
|
||||||
|
MethodCreatePermission uint16 = 0x008
|
||||||
|
MethodChannelBind uint16 = 0x009
|
||||||
|
MethodConnect uint16 = 0x00a // RFC 6062
|
||||||
|
MethodConnectionBind uint16 = 0x00b
|
||||||
|
MethodConnectionAttempt uint16 = 0x00c
|
||||||
|
)
|
||||||
|
|
||||||
|
// STUN method names.
|
||||||
|
var methodNames = map[uint16]string{
|
||||||
|
MethodBinding: "Binding",
|
||||||
|
MethodSharedSecret: "SharedSecret",
|
||||||
|
MethodAllocate: "Allocate",
|
||||||
|
MethodRefresh: "Refresh",
|
||||||
|
MethodSend: "Send",
|
||||||
|
MethodData: "Data",
|
||||||
|
MethodCreatePermission: "CreatePermission",
|
||||||
|
MethodChannelBind: "ChannelBind",
|
||||||
|
MethodConnect: "Connect",
|
||||||
|
MethodConnectionBind: "ConnectionBind",
|
||||||
|
MethodConnectionAttempt: "ConnectionAttempt",
|
||||||
|
}
|
||||||
|
|
||||||
|
// STUN attributes.
|
||||||
|
const (
|
||||||
|
AttrMappedAddress uint16 = 0x0001 // RFC 5389
|
||||||
|
AttrChangeRequest uint16 = 0x0003 // RFC 5780
|
||||||
|
AttrUsername uint16 = 0x0006 // RFC 5389
|
||||||
|
AttrMessageIntegrity uint16 = 0x0008
|
||||||
|
AttrErrorCode uint16 = 0x0009
|
||||||
|
AttrUnknownAttributes uint16 = 0x000a
|
||||||
|
AttrChannelNumber uint16 = 0x000c // RFC 5766
|
||||||
|
AttrLifetime uint16 = 0x000d
|
||||||
|
AttrXorPeerAddress uint16 = 0x0012
|
||||||
|
AttrData uint16 = 0x0013
|
||||||
|
AttrRealm uint16 = 0x0014 // RFC 5389
|
||||||
|
AttrNonce uint16 = 0x0015
|
||||||
|
AttrXorRelayedAddress uint16 = 0x0016 // RFC 5766
|
||||||
|
AttrRequestedAddressFamily uint16 = 0x0017 // RFC 6156
|
||||||
|
AttrEvenPort uint16 = 0x0018 // RFC 5766
|
||||||
|
AttrRequestedTransport uint16 = 0x0019
|
||||||
|
AttrDontFragment uint16 = 0x001a
|
||||||
|
AttrAccessToken uint16 = 0x001b // RFC 7635
|
||||||
|
AttrXorMappedAddress uint16 = 0x0020 // RFC 5389
|
||||||
|
AttrReservationToken uint16 = 0x0022 // RFC 5766
|
||||||
|
AttrPriority uint16 = 0x0024 // RFC 5245
|
||||||
|
AttrUseCandidate uint16 = 0x0025
|
||||||
|
AttrPadding uint16 = 0x0026 // RFC 5780
|
||||||
|
AttrResponsePort uint16 = 0x0027
|
||||||
|
AttrConnectionID uint16 = 0x002a // RFC 6062
|
||||||
|
AttrSoftware uint16 = 0x8022 // RFC 5389
|
||||||
|
AttrAlternateServer uint16 = 0x8023
|
||||||
|
AttrTransactionTransmitCounter uint16 = 0x8025 // RFC 7982
|
||||||
|
AttrCacheTimeout uint16 = 0x8027 // RFC 5780
|
||||||
|
AttrFingerprint uint16 = 0x8028 // RFC 5389
|
||||||
|
AttrIceControlled uint16 = 0x8029 // RFC 5245
|
||||||
|
AttrIceControlling uint16 = 0x802a
|
||||||
|
AttrResponseOrigin uint16 = 0x802b // RFC 5780
|
||||||
|
AttrOtherAddress uint16 = 0x802c
|
||||||
|
AttrEcnCheck uint16 = 0x802d // RFC 6679
|
||||||
|
AttrThirdPartyAuthorization uint16 = 0x802e // RFC 7635
|
||||||
|
AttrMobilityTicket uint16 = 0x8030 // RFC 8016
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deprecated: For backwards compatibility only.
|
||||||
|
const (
|
||||||
|
AttrResponseAddress uint16 = 0x0002
|
||||||
|
AttrSourceAddress uint16 = 0x0004
|
||||||
|
AttrChangedAddress uint16 = 0x0005
|
||||||
|
AttrPassword uint16 = 0x0007
|
||||||
|
AttrReflectedFrom uint16 = 0x000b
|
||||||
|
AttrBandwidth uint16 = 0x0010
|
||||||
|
AttrTimerVal uint16 = 0x0021
|
||||||
|
)
|
||||||
|
|
||||||
|
// STUN attribute names.
|
||||||
|
var attrNames = map[uint16]string{
|
||||||
|
AttrMappedAddress: "MAPPED-ADDRESS",
|
||||||
|
AttrResponseAddress: "RESPONSE-ADDRESS",
|
||||||
|
AttrChangeRequest: "CHANGE-REQUEST",
|
||||||
|
AttrSourceAddress: "SOURCE-ADDRESS",
|
||||||
|
AttrChangedAddress: "CHANGED-ADDRESS",
|
||||||
|
AttrUsername: "USERNAME",
|
||||||
|
AttrPassword: "PASSWORD",
|
||||||
|
AttrMessageIntegrity: "MESSAGE-INTEGRITY",
|
||||||
|
AttrErrorCode: "ERROR-CODE",
|
||||||
|
AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES",
|
||||||
|
AttrReflectedFrom: "REFLECTED-FROM",
|
||||||
|
AttrChannelNumber: "CHANNEL-NUMBER",
|
||||||
|
AttrLifetime: "LIFETIME",
|
||||||
|
AttrBandwidth: "BANDWIDTH",
|
||||||
|
AttrXorPeerAddress: "XOR-PEER-ADDRESS",
|
||||||
|
AttrData: "DATA",
|
||||||
|
AttrRealm: "REALM",
|
||||||
|
AttrNonce: "NONCE",
|
||||||
|
AttrXorRelayedAddress: "XOR-RELAYED-ADDRESS",
|
||||||
|
AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY",
|
||||||
|
AttrEvenPort: "EVEN-PORT",
|
||||||
|
AttrRequestedTransport: "REQUESTED-TRANSPORT",
|
||||||
|
AttrDontFragment: "DONT-FRAGMENT",
|
||||||
|
AttrAccessToken: "ACCESS-TOKEN",
|
||||||
|
AttrXorMappedAddress: "XOR-MAPPED-ADDRESS",
|
||||||
|
AttrTimerVal: "TIMER-VAL",
|
||||||
|
AttrReservationToken: "RESERVATION-TOKEN",
|
||||||
|
AttrPriority: "PRIORITY",
|
||||||
|
AttrUseCandidate: "USE-CANDIDATE",
|
||||||
|
AttrPadding: "PADDING",
|
||||||
|
AttrResponsePort: "RESPONSE-PORT",
|
||||||
|
AttrConnectionID: "CONNECTION-ID",
|
||||||
|
AttrSoftware: "SOFTWARE",
|
||||||
|
AttrAlternateServer: "ALTERNATE-SERVER",
|
||||||
|
AttrTransactionTransmitCounter: "TRANSACTION-TRANSMIT-COUNTER",
|
||||||
|
AttrCacheTimeout: "CACHE-TIMEOUT",
|
||||||
|
AttrFingerprint: "FINGERPRINT",
|
||||||
|
AttrIceControlled: "ICE-CONTROLLED",
|
||||||
|
AttrIceControlling: "ICE-CONTROLLING",
|
||||||
|
AttrResponseOrigin: "RESPONSE-ORIGIN",
|
||||||
|
AttrOtherAddress: "OTHER-ADDRESS",
|
||||||
|
AttrEcnCheck: "ECN-CHECK",
|
||||||
|
AttrThirdPartyAuthorization: "THIRD-PARTY-AUTHORIZATION",
|
||||||
|
AttrMobilityTicket: "MOBILITY-TICKET",
|
||||||
|
}
|
||||||
|
|
||||||
|
// STUN error codes.
|
||||||
|
const (
|
||||||
|
CodeTryAlternate int = 300 // RFC 5389
|
||||||
|
CodeBadRequest int = 400
|
||||||
|
CodeUnauthorized int = 401
|
||||||
|
CodeForbidden int = 403 // RFC 5766
|
||||||
|
CodeMobilityForbidden int = 405 // RFC 8016
|
||||||
|
CodeUnknownAttribute int = 420 // RFC 5389
|
||||||
|
CodeAllocationMismatch int = 437 // RFC 5766
|
||||||
|
CodeStaleNonce int = 438 // RFC 5389
|
||||||
|
CodeAddressFamilyNotSupported int = 440 // RFC 6156
|
||||||
|
CodeWrongCredentials int = 441 // RFC 5766
|
||||||
|
CodeUnsupportedTransportProtocol int = 442
|
||||||
|
CodePeerAddressFamilyMismatch int = 443 // RFC 6156
|
||||||
|
CodeConnectionAlreadyExists int = 446 // RFC 6062
|
||||||
|
CodeConnectionTimeoutOrFailure int = 447
|
||||||
|
CodeAllocationQuotaReached int = 486 // RFC 5766
|
||||||
|
CodeRoleConflict int = 487 // RFC 5245
|
||||||
|
CodeServerError int = 500 // RFC 5389
|
||||||
|
CodeInsufficientCapacity int = 508 // RFC 5766
|
||||||
|
)
|
||||||
|
|
||||||
|
// STUN error texts.
|
||||||
|
var errorText = map[int]string{
|
||||||
|
CodeTryAlternate: "Try Alternate",
|
||||||
|
CodeBadRequest: "Bad Request",
|
||||||
|
CodeUnauthorized: "Unauthorized",
|
||||||
|
CodeForbidden: "Forbidden",
|
||||||
|
CodeMobilityForbidden: "Mobility Forbidden",
|
||||||
|
CodeUnknownAttribute: "Unknown Attribute",
|
||||||
|
CodeAllocationMismatch: "Allocation Mismatch",
|
||||||
|
CodeStaleNonce: "Stale Nonce",
|
||||||
|
CodeAddressFamilyNotSupported: "Address Family not Supported",
|
||||||
|
CodeWrongCredentials: "Wrong Credentials",
|
||||||
|
CodeUnsupportedTransportProtocol: "Unsupported Transport Protocol",
|
||||||
|
CodePeerAddressFamilyMismatch: "Peer Address Family Mismatch",
|
||||||
|
CodeConnectionAlreadyExists: "Connection Already Exists",
|
||||||
|
CodeConnectionTimeoutOrFailure: "Connection Timeout or Failure",
|
||||||
|
CodeAllocationQuotaReached: "Allocation Quota Reached",
|
||||||
|
CodeRoleConflict: "Role Conflict",
|
||||||
|
CodeServerError: "Server Error",
|
||||||
|
CodeInsufficientCapacity: "Insufficient Capacity",
|
||||||
|
}
|
126
vendor/github.com/isofew/go-stun/stun/server.go
generated
vendored
Normal file
126
vendor/github.com/isofew/go-stun/stun/server.go
generated
vendored
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ListenAndServe(network, laddr string, config *Config) error {
|
||||||
|
srv := NewServer(config)
|
||||||
|
return srv.ListenAndServe(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
agent *Agent
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
conns []net.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(config *Config) *Server {
|
||||||
|
srv := &Server{agent: NewAgent(config)}
|
||||||
|
srv.agent.Handler = srv
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) ListenAndServe(network, laddr string) error {
|
||||||
|
c, err := net.ListenPacket(network, laddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
srv.addConn(c)
|
||||||
|
defer srv.removeConn(c)
|
||||||
|
// not using stop channel
|
||||||
|
return srv.agent.ServePacket(c, make(chan struct{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) ServeSTUN(msg *Message, from Transport) {
|
||||||
|
if msg.Type == MethodBinding {
|
||||||
|
to := from
|
||||||
|
mapped := from.RemoteAddr()
|
||||||
|
ip, port := SockAddr(from.LocalAddr())
|
||||||
|
|
||||||
|
res := &Message{
|
||||||
|
Type: MethodBinding | KindResponse,
|
||||||
|
Transaction: msg.Transaction,
|
||||||
|
Attributes: []Attr{
|
||||||
|
Addr(AttrXorMappedAddress, mapped),
|
||||||
|
Addr(AttrMappedAddress, mapped),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.mu.RLock()
|
||||||
|
defer srv.mu.RUnlock()
|
||||||
|
|
||||||
|
if ch, ok := msg.GetInt(AttrChangeRequest); ok && ch != 0 {
|
||||||
|
for _, c := range srv.conns {
|
||||||
|
chip, chport := SockAddr(c.LocalAddr())
|
||||||
|
if chip.IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch&ChangeIP != 0 {
|
||||||
|
if !ip.Equal(chip) {
|
||||||
|
to = &packetConn{c, mapped}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if ch&ChangePort != 0 {
|
||||||
|
if ip.Equal(chip) && port != chport {
|
||||||
|
to = &packetConn{c, mapped}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(srv.conns) < 2 {
|
||||||
|
srv.agent.Send(res, to)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
other:
|
||||||
|
for _, a := range srv.conns {
|
||||||
|
aip, aport := SockAddr(a.LocalAddr())
|
||||||
|
if aip.IsUnspecified() || !ip.Equal(aip) || port == aport {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, b := range srv.conns {
|
||||||
|
bip, bport := SockAddr(b.LocalAddr())
|
||||||
|
if bip.IsUnspecified() || bip.Equal(ip) || aport != bport {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res.Set(Addr(AttrOtherAddress, b.LocalAddr()))
|
||||||
|
break other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.agent.Send(res, to)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) addConn(c net.PacketConn) {
|
||||||
|
srv.mu.Lock()
|
||||||
|
srv.conns = append(srv.conns, c)
|
||||||
|
srv.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) removeConn(c net.PacketConn) {
|
||||||
|
srv.mu.Lock()
|
||||||
|
l := srv.conns
|
||||||
|
for i, it := range l {
|
||||||
|
if it == c {
|
||||||
|
srv.conns = append(l[:i], l[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
srv.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) Close() error {
|
||||||
|
srv.mu.RLock()
|
||||||
|
defer srv.mu.RUnlock()
|
||||||
|
for _, it := range srv.conns {
|
||||||
|
it.Close()
|
||||||
|
}
|
||||||
|
srv.conns = nil
|
||||||
|
return nil
|
||||||
|
}
|
129
vendor/github.com/isofew/go-stun/stun/stun.go
generated
vendored
Normal file
129
vendor/github.com/isofew/go-stun/stun/stun.go
generated
vendored
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Discover(uri string) (net.PacketConn, net.Addr, error) {
|
||||||
|
stop := make(chan struct{})
|
||||||
|
conn, err := Dial(uri, nil, stop)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
addr, err := conn.Discover()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// TODO: hijack
|
||||||
|
// stop reading conn before returning it
|
||||||
|
close(stop)
|
||||||
|
// note. serveconn/packet func is blocked by read/readfrom at the time
|
||||||
|
// we send the signal, which means it will still consume one more
|
||||||
|
// packet and we can only read starting from the second packet.
|
||||||
|
// (not too much of a problem, since we'll punch a few packets anyway)
|
||||||
|
return conn.Conn.(net.PacketConn), addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthMethod func(sess *Session) error
|
||||||
|
|
||||||
|
// LongTermAuthMethod returns AuthMethod for long-term credentials.
|
||||||
|
// Key = MD5(username ":" realm ":" SASLprep(password)).
|
||||||
|
// SASLprep is defined in RFC 4013.
|
||||||
|
func LongTermAuthMethod(username, password string) AuthMethod {
|
||||||
|
return func(sess *Session) error {
|
||||||
|
h := md5.New()
|
||||||
|
h.Write([]byte(username + ":" + sess.Realm + ":" + password))
|
||||||
|
sess.Username = username
|
||||||
|
sess.Key = h.Sum(nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShotTermAuthMethod returns AuthMethod for short-term credentials.
|
||||||
|
// Key = SASLprep(password).
|
||||||
|
// SASLprep is defined in RFC 4013.
|
||||||
|
func ShortTermAuthMethod(password string) AuthMethod {
|
||||||
|
key := []byte(password)
|
||||||
|
return func(sess *Session) error {
|
||||||
|
sess.Key = key
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dial(uri string, config *Config, stop chan struct{}) (*Conn, error) {
|
||||||
|
secure, network, addr, auth, err := parseURI(uri)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var conn net.Conn
|
||||||
|
if secure {
|
||||||
|
conn, err = tls.Dial(network, addr, nil)
|
||||||
|
} else {
|
||||||
|
if strings.HasPrefix(network, "udp") {
|
||||||
|
conn, err = dialUDP(network, addr)
|
||||||
|
} else {
|
||||||
|
conn, err = dialTCP(network, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if auth != nil {
|
||||||
|
config = config.Clone()
|
||||||
|
config.AuthMethod = auth
|
||||||
|
}
|
||||||
|
return NewConn(conn, config, stop), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, err error) {
|
||||||
|
var u *url.URL
|
||||||
|
if u, err = url.Parse(uri); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host, port, e := net.SplitHostPort(u.Opaque)
|
||||||
|
if e != nil {
|
||||||
|
host = u.Opaque
|
||||||
|
}
|
||||||
|
if a := u.User; a != nil {
|
||||||
|
if password, ok := a.Password(); ok {
|
||||||
|
auth = LongTermAuthMethod(a.Username(), password)
|
||||||
|
} else {
|
||||||
|
auth = ShortTermAuthMethod(a.Username())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
network = u.Query().Get("transport")
|
||||||
|
if network == "" {
|
||||||
|
network = "udp"
|
||||||
|
}
|
||||||
|
switch u.Scheme {
|
||||||
|
case "stun", "turn":
|
||||||
|
if port == "" {
|
||||||
|
port = "3478"
|
||||||
|
}
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6", "tcp", "tcp4", "tcp6":
|
||||||
|
default:
|
||||||
|
err = errors.New("stun: unsupported transport: " + network)
|
||||||
|
}
|
||||||
|
case "stuns", "turns":
|
||||||
|
if port == "" {
|
||||||
|
port = "5478"
|
||||||
|
}
|
||||||
|
secure = true
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
default:
|
||||||
|
err = errors.New("stun: unsupported transport: " + network)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = errors.New("stun: unsupported scheme " + u.Scheme)
|
||||||
|
}
|
||||||
|
addr = net.JoinHostPort(host, port)
|
||||||
|
return
|
||||||
|
}
|
96
vendor/github.com/isofew/go-stun/stun/transport.go
generated
vendored
Normal file
96
vendor/github.com/isofew/go-stun/stun/transport.go
generated
vendored
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Listener interface {
|
||||||
|
Addr() net.Addr
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Transport interface {
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
Write(p []byte) (int, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Marshaler interface {
|
||||||
|
Marshal(b []byte) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransportHandler interface {
|
||||||
|
ServeTransport(b []byte, tr Transport) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialUDP(network, raddr string) (net.Conn, error) {
|
||||||
|
addr, err := net.ResolveUDPAddr(network, raddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP(network, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &packetConn{conn, addr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialTCP(network, raddr string) (net.Conn, error) {
|
||||||
|
addr, err := net.ResolveTCPAddr(network, raddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.DialTCP(network, nil, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type packetConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetConn) Read(p []byte) (n int, err error) {
|
||||||
|
n, _, err = t.ReadFrom(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetConn) Write(p []byte) (int, error) {
|
||||||
|
return t.WriteTo(p, t.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetConn) RemoteAddr() net.Addr {
|
||||||
|
return t.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errBufferOverflow = errors.New("stun: buffer overflow")
|
||||||
|
errFormat = errors.New("stun: format error")
|
||||||
|
)
|
||||||
|
|
||||||
|
func getBuffer() []byte {
|
||||||
|
return make([]byte, 2048)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBuffer(b []byte) {
|
||||||
|
if cap(b) >= 2048 {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func grow(p []byte, n int) (b, a []byte) {
|
||||||
|
l := len(p)
|
||||||
|
r := l + n
|
||||||
|
if r > cap(p) {
|
||||||
|
b = make([]byte, (1+((r-1)>>10))<<10)[:r]
|
||||||
|
a = b[l:r]
|
||||||
|
if l > 0 {
|
||||||
|
copy(b, p[:l])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return p[:r], p[l:r]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var be = binary.BigEndian
|
14
vendor/github.com/lucas-clemente/quic-go/Changelog.md
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/Changelog.md
generated
vendored
@ -1,6 +1,18 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## v0.6.0 (unreleased)
|
## v0.8.0 (unreleased)
|
||||||
|
|
||||||
|
- Add support for unidirectional streams (for IETF QUIC).
|
||||||
|
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||||
|
|
||||||
|
## v0.7.0 (2018-02-03)
|
||||||
|
|
||||||
|
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||||
|
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||||
|
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||||
|
- Implement packet pacing.
|
||||||
|
|
||||||
|
## v0.6.0 (2017-12-12)
|
||||||
|
|
||||||
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
||||||
- Added `quic.Config` options for maximal flow control windows
|
- Added `quic.Config` options for maximal flow control windows
|
||||||
|
7
vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go
generated
vendored
@ -1,7 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
_ "github.com/clipperhouse/linkedlist"
|
|
||||||
_ "github.com/clipperhouse/slice"
|
|
||||||
_ "github.com/clipperhouse/stringer"
|
|
||||||
)
|
|
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
@ -1,34 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SentPacketHandler handles ACKs received for outgoing packets
|
|
||||||
type SentPacketHandler interface {
|
|
||||||
// SentPacket may modify the packet
|
|
||||||
SentPacket(packet *Packet) error
|
|
||||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
|
||||||
SetHandshakeComplete()
|
|
||||||
|
|
||||||
SendingAllowed() bool
|
|
||||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
|
||||||
ShouldSendRetransmittablePacket() bool
|
|
||||||
DequeuePacketForRetransmission() (packet *Packet)
|
|
||||||
GetLeastUnacked() protocol.PacketNumber
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
OnAlarm()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
|
||||||
type ReceivedPacketHandler interface {
|
|
||||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
|
||||||
SetLowerLimit(protocol.PacketNumber)
|
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
|
||||||
GetAckFrame() *wire.AckFrame
|
|
||||||
}
|
|
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go
generated
vendored
34
vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go
generated
vendored
@ -1,34 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Packet is a packet
|
|
||||||
// +gen linkedlist
|
|
||||||
type Packet struct {
|
|
||||||
PacketNumber protocol.PacketNumber
|
|
||||||
Frames []wire.Frame
|
|
||||||
Length protocol.ByteCount
|
|
||||||
EncryptionLevel protocol.EncryptionLevel
|
|
||||||
|
|
||||||
SendTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFramesForRetransmission gets all the frames for retransmission
|
|
||||||
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
|
|
||||||
var fs []wire.Frame
|
|
||||||
for _, frame := range p.Frames {
|
|
||||||
switch frame.(type) {
|
|
||||||
case *wire.AckFrame:
|
|
||||||
continue
|
|
||||||
case *wire.StopWaitingFrame:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fs = append(fs, frame)
|
|
||||||
}
|
|
||||||
return fs
|
|
||||||
}
|
|
141
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
141
vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go
generated
vendored
@ -1,141 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
|
|
||||||
|
|
||||||
type receivedPacketHandler struct {
|
|
||||||
largestObserved protocol.PacketNumber
|
|
||||||
lowerLimit protocol.PacketNumber
|
|
||||||
largestObservedReceivedTime time.Time
|
|
||||||
|
|
||||||
packetHistory *receivedPacketHistory
|
|
||||||
|
|
||||||
ackSendDelay time.Duration
|
|
||||||
|
|
||||||
packetsReceivedSinceLastAck int
|
|
||||||
retransmittablePacketsReceivedSinceLastAck int
|
|
||||||
ackQueued bool
|
|
||||||
ackAlarm time.Time
|
|
||||||
lastAck *wire.AckFrame
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
|
||||||
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
|
|
||||||
return &receivedPacketHandler{
|
|
||||||
packetHistory: newReceivedPacketHistory(),
|
|
||||||
ackSendDelay: protocol.AckSendDelay,
|
|
||||||
version: version,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
|
|
||||||
if packetNumber == 0 {
|
|
||||||
return errInvalidPacketNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
if packetNumber > h.largestObserved {
|
|
||||||
h.largestObserved = packetNumber
|
|
||||||
h.largestObservedReceivedTime = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
if packetNumber <= h.lowerLimit {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.maybeQueueAck(packetNumber, shouldInstigateAck)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLowerLimit sets a lower limit for acking packets.
|
|
||||||
// Packets with packet numbers smaller or equal than p will not be acked.
|
|
||||||
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
|
|
||||||
h.lowerLimit = p
|
|
||||||
h.packetHistory.DeleteUpTo(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
|
|
||||||
h.packetsReceivedSinceLastAck++
|
|
||||||
|
|
||||||
if shouldInstigateAck {
|
|
||||||
h.retransmittablePacketsReceivedSinceLastAck++
|
|
||||||
}
|
|
||||||
|
|
||||||
// always ack the first packet
|
|
||||||
if h.lastAck == nil {
|
|
||||||
h.ackQueued = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.version < protocol.Version39 {
|
|
||||||
// Always send an ack every 20 packets in order to allow the peer to discard
|
|
||||||
// information from the SentPacketManager and provide an RTT measurement.
|
|
||||||
// From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet.
|
|
||||||
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
|
|
||||||
h.ackQueued = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
|
||||||
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
|
||||||
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
|
||||||
h.ackQueued = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if a new missing range above the previously was created
|
|
||||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
|
|
||||||
h.ackQueued = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !h.ackQueued && shouldInstigateAck {
|
|
||||||
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
|
|
||||||
h.ackQueued = true
|
|
||||||
} else {
|
|
||||||
if h.ackAlarm.IsZero() {
|
|
||||||
h.ackAlarm = time.Now().Add(h.ackSendDelay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.ackQueued {
|
|
||||||
// cancel the ack alarm
|
|
||||||
h.ackAlarm = time.Time{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
|
||||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ackRanges := h.packetHistory.GetAckRanges()
|
|
||||||
ack := &wire.AckFrame{
|
|
||||||
LargestAcked: h.largestObserved,
|
|
||||||
LowestAcked: ackRanges[len(ackRanges)-1].First,
|
|
||||||
PacketReceivedTime: h.largestObservedReceivedTime,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ackRanges) > 1 {
|
|
||||||
ack.AckRanges = ackRanges
|
|
||||||
}
|
|
||||||
|
|
||||||
h.lastAck = ack
|
|
||||||
h.ackAlarm = time.Time{}
|
|
||||||
h.ackQueued = false
|
|
||||||
h.packetsReceivedSinceLastAck = 0
|
|
||||||
h.retransmittablePacketsReceivedSinceLastAck = 0
|
|
||||||
|
|
||||||
return ack
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
|
455
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go
generated
vendored
455
vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go
generated
vendored
@ -1,455 +0,0 @@
|
|||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
|
||||||
// In fraction of an RTT.
|
|
||||||
timeReorderingFraction = 1.0 / 8
|
|
||||||
// The default RTT used before an RTT sample is taken.
|
|
||||||
// Note: This constant is also defined in the congestion package.
|
|
||||||
defaultInitialRTT = 100 * time.Millisecond
|
|
||||||
// defaultRTOTimeout is the RTO time on new connections
|
|
||||||
defaultRTOTimeout = 500 * time.Millisecond
|
|
||||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
|
||||||
minTPLTimeout = 10 * time.Millisecond
|
|
||||||
// Minimum time in the future an RTO alarm may be set for.
|
|
||||||
minRTOTimeout = 200 * time.Millisecond
|
|
||||||
// maxRTOTimeout is the maximum RTO time
|
|
||||||
maxRTOTimeout = 60 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
|
||||||
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
|
||||||
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
|
|
||||||
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
|
||||||
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
|
|
||||||
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
|
||||||
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
|
||||||
)
|
|
||||||
|
|
||||||
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
|
|
||||||
|
|
||||||
type sentPacketHandler struct {
|
|
||||||
lastSentPacketNumber protocol.PacketNumber
|
|
||||||
skippedPackets []protocol.PacketNumber
|
|
||||||
|
|
||||||
numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet
|
|
||||||
|
|
||||||
LargestAcked protocol.PacketNumber
|
|
||||||
|
|
||||||
largestReceivedPacketWithAck protocol.PacketNumber
|
|
||||||
|
|
||||||
packetHistory *PacketList
|
|
||||||
stopWaitingManager stopWaitingManager
|
|
||||||
|
|
||||||
retransmissionQueue []*Packet
|
|
||||||
|
|
||||||
bytesInFlight protocol.ByteCount
|
|
||||||
|
|
||||||
congestion congestion.SendAlgorithm
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
handshakeComplete bool
|
|
||||||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
|
||||||
handshakeCount uint32
|
|
||||||
|
|
||||||
// The number of times an RTO has been sent without receiving an ack.
|
|
||||||
rtoCount uint32
|
|
||||||
|
|
||||||
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
|
||||||
lossTime time.Time
|
|
||||||
|
|
||||||
// The alarm timeout
|
|
||||||
alarm time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSentPacketHandler creates a new sentPacketHandler
|
|
||||||
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
|
||||||
congestion := congestion.NewCubicSender(
|
|
||||||
congestion.DefaultClock{},
|
|
||||||
rttStats,
|
|
||||||
false, /* don't use reno since chromium doesn't (why?) */
|
|
||||||
protocol.InitialCongestionWindow,
|
|
||||||
protocol.DefaultMaxCongestionWindow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return &sentPacketHandler{
|
|
||||||
packetHistory: NewPacketList(),
|
|
||||||
stopWaitingManager: stopWaitingManager{},
|
|
||||||
rttStats: rttStats,
|
|
||||||
congestion: congestion,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
|
|
||||||
if f := h.packetHistory.Front(); f != nil {
|
|
||||||
return f.Value.PacketNumber - 1
|
|
||||||
}
|
|
||||||
return h.LargestAcked
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
|
|
||||||
return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
|
||||||
h.handshakeComplete = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|
||||||
if packet.PacketNumber <= h.lastSentPacketNumber {
|
|
||||||
return errPacketNumberNotIncreasing
|
|
||||||
}
|
|
||||||
|
|
||||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
|
||||||
return ErrTooManyTrackedSentPackets
|
|
||||||
}
|
|
||||||
|
|
||||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
|
||||||
h.skippedPackets = append(h.skippedPackets, p)
|
|
||||||
|
|
||||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
|
||||||
h.skippedPackets = h.skippedPackets[1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h.lastSentPacketNumber = packet.PacketNumber
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
|
||||||
isRetransmittable := len(packet.Frames) != 0
|
|
||||||
|
|
||||||
if isRetransmittable {
|
|
||||||
packet.SendTime = now
|
|
||||||
h.bytesInFlight += packet.Length
|
|
||||||
h.packetHistory.PushBack(*packet)
|
|
||||||
h.numNonRetransmittablePackets = 0
|
|
||||||
} else {
|
|
||||||
h.numNonRetransmittablePackets++
|
|
||||||
}
|
|
||||||
|
|
||||||
h.congestion.OnPacketSent(
|
|
||||||
now,
|
|
||||||
h.bytesInFlight,
|
|
||||||
packet.PacketNumber,
|
|
||||||
packet.Length,
|
|
||||||
isRetransmittable,
|
|
||||||
)
|
|
||||||
|
|
||||||
h.updateLossDetectionAlarm()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
|
||||||
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
|
||||||
return errAckForUnsentPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
// duplicate or out-of-order ACK
|
|
||||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
|
||||||
return ErrDuplicateOrOutOfOrderAck
|
|
||||||
}
|
|
||||||
h.largestReceivedPacketWithAck = withPacketNumber
|
|
||||||
|
|
||||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
|
||||||
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
h.LargestAcked = ackFrame.LargestAcked
|
|
||||||
|
|
||||||
if h.skippedPacketsAcked(ackFrame) {
|
|
||||||
return ErrAckForSkippedPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
|
||||||
|
|
||||||
if rttUpdated {
|
|
||||||
h.congestion.MaybeExitSlowStart()
|
|
||||||
}
|
|
||||||
|
|
||||||
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ackedPackets) > 0 {
|
|
||||||
for _, p := range ackedPackets {
|
|
||||||
if encLevel < p.Value.EncryptionLevel {
|
|
||||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
|
||||||
}
|
|
||||||
h.onPacketAcked(p)
|
|
||||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h.detectLostPackets()
|
|
||||||
h.updateLossDetectionAlarm()
|
|
||||||
|
|
||||||
h.garbageCollectSkippedPackets()
|
|
||||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
|
||||||
var ackedPackets []*PacketElement
|
|
||||||
ackRangeIndex := 0
|
|
||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
|
||||||
packet := el.Value
|
|
||||||
packetNumber := packet.PacketNumber
|
|
||||||
|
|
||||||
// Ignore packets below the LowestAcked
|
|
||||||
if packetNumber < ackFrame.LowestAcked {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Break after LargestAcked is reached
|
|
||||||
if packetNumber > ackFrame.LargestAcked {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if ackFrame.HasMissingRanges() {
|
|
||||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
|
||||||
|
|
||||||
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
|
||||||
ackRangeIndex++
|
|
||||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
|
||||||
}
|
|
||||||
|
|
||||||
if packetNumber >= ackRange.First { // packet i contained in ACK range
|
|
||||||
if packetNumber > ackRange.Last {
|
|
||||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
|
|
||||||
}
|
|
||||||
ackedPackets = append(ackedPackets, el)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ackedPackets = append(ackedPackets, el)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ackedPackets, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
|
||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
|
||||||
packet := el.Value
|
|
||||||
if packet.PacketNumber == largestAcked {
|
|
||||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Packets are sorted by number, so we can stop searching
|
|
||||||
if packet.PacketNumber > largestAcked {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
|
||||||
// Cancel the alarm if no packets are outstanding
|
|
||||||
if h.packetHistory.Len() == 0 {
|
|
||||||
h.alarm = time.Time{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(#497): TLP
|
|
||||||
if !h.handshakeComplete {
|
|
||||||
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
|
|
||||||
} else if !h.lossTime.IsZero() {
|
|
||||||
// Early retransmit timer or time loss detection.
|
|
||||||
h.alarm = h.lossTime
|
|
||||||
} else {
|
|
||||||
// RTO
|
|
||||||
h.alarm = time.Now().Add(h.computeRTOTimeout())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) detectLostPackets() {
|
|
||||||
h.lossTime = time.Time{}
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
|
||||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
|
||||||
|
|
||||||
var lostPackets []*PacketElement
|
|
||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
|
||||||
packet := el.Value
|
|
||||||
|
|
||||||
if packet.PacketNumber > h.LargestAcked {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
timeSinceSent := now.Sub(packet.SendTime)
|
|
||||||
if timeSinceSent > delayUntilLost {
|
|
||||||
lostPackets = append(lostPackets, el)
|
|
||||||
} else if h.lossTime.IsZero() {
|
|
||||||
// Note: This conditional is only entered once per call
|
|
||||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(lostPackets) > 0 {
|
|
||||||
for _, p := range lostPackets {
|
|
||||||
h.queuePacketForRetransmission(p)
|
|
||||||
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) OnAlarm() {
|
|
||||||
// TODO(#497): TLP
|
|
||||||
if !h.handshakeComplete {
|
|
||||||
h.queueHandshakePacketsForRetransmission()
|
|
||||||
h.handshakeCount++
|
|
||||||
} else if !h.lossTime.IsZero() {
|
|
||||||
// Early retransmit or time loss detection
|
|
||||||
h.detectLostPackets()
|
|
||||||
} else {
|
|
||||||
// RTO
|
|
||||||
h.retransmitOldestTwoPackets()
|
|
||||||
h.rtoCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
h.updateLossDetectionAlarm()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
|
||||||
return h.alarm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
|
||||||
h.bytesInFlight -= packetElement.Value.Length
|
|
||||||
h.rtoCount = 0
|
|
||||||
h.handshakeCount = 0
|
|
||||||
// TODO(#497): h.tlpCount = 0
|
|
||||||
h.packetHistory.Remove(packetElement)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|
||||||
if len(h.retransmissionQueue) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
packet := h.retransmissionQueue[0]
|
|
||||||
// Shift the slice and don't retain anything that isn't needed.
|
|
||||||
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
|
|
||||||
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
|
|
||||||
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
|
|
||||||
return packet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
|
||||||
return h.largestInOrderAcked() + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
|
||||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
|
||||||
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
|
|
||||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
|
||||||
if congestionLimited {
|
|
||||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
|
|
||||||
h.bytesInFlight,
|
|
||||||
h.congestion.GetCongestionWindow())
|
|
||||||
}
|
|
||||||
// Workaround for #555:
|
|
||||||
// Always allow sending of retransmissions. This should probably be limited
|
|
||||||
// to RTOs, but we currently don't have a nice way of distinguishing them.
|
|
||||||
haveRetransmissions := len(h.retransmissionQueue) > 0
|
|
||||||
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
|
||||||
if p := h.packetHistory.Front(); p != nil {
|
|
||||||
h.queueRTO(p)
|
|
||||||
}
|
|
||||||
if p := h.packetHistory.Front(); p != nil {
|
|
||||||
h.queueRTO(p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
|
||||||
packet := &el.Value
|
|
||||||
utils.Debugf(
|
|
||||||
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
|
|
||||||
packet.PacketNumber,
|
|
||||||
h.packetHistory.Len(),
|
|
||||||
)
|
|
||||||
h.queuePacketForRetransmission(el)
|
|
||||||
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
|
|
||||||
h.congestion.OnRetransmissionTimeout(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
|
|
||||||
var handshakePackets []*PacketElement
|
|
||||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
|
||||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
|
||||||
handshakePackets = append(handshakePackets, el)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, el := range handshakePackets {
|
|
||||||
h.queuePacketForRetransmission(el)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
|
||||||
packet := &packetElement.Value
|
|
||||||
h.bytesInFlight -= packet.Length
|
|
||||||
h.retransmissionQueue = append(h.retransmissionQueue, packet)
|
|
||||||
h.packetHistory.Remove(packetElement)
|
|
||||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
|
||||||
duration := 2 * h.rttStats.SmoothedRTT()
|
|
||||||
if duration == 0 {
|
|
||||||
duration = 2 * defaultInitialRTT
|
|
||||||
}
|
|
||||||
duration = utils.MaxDuration(duration, minTPLTimeout)
|
|
||||||
// exponential backoff
|
|
||||||
// There's an implicit limit to this set by the handshake timeout.
|
|
||||||
return duration << h.handshakeCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
|
||||||
rto := h.congestion.RetransmissionDelay()
|
|
||||||
if rto == 0 {
|
|
||||||
rto = defaultRTOTimeout
|
|
||||||
}
|
|
||||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
|
||||||
// Exponential backoff
|
|
||||||
rto = rto << h.rtoCount
|
|
||||||
return utils.MinDuration(rto, maxRTOTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
|
||||||
for _, p := range h.skippedPackets {
|
|
||||||
if ackFrame.AcksPacket(p) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
|
||||||
lioa := h.largestInOrderAcked()
|
|
||||||
deleteIndex := 0
|
|
||||||
for i, p := range h.skippedPackets {
|
|
||||||
if p <= lioa {
|
|
||||||
deleteIndex = i + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
|
||||||
}
|
|
4
vendor/github.com/lucas-clemente/quic-go/appveyor.yml
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/appveyor.yml
generated
vendored
@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
|||||||
|
|
||||||
install:
|
install:
|
||||||
- rmdir c:\go /s /q
|
- rmdir c:\go /s /q
|
||||||
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip
|
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.10.2.windows-amd64.zip
|
||||||
- 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL
|
- 7z x go1.10.2.windows-amd64.zip -y -oC:\ > NUL
|
||||||
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
||||||
- echo %PATH%
|
- echo %PATH%
|
||||||
- echo %GOPATH%
|
- echo %GOPATH%
|
||||||
|
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
@ -8,19 +8,20 @@ import (
|
|||||||
|
|
||||||
var bufferPool sync.Pool
|
var bufferPool sync.Pool
|
||||||
|
|
||||||
func getPacketBuffer() []byte {
|
func getPacketBuffer() *[]byte {
|
||||||
return bufferPool.Get().([]byte)
|
return bufferPool.Get().(*[]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func putPacketBuffer(buf []byte) {
|
func putPacketBuffer(buf *[]byte) {
|
||||||
if cap(buf) != int(protocol.MaxReceivePacketSize) {
|
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
||||||
panic("putPacketBuffer called with packet of wrong size!")
|
panic("putPacketBuffer called with packet of wrong size!")
|
||||||
}
|
}
|
||||||
bufferPool.Put(buf[:0])
|
bufferPool.Put(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
bufferPool.New = func() interface{} {
|
bufferPool.New = func() interface{} {
|
||||||
return make([]byte, 0, protocol.MaxReceivePacketSize)
|
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||||
|
return &b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
406
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
406
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
@ -10,6 +10,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
@ -22,24 +23,29 @@ type client struct {
|
|||||||
conn connection
|
conn connection
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
handshakeChan <-chan handshakeEvent
|
|
||||||
|
|
||||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||||
versionNegotiated bool // has version negotiation completed yet
|
versionNegotiated bool // has the server accepted our version
|
||||||
receivedVersionNegotiationPacket bool
|
receivedVersionNegotiationPacket bool
|
||||||
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
|
tls handshake.MintTLS // only used when using TLS
|
||||||
|
|
||||||
connectionID protocol.ConnectionID
|
srcConnID protocol.ConnectionID
|
||||||
|
destConnID protocol.ConnectionID
|
||||||
|
|
||||||
|
initialVersion protocol.VersionNumber
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
|
||||||
session packetHandler
|
session packetHandler
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// make it possible to mock connection ID generation in the tests
|
// make it possible to mock connection ID generation in the tests
|
||||||
generateConnectionID = utils.GenerateConnectionID
|
generateConnectionID = protocol.GenerateConnectionID
|
||||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -57,69 +63,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
|
||||||
// The hostname for SNI is taken from the given address.
|
|
||||||
func DialAddrNonFWSecure(
|
|
||||||
addr string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (NonFWSession, error) {
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
|
||||||
// The host parameter is used for SNI.
|
|
||||||
func DialNonFWSecure(
|
|
||||||
pconn net.PacketConn,
|
|
||||||
remoteAddr net.Addr,
|
|
||||||
host string,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
config *Config,
|
|
||||||
) (NonFWSession, error) {
|
|
||||||
connID, err := generateConnectionID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var hostname string
|
|
||||||
if tlsConf != nil {
|
|
||||||
hostname = tlsConf.ServerName
|
|
||||||
}
|
|
||||||
|
|
||||||
if hostname == "" {
|
|
||||||
hostname, _, err = net.SplitHostPort(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConfig := populateClientConfig(config)
|
|
||||||
c := &client{
|
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
|
||||||
connectionID: connID,
|
|
||||||
hostname: hostname,
|
|
||||||
tlsConf: tlsConf,
|
|
||||||
config: clientConfig,
|
|
||||||
version: clientConfig.Versions[0],
|
|
||||||
versionNegotiationChan: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
|
||||||
|
|
||||||
if err := c.establishSecureConnection(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c.session.(NonFWSession), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
// The host parameter is used for SNI.
|
// The host parameter is used for SNI.
|
||||||
func Dial(
|
func Dial(
|
||||||
@ -129,14 +72,57 @@ func Dial(
|
|||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
clientConfig := populateClientConfig(config)
|
||||||
|
version := clientConfig.Versions[0]
|
||||||
|
srcConnID, err := generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
destConnID := srcConnID
|
||||||
|
if version.UsesTLS() {
|
||||||
|
destConnID, err = generateConnectionID()
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sess, nil
|
}
|
||||||
|
|
||||||
|
var hostname string
|
||||||
|
if tlsConf != nil {
|
||||||
|
hostname = tlsConf.ServerName
|
||||||
|
}
|
||||||
|
if hostname == "" {
|
||||||
|
hostname, _, err = net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that all versions are actually supported
|
||||||
|
if config != nil {
|
||||||
|
for _, v := range config.Versions {
|
||||||
|
if !protocol.IsValidVersion(v) {
|
||||||
|
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c := &client{
|
||||||
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
|
srcConnID: srcConnID,
|
||||||
|
destConnID: destConnID,
|
||||||
|
hostname: hostname,
|
||||||
|
tlsConf: tlsConf,
|
||||||
|
config: clientConfig,
|
||||||
|
version: version,
|
||||||
|
versionNegotiationChan: make(chan struct{}),
|
||||||
|
logger: utils.DefaultLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||||
|
|
||||||
|
if err := c.dial(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
@ -167,6 +153,18 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
if maxReceiveConnectionFlowControlWindow == 0 {
|
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
||||||
}
|
}
|
||||||
|
maxIncomingStreams := config.MaxIncomingStreams
|
||||||
|
if maxIncomingStreams == 0 {
|
||||||
|
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||||
|
} else if maxIncomingStreams < 0 {
|
||||||
|
maxIncomingStreams = 0
|
||||||
|
}
|
||||||
|
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||||
|
if maxIncomingUniStreams == 0 {
|
||||||
|
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||||
|
} else if maxIncomingUniStreams < 0 {
|
||||||
|
maxIncomingUniStreams = 0
|
||||||
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
@ -175,29 +173,87 @@ func populateClientConfig(config *Config) *Config {
|
|||||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||||
|
MaxIncomingStreams: maxIncomingStreams,
|
||||||
|
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||||
KeepAlive: config.KeepAlive,
|
KeepAlive: config.KeepAlive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
func (c *client) dial() error {
|
||||||
func (c *client) establishSecureConnection() error {
|
var err error
|
||||||
if err := c.createNewSession(c.version, nil); err != nil {
|
if c.version.UsesTLS() {
|
||||||
|
err = c.dialTLS()
|
||||||
|
} else {
|
||||||
|
err = c.dialGQUIC()
|
||||||
|
}
|
||||||
|
if err == errCloseSessionForNewVersion {
|
||||||
|
return c.dial()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialGQUIC() error {
|
||||||
|
if err := c.createNewGQUICSession(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go c.listen()
|
go c.listen()
|
||||||
|
return c.establishSecureConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialTLS() error {
|
||||||
|
params := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
IdleTimeout: c.config.IdleTimeout,
|
||||||
|
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||||
|
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||||
|
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||||
|
}
|
||||||
|
csc := handshake.NewCryptoStreamConn(nil)
|
||||||
|
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||||
|
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mintConf.ExtensionHandler = extHandler
|
||||||
|
mintConf.ServerName = c.hostname
|
||||||
|
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||||
|
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go c.listen()
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
if err != handshake.ErrCloseSessionForRetry {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Infof("Received a Retry packet. Recreating session.")
|
||||||
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||||
|
// It returns:
|
||||||
|
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||||
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||||
|
// - any other error that might occur
|
||||||
|
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||||
|
func (c *client) establishSecureConnection() error {
|
||||||
var runErr error
|
var runErr error
|
||||||
errorChan := make(chan struct{})
|
errorChan := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
// session.run() returns as soon as the session is closed
|
runErr = c.session.run() // returns as soon as the session is closed
|
||||||
runErr = c.session.run()
|
|
||||||
if runErr == errCloseSessionForNewVersion {
|
|
||||||
// run the new session
|
|
||||||
runErr = c.session.run()
|
|
||||||
}
|
|
||||||
close(errorChan)
|
close(errorChan)
|
||||||
utils.Infof("Connection %x closed.", c.connectionID)
|
c.logger.Infof("Connection %s closed.", c.srcConnID)
|
||||||
|
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait until the server accepts the QUIC version (or an error occurs)
|
// wait until the server accepts the QUIC version (or an error occurs)
|
||||||
@ -210,96 +266,95 @@ func (c *client) establishSecureConnection() error {
|
|||||||
select {
|
select {
|
||||||
case <-errorChan:
|
case <-errorChan:
|
||||||
return runErr
|
return runErr
|
||||||
case ev := <-c.handshakeChan:
|
case err := <-c.session.handshakeStatus():
|
||||||
if ev.err != nil {
|
return err
|
||||||
return ev.err
|
|
||||||
}
|
|
||||||
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
|
|
||||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens
|
// Listen listens on the underlying connection and passes packets on for handling.
|
||||||
|
// It returns when the connection is closed.
|
||||||
func (c *client) listen() {
|
func (c *client) listen() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var n int
|
var n int
|
||||||
var addr net.Addr
|
var addr net.Addr
|
||||||
data := getPacketBuffer()
|
data := *getPacketBuffer()
|
||||||
data = data[:protocol.MaxReceivePacketSize]
|
data = data[:protocol.MaxReceivePacketSize]
|
||||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||||
n, addr, err = c.conn.Read(data)
|
n, addr, err = c.conn.Read(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
|
c.mutex.Lock()
|
||||||
|
if c.session != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
data = data[:n]
|
if err := c.handlePacket(addr, data[:n]); err != nil {
|
||||||
|
c.logger.Errorf("error handling packet: %s", err.Error())
|
||||||
c.handlePacket(addr, data)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||||
rcvTime := time.Now()
|
rcvTime := time.Now()
|
||||||
|
|
||||||
r := bytes.NewReader(packet)
|
r := bytes.NewReader(packet)
|
||||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||||
|
// drop the packet if we can't parse the header
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||||
// drop this packet if we can't parse the header
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// reject packets with truncated connection id if we didn't request truncation
|
// reject packets with truncated connection id if we didn't request truncation
|
||||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||||
return
|
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||||
}
|
|
||||||
// reject packets with the wrong connection ID
|
|
||||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
packetData := packet[len(packet)-r.Len():]
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
if hdr.ResetFlag {
|
|
||||||
cr := c.conn.RemoteAddr()
|
|
||||||
// check if the remote address and the connection ID match
|
|
||||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
|
||||||
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
|
|
||||||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pr, err := wire.ParsePublicReset(r)
|
|
||||||
if err != nil {
|
|
||||||
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
|
||||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */
|
|
||||||
|
|
||||||
// handle Version Negotiation Packets
|
// handle Version Negotiation Packets
|
||||||
if isVersionNegotiationPacket {
|
if hdr.IsVersionNegotiation {
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||||
return
|
return errors.New("received a delayed Version Negotiation Packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
// version negotiation packets have no payload
|
// version negotiation packets have no payload
|
||||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||||
c.session.Close(err)
|
c.session.Close(err)
|
||||||
}
|
}
|
||||||
return
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.IsPublicHeader {
|
||||||
|
return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime)
|
||||||
|
}
|
||||||
|
return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||||
|
// TODO(#1003): add support for server-chosen connection IDs
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||||
|
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
|
||||||
|
}
|
||||||
|
if hdr.IsLongHeader {
|
||||||
|
if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake {
|
||||||
|
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
|
||||||
|
}
|
||||||
|
c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen)
|
||||||
|
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
|
||||||
|
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
|
||||||
|
}
|
||||||
|
packetData = packetData[:int(hdr.PayloadLen)]
|
||||||
|
// TODO(#1312): implement parsing of compound packets
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet we are receiving
|
// this is the first packet we are receiving
|
||||||
@ -312,9 +367,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||||||
c.session.handlePacket(&receivedPacket{
|
c.session.handlePacket(&receivedPacket{
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
header: hdr,
|
header: hdr,
|
||||||
data: packet[len(packet)-r.Len():],
|
data: packetData,
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||||
|
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.ResetFlag {
|
||||||
|
cr := c.conn.RemoteAddr()
|
||||||
|
// check if the remote address and the connection ID match
|
||||||
|
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||||
|
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||||
|
return errors.New("Received a spoofed Public Reset")
|
||||||
|
}
|
||||||
|
pr, err := wire.ParsePublicReset(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||||
|
}
|
||||||
|
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||||
|
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is the first packet we are receiving
|
||||||
|
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||||
|
if !c.versionNegotiated {
|
||||||
|
c.versionNegotiated = true
|
||||||
|
close(c.versionNegotiationChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.session.handlePacket(&receivedPacket{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
header: hdr,
|
||||||
|
data: packetData,
|
||||||
|
rcvTime: rcvTime,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
@ -327,42 +421,66 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.receivedVersionNegotiationPacket = true
|
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||||
|
|
||||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.InvalidVersion
|
return qerr.InvalidVersion
|
||||||
}
|
}
|
||||||
|
c.receivedVersionNegotiationPacket = true
|
||||||
|
c.negotiatedVersions = hdr.SupportedVersions
|
||||||
|
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
initialVersion := c.version
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
var err error
|
var err error
|
||||||
c.connectionID, err = utils.GenerateConnectionID()
|
c.destConnID, err = generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
// in gQUIC, there's only one connection ID
|
||||||
|
if !c.version.UsesTLS() {
|
||||||
// create a new session and close the old one
|
c.srcConnID = c.destConnID
|
||||||
// the new session must be created first to update client member variables
|
}
|
||||||
oldSession := c.session
|
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
c.session.Close(errCloseSessionForNewVersion)
|
||||||
return c.createNewSession(initialVersion, hdr.SupportedVersions)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
|
func (c *client) createNewGQUICSession() (err error) {
|
||||||
var err error
|
c.mutex.Lock()
|
||||||
utils.Debugf("createNewSession with initial version %s", initialVersion)
|
defer c.mutex.Unlock()
|
||||||
c.session, c.handshakeChan, err = newClientSession(
|
c.session, err = newClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
c.hostname,
|
c.hostname,
|
||||||
c.version,
|
c.version,
|
||||||
c.connectionID,
|
c.destConnID,
|
||||||
c.tlsConf,
|
c.tlsConf,
|
||||||
c.config,
|
c.config,
|
||||||
initialVersion,
|
c.initialVersion,
|
||||||
negotiatedVersions,
|
c.negotiatedVersions,
|
||||||
|
c.logger,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) createNewTLSSession(
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) (err error) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
c.session, err = newTLSClientSession(
|
||||||
|
c.conn,
|
||||||
|
c.hostname,
|
||||||
|
c.version,
|
||||||
|
c.destConnID,
|
||||||
|
c.srcConnID,
|
||||||
|
c.config,
|
||||||
|
c.tls,
|
||||||
|
paramsChan,
|
||||||
|
1,
|
||||||
|
c.logger,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
7
vendor/github.com/lucas-clemente/quic-go/codecov.yml
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/codecov.yml
generated
vendored
@ -1,11 +1,16 @@
|
|||||||
coverage:
|
coverage:
|
||||||
round: nearest
|
round: nearest
|
||||||
ignore:
|
ignore:
|
||||||
- ackhandler/packet_linkedlist.go
|
- streams_map_incoming_bidi.go
|
||||||
|
- streams_map_incoming_uni.go
|
||||||
|
- streams_map_outgoing_bidi.go
|
||||||
|
- streams_map_outgoing_uni.go
|
||||||
- h2quic/gzipreader.go
|
- h2quic/gzipreader.go
|
||||||
- h2quic/response.go
|
- h2quic/response.go
|
||||||
|
- internal/ackhandler/packet_linkedlist.go
|
||||||
- internal/utils/byteinterval_linkedlist.go
|
- internal/utils/byteinterval_linkedlist.go
|
||||||
- internal/utils/packetinterval_linkedlist.go
|
- internal/utils/packetinterval_linkedlist.go
|
||||||
|
- internal/utils/linkedlist/linkedlist.go
|
||||||
status:
|
status:
|
||||||
project:
|
project:
|
||||||
default:
|
default:
|
||||||
|
183
vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go
generated
vendored
183
vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go
generated
vendored
@ -1,183 +0,0 @@
|
|||||||
package congestion
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Note: This constant is also defined in the ackhandler package.
|
|
||||||
initialRTTus = 100 * 1000
|
|
||||||
rttAlpha float32 = 0.125
|
|
||||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
|
||||||
rttBeta float32 = 0.25
|
|
||||||
oneMinusBeta float32 = (1 - rttBeta)
|
|
||||||
halfWindow float32 = 0.5
|
|
||||||
quarterWindow float32 = 0.25
|
|
||||||
)
|
|
||||||
|
|
||||||
type rttSample struct {
|
|
||||||
rtt time.Duration
|
|
||||||
time time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// RTTStats provides round-trip statistics
|
|
||||||
type RTTStats struct {
|
|
||||||
initialRTTus int64
|
|
||||||
|
|
||||||
recentMinRTTwindow time.Duration
|
|
||||||
minRTT time.Duration
|
|
||||||
latestRTT time.Duration
|
|
||||||
smoothedRTT time.Duration
|
|
||||||
meanDeviation time.Duration
|
|
||||||
|
|
||||||
numMinRTTsamplesRemaining uint32
|
|
||||||
|
|
||||||
newMinRTT rttSample
|
|
||||||
recentMinRTT rttSample
|
|
||||||
halfWindowRTT rttSample
|
|
||||||
quarterWindowRTT rttSample
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRTTStats makes a properly initialized RTTStats object
|
|
||||||
func NewRTTStats() *RTTStats {
|
|
||||||
return &RTTStats{
|
|
||||||
initialRTTus: initialRTTus,
|
|
||||||
recentMinRTTwindow: utils.InfDuration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitialRTTus is the initial RTT in us
|
|
||||||
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
|
|
||||||
|
|
||||||
// MinRTT Returns the minRTT for the entire connection.
|
|
||||||
// May return Zero if no valid updates have occurred.
|
|
||||||
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
|
||||||
|
|
||||||
// LatestRTT returns the most recent rtt measurement.
|
|
||||||
// May return Zero if no valid updates have occurred.
|
|
||||||
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
|
||||||
|
|
||||||
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
|
|
||||||
// minRTT for the entire connection if SampleNewMinRtt was never called.
|
|
||||||
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
|
|
||||||
|
|
||||||
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
|
|
||||||
// May return Zero if no valid updates have occurred.
|
|
||||||
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
|
||||||
|
|
||||||
// GetQuarterWindowRTT gets the quarter window RTT
|
|
||||||
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
|
|
||||||
|
|
||||||
// GetHalfWindowRTT gets the half window RTT
|
|
||||||
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
|
|
||||||
|
|
||||||
// MeanDeviation gets the mean deviation
|
|
||||||
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
|
||||||
|
|
||||||
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
|
|
||||||
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
|
|
||||||
r.recentMinRTTwindow = recentMinRTTwindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRTT updates the RTT based on a new sample.
|
|
||||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
|
||||||
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
|
||||||
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
|
|
||||||
// ackDelay but the raw observed sendDelta, since poor clock granularity at
|
|
||||||
// the client may cause a high ackDelay to result in underestimation of the
|
|
||||||
// r.minRTT.
|
|
||||||
if r.minRTT == 0 || r.minRTT > sendDelta {
|
|
||||||
r.minRTT = sendDelta
|
|
||||||
}
|
|
||||||
r.updateRecentMinRTT(sendDelta, now)
|
|
||||||
|
|
||||||
// Correct for ackDelay if information received from the peer results in a
|
|
||||||
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
|
|
||||||
// measure for smoothedRTT.
|
|
||||||
sample := sendDelta
|
|
||||||
if sample > ackDelay {
|
|
||||||
sample -= ackDelay
|
|
||||||
}
|
|
||||||
r.latestRTT = sample
|
|
||||||
// First time call.
|
|
||||||
if r.smoothedRTT == 0 {
|
|
||||||
r.smoothedRTT = sample
|
|
||||||
r.meanDeviation = sample / 2
|
|
||||||
} else {
|
|
||||||
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
|
||||||
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
|
|
||||||
if r.numMinRTTsamplesRemaining > 0 {
|
|
||||||
r.numMinRTTsamplesRemaining--
|
|
||||||
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
|
|
||||||
r.newMinRTT = rttSample{rtt: sample, time: now}
|
|
||||||
}
|
|
||||||
if r.numMinRTTsamplesRemaining == 0 {
|
|
||||||
r.recentMinRTT = r.newMinRTT
|
|
||||||
r.halfWindowRTT = r.newMinRTT
|
|
||||||
r.quarterWindowRTT = r.newMinRTT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the three recent rtt samples.
|
|
||||||
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
|
|
||||||
r.recentMinRTT = rttSample{rtt: sample, time: now}
|
|
||||||
r.halfWindowRTT = r.recentMinRTT
|
|
||||||
r.quarterWindowRTT = r.recentMinRTT
|
|
||||||
} else if sample <= r.halfWindowRTT.rtt {
|
|
||||||
r.halfWindowRTT = rttSample{rtt: sample, time: now}
|
|
||||||
r.quarterWindowRTT = r.halfWindowRTT
|
|
||||||
} else if sample <= r.quarterWindowRTT.rtt {
|
|
||||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expire old min rtt samples.
|
|
||||||
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
|
|
||||||
r.recentMinRTT = r.halfWindowRTT
|
|
||||||
r.halfWindowRTT = r.quarterWindowRTT
|
|
||||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
|
||||||
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
|
|
||||||
r.halfWindowRTT = r.quarterWindowRTT
|
|
||||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
|
||||||
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
|
|
||||||
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
|
|
||||||
// |numSamples| UpdateRTT calls.
|
|
||||||
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
|
|
||||||
r.numMinRTTsamplesRemaining = numSamples
|
|
||||||
r.newMinRTT = rttSample{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
|
||||||
func (r *RTTStats) OnConnectionMigration() {
|
|
||||||
r.latestRTT = 0
|
|
||||||
r.minRTT = 0
|
|
||||||
r.smoothedRTT = 0
|
|
||||||
r.meanDeviation = 0
|
|
||||||
r.initialRTTus = initialRTTus
|
|
||||||
r.numMinRTTsamplesRemaining = 0
|
|
||||||
r.recentMinRTTwindow = utils.InfDuration
|
|
||||||
r.recentMinRTT = rttSample{}
|
|
||||||
r.halfWindowRTT = rttSample{}
|
|
||||||
r.quarterWindowRTT = rttSample{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
|
||||||
// is larger. The mean deviation is increased to the most recent deviation if
|
|
||||||
// it's larger.
|
|
||||||
func (r *RTTStats) ExpireSmoothedMetrics() {
|
|
||||||
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
|
|
||||||
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
|
|
||||||
}
|
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoStreamI interface {
|
||||||
|
StreamID() protocol.StreamID
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
handleStreamFrame(*wire.StreamFrame) error
|
||||||
|
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||||
|
closeForShutdown(error)
|
||||||
|
setReadOffset(protocol.ByteCount)
|
||||||
|
// methods needed for flow control
|
||||||
|
getWindowUpdate() protocol.ByteCount
|
||||||
|
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cryptoStream struct {
|
||||||
|
*stream
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ cryptoStreamI = &cryptoStream{}
|
||||||
|
|
||||||
|
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||||
|
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||||
|
return &cryptoStream{str}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadOffset sets the read offset.
|
||||||
|
// It is only needed for the crypto stream.
|
||||||
|
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||||
|
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||||
|
s.receiveStream.readOffset = offset
|
||||||
|
s.receiveStream.frameQueue.readPosition = offset
|
||||||
|
}
|
109
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
109
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
@ -16,23 +16,48 @@ type StreamID = protocol.StreamID
|
|||||||
// A VersionNumber is a QUIC version number.
|
// A VersionNumber is a QUIC version number.
|
||||||
type VersionNumber = protocol.VersionNumber
|
type VersionNumber = protocol.VersionNumber
|
||||||
|
|
||||||
|
// VersionGQUIC39 is gQUIC version 39.
|
||||||
|
const VersionGQUIC39 = protocol.Version39
|
||||||
|
|
||||||
// A Cookie can be used to verify the ownership of the client address.
|
// A Cookie can be used to verify the ownership of the client address.
|
||||||
type Cookie = handshake.Cookie
|
type Cookie = handshake.Cookie
|
||||||
|
|
||||||
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
|
type ConnectionState = handshake.ConnectionState
|
||||||
|
|
||||||
|
// An ErrorCode is an application-defined error code.
|
||||||
|
type ErrorCode = protocol.ApplicationErrorCode
|
||||||
|
|
||||||
// Stream is the interface implemented by QUIC streams
|
// Stream is the interface implemented by QUIC streams
|
||||||
type Stream interface {
|
type Stream interface {
|
||||||
|
// StreamID returns the stream ID.
|
||||||
|
StreamID() StreamID
|
||||||
// Read reads data from the stream.
|
// Read reads data from the stream.
|
||||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Reader
|
io.Reader
|
||||||
// Write writes data to the stream.
|
// Write writes data to the stream.
|
||||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||||
|
// If the stream was canceled by the peer, the error implements the StreamError
|
||||||
|
// interface, and Canceled() == true.
|
||||||
io.Writer
|
io.Writer
|
||||||
|
// Close closes the write-direction of the stream.
|
||||||
|
// Future calls to Write are not permitted after calling Close.
|
||||||
|
// It must not be called concurrently with Write.
|
||||||
|
// It must not be called after calling CancelWrite.
|
||||||
io.Closer
|
io.Closer
|
||||||
StreamID() StreamID
|
// CancelWrite aborts sending on this stream.
|
||||||
// Reset closes the stream with an error.
|
// It must not be called after Close.
|
||||||
Reset(error)
|
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||||
|
// Write will unblock immediately, and future calls to Write will fail.
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// CancelRead aborts receiving on this stream.
|
||||||
|
// It will ask the peer to stop transmitting stream data.
|
||||||
|
// Read will unblock immediately, and future Read calls will fail.
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
// The context is canceled as soon as the write-side of the stream is closed.
|
// The context is canceled as soon as the write-side of the stream is closed.
|
||||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
@ -53,18 +78,63 @@ type Stream interface {
|
|||||||
SetDeadline(t time.Time) error
|
SetDeadline(t time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A ReceiveStream is a unidirectional Receive Stream.
|
||||||
|
type ReceiveStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Read
|
||||||
|
io.Reader
|
||||||
|
// see Stream.CancelRead
|
||||||
|
CancelRead(ErrorCode) error
|
||||||
|
// see Stream.SetReadDealine
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SendStream is a unidirectional Send Stream.
|
||||||
|
type SendStream interface {
|
||||||
|
// see Stream.StreamID
|
||||||
|
StreamID() StreamID
|
||||||
|
// see Stream.Write
|
||||||
|
io.Writer
|
||||||
|
// see Stream.Close
|
||||||
|
io.Closer
|
||||||
|
// see Stream.CancelWrite
|
||||||
|
CancelWrite(ErrorCode) error
|
||||||
|
// see Stream.Context
|
||||||
|
Context() context.Context
|
||||||
|
// see Stream.SetWriteDeadline
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||||
|
type StreamError interface {
|
||||||
|
error
|
||||||
|
Canceled() bool
|
||||||
|
ErrorCode() ErrorCode
|
||||||
|
}
|
||||||
|
|
||||||
// A Session is a QUIC connection between two peers.
|
// A Session is a QUIC connection between two peers.
|
||||||
type Session interface {
|
type Session interface {
|
||||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||||
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
|
|
||||||
AcceptStream() (Stream, error)
|
AcceptStream() (Stream, error)
|
||||||
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
|
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||||
// New streams always have the smallest possible stream ID.
|
AcceptUniStream() (ReceiveStream, error)
|
||||||
// TODO: Enable testing for the special error
|
// OpenStream opens a new bidirectional QUIC stream.
|
||||||
|
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||||
|
// There is no signaling to the peer about new streams:
|
||||||
|
// The peer can only accept the stream after data has been sent on the stream.
|
||||||
|
// TODO(#1152): Enable testing for the special error
|
||||||
OpenStream() (Stream, error)
|
OpenStream() (Stream, error)
|
||||||
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
|
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||||
// It always picks the smallest possible stream ID.
|
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||||
OpenStreamSync() (Stream, error)
|
OpenStreamSync() (Stream, error)
|
||||||
|
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||||
|
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||||
|
// TODO(#1152): Enable testing for the special error
|
||||||
|
OpenUniStream() (SendStream, error)
|
||||||
|
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||||
|
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||||
|
OpenUniStreamSync() (SendStream, error)
|
||||||
// LocalAddr returns the local address.
|
// LocalAddr returns the local address.
|
||||||
LocalAddr() net.Addr
|
LocalAddr() net.Addr
|
||||||
// RemoteAddr returns the address of the peer.
|
// RemoteAddr returns the address of the peer.
|
||||||
@ -74,13 +144,9 @@ type Session interface {
|
|||||||
// The context is cancelled when the session is closed.
|
// The context is cancelled when the session is closed.
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
}
|
// ConnectionState returns basic details about the QUIC connection.
|
||||||
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
ConnectionState() ConnectionState
|
||||||
// The communication is encrypted, but not yet forward secure.
|
|
||||||
type NonFWSession interface {
|
|
||||||
Session
|
|
||||||
WaitUntilHandshakeComplete() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains all configuration data needed for a QUIC server or client.
|
// Config contains all configuration data needed for a QUIC server or client.
|
||||||
@ -113,6 +179,17 @@ type Config struct {
|
|||||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||||
MaxReceiveConnectionFlowControlWindow uint64
|
MaxReceiveConnectionFlowControlWindow uint64
|
||||||
|
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||||
|
// If not set, it will default to 100.
|
||||||
|
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||||
|
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||||
|
MaxIncomingStreams int
|
||||||
|
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||||
|
// This value doesn't have any effect in Google QUIC.
|
||||||
|
// If not set, it will default to 100.
|
||||||
|
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||||
|
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||||
|
MaxIncomingUniStreams int
|
||||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||||
KeepAlive bool
|
KeepAlive bool
|
||||||
}
|
}
|
||||||
|
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
3
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
|
46
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
46
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SentPacketHandler handles ACKs received for outgoing packets
|
||||||
|
type SentPacketHandler interface {
|
||||||
|
// SentPacket may modify the packet
|
||||||
|
SentPacket(packet *Packet)
|
||||||
|
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
|
||||||
|
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||||
|
SetHandshakeComplete()
|
||||||
|
|
||||||
|
// The SendMode determines if and what kind of packets can be sent.
|
||||||
|
SendMode() SendMode
|
||||||
|
// TimeUntilSend is the time when the next packet should be sent.
|
||||||
|
// It is used for pacing packets.
|
||||||
|
TimeUntilSend() time.Time
|
||||||
|
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||||
|
// It always returns a number greater or equal than 1.
|
||||||
|
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||||
|
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||||
|
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||||
|
ShouldSendNumPackets() int
|
||||||
|
|
||||||
|
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||||
|
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||||
|
DequeuePacketForRetransmission() (packet *Packet)
|
||||||
|
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
OnAlarm() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||||
|
type ReceivedPacketHandler interface {
|
||||||
|
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||||
|
IgnoreBelow(protocol.PacketNumber)
|
||||||
|
|
||||||
|
GetAlarmTimeout() time.Time
|
||||||
|
GetAckFrame() *wire.AckFrame
|
||||||
|
}
|
29
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
29
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Packet is a packet
|
||||||
|
type Packet struct {
|
||||||
|
PacketNumber protocol.PacketNumber
|
||||||
|
PacketType protocol.PacketType
|
||||||
|
Frames []wire.Frame
|
||||||
|
Length protocol.ByteCount
|
||||||
|
EncryptionLevel protocol.EncryptionLevel
|
||||||
|
SendTime time.Time
|
||||||
|
|
||||||
|
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||||
|
|
||||||
|
// There are two reasons why a packet cannot be retransmitted:
|
||||||
|
// * it was already retransmitted
|
||||||
|
// * this packet is a retransmission, and we already received an ACK for the original packet
|
||||||
|
canBeRetransmitted bool
|
||||||
|
includedInBytesInFlight bool
|
||||||
|
retransmittedAs []protocol.PacketNumber
|
||||||
|
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
|
||||||
|
retransmissionOf protocol.PacketNumber
|
||||||
|
}
|
@ -1,13 +1,10 @@
|
|||||||
// Generated by: main
|
// This file was automatically generated by genny.
|
||||||
// TypeWriter: linkedlist
|
// Any changes will be lost if this file is regenerated.
|
||||||
// Directive: +gen on Packet
|
// see https://github.com/cheekybits/genny
|
||||||
|
|
||||||
package ackhandler
|
package ackhandler
|
||||||
|
|
||||||
// List is a modification of http://golang.org/pkg/container/list/
|
// Linked list implementation from the Go standard library.
|
||||||
// Copyright 2009 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// PacketElement is an element of a linked list.
|
// PacketElement is an element of a linked list.
|
||||||
type PacketElement struct {
|
type PacketElement struct {
|
||||||
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PacketList represents a doubly linked list.
|
// PacketList is a linked list of Packets.
|
||||||
// The zero value for PacketList is an empty list ready to use.
|
|
||||||
type PacketList struct {
|
type PacketList struct {
|
||||||
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
|
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
|
||||||
len int // current list length excluding (this) sentinel element
|
len int // current list length excluding (this) sentinel element
|
||||||
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
|
|||||||
// The complexity is O(1).
|
// The complexity is O(1).
|
||||||
func (l *PacketList) Len() int { return l.len }
|
func (l *PacketList) Len() int { return l.len }
|
||||||
|
|
||||||
// Front returns the first element of list l or nil.
|
// Front returns the first element of list l or nil if the list is empty.
|
||||||
func (l *PacketList) Front() *PacketElement {
|
func (l *PacketList) Front() *PacketElement {
|
||||||
if l.len == 0 {
|
if l.len == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
|
|||||||
return l.root.next
|
return l.root.next
|
||||||
}
|
}
|
||||||
|
|
||||||
// Back returns the last element of list l or nil.
|
// Back returns the last element of list l or nil if the list is empty.
|
||||||
func (l *PacketList) Back() *PacketElement {
|
func (l *PacketList) Back() *PacketElement {
|
||||||
if l.len == 0 {
|
if l.len == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
|
|||||||
return l.root.prev
|
return l.root.prev
|
||||||
}
|
}
|
||||||
|
|
||||||
// lazyInit lazily initializes a zero PacketList value.
|
// lazyInit lazily initializes a zero List value.
|
||||||
func (l *PacketList) lazyInit() {
|
func (l *PacketList) lazyInit() {
|
||||||
if l.root.next == nil {
|
if l.root.next == nil {
|
||||||
l.Init()
|
l.Init()
|
||||||
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at).
|
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||||
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
|
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
|
||||||
return l.insert(&PacketElement{Value: v}, at)
|
return l.insert(&PacketElement{Value: v}, at)
|
||||||
}
|
}
|
||||||
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
|
|||||||
|
|
||||||
// Remove removes e from l if e is an element of list l.
|
// Remove removes e from l if e is an element of list l.
|
||||||
// It returns the element value e.Value.
|
// It returns the element value e.Value.
|
||||||
|
// The element must not be nil.
|
||||||
func (l *PacketList) Remove(e *PacketElement) Packet {
|
func (l *PacketList) Remove(e *PacketElement) Packet {
|
||||||
if e.list == l {
|
if e.list == l {
|
||||||
// if e.list == l, l must have been initialized when e was inserted
|
// if e.list == l, l must have been initialized when e was inserted
|
||||||
// in l or l == nil (e is a zero PacketElement) and l.remove will crash
|
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||||
l.remove(e)
|
l.remove(e)
|
||||||
}
|
}
|
||||||
return e.Value
|
return e.Value
|
||||||
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
|
|||||||
|
|
||||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||||
// If mark is not an element of l, the list is not modified.
|
// If mark is not an element of l, the list is not modified.
|
||||||
|
// The mark must not be nil.
|
||||||
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
|
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
|
||||||
if mark.list != l {
|
if mark.list != l {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// see comment in PacketList.Remove about initialization of l
|
// see comment in List.Remove about initialization of l
|
||||||
return l.insertValue(v, mark.prev)
|
return l.insertValue(v, mark.prev)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||||
// If mark is not an element of l, the list is not modified.
|
// If mark is not an element of l, the list is not modified.
|
||||||
|
// The mark must not be nil.
|
||||||
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
|
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
|
||||||
if mark.list != l {
|
if mark.list != l {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// see comment in PacketList.Remove about initialization of l
|
// see comment in List.Remove about initialization of l
|
||||||
return l.insertValue(v, mark)
|
return l.insertValue(v, mark)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoveToFront moves element e to the front of list l.
|
// MoveToFront moves element e to the front of list l.
|
||||||
// If e is not an element of l, the list is not modified.
|
// If e is not an element of l, the list is not modified.
|
||||||
|
// The element must not be nil.
|
||||||
func (l *PacketList) MoveToFront(e *PacketElement) {
|
func (l *PacketList) MoveToFront(e *PacketElement) {
|
||||||
if e.list != l || l.root.next == e {
|
if e.list != l || l.root.next == e {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// see comment in PacketList.Remove about initialization of l
|
// see comment in List.Remove about initialization of l
|
||||||
l.insert(l.remove(e), &l.root)
|
l.insert(l.remove(e), &l.root)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoveToBack moves element e to the back of list l.
|
// MoveToBack moves element e to the back of list l.
|
||||||
// If e is not an element of l, the list is not modified.
|
// If e is not an element of l, the list is not modified.
|
||||||
|
// The element must not be nil.
|
||||||
func (l *PacketList) MoveToBack(e *PacketElement) {
|
func (l *PacketList) MoveToBack(e *PacketElement) {
|
||||||
if e.list != l || l.root.prev == e {
|
if e.list != l || l.root.prev == e {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// see comment in PacketList.Remove about initialization of l
|
// see comment in List.Remove about initialization of l
|
||||||
l.insert(l.remove(e), l.root.prev)
|
l.insert(l.remove(e), l.root.prev)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoveBefore moves element e to its new position before mark.
|
// MoveBefore moves element e to its new position before mark.
|
||||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||||
|
// The element and mark must not be nil.
|
||||||
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
||||||
if e.list != l || e == mark || mark.list != l {
|
if e.list != l || e == mark || mark.list != l {
|
||||||
return
|
return
|
||||||
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MoveAfter moves element e to its new position after mark.
|
// MoveAfter moves element e to its new position after mark.
|
||||||
// If e is not an element of l, or e == mark, the list is not modified.
|
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||||
|
// The element and mark must not be nil.
|
||||||
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
||||||
if e.list != l || e == mark || mark.list != l {
|
if e.list != l || e == mark || mark.list != l {
|
||||||
return
|
return
|
||||||
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PushBackList inserts a copy of an other list at the back of list l.
|
// PushBackList inserts a copy of an other list at the back of list l.
|
||||||
// The lists l and other may be the same.
|
// The lists l and other may be the same. They must not be nil.
|
||||||
func (l *PacketList) PushBackList(other *PacketList) {
|
func (l *PacketList) PushBackList(other *PacketList) {
|
||||||
l.lazyInit()
|
l.lazyInit()
|
||||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||||
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PushFrontList inserts a copy of an other list at the front of list l.
|
// PushFrontList inserts a copy of an other list at the front of list l.
|
||||||
// The lists l and other may be the same.
|
// The lists l and other may be the same. They must not be nil.
|
||||||
func (l *PacketList) PushFrontList(other *PacketList) {
|
func (l *PacketList) PushFrontList(other *PacketList) {
|
||||||
l.lazyInit()
|
l.lazyInit()
|
||||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
215
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
215
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type receivedPacketHandler struct {
|
||||||
|
largestObserved protocol.PacketNumber
|
||||||
|
ignoreBelow protocol.PacketNumber
|
||||||
|
largestObservedReceivedTime time.Time
|
||||||
|
|
||||||
|
packetHistory *receivedPacketHistory
|
||||||
|
|
||||||
|
ackSendDelay time.Duration
|
||||||
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
packetsReceivedSinceLastAck int
|
||||||
|
retransmittablePacketsReceivedSinceLastAck int
|
||||||
|
ackQueued bool
|
||||||
|
ackAlarm time.Time
|
||||||
|
lastAck *wire.AckFrame
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
|
||||||
|
version protocol.VersionNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maximum delay that can be applied to an ACK for a retransmittable packet
|
||||||
|
ackSendDelay = 25 * time.Millisecond
|
||||||
|
// initial maximum number of retransmittable packets received before sending an ack.
|
||||||
|
initialRetransmittablePacketsBeforeAck = 2
|
||||||
|
// number of retransmittable that an ACK is sent for
|
||||||
|
retransmittablePacketsBeforeAck = 10
|
||||||
|
// 1/5 RTT delay when doing ack decimation
|
||||||
|
ackDecimationDelay = 1.0 / 4
|
||||||
|
// 1/8 RTT delay when doing ack decimation
|
||||||
|
shortAckDecimationDelay = 1.0 / 8
|
||||||
|
// Minimum number of packets received before ack decimation is enabled.
|
||||||
|
// This intends to avoid the beginning of slow start, when CWNDs may be
|
||||||
|
// rapidly increasing.
|
||||||
|
minReceivedBeforeAckDecimation = 100
|
||||||
|
// Maximum number of packets to ack immediately after a missing packet for
|
||||||
|
// fast retransmission to kick in at the sender. This limit is created to
|
||||||
|
// reduce the number of acks sent that have no benefit for fast retransmission.
|
||||||
|
// Set to the number of nacks needed for fast retransmit plus one for protection
|
||||||
|
// against an ack loss
|
||||||
|
maxPacketsAfterNewMissing = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||||
|
func NewReceivedPacketHandler(
|
||||||
|
rttStats *congestion.RTTStats,
|
||||||
|
logger utils.Logger,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) ReceivedPacketHandler {
|
||||||
|
return &receivedPacketHandler{
|
||||||
|
packetHistory: newReceivedPacketHistory(),
|
||||||
|
ackSendDelay: ackSendDelay,
|
||||||
|
rttStats: rttStats,
|
||||||
|
logger: logger,
|
||||||
|
version: version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||||
|
if packetNumber < h.ignoreBelow {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isMissing := h.isMissing(packetNumber)
|
||||||
|
if packetNumber > h.largestObserved {
|
||||||
|
h.largestObserved = packetNumber
|
||||||
|
h.largestObservedReceivedTime = rcvTime
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IgnoreBelow sets a lower limit for acking packets.
|
||||||
|
// Packets with packet numbers smaller than p will not be acked.
|
||||||
|
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||||
|
if p <= h.ignoreBelow {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.ignoreBelow = p
|
||||||
|
h.packetHistory.DeleteBelow(p)
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMissing says if a packet was reported missing in the last ACK.
|
||||||
|
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
|
||||||
|
if h.lastAck == nil || p < h.ignoreBelow {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
||||||
|
if h.lastAck == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
highestRange := h.packetHistory.GetHighestAckRange()
|
||||||
|
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeQueueAck queues an ACK, if necessary.
|
||||||
|
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
|
||||||
|
// in ACK_DECIMATION_WITH_REORDERING mode.
|
||||||
|
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
|
||||||
|
h.packetsReceivedSinceLastAck++
|
||||||
|
|
||||||
|
// always ack the first packet
|
||||||
|
if h.lastAck == nil {
|
||||||
|
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||||
|
h.ackQueued = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||||
|
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||||
|
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||||
|
if wasMissing {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
|
||||||
|
}
|
||||||
|
h.ackQueued = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.ackQueued && shouldInstigateAck {
|
||||||
|
h.retransmittablePacketsReceivedSinceLastAck++
|
||||||
|
|
||||||
|
if packetNumber > minReceivedBeforeAckDecimation {
|
||||||
|
// ack up to 10 packets at once
|
||||||
|
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
|
||||||
|
h.ackQueued = true
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
|
||||||
|
}
|
||||||
|
} else if h.ackAlarm.IsZero() {
|
||||||
|
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
|
||||||
|
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
|
||||||
|
h.ackAlarm = rcvTime.Add(ackDelay)
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// send an ACK every 2 retransmittable packets
|
||||||
|
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
|
||||||
|
}
|
||||||
|
h.ackQueued = true
|
||||||
|
} else if h.ackAlarm.IsZero() {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
|
||||||
|
}
|
||||||
|
h.ackAlarm = rcvTime.Add(ackSendDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If there are new missing packets to report, set a short timer to send an ACK.
|
||||||
|
if h.hasNewMissingPackets() {
|
||||||
|
// wait the minimum of 1/8 min RTT and the existing ack time
|
||||||
|
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
|
||||||
|
ackTime := rcvTime.Add(ackDelay)
|
||||||
|
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
|
||||||
|
h.ackAlarm = ackTime
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.ackQueued {
|
||||||
|
// cancel the ack alarm
|
||||||
|
h.ackAlarm = time.Time{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||||
|
now := time.Now()
|
||||||
|
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||||
|
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||||
|
}
|
||||||
|
|
||||||
|
ack := &wire.AckFrame{
|
||||||
|
AckRanges: h.packetHistory.GetAckRanges(),
|
||||||
|
DelayTime: now.Sub(h.largestObservedReceivedTime),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.lastAck = ack
|
||||||
|
h.ackAlarm = time.Time{}
|
||||||
|
h.ackQueued = false
|
||||||
|
h.packetsReceivedSinceLastAck = 0
|
||||||
|
h.retransmittablePacketsReceivedSinceLastAck = 0
|
||||||
|
return ack
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
@ -74,17 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all entries up to (and including) p
|
// DeleteBelow deletes all entries below (but not including) p
|
||||||
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||||
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
|
if p <= h.lowestInReceivedPacketNumbers {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.lowestInReceivedPacketNumbers = p
|
||||||
|
|
||||||
nextEl := h.ranges.Front()
|
nextEl := h.ranges.Front()
|
||||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||||
nextEl = el.Next()
|
nextEl = el.Next()
|
||||||
|
|
||||||
if p >= el.Value.Start && p < el.Value.End {
|
if p > el.Value.Start && p <= el.Value.End {
|
||||||
el.Value.Start = p + 1
|
el.Value.Start = p
|
||||||
} else if el.Value.End <= p { // delete a whole range
|
} else if el.Value.End < p { // delete a whole range
|
||||||
h.ranges.Remove(el)
|
h.ranges.Remove(el)
|
||||||
} else { // no ranges affected. Nothing to do
|
} else { // no ranges affected. Nothing to do
|
||||||
return
|
return
|
||||||
@ -101,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
|||||||
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||||
i := 0
|
i := 0
|
||||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||||
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
|
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
return ackRanges
|
return ackRanges
|
||||||
@ -111,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
|||||||
ackRange := wire.AckRange{}
|
ackRange := wire.AckRange{}
|
||||||
if h.ranges.Len() > 0 {
|
if h.ranges.Len() > 0 {
|
||||||
r := h.ranges.Back().Value
|
r := h.ranges.Back().Value
|
||||||
ackRange.First = r.Start
|
ackRange.Smallest = r.Start
|
||||||
ackRange.Last = r.End
|
ackRange.Largest = r.End
|
||||||
}
|
}
|
||||||
return ackRange
|
return ackRange
|
||||||
}
|
}
|
40
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
40
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// The SendMode says what kind of packets can be sent.
|
||||||
|
type SendMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SendNone means that no packets should be sent
|
||||||
|
SendNone SendMode = iota
|
||||||
|
// SendAck means an ACK-only packet should be sent
|
||||||
|
SendAck
|
||||||
|
// SendRetransmission means that retransmissions should be sent
|
||||||
|
SendRetransmission
|
||||||
|
// SendRTO means that an RTO probe packet should be sent
|
||||||
|
SendRTO
|
||||||
|
// SendTLP means that a TLP probe packet should be sent
|
||||||
|
SendTLP
|
||||||
|
// SendAny means that any packet should be sent
|
||||||
|
SendAny
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s SendMode) String() string {
|
||||||
|
switch s {
|
||||||
|
case SendNone:
|
||||||
|
return "none"
|
||||||
|
case SendAck:
|
||||||
|
return "ack"
|
||||||
|
case SendRetransmission:
|
||||||
|
return "retransmission"
|
||||||
|
case SendRTO:
|
||||||
|
return "rto"
|
||||||
|
case SendTLP:
|
||||||
|
return "tlp"
|
||||||
|
case SendAny:
|
||||||
|
return "any"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("invalid send mode: %d", s)
|
||||||
|
}
|
||||||
|
}
|
649
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
649
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
@ -0,0 +1,649 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||||
|
// In fraction of an RTT.
|
||||||
|
timeReorderingFraction = 1.0 / 8
|
||||||
|
// defaultRTOTimeout is the RTO time on new connections
|
||||||
|
defaultRTOTimeout = 500 * time.Millisecond
|
||||||
|
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||||
|
minTPLTimeout = 10 * time.Millisecond
|
||||||
|
// Maximum number of tail loss probes before an RTO fires.
|
||||||
|
maxTLPs = 2
|
||||||
|
// Minimum time in the future an RTO alarm may be set for.
|
||||||
|
minRTOTimeout = 200 * time.Millisecond
|
||||||
|
// maxRTOTimeout is the maximum RTO time
|
||||||
|
maxRTOTimeout = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type sentPacketHandler struct {
|
||||||
|
lastSentPacketNumber protocol.PacketNumber
|
||||||
|
lastSentRetransmittablePacketTime time.Time
|
||||||
|
lastSentHandshakePacketTime time.Time
|
||||||
|
|
||||||
|
nextPacketSendTime time.Time
|
||||||
|
skippedPackets []protocol.PacketNumber
|
||||||
|
|
||||||
|
largestAcked protocol.PacketNumber
|
||||||
|
largestReceivedPacketWithAck protocol.PacketNumber
|
||||||
|
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||||
|
// example: we send an ACK for packets 90-100 with packet number 20
|
||||||
|
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||||
|
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||||
|
largestSentBeforeRTO protocol.PacketNumber
|
||||||
|
|
||||||
|
packetHistory *sentPacketHistory
|
||||||
|
stopWaitingManager stopWaitingManager
|
||||||
|
|
||||||
|
retransmissionQueue []*Packet
|
||||||
|
|
||||||
|
bytesInFlight protocol.ByteCount
|
||||||
|
|
||||||
|
congestion congestion.SendAlgorithm
|
||||||
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
handshakeComplete bool
|
||||||
|
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||||
|
handshakeCount uint32
|
||||||
|
|
||||||
|
// The number of times a TLP has been sent without receiving an ack.
|
||||||
|
tlpCount uint32
|
||||||
|
allowTLP bool
|
||||||
|
|
||||||
|
// The number of times an RTO has been sent without receiving an ack.
|
||||||
|
rtoCount uint32
|
||||||
|
// The number of RTO probe packets that should be sent.
|
||||||
|
numRTOs int
|
||||||
|
|
||||||
|
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
||||||
|
lossTime time.Time
|
||||||
|
|
||||||
|
// The alarm timeout
|
||||||
|
alarm time.Time
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSentPacketHandler creates a new sentPacketHandler
|
||||||
|
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||||
|
congestion := congestion.NewCubicSender(
|
||||||
|
congestion.DefaultClock{},
|
||||||
|
rttStats,
|
||||||
|
false, /* don't use reno since chromium doesn't (why?) */
|
||||||
|
protocol.InitialCongestionWindow,
|
||||||
|
protocol.DefaultMaxCongestionWindow,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &sentPacketHandler{
|
||||||
|
packetHistory: newSentPacketHistory(),
|
||||||
|
stopWaitingManager: stopWaitingManager{},
|
||||||
|
rttStats: rttStats,
|
||||||
|
congestion: congestion,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||||
|
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||||
|
return p.PacketNumber
|
||||||
|
}
|
||||||
|
return h.largestAcked + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||||
|
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||||
|
var queue []*Packet
|
||||||
|
for _, packet := range h.retransmissionQueue {
|
||||||
|
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||||
|
queue = append(queue, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var handshakePackets []*Packet
|
||||||
|
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||||
|
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||||
|
handshakePackets = append(handshakePackets, p)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
for _, p := range handshakePackets {
|
||||||
|
h.packetHistory.Remove(p.PacketNumber)
|
||||||
|
}
|
||||||
|
h.retransmissionQueue = queue
|
||||||
|
h.handshakeComplete = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) SentPacket(packet *Packet) {
|
||||||
|
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||||
|
h.packetHistory.SentPacket(packet)
|
||||||
|
h.updateLossDetectionAlarm()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||||
|
var p []*Packet
|
||||||
|
for _, packet := range packets {
|
||||||
|
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||||
|
p = append(p, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
|
||||||
|
h.updateLossDetectionAlarm()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||||
|
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||||
|
h.skippedPackets = append(h.skippedPackets, p)
|
||||||
|
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||||
|
h.skippedPackets = h.skippedPackets[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.lastSentPacketNumber = packet.PacketNumber
|
||||||
|
|
||||||
|
if len(packet.Frames) > 0 {
|
||||||
|
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||||
|
packet.largestAcked = ackFrame.LargestAcked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||||
|
isRetransmittable := len(packet.Frames) != 0
|
||||||
|
|
||||||
|
if isRetransmittable {
|
||||||
|
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||||
|
h.lastSentHandshakePacketTime = packet.SendTime
|
||||||
|
}
|
||||||
|
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||||
|
packet.includedInBytesInFlight = true
|
||||||
|
h.bytesInFlight += packet.Length
|
||||||
|
packet.canBeRetransmitted = true
|
||||||
|
if h.numRTOs > 0 {
|
||||||
|
h.numRTOs--
|
||||||
|
}
|
||||||
|
h.allowTLP = false
|
||||||
|
}
|
||||||
|
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
|
||||||
|
|
||||||
|
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||||
|
return isRetransmittable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||||
|
largestAcked := ackFrame.LargestAcked()
|
||||||
|
if largestAcked > h.lastSentPacketNumber {
|
||||||
|
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||||
|
}
|
||||||
|
|
||||||
|
// duplicate or out of order ACK
|
||||||
|
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||||
|
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.largestReceivedPacketWithAck = withPacketNumber
|
||||||
|
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||||
|
|
||||||
|
if h.skippedPacketsAcked(ackFrame) {
|
||||||
|
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||||
|
h.congestion.MaybeExitSlowStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
priorInFlight := h.bytesInFlight
|
||||||
|
for _, p := range ackedPackets {
|
||||||
|
if encLevel < p.EncryptionLevel {
|
||||||
|
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||||
|
}
|
||||||
|
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||||
|
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||||
|
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||||
|
if p.largestAcked != 0 {
|
||||||
|
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
|
||||||
|
}
|
||||||
|
if err := h.onPacketAcked(p, rcvTime); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if p.includedInBytesInFlight {
|
||||||
|
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.updateLossDetectionAlarm()
|
||||||
|
|
||||||
|
h.garbageCollectSkippedPackets()
|
||||||
|
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||||
|
return h.lowestPacketNotConfirmedAcked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
|
||||||
|
var ackedPackets []*Packet
|
||||||
|
ackRangeIndex := 0
|
||||||
|
lowestAcked := ackFrame.LowestAcked()
|
||||||
|
largestAcked := ackFrame.LargestAcked()
|
||||||
|
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||||
|
// Ignore packets below the lowest acked
|
||||||
|
if p.PacketNumber < lowestAcked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
// Break after largest acked is reached
|
||||||
|
if p.PacketNumber > largestAcked {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ackFrame.HasMissingRanges() {
|
||||||
|
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||||
|
|
||||||
|
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||||
|
ackRangeIndex++
|
||||||
|
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
|
||||||
|
if p.PacketNumber > ackRange.Largest {
|
||||||
|
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||||
|
}
|
||||||
|
ackedPackets = append(ackedPackets, p)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ackedPackets = append(ackedPackets, p)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
if h.logger.Debug() && len(ackedPackets) > 0 {
|
||||||
|
pns := make([]protocol.PacketNumber, len(ackedPackets))
|
||||||
|
for i, p := range ackedPackets {
|
||||||
|
pns[i] = p.PacketNumber
|
||||||
|
}
|
||||||
|
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
|
||||||
|
}
|
||||||
|
return ackedPackets, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||||
|
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
|
||||||
|
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||||
|
// Cancel the alarm if no packets are outstanding
|
||||||
|
if h.packetHistory.Len() == 0 {
|
||||||
|
h.alarm = time.Time{}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.handshakeComplete {
|
||||||
|
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
|
||||||
|
} else if !h.lossTime.IsZero() {
|
||||||
|
// Early retransmit timer or time loss detection.
|
||||||
|
h.alarm = h.lossTime
|
||||||
|
} else {
|
||||||
|
// RTO or TLP alarm
|
||||||
|
alarmDuration := h.computeRTOTimeout()
|
||||||
|
if h.tlpCount < maxTLPs {
|
||||||
|
tlpAlarm := h.computeTLPTimeout()
|
||||||
|
// if the RTO duration is shorter than the TLP duration, use the RTO duration
|
||||||
|
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
|
||||||
|
}
|
||||||
|
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
|
||||||
|
h.lossTime = time.Time{}
|
||||||
|
|
||||||
|
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||||
|
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||||
|
|
||||||
|
var lostPackets []*Packet
|
||||||
|
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
|
||||||
|
if packet.PacketNumber > h.largestAcked {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeSinceSent := now.Sub(packet.SendTime)
|
||||||
|
if timeSinceSent > delayUntilLost {
|
||||||
|
lostPackets = append(lostPackets, packet)
|
||||||
|
} else if h.lossTime.IsZero() {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
|
||||||
|
}
|
||||||
|
// Note: This conditional is only entered once per call
|
||||||
|
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
if h.logger.Debug() && len(lostPackets) > 0 {
|
||||||
|
pns := make([]protocol.PacketNumber, len(lostPackets))
|
||||||
|
for i, p := range lostPackets {
|
||||||
|
pns[i] = p.PacketNumber
|
||||||
|
}
|
||||||
|
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range lostPackets {
|
||||||
|
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
|
||||||
|
if p.includedInBytesInFlight {
|
||||||
|
h.bytesInFlight -= p.Length
|
||||||
|
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
|
||||||
|
}
|
||||||
|
if p.canBeRetransmitted {
|
||||||
|
// queue the packet for retransmission, and report the loss to the congestion controller
|
||||||
|
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.packetHistory.Remove(p.PacketNumber)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) OnAlarm() error {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if !h.handshakeComplete {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Loss detection alarm fired in handshake mode")
|
||||||
|
}
|
||||||
|
h.handshakeCount++
|
||||||
|
err = h.queueHandshakePacketsForRetransmission()
|
||||||
|
} else if !h.lossTime.IsZero() {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Loss detection alarm fired in loss timer mode")
|
||||||
|
}
|
||||||
|
// Early retransmit or time loss detection
|
||||||
|
err = h.detectLostPackets(now, h.bytesInFlight)
|
||||||
|
} else if h.tlpCount < maxTLPs {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Loss detection alarm fired in TLP mode")
|
||||||
|
}
|
||||||
|
h.allowTLP = true
|
||||||
|
h.tlpCount++
|
||||||
|
} else {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Loss detection alarm fired in RTO mode")
|
||||||
|
}
|
||||||
|
// RTO
|
||||||
|
h.rtoCount++
|
||||||
|
h.numRTOs += 2
|
||||||
|
err = h.queueRTOs()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.updateLossDetectionAlarm()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||||
|
return h.alarm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
|
||||||
|
// This happens if a packet and its retransmissions is acked in the same ACK.
|
||||||
|
// As soon as we process the first one, this will remove all the retransmissions,
|
||||||
|
// so we won't find the retransmitted packet number later.
|
||||||
|
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// only report the acking of this packet to the congestion controller if:
|
||||||
|
// * it is a retransmittable packet
|
||||||
|
// * this packet wasn't retransmitted yet
|
||||||
|
if p.isRetransmission {
|
||||||
|
// that the parent doesn't exist is expected to happen every time the original packet was already acked
|
||||||
|
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
|
||||||
|
if len(parent.retransmittedAs) == 1 {
|
||||||
|
parent.retransmittedAs = nil
|
||||||
|
} else {
|
||||||
|
// remove this packet from the slice of retransmission
|
||||||
|
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
|
||||||
|
for _, pn := range parent.retransmittedAs {
|
||||||
|
if pn != p.PacketNumber {
|
||||||
|
retransmittedAs = append(retransmittedAs, pn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parent.retransmittedAs = retransmittedAs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// this also applies to packets that have been retransmitted as probe packets
|
||||||
|
if p.includedInBytesInFlight {
|
||||||
|
h.bytesInFlight -= p.Length
|
||||||
|
}
|
||||||
|
if h.rtoCount > 0 {
|
||||||
|
h.verifyRTO(p.PacketNumber)
|
||||||
|
}
|
||||||
|
if err := h.stopRetransmissionsFor(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.rtoCount = 0
|
||||||
|
h.tlpCount = 0
|
||||||
|
h.handshakeCount = 0
|
||||||
|
return h.packetHistory.Remove(p.PacketNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
|
||||||
|
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, r := range p.retransmittedAs {
|
||||||
|
packet := h.packetHistory.GetPacket(r)
|
||||||
|
if packet == nil {
|
||||||
|
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
|
||||||
|
}
|
||||||
|
h.stopRetransmissionsFor(packet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
|
||||||
|
if pn <= h.largestSentBeforeRTO {
|
||||||
|
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
|
||||||
|
// Replace SRTT with latest_rtt and increase the variance to prevent
|
||||||
|
// a spurious RTO from happening again.
|
||||||
|
h.rttStats.ExpireSmoothedMetrics()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
|
||||||
|
h.congestion.OnRetransmissionTimeout(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
||||||
|
if len(h.retransmissionQueue) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
packet := h.retransmissionQueue[0]
|
||||||
|
// Shift the slice and don't retain anything that isn't needed.
|
||||||
|
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
|
||||||
|
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
|
||||||
|
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
|
||||||
|
return packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||||
|
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||||
|
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) SendMode() SendMode {
|
||||||
|
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||||
|
|
||||||
|
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||||
|
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||||
|
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||||
|
// but still allow sending of retransmissions and ACKs.
|
||||||
|
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||||
|
}
|
||||||
|
return SendNone
|
||||||
|
}
|
||||||
|
if h.allowTLP {
|
||||||
|
return SendTLP
|
||||||
|
}
|
||||||
|
if h.numRTOs > 0 {
|
||||||
|
return SendRTO
|
||||||
|
}
|
||||||
|
// Only send ACKs if we're congestion limited.
|
||||||
|
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||||
|
}
|
||||||
|
return SendAck
|
||||||
|
}
|
||||||
|
// Send retransmissions first, if there are any.
|
||||||
|
if len(h.retransmissionQueue) > 0 {
|
||||||
|
return SendRetransmission
|
||||||
|
}
|
||||||
|
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||||
|
if h.logger.Debug() {
|
||||||
|
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||||
|
}
|
||||||
|
return SendAck
|
||||||
|
}
|
||||||
|
return SendAny
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||||
|
return h.nextPacketSendTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||||
|
if h.numRTOs > 0 {
|
||||||
|
// RTO probes should not be paced, but must be sent immediately.
|
||||||
|
return h.numRTOs
|
||||||
|
}
|
||||||
|
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||||
|
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// retransmit the oldest two packets
|
||||||
|
func (h *sentPacketHandler) queueRTOs() error {
|
||||||
|
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||||
|
// Queue the first two outstanding packets for retransmission.
|
||||||
|
// This does NOT declare this packets as lost:
|
||||||
|
// They are still tracked in the packet history and count towards the bytes in flight.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||||
|
h.logger.Debugf("Queueing packet %#x for retransmission (RTO)", p.PacketNumber)
|
||||||
|
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||||
|
var handshakePackets []*Packet
|
||||||
|
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||||
|
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||||
|
handshakePackets = append(handshakePackets, p)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
for _, p := range handshakePackets {
|
||||||
|
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||||
|
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
||||||
|
if !p.canBeRetransmitted {
|
||||||
|
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
|
||||||
|
}
|
||||||
|
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||||
|
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||||
|
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
|
||||||
|
// exponential backoff
|
||||||
|
// There's an implicit limit to this set by the handshake timeout.
|
||||||
|
return duration << h.handshakeCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
|
||||||
|
// TODO(#1236): include the max_ack_delay
|
||||||
|
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||||
|
var rto time.Duration
|
||||||
|
rtt := h.rttStats.SmoothedRTT()
|
||||||
|
if rtt == 0 {
|
||||||
|
rto = defaultRTOTimeout
|
||||||
|
} else {
|
||||||
|
rto = rtt + 4*h.rttStats.MeanDeviation()
|
||||||
|
}
|
||||||
|
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||||
|
// Exponential backoff
|
||||||
|
rto = rto << h.rtoCount
|
||||||
|
return utils.MinDuration(rto, maxRTOTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||||
|
for _, p := range h.skippedPackets {
|
||||||
|
if ackFrame.AcksPacket(p) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||||
|
lowestUnacked := h.lowestUnacked()
|
||||||
|
deleteIndex := 0
|
||||||
|
for i, p := range h.skippedPackets {
|
||||||
|
if p < lowestUnacked {
|
||||||
|
deleteIndex = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||||
|
}
|
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
127
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package ackhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sentPacketHistory struct {
|
||||||
|
packetList *PacketList
|
||||||
|
packetMap map[protocol.PacketNumber]*PacketElement
|
||||||
|
|
||||||
|
firstOutstanding *PacketElement
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSentPacketHistory() *sentPacketHistory {
|
||||||
|
return &sentPacketHistory{
|
||||||
|
packetList: NewPacketList(),
|
||||||
|
packetMap: make(map[protocol.PacketNumber]*PacketElement),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) SentPacket(p *Packet) {
|
||||||
|
h.sentPacketImpl(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
||||||
|
el := h.packetList.PushBack(*p)
|
||||||
|
h.packetMap[p.PacketNumber] = el
|
||||||
|
if h.firstOutstanding == nil {
|
||||||
|
h.firstOutstanding = el
|
||||||
|
}
|
||||||
|
return el
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||||
|
retransmission, ok := h.packetMap[retransmissionOf]
|
||||||
|
// The retransmitted packet is not present anymore.
|
||||||
|
// This can happen if it was acked in between dequeueing of the retransmission and sending.
|
||||||
|
// Just treat the retransmissions as normal packets.
|
||||||
|
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
|
||||||
|
if !ok {
|
||||||
|
for _, packet := range packets {
|
||||||
|
h.sentPacketImpl(packet)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
|
||||||
|
for i, packet := range packets {
|
||||||
|
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
|
||||||
|
el := h.sentPacketImpl(packet)
|
||||||
|
el.Value.isRetransmission = true
|
||||||
|
el.Value.retransmissionOf = retransmissionOf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
|
||||||
|
if el, ok := h.packetMap[p]; ok {
|
||||||
|
return &el.Value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate iterates through all packets.
|
||||||
|
// The callback must not modify the history.
|
||||||
|
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
|
||||||
|
cont := true
|
||||||
|
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
|
||||||
|
var err error
|
||||||
|
cont, err = cb(&el.Value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstOutStanding returns the first outstanding packet.
|
||||||
|
// It must not be modified (e.g. retransmitted).
|
||||||
|
// Use DequeueFirstPacketForRetransmission() to retransmit it.
|
||||||
|
func (h *sentPacketHistory) FirstOutstanding() *Packet {
|
||||||
|
if h.firstOutstanding == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &h.firstOutstanding.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueuePacketForRetransmission marks a packet for retransmission.
|
||||||
|
// A packet can only be queued once.
|
||||||
|
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
|
||||||
|
el, ok := h.packetMap[pn]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("sent packet history: packet %d not found", pn)
|
||||||
|
}
|
||||||
|
el.Value.canBeRetransmitted = false
|
||||||
|
if el == h.firstOutstanding {
|
||||||
|
h.readjustFirstOutstanding()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
|
||||||
|
// This is necessary every time the first outstanding packet is deleted or retransmitted.
|
||||||
|
func (h *sentPacketHistory) readjustFirstOutstanding() {
|
||||||
|
el := h.firstOutstanding.Next()
|
||||||
|
for el != nil && !el.Value.canBeRetransmitted {
|
||||||
|
el = el.Next()
|
||||||
|
}
|
||||||
|
h.firstOutstanding = el
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) Len() int {
|
||||||
|
return len(h.packetMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||||
|
el, ok := h.packetMap[p]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("packet %d not found in sent packet history", p)
|
||||||
|
}
|
||||||
|
if el == h.firstOutstanding {
|
||||||
|
h.readjustFirstOutstanding()
|
||||||
|
}
|
||||||
|
h.packetList.Remove(el)
|
||||||
|
delete(h.packetMap, p)
|
||||||
|
return nil
|
||||||
|
}
|
@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||||
if ack.LargestAcked >= s.nextLeastUnacked {
|
largestAcked := ack.LargestAcked()
|
||||||
s.nextLeastUnacked = ack.LargestAcked + 1
|
if largestAcked >= s.nextLeastUnacked {
|
||||||
|
s.nextLeastUnacked = largestAcked + 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -16,11 +16,10 @@ import (
|
|||||||
// allow a 10 shift right to divide.
|
// allow a 10 shift right to divide.
|
||||||
|
|
||||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||||
// where 0.100 is 100 ms which is the scaling
|
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||||
// round trip time.
|
|
||||||
const cubeScale = 40
|
const cubeScale = 40
|
||||||
const cubeCongestionWindowScale = 410
|
const cubeCongestionWindowScale = 410
|
||||||
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale
|
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
|
||||||
|
|
||||||
const defaultNumConnections = 2
|
const defaultNumConnections = 2
|
||||||
|
|
||||||
@ -32,39 +31,35 @@ const beta float32 = 0.7
|
|||||||
// new concurrent flows and speed up convergence.
|
// new concurrent flows and speed up convergence.
|
||||||
const betaLastMax float32 = 0.85
|
const betaLastMax float32 = 0.85
|
||||||
|
|
||||||
// If true, Cubic's epoch is shifted when the sender is application-limited.
|
|
||||||
const shiftQuicCubicEpochWhenAppLimited = true
|
|
||||||
|
|
||||||
const maxCubicTimeInterval = 30 * time.Millisecond
|
|
||||||
|
|
||||||
// Cubic implements the cubic algorithm from TCP
|
// Cubic implements the cubic algorithm from TCP
|
||||||
type Cubic struct {
|
type Cubic struct {
|
||||||
clock Clock
|
clock Clock
|
||||||
|
|
||||||
// Number of connections to simulate.
|
// Number of connections to simulate.
|
||||||
numConnections int
|
numConnections int
|
||||||
|
|
||||||
// Time when this cycle started, after last loss event.
|
// Time when this cycle started, after last loss event.
|
||||||
epoch time.Time
|
epoch time.Time
|
||||||
// Time when sender went into application-limited period. Zero if not in
|
|
||||||
// application-limited period.
|
// Max congestion window used just before last loss event.
|
||||||
appLimitedStartTime time.Time
|
|
||||||
// Time when we updated last_congestion_window.
|
|
||||||
lastUpdateTime time.Time
|
|
||||||
// Last congestion window (in packets) used.
|
|
||||||
lastCongestionWindow protocol.PacketNumber
|
|
||||||
// Max congestion window (in packets) used just before last loss event.
|
|
||||||
// Note: to improve fairness to other streams an additional back off is
|
// Note: to improve fairness to other streams an additional back off is
|
||||||
// applied to this value if the new value is below our latest value.
|
// applied to this value if the new value is below our latest value.
|
||||||
lastMaxCongestionWindow protocol.PacketNumber
|
lastMaxCongestionWindow protocol.ByteCount
|
||||||
// Number of acked packets since the cycle started (epoch).
|
|
||||||
ackedPacketsCount protocol.PacketNumber
|
// Number of acked bytes since the cycle started (epoch).
|
||||||
|
ackedBytesCount protocol.ByteCount
|
||||||
|
|
||||||
// TCP Reno equivalent congestion window in packets.
|
// TCP Reno equivalent congestion window in packets.
|
||||||
estimatedTCPcongestionWindow protocol.PacketNumber
|
estimatedTCPcongestionWindow protocol.ByteCount
|
||||||
|
|
||||||
// Origin point of cubic function.
|
// Origin point of cubic function.
|
||||||
originPointCongestionWindow protocol.PacketNumber
|
originPointCongestionWindow protocol.ByteCount
|
||||||
|
|
||||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||||
timeToOriginPoint uint32
|
timeToOriginPoint uint32
|
||||||
|
|
||||||
// Last congestion window in packets computed by cubic function.
|
// Last congestion window in packets computed by cubic function.
|
||||||
lastTargetCongestionWindow protocol.PacketNumber
|
lastTargetCongestionWindow protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCubic returns a new Cubic instance
|
// NewCubic returns a new Cubic instance
|
||||||
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
|
|||||||
// Reset is called after a timeout to reset the cubic state
|
// Reset is called after a timeout to reset the cubic state
|
||||||
func (c *Cubic) Reset() {
|
func (c *Cubic) Reset() {
|
||||||
c.epoch = time.Time{}
|
c.epoch = time.Time{}
|
||||||
c.appLimitedStartTime = time.Time{}
|
|
||||||
c.lastUpdateTime = time.Time{}
|
|
||||||
c.lastCongestionWindow = 0
|
|
||||||
c.lastMaxCongestionWindow = 0
|
c.lastMaxCongestionWindow = 0
|
||||||
c.ackedPacketsCount = 0
|
c.ackedBytesCount = 0
|
||||||
c.estimatedTCPcongestionWindow = 0
|
c.estimatedTCPcongestionWindow = 0
|
||||||
c.originPointCongestionWindow = 0
|
c.originPointCongestionWindow = 0
|
||||||
c.timeToOriginPoint = 0
|
c.timeToOriginPoint = 0
|
||||||
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
|
|||||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cubic) betaLastMax() float32 {
|
||||||
|
// betaLastMax is the additional backoff factor after loss for our
|
||||||
|
// N-connection emulation, which emulates the additional backoff of
|
||||||
|
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||||
|
// effective multiplier is computed as:
|
||||||
|
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||||
|
}
|
||||||
|
|
||||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||||
// the available congestion window. Resets Cubic state during quiescence.
|
// the available congestion window. Resets Cubic state during quiescence.
|
||||||
func (c *Cubic) OnApplicationLimited() {
|
func (c *Cubic) OnApplicationLimited() {
|
||||||
if shiftQuicCubicEpochWhenAppLimited {
|
// When sender is not using the available congestion window, the window does
|
||||||
// When sender is not using the available congestion window, Cubic's epoch
|
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||||
// should not continue growing. Record the time when sender goes into an
|
// using the entire window during the time since the beginning of the current
|
||||||
// app-limited period here, to compensate later when cwnd growth happens.
|
// "epoch" (the end of the last loss recovery period). Since
|
||||||
if c.appLimitedStartTime.IsZero() {
|
// application-limited periods break this assumption, we reset the epoch when
|
||||||
c.appLimitedStartTime = c.clock.Now()
|
// in such a period. This reset effectively freezes congestion window growth
|
||||||
}
|
// through application-limited periods and allows Cubic growth to continue
|
||||||
} else {
|
// when the entire window is being used.
|
||||||
// When sender is not using the available congestion window, Cubic's epoch
|
|
||||||
// should not continue growing. Reset the epoch when in such a period.
|
|
||||||
c.epoch = time.Time{}
|
c.epoch = time.Time{}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||||
// a loss event. Returns the new congestion window in packets. The new
|
// a loss event. Returns the new congestion window in packets. The new
|
||||||
// congestion window is a multiplicative decrease of our current window.
|
// congestion window is a multiplicative decrease of our current window.
|
||||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber {
|
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||||
if currentCongestionWindow < c.lastMaxCongestionWindow {
|
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
|
||||||
// We never reached the old max, so assume we are competing with another
|
// We never reached the old max, so assume we are competing with another
|
||||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||||
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow))
|
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||||
} else {
|
} else {
|
||||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||||
}
|
}
|
||||||
c.epoch = time.Time{} // Reset time.
|
c.epoch = time.Time{} // Reset time.
|
||||||
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta())
|
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||||
// Returns the new congestion window in packets. The new congestion window
|
// Returns the new congestion window in packets. The new congestion window
|
||||||
// follows a cubic function that depends on the time passed since last
|
// follows a cubic function that depends on the time passed since last
|
||||||
// packet loss.
|
// packet loss.
|
||||||
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber {
|
func (c *Cubic) CongestionWindowAfterAck(
|
||||||
c.ackedPacketsCount++ // Packets acked.
|
ackedBytes protocol.ByteCount,
|
||||||
currentTime := c.clock.Now()
|
currentCongestionWindow protocol.ByteCount,
|
||||||
|
delayMin time.Duration,
|
||||||
// Cubic is "independent" of RTT, the update is limited by the time elapsed.
|
eventTime time.Time,
|
||||||
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) {
|
) protocol.ByteCount {
|
||||||
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow)
|
c.ackedBytesCount += ackedBytes
|
||||||
}
|
|
||||||
c.lastCongestionWindow = currentCongestionWindow
|
|
||||||
c.lastUpdateTime = currentTime
|
|
||||||
|
|
||||||
if c.epoch.IsZero() {
|
if c.epoch.IsZero() {
|
||||||
// First ACK after a loss event.
|
// First ACK after a loss event.
|
||||||
c.epoch = currentTime // Start of epoch.
|
c.epoch = eventTime // Start of epoch.
|
||||||
c.ackedPacketsCount = 1 // Reset count.
|
c.ackedBytesCount = ackedBytes // Reset count.
|
||||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||||
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
|||||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// If sender was app-limited, then freeze congestion window growth during
|
|
||||||
// app-limited period. Continue growth now by shifting the epoch-start
|
|
||||||
// through the app-limited period.
|
|
||||||
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
|
|
||||||
shift := currentTime.Sub(c.appLimitedStartTime)
|
|
||||||
c.epoch = c.epoch.Add(shift)
|
|
||||||
c.appLimitedStartTime = time.Time{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||||
// the round trip time in account. This is done to allow us to use shift as a
|
// the round trip time in account. This is done to allow us to use shift as a
|
||||||
// divide operator.
|
// divide operator.
|
||||||
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
|
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||||
|
|
||||||
|
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||||
|
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||||
// Right-shifts of negative, signed numbers have
|
|
||||||
// implementation-dependent behavior. Force the offset to be
|
|
||||||
// positive, similar to the kernel implementation.
|
|
||||||
if offset < 0 {
|
if offset < 0 {
|
||||||
offset = -offset
|
offset = -offset
|
||||||
}
|
}
|
||||||
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
|
|
||||||
var targetCongestionWindow protocol.PacketNumber
|
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
|
||||||
|
var targetCongestionWindow protocol.ByteCount
|
||||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||||
} else {
|
} else {
|
||||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||||
}
|
}
|
||||||
// With dynamic beta/alpha based on number of active streams, it is possible
|
// Limit the CWND increase to half the acked bytes.
|
||||||
// for the required_ack_count to become much lower than acked_packets_count_
|
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||||
// suddenly, leading to more than one iteration through the following loop.
|
|
||||||
for {
|
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||||
// Update estimated TCP congestion_window.
|
// time we ack an estimated tcp window of bytes. For small
|
||||||
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha())
|
// congestion windows (less than 25), the formula below will
|
||||||
if c.ackedPacketsCount < requiredAckCount {
|
// increase slightly slower than linearly per estimated tcp window
|
||||||
break
|
// of bytes.
|
||||||
}
|
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
|
||||||
c.ackedPacketsCount -= requiredAckCount
|
c.ackedBytesCount = 0
|
||||||
c.estimatedTCPcongestionWindow++
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have a new cubic congestion window.
|
// We have a new cubic congestion window.
|
||||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||||
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
|
|||||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
return targetCongestionWindow
|
return targetCongestionWindow
|
||||||
}
|
}
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||||
defaultMinimumCongestionWindow protocol.PacketNumber = 2
|
|
||||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||||
|
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
|
||||||
)
|
)
|
||||||
|
|
||||||
type cubicSender struct {
|
type cubicSender struct {
|
||||||
@ -31,12 +31,6 @@ type cubicSender struct {
|
|||||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||||
largestSentAtLastCutback protocol.PacketNumber
|
largestSentAtLastCutback protocol.PacketNumber
|
||||||
|
|
||||||
// Congestion window in packets.
|
|
||||||
congestionWindow protocol.PacketNumber
|
|
||||||
|
|
||||||
// Slow start congestion window in packets, aka ssthresh.
|
|
||||||
slowstartThreshold protocol.PacketNumber
|
|
||||||
|
|
||||||
// Whether the last loss event caused us to exit slowstart.
|
// Whether the last loss event caused us to exit slowstart.
|
||||||
// Used for stats collection of slowstartPacketsLost
|
// Used for stats collection of slowstartPacketsLost
|
||||||
lastCutbackExitedSlowstart bool
|
lastCutbackExitedSlowstart bool
|
||||||
@ -44,24 +38,35 @@ type cubicSender struct {
|
|||||||
// When true, exit slow start with large cutback of congestion window.
|
// When true, exit slow start with large cutback of congestion window.
|
||||||
slowStartLargeReduction bool
|
slowStartLargeReduction bool
|
||||||
|
|
||||||
// Minimum congestion window in packets.
|
// Congestion window in packets.
|
||||||
minCongestionWindow protocol.PacketNumber
|
congestionWindow protocol.ByteCount
|
||||||
|
|
||||||
// Maximum number of outstanding packets for tcp.
|
// Minimum congestion window in packets.
|
||||||
maxTCPCongestionWindow protocol.PacketNumber
|
minCongestionWindow protocol.ByteCount
|
||||||
|
|
||||||
|
// Maximum congestion window.
|
||||||
|
maxCongestionWindow protocol.ByteCount
|
||||||
|
|
||||||
|
// Slow start congestion window in bytes, aka ssthresh.
|
||||||
|
slowstartThreshold protocol.ByteCount
|
||||||
|
|
||||||
// Number of connections to simulate.
|
// Number of connections to simulate.
|
||||||
numConnections int
|
numConnections int
|
||||||
|
|
||||||
// ACK counter for the Reno implementation.
|
// ACK counter for the Reno implementation.
|
||||||
congestionWindowCount protocol.ByteCount
|
numAckedPackets uint64
|
||||||
|
|
||||||
initialCongestionWindow protocol.PacketNumber
|
initialCongestionWindow protocol.ByteCount
|
||||||
initialMaxCongestionWindow protocol.PacketNumber
|
initialMaxCongestionWindow protocol.ByteCount
|
||||||
|
|
||||||
|
minSlowStartExitWindow protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ SendAlgorithm = &cubicSender{}
|
||||||
|
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
|
||||||
|
|
||||||
// NewCubicSender makes a new cubic sender
|
// NewCubicSender makes a new cubic sender
|
||||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo {
|
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
|
||||||
return &cubicSender{
|
return &cubicSender{
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
initialCongestionWindow: initialCongestionWindow,
|
initialCongestionWindow: initialCongestionWindow,
|
||||||
@ -69,28 +74,37 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
|||||||
congestionWindow: initialCongestionWindow,
|
congestionWindow: initialCongestionWindow,
|
||||||
minCongestionWindow: defaultMinimumCongestionWindow,
|
minCongestionWindow: defaultMinimumCongestionWindow,
|
||||||
slowstartThreshold: initialMaxCongestionWindow,
|
slowstartThreshold: initialMaxCongestionWindow,
|
||||||
maxTCPCongestionWindow: initialMaxCongestionWindow,
|
maxCongestionWindow: initialMaxCongestionWindow,
|
||||||
numConnections: defaultNumConnections,
|
numConnections: defaultNumConnections,
|
||||||
cubic: NewCubic(clock),
|
cubic: NewCubic(clock),
|
||||||
reno: reno,
|
reno: reno,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
// TimeUntilSend returns when the next packet should be sent.
|
||||||
|
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||||
if c.InRecovery() {
|
if c.InRecovery() {
|
||||||
// PRR is used when in recovery.
|
// PRR is used when in recovery.
|
||||||
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
|
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
|
||||||
}
|
|
||||||
if c.GetCongestionWindow() > bytesInFlight {
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return utils.InfDuration
|
}
|
||||||
|
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
|
||||||
|
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||||
|
delay = delay * 8 / 5
|
||||||
|
}
|
||||||
|
return delay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
func (c *cubicSender) OnPacketSent(
|
||||||
// Only update bytesInFlight for data packets.
|
sentTime time.Time,
|
||||||
|
bytesInFlight protocol.ByteCount,
|
||||||
|
packetNumber protocol.PacketNumber,
|
||||||
|
bytes protocol.ByteCount,
|
||||||
|
isRetransmittable bool,
|
||||||
|
) {
|
||||||
if !isRetransmittable {
|
if !isRetransmittable {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
if c.InRecovery() {
|
if c.InRecovery() {
|
||||||
// PRR is used when in recovery.
|
// PRR is used when in recovery.
|
||||||
@ -98,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
|
|||||||
}
|
}
|
||||||
c.largestSentPacketNumber = packetNumber
|
c.largestSentPacketNumber = packetNumber
|
||||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) InRecovery() bool {
|
func (c *cubicSender) InRecovery() bool {
|
||||||
@ -110,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||||
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS
|
return c.congestionWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
|
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
|
||||||
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS
|
return c.slowstartThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) ExitSlowstart() {
|
func (c *cubicSender) ExitSlowstart() {
|
||||||
c.slowstartThreshold = c.congestionWindow
|
c.slowstartThreshold = c.congestionWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
|
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
|
||||||
return c.slowstartThreshold
|
return c.slowstartThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
func (c *cubicSender) OnPacketAcked(
|
||||||
|
ackedPacketNumber protocol.PacketNumber,
|
||||||
|
ackedBytes protocol.ByteCount,
|
||||||
|
priorInFlight protocol.ByteCount,
|
||||||
|
eventTime time.Time,
|
||||||
|
) {
|
||||||
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||||
if c.InRecovery() {
|
if c.InRecovery() {
|
||||||
// PRR is used when in recovery.
|
// PRR is used when in recovery.
|
||||||
c.prr.OnPacketAcked(ackedBytes)
|
c.prr.OnPacketAcked(ackedBytes)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight)
|
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||||
if c.InSlowStart() {
|
if c.InSlowStart() {
|
||||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
func (c *cubicSender) OnPacketLost(
|
||||||
|
packetNumber protocol.PacketNumber,
|
||||||
|
lostBytes protocol.ByteCount,
|
||||||
|
priorInFlight protocol.ByteCount,
|
||||||
|
) {
|
||||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||||
// already sent should be treated as a single loss event, since it's expected.
|
// already sent should be treated as a single loss event, since it's expected.
|
||||||
if packetNumber <= c.largestSentAtLastCutback {
|
if packetNumber <= c.largestSentAtLastCutback {
|
||||||
@ -152,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||||||
c.stats.slowstartPacketsLost++
|
c.stats.slowstartPacketsLost++
|
||||||
c.stats.slowstartBytesLost += lostBytes
|
c.stats.slowstartBytesLost += lostBytes
|
||||||
if c.slowStartLargeReduction {
|
if c.slowStartLargeReduction {
|
||||||
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS {
|
// Reduce congestion window by lost_bytes for every loss.
|
||||||
// Reduce congestion window by 1 for every mss of bytes lost.
|
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
|
||||||
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
|
|
||||||
}
|
|
||||||
c.slowstartThreshold = c.congestionWindow
|
c.slowstartThreshold = c.congestionWindow
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -166,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||||||
c.stats.slowstartPacketsLost++
|
c.stats.slowstartPacketsLost++
|
||||||
}
|
}
|
||||||
|
|
||||||
c.prr.OnPacketLost(bytesInFlight)
|
c.prr.OnPacketLost(priorInFlight)
|
||||||
|
|
||||||
// TODO(chromium): Separate out all of slow start into a separate class.
|
// TODO(chromium): Separate out all of slow start into a separate class.
|
||||||
if c.slowStartLargeReduction && c.InSlowStart() {
|
if c.slowStartLargeReduction && c.InSlowStart() {
|
||||||
c.congestionWindow = c.congestionWindow - 1
|
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||||
|
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||||
|
}
|
||||||
|
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
||||||
} else if c.reno {
|
} else if c.reno {
|
||||||
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta())
|
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||||
} else {
|
} else {
|
||||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||||
}
|
}
|
||||||
// Enforce a minimum congestion window.
|
|
||||||
if c.congestionWindow < c.minCongestionWindow {
|
if c.congestionWindow < c.minCongestionWindow {
|
||||||
c.congestionWindow = c.minCongestionWindow
|
c.congestionWindow = c.minCongestionWindow
|
||||||
}
|
}
|
||||||
@ -184,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
|
|||||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||||
// reset packet count from congestion avoidance mode. We start
|
// reset packet count from congestion avoidance mode. We start
|
||||||
// counting again when we're out of recovery.
|
// counting again when we're out of recovery.
|
||||||
c.congestionWindowCount = 0
|
c.numAckedPackets = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cubicSender) RenoBeta() float32 {
|
func (c *cubicSender) RenoBeta() float32 {
|
||||||
@ -197,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
|
|||||||
|
|
||||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||||
// represents, but quic has a separate ack for each packet.
|
// represents, but quic has a separate ack for each packet.
|
||||||
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
|
func (c *cubicSender) maybeIncreaseCwnd(
|
||||||
|
ackedPacketNumber protocol.PacketNumber,
|
||||||
|
ackedBytes protocol.ByteCount,
|
||||||
|
priorInFlight protocol.ByteCount,
|
||||||
|
eventTime time.Time,
|
||||||
|
) {
|
||||||
// Do not increase the congestion window unless the sender is close to using
|
// Do not increase the congestion window unless the sender is close to using
|
||||||
// the current window.
|
// the current window.
|
||||||
if !c.isCwndLimited(bytesInFlight) {
|
if !c.isCwndLimited(priorInFlight) {
|
||||||
c.cubic.OnApplicationLimited()
|
c.cubic.OnApplicationLimited()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.congestionWindow >= c.maxTCPCongestionWindow {
|
if c.congestionWindow >= c.maxCongestionWindow {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.InSlowStart() {
|
if c.InSlowStart() {
|
||||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||||
c.congestionWindow++
|
c.congestionWindow += protocol.DefaultTCPMSS
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Congestion avoidance
|
||||||
if c.reno {
|
if c.reno {
|
||||||
// Classic Reno congestion avoidance.
|
// Classic Reno congestion avoidance.
|
||||||
c.congestionWindowCount++
|
c.numAckedPackets++
|
||||||
// Divide by num_connections to smoothly increase the CWND at a faster
|
// Divide by num_connections to smoothly increase the CWND at a faster
|
||||||
// rate than conventional Reno.
|
// rate than conventional Reno.
|
||||||
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow {
|
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
|
||||||
c.congestionWindow++
|
c.congestionWindow += protocol.DefaultTCPMSS
|
||||||
c.congestionWindowCount = 0
|
c.numAckedPackets = 0
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT()))
|
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,21 +306,13 @@ func (c *cubicSender) OnConnectionMigration() {
|
|||||||
c.largestSentAtLastCutback = 0
|
c.largestSentAtLastCutback = 0
|
||||||
c.lastCutbackExitedSlowstart = false
|
c.lastCutbackExitedSlowstart = false
|
||||||
c.cubic.Reset()
|
c.cubic.Reset()
|
||||||
c.congestionWindowCount = 0
|
c.numAckedPackets = 0
|
||||||
c.congestionWindow = c.initialCongestionWindow
|
c.congestionWindow = c.initialCongestionWindow
|
||||||
c.slowstartThreshold = c.initialMaxCongestionWindow
|
c.slowstartThreshold = c.initialMaxCongestionWindow
|
||||||
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow
|
c.maxCongestionWindow = c.initialMaxCongestionWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSlowStartLargeReduction allows enabling the SSLR experiment
|
// SetSlowStartLargeReduction allows enabling the SSLR experiment
|
||||||
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
|
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
|
||||||
c.slowStartLargeReduction = enabled
|
c.slowStartLargeReduction = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetransmissionDelay gives the time to retransmission
|
|
||||||
func (c *cubicSender) RetransmissionDelay() time.Duration {
|
|
||||||
if c.rttStats.SmoothedRTT() == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
|
|
||||||
}
|
|
@ -8,16 +8,15 @@ import (
|
|||||||
|
|
||||||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||||
type SendAlgorithm interface {
|
type SendAlgorithm interface {
|
||||||
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
|
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||||
GetCongestionWindow() protocol.ByteCount
|
GetCongestionWindow() protocol.ByteCount
|
||||||
MaybeExitSlowStart()
|
MaybeExitSlowStart()
|
||||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
|
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||||
SetNumEmulatedConnections(n int)
|
SetNumEmulatedConnections(n int)
|
||||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||||
OnConnectionMigration()
|
OnConnectionMigration()
|
||||||
RetransmissionDelay() time.Duration
|
|
||||||
|
|
||||||
// Experiments
|
// Experiments
|
||||||
SetSlowStartLargeReduction(enabled bool)
|
SetSlowStartLargeReduction(enabled bool)
|
||||||
@ -31,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
|
|||||||
// Stuff only used in testing
|
// Stuff only used in testing
|
||||||
|
|
||||||
HybridSlowStart() *HybridSlowStart
|
HybridSlowStart() *HybridSlowStart
|
||||||
SlowstartThreshold() protocol.PacketNumber
|
SlowstartThreshold() protocol.ByteCount
|
||||||
RenoBeta() float32
|
RenoBeta() float32
|
||||||
InRecovery() bool
|
InRecovery() bool
|
||||||
}
|
}
|
@ -1,10 +1,7 @@
|
|||||||
package congestion
|
package congestion
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
||||||
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
|
|||||||
// OnPacketLost should be called on the first loss that triggers a recovery
|
// OnPacketLost should be called on the first loss that triggers a recovery
|
||||||
// period and all other methods in this class should only be called when in
|
// period and all other methods in this class should only be called when in
|
||||||
// recovery.
|
// recovery.
|
||||||
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) {
|
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
|
||||||
p.bytesSentSinceLoss = 0
|
p.bytesSentSinceLoss = 0
|
||||||
p.bytesInFlightBeforeLoss = bytesInFlight
|
p.bytesInFlightBeforeLoss = priorInFlight
|
||||||
p.bytesDeliveredSinceLoss = 0
|
p.bytesDeliveredSinceLoss = 0
|
||||||
p.ackCountSinceLoss = 0
|
p.ackCountSinceLoss = 0
|
||||||
}
|
}
|
||||||
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
|
|||||||
p.ackCountSinceLoss++
|
p.ackCountSinceLoss++
|
||||||
}
|
}
|
||||||
|
|
||||||
// TimeUntilSend calculates the time until a packet can be sent
|
// CanSend returns if packets can be sent
|
||||||
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration {
|
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
|
||||||
// Return QuicTime::Zero In order to ensure limited transmit always works.
|
// Return QuicTime::Zero In order to ensure limited transmit always works.
|
||||||
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
|
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
|
||||||
return 0
|
return true
|
||||||
}
|
}
|
||||||
if congestionWindow > bytesInFlight {
|
if congestionWindow > bytesInFlight {
|
||||||
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
|
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
|
||||||
// of sending the entire available window. This prevents burst retransmits
|
// of sending the entire available window. This prevents burst retransmits
|
||||||
// when more packets are lost than the CWND reduction.
|
// when more packets are lost than the CWND reduction.
|
||||||
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
|
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
|
||||||
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss {
|
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
|
||||||
return utils.InfDuration
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
// Implement Proportional Rate Reduction (RFC6937).
|
// Implement Proportional Rate Reduction (RFC6937).
|
||||||
// Checks a simplified version of the PRR formula that doesn't use division:
|
// Checks a simplified version of the PRR formula that doesn't use division:
|
||||||
// AvailableSendWindow =
|
// AvailableSendWindow =
|
||||||
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
|
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
|
||||||
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss {
|
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return utils.InfDuration
|
|
||||||
}
|
}
|
101
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go
generated
vendored
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package congestion
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
rttAlpha float32 = 0.125
|
||||||
|
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||||
|
rttBeta float32 = 0.25
|
||||||
|
oneMinusBeta float32 = (1 - rttBeta)
|
||||||
|
// The default RTT used before an RTT sample is taken.
|
||||||
|
defaultInitialRTT = 100 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// RTTStats provides round-trip statistics
|
||||||
|
type RTTStats struct {
|
||||||
|
minRTT time.Duration
|
||||||
|
latestRTT time.Duration
|
||||||
|
smoothedRTT time.Duration
|
||||||
|
meanDeviation time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRTTStats makes a properly initialized RTTStats object
|
||||||
|
func NewRTTStats() *RTTStats {
|
||||||
|
return &RTTStats{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinRTT Returns the minRTT for the entire connection.
|
||||||
|
// May return Zero if no valid updates have occurred.
|
||||||
|
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
||||||
|
|
||||||
|
// LatestRTT returns the most recent rtt measurement.
|
||||||
|
// May return Zero if no valid updates have occurred.
|
||||||
|
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
||||||
|
|
||||||
|
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
|
||||||
|
// May return Zero if no valid updates have occurred.
|
||||||
|
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
||||||
|
|
||||||
|
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
|
||||||
|
// If no valid updates have occurred, it returns the initial RTT.
|
||||||
|
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
|
||||||
|
if r.smoothedRTT != 0 {
|
||||||
|
return r.smoothedRTT
|
||||||
|
}
|
||||||
|
return defaultInitialRTT
|
||||||
|
}
|
||||||
|
|
||||||
|
// MeanDeviation gets the mean deviation
|
||||||
|
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
||||||
|
|
||||||
|
// UpdateRTT updates the RTT based on a new sample.
|
||||||
|
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||||
|
if sendDelta == utils.InfDuration || sendDelta <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
|
||||||
|
// ackDelay but the raw observed sendDelta, since poor clock granularity at
|
||||||
|
// the client may cause a high ackDelay to result in underestimation of the
|
||||||
|
// r.minRTT.
|
||||||
|
if r.minRTT == 0 || r.minRTT > sendDelta {
|
||||||
|
r.minRTT = sendDelta
|
||||||
|
}
|
||||||
|
|
||||||
|
// Correct for ackDelay if information received from the peer results in a
|
||||||
|
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||||
|
// sendDelta.
|
||||||
|
sample := sendDelta
|
||||||
|
if sample-r.minRTT >= ackDelay {
|
||||||
|
sample -= ackDelay
|
||||||
|
}
|
||||||
|
r.latestRTT = sample
|
||||||
|
// First time call.
|
||||||
|
if r.smoothedRTT == 0 {
|
||||||
|
r.smoothedRTT = sample
|
||||||
|
r.meanDeviation = sample / 2
|
||||||
|
} else {
|
||||||
|
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
||||||
|
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
||||||
|
func (r *RTTStats) OnConnectionMigration() {
|
||||||
|
r.latestRTT = 0
|
||||||
|
r.minRTT = 0
|
||||||
|
r.smoothedRTT = 0
|
||||||
|
r.meanDeviation = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
||||||
|
// is larger. The mean deviation is increased to the most recent deviation if
|
||||||
|
// it's larger.
|
||||||
|
func (r *RTTStats) ExpireSmoothedMetrics() {
|
||||||
|
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
|
||||||
|
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
|
||||||
|
}
|
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
22
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go
generated
vendored
@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
|||||||
return cert.Certificate[0], nil
|
return cert.Certificate[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||||
c := cc.config
|
conf := c.config
|
||||||
c, err := maybeGetConfigForClient(c, sni)
|
conf, err := maybeGetConfigForClient(conf, sni)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
||||||
|
|
||||||
if c.GetCertificate != nil {
|
if conf.GetCertificate != nil {
|
||||||
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||||
if cert != nil || err != nil {
|
if cert != nil || err != nil {
|
||||||
return cert, err
|
return cert, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.Certificates) == 0 {
|
if len(conf.Certificates) == 0 {
|
||||||
return nil, errNoMatchingCertificate
|
return nil, errNoMatchingCertificate
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
||||||
// There's only one choice, so no point doing any work.
|
// There's only one choice, so no point doing any work.
|
||||||
return &c.Certificates[0], nil
|
return &conf.Certificates[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
name := strings.ToLower(sni)
|
name := strings.ToLower(sni)
|
||||||
@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
|||||||
name = name[:len(name)-1]
|
name = name[:len(name)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
if cert, ok := c.NameToCertificate[name]; ok {
|
if cert, ok := conf.NameToCertificate[name]; ok {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
|||||||
for i := range labels {
|
for i := range labels {
|
||||||
labels[i] = "*"
|
labels[i] = "*"
|
||||||
candidate := strings.Join(labels, ".")
|
candidate := strings.Join(labels, ".")
|
||||||
if cert, ok := c.NameToCertificate[candidate]; ok {
|
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If nothing matches, return the first certificate.
|
// If nothing matches, return the first certificate.
|
||||||
return &c.Certificates[0], nil
|
return &conf.Certificates[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
||||||
|
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go
generated
vendored
@ -18,6 +18,7 @@ type CertManager interface {
|
|||||||
GetLeafCertHash() (uint64, error)
|
GetLeafCertHash() (uint64, error)
|
||||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||||
Verify(hostname string) error
|
Verify(hostname string) error
|
||||||
|
GetChain() []*x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
type certManager struct {
|
type certManager struct {
|
||||||
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *certManager) GetChain() []*x509.Certificate {
|
||||||
|
return c.chain
|
||||||
|
}
|
||||||
|
|
||||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||||
return getCommonCertificateHashes()
|
return getCommonCertificateHashes()
|
||||||
}
|
}
|
||||||
|
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
@ -1,61 +0,0 @@
|
|||||||
// +build ignore
|
|
||||||
|
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/cipher"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/aead/chacha20"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aeadChacha20Poly1305 struct {
|
|
||||||
otherIV []byte
|
|
||||||
myIV []byte
|
|
||||||
encrypter cipher.AEAD
|
|
||||||
decrypter cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
|
|
||||||
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
|
||||||
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
|
|
||||||
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
|
|
||||||
}
|
|
||||||
// copy because ChaCha20Poly1305 expects array pointers
|
|
||||||
var MyKey, OtherKey [32]byte
|
|
||||||
copy(MyKey[:], myKey)
|
|
||||||
copy(OtherKey[:], otherKey)
|
|
||||||
|
|
||||||
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &aeadChacha20Poly1305{
|
|
||||||
otherIV: otherIV,
|
|
||||||
myIV: myIV,
|
|
||||||
encrypter: encrypter,
|
|
||||||
decrypter: decrypter,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
|
||||||
res := make([]byte, 12)
|
|
||||||
copy(res[0:4], iv)
|
|
||||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
|
||||||
return res
|
|
||||||
}
|
|
71
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go
generated
vendored
71
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go
generated
vendored
@ -1,71 +0,0 @@
|
|||||||
// +build ignore
|
|
||||||
|
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Chacha20poly1305", func() {
|
|
||||||
var (
|
|
||||||
alice, bob AEAD
|
|
||||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
keyAlice = make([]byte, 32)
|
|
||||||
keyBob = make([]byte, 32)
|
|
||||||
ivAlice = make([]byte, 4)
|
|
||||||
ivBob = make([]byte, 4)
|
|
||||||
rand.Reader.Read(keyAlice)
|
|
||||||
rand.Reader.Read(keyBob)
|
|
||||||
rand.Reader.Read(ivAlice)
|
|
||||||
rand.Reader.Read(ivBob)
|
|
||||||
var err error
|
|
||||||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("seals and opens reverse", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(text).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the proper length", func() {
|
|
||||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
Expect(b).To(HaveLen(6 + 12))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("fails with wrong aad", func() {
|
|
||||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
|
||||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects wrong key and iv sizes", func() {
|
|
||||||
var err error
|
|
||||||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
|
|
||||||
Expect(err).To(MatchError(e))
|
|
||||||
})
|
|
||||||
})
|
|
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go
generated
vendored
@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
|
|||||||
if _, err := rand.Read(c.secret[:]); err != nil {
|
if _, err := rand.Read(c.secret[:]); err != nil {
|
||||||
return nil, errors.New("Curve25519: could not create private key")
|
return nil, errors.New("Curve25519: could not create private key")
|
||||||
}
|
}
|
||||||
// See https://cr.yp.to/ecdh.html
|
|
||||||
c.secret[0] &= 248
|
|
||||||
c.secret[31] &= 127
|
|
||||||
c.secret[31] |= 64
|
|
||||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
19
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
@ -1,13 +1,16 @@
|
|||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
"github.com/bifurcation/mint"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
|
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
|
||||||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||||
@ -16,6 +19,14 @@ type TLSExporter interface {
|
|||||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func qhkdfExpand(secret []byte, label string, length int) []byte {
|
||||||
|
qlabel := make([]byte, 2+1+5+len(label))
|
||||||
|
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||||
|
qlabel[2] = uint8(5 + len(label))
|
||||||
|
copy(qlabel[3:], []byte("QUIC "+label))
|
||||||
|
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
|
||||||
|
}
|
||||||
|
|
||||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||||
var myLabel, otherLabel string
|
var myLabel, otherLabel string
|
||||||
@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
|
key = qhkdfExpand(secret, "key", cs.KeyLen)
|
||||||
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
|
iv = qhkdfExpand(secret, "iv", cs.IvLen)
|
||||||
return key, iv, nil
|
return key, iv, nil
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
"golang.org/x/crypto/hkdf"
|
||||||
)
|
)
|
||||||
@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
|
|||||||
} else {
|
} else {
|
||||||
info.Write([]byte("QUIC key expansion\x00"))
|
info.Write([]byte("QUIC key expansion\x00"))
|
||||||
}
|
}
|
||||||
utils.BigEndian.WriteUint64(&info, uint64(connID))
|
info.Write(connID)
|
||||||
info.Write(chlo)
|
info.Write(chlo)
|
||||||
info.Write(scfg)
|
info.Write(scfg)
|
||||||
info.Write(cert)
|
info.Write(cert)
|
||||||
|
17
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
17
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
@ -2,13 +2,12 @@ package crypto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"encoding/binary"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
"github.com/bifurcation/mint"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
|
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||||
|
|
||||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||||
@ -28,17 +27,15 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
|||||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||||
connID := make([]byte, 8)
|
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
||||||
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
||||||
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
|
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
||||||
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
|
|
||||||
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||||
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
|
key = qhkdfExpand(secret, "key", 16)
|
||||||
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
|
iv = qhkdfExpand(secret, "iv", 12)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
36
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go
generated
vendored
@ -1,10 +1,11 @@
|
|||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
|
||||||
"github.com/lucas-clemente/fnv128a"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
|||||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := fnv128a.New()
|
hash := fnv.New128a()
|
||||||
hash.Write(associatedData)
|
hash.Write(associatedData)
|
||||||
hash.Write(src[12:])
|
hash.Write(src[12:])
|
||||||
if n.perspective == protocol.PerspectiveServer {
|
if n.perspective == protocol.PerspectiveServer {
|
||||||
@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
|
|||||||
} else {
|
} else {
|
||||||
hash.Write([]byte("Server"))
|
hash.Write([]byte("Server"))
|
||||||
}
|
}
|
||||||
testHigh, testLow := hash.Sum128()
|
sum := make([]byte, 0, 16)
|
||||||
|
sum = hash.Sum(sum)
|
||||||
|
// The tag is written in little endian, so we need to reverse the slice.
|
||||||
|
reverse(sum)
|
||||||
|
|
||||||
low := binary.LittleEndian.Uint64(src)
|
if !bytes.Equal(sum[:12], src[:12]) {
|
||||||
high := binary.LittleEndian.Uint32(src[8:])
|
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
||||||
|
|
||||||
if uint32(testHigh&0xffffffff) != high || testLow != low {
|
|
||||||
return nil, errors.New("NullAEAD: failed to authenticate received data")
|
|
||||||
}
|
}
|
||||||
return src[12:], nil
|
return src[12:], nil
|
||||||
}
|
}
|
||||||
@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
|||||||
dst = dst[:12+len(src)]
|
dst = dst[:12+len(src)]
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := fnv128a.New()
|
hash := fnv.New128a()
|
||||||
hash.Write(associatedData)
|
hash.Write(associatedData)
|
||||||
hash.Write(src)
|
hash.Write(src)
|
||||||
|
|
||||||
@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
|
|||||||
} else {
|
} else {
|
||||||
hash.Write([]byte("Client"))
|
hash.Write([]byte("Client"))
|
||||||
}
|
}
|
||||||
|
sum := make([]byte, 0, 16)
|
||||||
high, low := hash.Sum128()
|
sum = hash.Sum(sum)
|
||||||
|
// The tag is written in little endian, so we need to reverse the slice.
|
||||||
|
reverse(sum)
|
||||||
|
|
||||||
copy(dst[12:], src)
|
copy(dst[12:], src)
|
||||||
binary.LittleEndian.PutUint64(dst, low)
|
copy(dst, sum[:12])
|
||||||
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
|
||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nullAEADFNV128a) Overhead() int {
|
func (n *nullAEADFNV128a) Overhead() int {
|
||||||
return 12
|
return 12
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverse(a []byte) {
|
||||||
|
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||||
|
a[left], a[right] = a[right], a[left]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
76
vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go
generated
vendored
@ -1,76 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StkSource is used to create and verify source address tokens
|
|
||||||
type StkSource interface {
|
|
||||||
// NewToken creates a new token
|
|
||||||
NewToken([]byte) ([]byte, error)
|
|
||||||
// DecodeToken decodes a token
|
|
||||||
DecodeToken([]byte) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type stkSource struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
const stkKeySize = 16
|
|
||||||
|
|
||||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
|
||||||
// at 16 :)
|
|
||||||
const stkNonceSize = 16
|
|
||||||
|
|
||||||
// NewStkSource creates a source for source address tokens
|
|
||||||
func NewStkSource() (StkSource, error) {
|
|
||||||
secret := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(secret); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key, err := deriveKey(secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &stkSource{aead: aead}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
|
||||||
nonce := make([]byte, stkNonceSize)
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
|
||||||
if len(p) < stkNonceSize {
|
|
||||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
|
||||||
}
|
|
||||||
nonce := p[:stkNonceSize]
|
|
||||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deriveKey(secret []byte) ([]byte, error) {
|
|
||||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
|
||||||
key := make([]byte, stkKeySize)
|
|
||||||
if _, err := io.ReadFull(r, key); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
78
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
78
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
@ -4,41 +4,38 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type baseFlowController struct {
|
type baseFlowController struct {
|
||||||
mutex sync.RWMutex
|
// for sending data
|
||||||
|
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
bytesSent protocol.ByteCount
|
bytesSent protocol.ByteCount
|
||||||
sendWindow protocol.ByteCount
|
sendWindow protocol.ByteCount
|
||||||
|
|
||||||
lastWindowUpdateTime time.Time
|
// for receiving data
|
||||||
|
mutex sync.RWMutex
|
||||||
bytesRead protocol.ByteCount
|
bytesRead protocol.ByteCount
|
||||||
highestReceived protocol.ByteCount
|
highestReceived protocol.ByteCount
|
||||||
receiveWindow protocol.ByteCount
|
receiveWindow protocol.ByteCount
|
||||||
receiveWindowIncrement protocol.ByteCount
|
receiveWindowSize protocol.ByteCount
|
||||||
maxReceiveWindowIncrement protocol.ByteCount
|
maxReceiveWindowSize protocol.ByteCount
|
||||||
|
|
||||||
|
epochStartTime time.Time
|
||||||
|
epochStartOffset protocol.ByteCount
|
||||||
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
|
|
||||||
c.bytesSent += n
|
c.bytesSent += n
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||||
// it returns true if the window was actually updated
|
// it returns true if the window was actually updated
|
||||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
|
|
||||||
if offset > c.sendWindow {
|
if offset > c.sendWindow {
|
||||||
c.sendWindow = offset
|
c.sendWindow = offset
|
||||||
}
|
}
|
||||||
@ -57,52 +54,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
|
|||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
// pretend we sent a WindowUpdate when reading the first byte
|
// pretend we sent a WindowUpdate when reading the first byte
|
||||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||||
if c.bytesRead == 0 {
|
if c.bytesRead == 0 {
|
||||||
c.lastWindowUpdateTime = time.Now()
|
c.startNewAutoTuningEpoch()
|
||||||
}
|
}
|
||||||
c.bytesRead += n
|
c.bytesRead += n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||||
|
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||||
|
// update the window when more than the threshold was consumed
|
||||||
|
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
|
||||||
|
}
|
||||||
|
|
||||||
// getWindowUpdate updates the receive window, if necessary
|
// getWindowUpdate updates the receive window, if necessary
|
||||||
// it returns the new offset
|
// it returns the new offset
|
||||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||||
diff := c.receiveWindow - c.bytesRead
|
if !c.hasWindowUpdate() {
|
||||||
// update the window when more than half of it was already consumed
|
|
||||||
if diff >= (c.receiveWindowIncrement / 2) {
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
c.maybeAdjustWindowIncrement()
|
c.maybeAdjustWindowSize()
|
||||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||||
c.lastWindowUpdateTime = time.Now()
|
|
||||||
return c.receiveWindow
|
return c.receiveWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseFlowController) IsBlocked() bool {
|
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||||
c.mutex.RLock()
|
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||||
defer c.mutex.RUnlock()
|
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||||
|
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||||
return c.sendWindowSize() == 0
|
// don't do anything if less than half the window has been consumed
|
||||||
}
|
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||||
|
|
||||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
|
||||||
func (c *baseFlowController) maybeAdjustWindowIncrement() {
|
|
||||||
if c.lastWindowUpdateTime.IsZero() {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rtt := c.rttStats.SmoothedRTT()
|
rtt := c.rttStats.SmoothedRTT()
|
||||||
if rtt == 0 {
|
if rtt == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
// window is consumed too fast, try to increase the window size
|
||||||
return
|
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||||
}
|
}
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
c.startNewAutoTuningEpoch()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *baseFlowController) startNewAutoTuningEpoch() {
|
||||||
|
c.epochStartTime = time.Now()
|
||||||
|
c.epochStartOffset = c.bytesRead
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||||
|
@ -2,16 +2,18 @@ package flowcontrol
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type connectionFlowController struct {
|
type connectionFlowController struct {
|
||||||
|
lastBlockedAt protocol.ByteCount
|
||||||
baseFlowController
|
baseFlowController
|
||||||
|
|
||||||
|
queueWindowUpdate func()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ConnectionFlowController = &connectionFlowController{}
|
var _ ConnectionFlowController = &connectionFlowController{}
|
||||||
@ -21,25 +23,37 @@ var _ ConnectionFlowController = &connectionFlowController{}
|
|||||||
func NewConnectionFlowController(
|
func NewConnectionFlowController(
|
||||||
receiveWindow protocol.ByteCount,
|
receiveWindow protocol.ByteCount,
|
||||||
maxReceiveWindow protocol.ByteCount,
|
maxReceiveWindow protocol.ByteCount,
|
||||||
|
queueWindowUpdate func(),
|
||||||
rttStats *congestion.RTTStats,
|
rttStats *congestion.RTTStats,
|
||||||
|
logger utils.Logger,
|
||||||
) ConnectionFlowController {
|
) ConnectionFlowController {
|
||||||
return &connectionFlowController{
|
return &connectionFlowController{
|
||||||
baseFlowController: baseFlowController{
|
baseFlowController: baseFlowController{
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
receiveWindow: receiveWindow,
|
receiveWindow: receiveWindow,
|
||||||
receiveWindowIncrement: receiveWindow,
|
receiveWindowSize: receiveWindow,
|
||||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
maxReceiveWindowSize: maxReceiveWindow,
|
||||||
|
logger: logger,
|
||||||
},
|
},
|
||||||
|
queueWindowUpdate: queueWindowUpdate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
c.mutex.RLock()
|
|
||||||
defer c.mutex.RUnlock()
|
|
||||||
|
|
||||||
return c.baseFlowController.sendWindowSize()
|
return c.baseFlowController.sendWindowSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||||
|
// For every offset, it only returns true once.
|
||||||
|
// If it is blocked, the offset is returned.
|
||||||
|
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||||
|
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
c.lastBlockedAt = c.sendWindow
|
||||||
|
return true, c.sendWindow
|
||||||
|
}
|
||||||
|
|
||||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
@ -52,26 +66,34 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
|
||||||
|
c.mutex.Lock()
|
||||||
|
hasWindowUpdate := c.hasWindowUpdate()
|
||||||
|
c.mutex.Unlock()
|
||||||
|
if hasWindowUpdate {
|
||||||
|
c.queueWindowUpdate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
oldWindowSize := c.receiveWindowSize
|
||||||
|
|
||||||
oldWindowIncrement := c.receiveWindowIncrement
|
|
||||||
offset := c.baseFlowController.getWindowUpdate()
|
offset := c.baseFlowController.getWindowUpdate()
|
||||||
if oldWindowIncrement < c.receiveWindowIncrement {
|
if oldWindowSize < c.receiveWindowSize {
|
||||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||||
}
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
return offset
|
return offset
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
// EnsureMinimumWindowSize sets a minimum window size
|
||||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||||
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
if inc > c.receiveWindowSize {
|
||||||
|
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||||
if inc > c.receiveWindowIncrement {
|
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
c.startNewAutoTuningEpoch()
|
||||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
|
||||||
}
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
8
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
@ -5,17 +5,19 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
|
|||||||
type flowController interface {
|
type flowController interface {
|
||||||
// for sending
|
// for sending
|
||||||
SendWindowSize() protocol.ByteCount
|
SendWindowSize() protocol.ByteCount
|
||||||
IsBlocked() bool
|
|
||||||
UpdateSendWindow(protocol.ByteCount)
|
UpdateSendWindow(protocol.ByteCount)
|
||||||
AddBytesSent(protocol.ByteCount)
|
AddBytesSent(protocol.ByteCount)
|
||||||
// for receiving
|
// for receiving
|
||||||
AddBytesRead(protocol.ByteCount)
|
AddBytesRead(protocol.ByteCount)
|
||||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||||
|
MaybeQueueWindowUpdate() // queues a window update, if necessary
|
||||||
}
|
}
|
||||||
|
|
||||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||||
type StreamFlowController interface {
|
type StreamFlowController interface {
|
||||||
flowController
|
flowController
|
||||||
|
// for sending
|
||||||
|
IsBlocked() (bool, protocol.ByteCount)
|
||||||
// for receiving
|
// for receiving
|
||||||
// UpdateHighestReceived should be called when a new highest offset is received
|
// UpdateHighestReceived should be called when a new highest offset is received
|
||||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||||
@ -25,13 +27,15 @@ type StreamFlowController interface {
|
|||||||
// The ConnectionFlowController is the flow controller for the connection.
|
// The ConnectionFlowController is the flow controller for the connection.
|
||||||
type ConnectionFlowController interface {
|
type ConnectionFlowController interface {
|
||||||
flowController
|
flowController
|
||||||
|
// for sending
|
||||||
|
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
type connectionFlowControllerI interface {
|
type connectionFlowControllerI interface {
|
||||||
ConnectionFlowController
|
ConnectionFlowController
|
||||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||||
// for sending
|
// for sending
|
||||||
EnsureMinimumWindowIncrement(protocol.ByteCount)
|
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||||
// for receiving
|
// for receiving
|
||||||
IncrementHighestReceived(protocol.ByteCount) error
|
IncrementHighestReceived(protocol.ByteCount) error
|
||||||
}
|
}
|
||||||
|
58
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
58
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
@ -3,7 +3,7 @@ package flowcontrol
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
@ -14,6 +14,8 @@ type streamFlowController struct {
|
|||||||
|
|
||||||
streamID protocol.StreamID
|
streamID protocol.StreamID
|
||||||
|
|
||||||
|
queueWindowUpdate func()
|
||||||
|
|
||||||
connection connectionFlowControllerI
|
connection connectionFlowControllerI
|
||||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||||
|
|
||||||
@ -30,18 +32,22 @@ func NewStreamFlowController(
|
|||||||
receiveWindow protocol.ByteCount,
|
receiveWindow protocol.ByteCount,
|
||||||
maxReceiveWindow protocol.ByteCount,
|
maxReceiveWindow protocol.ByteCount,
|
||||||
initialSendWindow protocol.ByteCount,
|
initialSendWindow protocol.ByteCount,
|
||||||
|
queueWindowUpdate func(protocol.StreamID),
|
||||||
rttStats *congestion.RTTStats,
|
rttStats *congestion.RTTStats,
|
||||||
|
logger utils.Logger,
|
||||||
) StreamFlowController {
|
) StreamFlowController {
|
||||||
return &streamFlowController{
|
return &streamFlowController{
|
||||||
streamID: streamID,
|
streamID: streamID,
|
||||||
contributesToConnection: contributesToConnection,
|
contributesToConnection: contributesToConnection,
|
||||||
connection: cfc.(connectionFlowControllerI),
|
connection: cfc.(connectionFlowControllerI),
|
||||||
|
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||||
baseFlowController: baseFlowController{
|
baseFlowController: baseFlowController{
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
receiveWindow: receiveWindow,
|
receiveWindow: receiveWindow,
|
||||||
receiveWindowIncrement: receiveWindow,
|
receiveWindowSize: receiveWindow,
|
||||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
maxReceiveWindowSize: maxReceiveWindow,
|
||||||
sendWindow: initialSendWindow,
|
sendWindow: initialSendWindow,
|
||||||
|
logger: logger,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,9 +108,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
|
|
||||||
window := c.baseFlowController.sendWindowSize()
|
window := c.baseFlowController.sendWindowSize()
|
||||||
if c.contributesToConnection {
|
if c.contributesToConnection {
|
||||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||||
@ -112,17 +115,44 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
|||||||
return window
|
return window
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
// IsBlocked says if it is blocked by stream-level flow control.
|
||||||
c.mutex.Lock()
|
// If it is blocked, the offset is returned.
|
||||||
defer c.mutex.Unlock()
|
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||||
|
if c.sendWindowSize() != 0 {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
return true, c.sendWindow
|
||||||
|
}
|
||||||
|
|
||||||
oldWindowIncrement := c.receiveWindowIncrement
|
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||||
offset := c.baseFlowController.getWindowUpdate()
|
c.mutex.Lock()
|
||||||
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment
|
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10))
|
c.mutex.Unlock()
|
||||||
|
if hasWindowUpdate {
|
||||||
|
c.queueWindowUpdate()
|
||||||
|
}
|
||||||
if c.contributesToConnection {
|
if c.contributesToConnection {
|
||||||
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier))
|
c.connection.MaybeQueueWindowUpdate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
|
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||||
|
c.mutex.Lock()
|
||||||
|
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
|
||||||
|
if c.receivedFinalOffset {
|
||||||
|
c.mutex.Unlock()
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
oldWindowSize := c.receiveWindowSize
|
||||||
|
offset := c.baseFlowController.getWindowUpdate()
|
||||||
|
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||||
|
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||||
|
if c.contributesToConnection {
|
||||||
|
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
return offset
|
return offset
|
||||||
}
|
}
|
||||||
|
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
@ -6,7 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/bifurcation/mint"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -29,17 +29,17 @@ type token struct {
|
|||||||
|
|
||||||
// A CookieGenerator generates Cookies
|
// A CookieGenerator generates Cookies
|
||||||
type CookieGenerator struct {
|
type CookieGenerator struct {
|
||||||
cookieSource crypto.StkSource
|
cookieProtector mint.CookieProtector
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieGenerator initializes a new CookieGenerator
|
// NewCookieGenerator initializes a new CookieGenerator
|
||||||
func NewCookieGenerator() (*CookieGenerator, error) {
|
func NewCookieGenerator() (*CookieGenerator, error) {
|
||||||
stkSource, err := crypto.NewStkSource()
|
cookieProtector, err := mint.NewDefaultCookieProtector()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &CookieGenerator{
|
return &CookieGenerator{
|
||||||
cookieSource: stkSource,
|
cookieProtector: cookieProtector,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return g.cookieSource.NewToken(data)
|
return g.cookieProtector.NewToken(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeToken decodes a Cookie
|
// DecodeToken decodes a Cookie
|
||||||
@ -62,7 +62,7 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := g.cookieSource.DecodeToken(encrypted)
|
data, err := g.cookieProtector.DecodeToken(encrypted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
24
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
@ -7,36 +7,44 @@ import (
|
|||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cookieHandler struct {
|
// A CookieHandler generates and validates cookies.
|
||||||
|
// The cookie is sent in the TLS Retry.
|
||||||
|
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
|
||||||
|
type CookieHandler struct {
|
||||||
callback func(net.Addr, *Cookie) bool
|
callback func(net.Addr, *Cookie) bool
|
||||||
|
|
||||||
cookieGenerator *CookieGenerator
|
cookieGenerator *CookieGenerator
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.CookieHandler = &cookieHandler{}
|
var _ mint.CookieHandler = &CookieHandler{}
|
||||||
|
|
||||||
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
|
// NewCookieHandler creates a new CookieHandler.
|
||||||
|
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
|
||||||
cookieGenerator, err := NewCookieGenerator()
|
cookieGenerator, err := NewCookieGenerator()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &cookieHandler{
|
return &CookieHandler{
|
||||||
callback: callback,
|
callback: callback,
|
||||||
cookieGenerator: cookieGenerator,
|
cookieGenerator: cookieGenerator,
|
||||||
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
// Generate a new cookie for a mint connection.
|
||||||
|
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||||
if h.callback(conn.RemoteAddr(), nil) {
|
if h.callback(conn.RemoteAddr(), nil) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
// Validate a cookie.
|
||||||
|
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||||
data, err := h.cookieGenerator.DecodeToken(token)
|
data, err := h.cookieGenerator.DecodeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return h.callback(conn.RemoteAddr(), data)
|
return h.callback(conn.RemoteAddr(), data)
|
||||||
|
88
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
88
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
@ -38,13 +38,12 @@ type cryptoSetupClient struct {
|
|||||||
lastSentCHLO []byte
|
lastSentCHLO []byte
|
||||||
certManager crypto.CertManager
|
certManager crypto.CertManager
|
||||||
|
|
||||||
divNonceChan chan []byte
|
divNonceChan chan struct{}
|
||||||
diversificationNonce []byte
|
diversificationNonce []byte
|
||||||
|
|
||||||
clientHelloCounter int
|
clientHelloCounter int
|
||||||
serverVerified bool // has the certificate chain and the proof already been verified
|
serverVerified bool // has the certificate chain and the proof already been verified
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
keyDerivation QuicCryptoKeyDerivationFunction
|
||||||
keyExchange KeyExchangeFunction
|
|
||||||
|
|
||||||
receivedSecurePacket bool
|
receivedSecurePacket bool
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
@ -52,9 +51,11 @@ type cryptoSetupClient struct {
|
|||||||
forwardSecureAEAD crypto.AEAD
|
forwardSecureAEAD crypto.AEAD
|
||||||
|
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
handshakeEvent chan<- struct{}
|
||||||
|
|
||||||
params *TransportParameters
|
params *TransportParameters
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupClient{}
|
var _ CryptoSetup = &cryptoSetupClient{}
|
||||||
@ -74,15 +75,17 @@ func NewCryptoSetupClient(
|
|||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
paramsChan chan<- TransportParameters,
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
handshakeEvent chan<- struct{},
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
negotiatedVersions []protocol.VersionNumber,
|
||||||
|
logger utils.Logger,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &cryptoSetupClient{
|
divNonceChan := make(chan struct{})
|
||||||
|
cs := &cryptoSetupClient{
|
||||||
cryptoStream: cryptoStream,
|
cryptoStream: cryptoStream,
|
||||||
hostname: hostname,
|
hostname: hostname,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
@ -90,19 +93,20 @@ func NewCryptoSetupClient(
|
|||||||
certManager: crypto.NewCertManager(tlsConfig),
|
certManager: crypto.NewCertManager(tlsConfig),
|
||||||
params: params,
|
params: params,
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
keyExchange: getEphermalKEX,
|
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
aeadChanged: aeadChanged,
|
handshakeEvent: handshakeEvent,
|
||||||
initialVersion: initialVersion,
|
initialVersion: initialVersion,
|
||||||
negotiatedVersions: negotiatedVersions,
|
negotiatedVersions: negotiatedVersions,
|
||||||
divNonceChan: make(chan []byte),
|
divNonceChan: divNonceChan,
|
||||||
}, nil
|
logger: logger,
|
||||||
|
}
|
||||||
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||||
messageChan := make(chan HandshakeMessage)
|
messageChan := make(chan HandshakeMessage)
|
||||||
errorChan := make(chan error)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@ -116,37 +120,30 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := h.maybeUpgradeCrypto()
|
if err := h.maybeUpgradeCrypto(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
sendCHLO := h.secureAEAD == nil
|
sendCHLO := h.secureAEAD == nil
|
||||||
h.mutex.RUnlock()
|
h.mutex.RUnlock()
|
||||||
|
|
||||||
if sendCHLO {
|
if sendCHLO {
|
||||||
err = h.sendCHLO()
|
if err := h.sendCHLO(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var message HandshakeMessage
|
var message HandshakeMessage
|
||||||
select {
|
select {
|
||||||
case divNonce := <-h.divNonceChan:
|
case <-h.divNonceChan:
|
||||||
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
|
|
||||||
return errConflictingDiversificationNonces
|
|
||||||
}
|
|
||||||
h.diversificationNonce = divNonce
|
|
||||||
// there's no message to process, but we should try upgrading the crypto again
|
// there's no message to process, but we should try upgrading the crypto again
|
||||||
continue
|
continue
|
||||||
case message = <-messageChan:
|
case message = <-messageChan:
|
||||||
case err = <-errorChan:
|
case err := <-errorChan:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.Debugf("Got %s", message)
|
h.logger.Debugf("Got %s", message)
|
||||||
switch message.Tag {
|
switch message.Tag {
|
||||||
case TagREJ:
|
case TagREJ:
|
||||||
if err := h.handleREJMessage(message.Data); err != nil {
|
if err := h.handleREJMessage(message.Data); err != nil {
|
||||||
@ -159,8 +156,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||||||
}
|
}
|
||||||
// blocks until the session has received the parameters
|
// blocks until the session has received the parameters
|
||||||
h.paramsChan <- *params
|
h.paramsChan <- *params
|
||||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
h.handshakeEvent <- struct{}{}
|
||||||
close(h.aeadChanged)
|
close(h.handshakeEvent)
|
||||||
default:
|
default:
|
||||||
return qerr.InvalidCryptoMessageType
|
return qerr.InvalidCryptoMessageType
|
||||||
}
|
}
|
||||||
@ -211,7 +208,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||||||
|
|
||||||
err = h.certManager.Verify(h.hostname)
|
err = h.certManager.Verify(h.hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Infof("Certificate validation failed: %s", err.Error())
|
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
||||||
return qerr.ProofInvalid
|
return qerr.ProofInvalid
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,7 +216,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||||||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
||||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
||||||
if !validProof {
|
if !validProof {
|
||||||
utils.Infof("Server proof verification failed")
|
h.logger.Infof("Server proof verification failed")
|
||||||
return qerr.ProofInvalid
|
return qerr.ProofInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -277,6 +274,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
||||||
|
|
||||||
params, err := readHelloMap(cryptoData)
|
params, err := readHelloMap(cryptoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -322,6 +320,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||||||
if h.secureAEAD != nil {
|
if h.secureAEAD != nil {
|
||||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||||
h.receivedSecurePacket = true
|
h.receivedSecurePacket = true
|
||||||
return data, protocol.EncryptionSecure, nil
|
return data, protocol.EncryptionSecure, nil
|
||||||
}
|
}
|
||||||
@ -373,16 +372,28 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
|
|||||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||||
panic("not needed for cryptoSetupClient")
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
return ConnectionState{
|
||||||
|
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||||
|
PeerCertificates: h.certManager.GetChain(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
||||||
h.divNonceChan <- data
|
h.mutex.Lock()
|
||||||
}
|
if len(h.diversificationNonce) > 0 {
|
||||||
|
defer h.mutex.Unlock()
|
||||||
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
|
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||||
panic("not needed for cryptoSetupServer")
|
return errConflictingDiversificationNonces
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.diversificationNonce = divNonce
|
||||||
|
h.mutex.Unlock()
|
||||||
|
h.divNonceChan <- struct{}{}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sendCHLO() error {
|
func (h *cryptoSetupClient) sendCHLO() error {
|
||||||
@ -403,7 +414,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
|||||||
Data: tags,
|
Data: tags,
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.Debugf("Sending %s", message)
|
h.logger.Debugf("Sending %s", message)
|
||||||
message.Write(b)
|
message.Write(b)
|
||||||
|
|
||||||
_, err = h.cryptoStream.Write(b.Bytes())
|
_, err = h.cryptoStream.Write(b.Bytes())
|
||||||
@ -462,7 +473,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
|||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||||
}
|
}
|
||||||
paddingSize := protocol.ClientHelloMinimumSize - size
|
paddingSize := protocol.MinClientHelloSize - size
|
||||||
if paddingSize > 0 {
|
if paddingSize > 0 {
|
||||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||||
}
|
}
|
||||||
@ -500,10 +511,9 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||||
h.aeadChanged <- protocol.EncryptionSecure
|
h.handshakeEvent <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
72
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
72
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
@ -19,10 +19,12 @@ import (
|
|||||||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
||||||
|
|
||||||
// KeyExchangeFunction is used to make a new KEX
|
// KeyExchangeFunction is used to make a new KEX
|
||||||
type KeyExchangeFunction func() crypto.KeyExchange
|
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
||||||
|
|
||||||
// The CryptoSetupServer handles all things crypto for the Session
|
// The CryptoSetupServer handles all things crypto for the Session
|
||||||
type cryptoSetupServer struct {
|
type cryptoSetupServer struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
|
||||||
connID protocol.ConnectionID
|
connID protocol.ConnectionID
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
scfg *ServerConfig
|
scfg *ServerConfig
|
||||||
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
|
|||||||
|
|
||||||
receivedParams bool
|
receivedParams bool
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
handshakeEvent chan<- struct{}
|
||||||
|
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
keyDerivation QuicCryptoKeyDerivationFunction
|
||||||
keyExchange KeyExchangeFunction
|
keyExchange KeyExchangeFunction
|
||||||
@ -51,7 +53,9 @@ type cryptoSetupServer struct {
|
|||||||
|
|
||||||
params *TransportParameters
|
params *TransportParameters
|
||||||
|
|
||||||
mutex sync.RWMutex
|
sni string // need to fill out the ConnectionState
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupServer{}
|
var _ CryptoSetup = &cryptoSetupServer{}
|
||||||
@ -71,12 +75,14 @@ func NewCryptoSetup(
|
|||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
|
divNonce []byte,
|
||||||
scfg *ServerConfig,
|
scfg *ServerConfig,
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
acceptSTK func(net.Addr, *Cookie) bool,
|
acceptSTK func(net.Addr, *Cookie) bool,
|
||||||
paramsChan chan<- TransportParameters,
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
handshakeEvent chan<- struct{},
|
||||||
|
logger utils.Logger,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -88,6 +94,7 @@ func NewCryptoSetup(
|
|||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
version: version,
|
version: version,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
|
diversificationNonce: divNonce,
|
||||||
scfg: scfg,
|
scfg: scfg,
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
keyExchange: getEphermalKEX,
|
keyExchange: getEphermalKEX,
|
||||||
@ -96,7 +103,8 @@ func NewCryptoSetup(
|
|||||||
acceptSTKCallback: acceptSTK,
|
acceptSTKCallback: acceptSTK,
|
||||||
sentSHLO: make(chan struct{}),
|
sentSHLO: make(chan struct{}),
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
aeadChanged: aeadChanged,
|
handshakeEvent: handshakeEvent,
|
||||||
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
|
|||||||
return qerr.InvalidCryptoMessageType
|
return qerr.InvalidCryptoMessageType
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.Debugf("Got %s", message)
|
h.logger.Debugf("Got %s", message)
|
||||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -139,6 +147,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
|||||||
if sni == "" {
|
if sni == "" {
|
||||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||||
}
|
}
|
||||||
|
h.sni = sni
|
||||||
|
|
||||||
// prevent version downgrade attacks
|
// prevent version downgrade attacks
|
||||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
||||||
@ -182,7 +191,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
|||||||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
h.handshakeEvent <- struct{}{}
|
||||||
close(h.sentSHLO)
|
close(h.sentSHLO)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@ -205,10 +214,11 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||||||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||||
|
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
||||||
h.receivedForwardSecurePacket = true
|
h.receivedForwardSecurePacket = true
|
||||||
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan
|
// wait for the send on the handshakeEvent chan
|
||||||
<-h.sentSHLO
|
<-h.sentSHLO
|
||||||
close(h.aeadChanged)
|
close(h.handshakeEvent)
|
||||||
}
|
}
|
||||||
return res, protocol.EncryptionForwardSecure, nil
|
return res, protocol.EncryptionForwardSecure, nil
|
||||||
}
|
}
|
||||||
@ -219,6 +229,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||||||
if h.secureAEAD != nil {
|
if h.secureAEAD != nil {
|
||||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||||
h.receivedSecurePacket = true
|
h.receivedSecurePacket = true
|
||||||
return res, protocol.EncryptionSecure, nil
|
return res, protocol.EncryptionSecure, nil
|
||||||
}
|
}
|
||||||
@ -294,17 +305,13 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
|||||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Debugf("STK invalid: %s", err.Error())
|
h.logger.Debugf("STK invalid: %s", err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||||
if len(chlo) < protocol.ClientHelloMinimumSize {
|
|
||||||
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -341,7 +348,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
|||||||
|
|
||||||
var serverReply bytes.Buffer
|
var serverReply bytes.Buffer
|
||||||
message.Write(&serverReply)
|
message.Write(&serverReply)
|
||||||
utils.Debugf("Sending %s", message)
|
h.logger.Debugf("Sending %s", message)
|
||||||
return serverReply.Bytes(), nil
|
return serverReply.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -365,11 +372,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
h.diversificationNonce = make([]byte, 32)
|
|
||||||
if _, err = rand.Read(h.diversificationNonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientNonce := cryptoData[TagNONC]
|
clientNonce := cryptoData[TagNONC]
|
||||||
err = h.validateClientNonce(clientNonce)
|
err = h.validateClientNonce(clientNonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -400,14 +402,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||||
h.aeadChanged <- protocol.EncryptionSecure
|
h.handshakeEvent <- struct{}{}
|
||||||
|
|
||||||
// Generate a new curve instance to derive the forward secure key
|
// Generate a new curve instance to derive the forward secure key
|
||||||
var fsNonce bytes.Buffer
|
var fsNonce bytes.Buffer
|
||||||
fsNonce.Write(clientNonce)
|
fsNonce.Write(clientNonce)
|
||||||
fsNonce.Write(serverNonce)
|
fsNonce.Write(serverNonce)
|
||||||
ephermalKex := h.keyExchange()
|
ephermalKex, err := h.keyExchange()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -427,6 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
||||||
|
|
||||||
replyMap := h.params.getHelloMap()
|
replyMap := h.params.getHelloMap()
|
||||||
// add crypto parameters
|
// add crypto parameters
|
||||||
@ -445,21 +451,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||||||
}
|
}
|
||||||
var reply bytes.Buffer
|
var reply bytes.Buffer
|
||||||
message.Write(&reply)
|
message.Write(&reply)
|
||||||
utils.Debugf("Sending %s", message)
|
h.logger.Debugf("Sending %s", message)
|
||||||
return reply.Bytes(), nil
|
return reply.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DiversificationNonce returns the diversification nonce
|
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||||
func (h *cryptoSetupServer) DiversificationNonce() []byte {
|
h.mutex.Lock()
|
||||||
return h.diversificationNonce
|
defer h.mutex.Unlock()
|
||||||
}
|
return ConnectionState{
|
||||||
|
ServerName: h.sni,
|
||||||
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
HandshakeComplete: h.receivedForwardSecurePacket,
|
||||||
panic("not needed for cryptoSetupServer")
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
|
|
||||||
panic("not needed for cryptoSetupServer")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||||
|
185
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
185
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
@ -1,10 +1,9 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
"github.com/bifurcation/mint"
|
||||||
@ -12,6 +11,9 @@ import (
|
|||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
|
||||||
|
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
|
||||||
|
|
||||||
// KeyDerivationFunction is used for key derivation
|
// KeyDerivationFunction is used for key derivation
|
||||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||||
|
|
||||||
@ -20,64 +22,33 @@ type cryptoSetupTLS struct {
|
|||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
|
||||||
tls mintTLS
|
|
||||||
conn *fakeConn
|
|
||||||
|
|
||||||
nextPacketType protocol.PacketType
|
|
||||||
|
|
||||||
keyDerivation KeyDerivationFunction
|
keyDerivation KeyDerivationFunction
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
aead crypto.AEAD
|
aead crypto.AEAD
|
||||||
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
tls MintTLS
|
||||||
|
cryptoStream *CryptoStreamConn
|
||||||
|
handshakeEvent chan<- struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||||
|
|
||||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||||
func NewCryptoSetupTLSServer(
|
func NewCryptoSetupTLSServer(
|
||||||
cryptoStream io.ReadWriter,
|
tls MintTLS,
|
||||||
connID protocol.ConnectionID,
|
cryptoStream *CryptoStreamConn,
|
||||||
tlsConfig *tls.Config,
|
nullAEAD crypto.AEAD,
|
||||||
remoteAddr net.Addr,
|
handshakeEvent chan<- struct{},
|
||||||
params *TransportParameters,
|
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
|
||||||
checkCookie func(net.Addr, *Cookie) bool,
|
|
||||||
supportedVersions []protocol.VersionNumber,
|
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, error) {
|
) CryptoSetupTLS {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mintConf.RequireCookie = true
|
|
||||||
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn := &fakeConn{
|
|
||||||
stream: cryptoStream,
|
|
||||||
pers: protocol.PerspectiveServer,
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
}
|
|
||||||
mintConn := mint.Server(conn, mintConf)
|
|
||||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
|
||||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
perspective: protocol.PerspectiveServer,
|
tls: tls,
|
||||||
tls: &mintController{mintConn},
|
cryptoStream: cryptoStream,
|
||||||
conn: conn,
|
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
|
perspective: protocol.PerspectiveServer,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
aeadChanged: aeadChanged,
|
handshakeEvent: handshakeEvent,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||||
@ -85,59 +56,44 @@ func NewCryptoSetupTLSClient(
|
|||||||
cryptoStream io.ReadWriter,
|
cryptoStream io.ReadWriter,
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
hostname string,
|
hostname string,
|
||||||
tlsConfig *tls.Config,
|
handshakeEvent chan<- struct{},
|
||||||
params *TransportParameters,
|
tls MintTLS,
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
|
||||||
initialVersion protocol.VersionNumber,
|
|
||||||
supportedVersions []protocol.VersionNumber,
|
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetupTLS, error) {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mintConf.ServerName = hostname
|
|
||||||
conn := &fakeConn{
|
|
||||||
stream: cryptoStream,
|
|
||||||
pers: protocol.PerspectiveClient,
|
|
||||||
}
|
|
||||||
mintConn := mint.Client(conn, mintConf)
|
|
||||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
|
||||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
conn: conn,
|
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
tls: &mintController{mintConn},
|
tls: tls,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
aeadChanged: aeadChanged,
|
handshakeEvent: handshakeEvent,
|
||||||
nextPacketType: protocol.PacketTypeInitial,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||||
handshakeLoop:
|
if h.perspective == protocol.PerspectiveServer {
|
||||||
for {
|
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
|
||||||
switch alert := h.tls.Handshake(); alert {
|
// send out that data now
|
||||||
case mint.AlertNoAlert: // handshake complete
|
if _, err := h.cryptoStream.Flush(); err != nil {
|
||||||
break handshakeLoop
|
|
||||||
case mint.AlertWouldBlock:
|
|
||||||
h.determineNextPacketType()
|
|
||||||
if err := h.conn.Continue(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
}
|
||||||
|
|
||||||
|
handshakeLoop:
|
||||||
|
for {
|
||||||
|
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||||
}
|
}
|
||||||
|
switch h.tls.State() {
|
||||||
|
case mint.StateClientStart: // this happens if a stateless retry is performed
|
||||||
|
return ErrCloseSessionForRetry
|
||||||
|
case mint.StateClientConnected, mint.StateServerConnected:
|
||||||
|
break handshakeLoop
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||||
@ -148,28 +104,23 @@ handshakeLoop:
|
|||||||
h.aead = aead
|
h.aead = aead
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
|
|
||||||
// signal to the outside world that the handshake completed
|
h.handshakeEvent <- struct{}{}
|
||||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
close(h.handshakeEvent)
|
||||||
close(h.aeadChanged)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
|
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
defer h.mutex.RUnlock()
|
defer h.mutex.RUnlock()
|
||||||
|
|
||||||
if h.aead != nil {
|
if h.aead == nil {
|
||||||
data, err := h.aead.Open(dst, src, packetNumber, associatedData)
|
return nil, errors.New("no 1-RTT sealer")
|
||||||
if err != nil {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
}
|
||||||
return data, protocol.EncryptionForwardSecure, nil
|
return h.aead.Open(dst, src, packetNumber, associatedData)
|
||||||
}
|
|
||||||
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
return data, protocol.EncryptionUnencrypted, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||||
@ -204,39 +155,13 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
|||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) determineNextPacketType() error {
|
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
defer h.mutex.Unlock()
|
defer h.mutex.Unlock()
|
||||||
state := h.tls.State().HandshakeState
|
mintConnState := h.tls.ConnectionState()
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
return ConnectionState{
|
||||||
switch state {
|
// TODO: set the ServerName, once mint exports it
|
||||||
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
|
HandshakeComplete: h.aead != nil,
|
||||||
h.nextPacketType = protocol.PacketTypeRetry
|
PeerCertificates: mintConnState.PeerCertificates,
|
||||||
case "ServerStateWaitFinished":
|
|
||||||
h.nextPacketType = protocol.PacketTypeHandshake
|
|
||||||
default:
|
|
||||||
// TODO: accept 0-RTT data
|
|
||||||
return fmt.Errorf("Unexpected handshake state: %s", state)
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// client
|
|
||||||
if state != "ClientStateWaitSH" {
|
|
||||||
h.nextPacketType = protocol.PacketTypeHandshake
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.nextPacketType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
|
||||||
panic("diversification nonce not needed for TLS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
|
|
||||||
panic("diversification nonce not needed for TLS")
|
|
||||||
}
|
}
|
||||||
|
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The CryptoStreamConn is used as the net.Conn passed to mint.
|
||||||
|
// It has two operating modes:
|
||||||
|
// 1. It can read and write to bytes.Buffers.
|
||||||
|
// 2. It can use a quic.Stream for reading and writing.
|
||||||
|
// The buffer-mode is only used by the server, in order to statelessly handle retries.
|
||||||
|
type CryptoStreamConn struct {
|
||||||
|
remoteAddr net.Addr
|
||||||
|
|
||||||
|
// the buffers are used before the session is initialized
|
||||||
|
readBuf bytes.Buffer
|
||||||
|
writeBuf bytes.Buffer
|
||||||
|
|
||||||
|
// stream will be set once the session is initialized
|
||||||
|
stream io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &CryptoStreamConn{}
|
||||||
|
|
||||||
|
// NewCryptoStreamConn creates a new CryptoStreamConn
|
||||||
|
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
|
||||||
|
return &CryptoStreamConn{remoteAddr: remoteAddr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
|
||||||
|
if c.stream != nil {
|
||||||
|
return c.stream.Read(b)
|
||||||
|
}
|
||||||
|
return c.readBuf.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDataForReading adds data to the read buffer.
|
||||||
|
// This data will ONLY be read when the stream has not been set.
|
||||||
|
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
|
||||||
|
c.readBuf.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
|
||||||
|
if c.stream != nil {
|
||||||
|
return c.stream.Write(p)
|
||||||
|
}
|
||||||
|
return c.writeBuf.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
|
||||||
|
func (c *CryptoStreamConn) GetDataForWriting() []byte {
|
||||||
|
defer c.writeBuf.Reset()
|
||||||
|
data := make([]byte, c.writeBuf.Len())
|
||||||
|
copy(data, c.writeBuf.Bytes())
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStream sets the stream.
|
||||||
|
// After setting the stream, the read and write buffer won't be used any more.
|
||||||
|
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
|
||||||
|
c.stream = stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush copies the contents of the write buffer to the stream
|
||||||
|
func (c *CryptoStreamConn) Flush() (int, error) {
|
||||||
|
n, err := io.Copy(c.stream, &c.writeBuf)
|
||||||
|
return int(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is not implemented
|
||||||
|
func (c *CryptoStreamConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr is not implemented
|
||||||
|
func (c *CryptoStreamConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote address
|
||||||
|
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
|
||||||
|
return c.remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
16
vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go
generated
vendored
@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -24,27 +23,26 @@ var (
|
|||||||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
||||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
// the Diffie-Hellman key generation at the server over all the connections in a
|
||||||
// small time span.
|
// small time span.
|
||||||
func getEphermalKEX() (res crypto.KeyExchange) {
|
func getEphermalKEX() (crypto.KeyExchange, error) {
|
||||||
kexMutex.RLock()
|
kexMutex.RLock()
|
||||||
res = kexCurrent
|
res := kexCurrent
|
||||||
t := kexCurrentTime
|
t := kexCurrentTime
|
||||||
kexMutex.RUnlock()
|
kexMutex.RUnlock()
|
||||||
if res != nil && time.Since(t) < kexLifetime {
|
if res != nil && time.Since(t) < kexLifetime {
|
||||||
return res
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
kexMutex.Lock()
|
kexMutex.Lock()
|
||||||
defer kexMutex.Unlock()
|
defer kexMutex.Unlock()
|
||||||
// Check if still unfulfilled
|
// Check if still unfulfilled
|
||||||
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
|
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
||||||
kex, err := crypto.NewCurve25519KEX()
|
kex, err := crypto.NewCurve25519KEX()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Errorf("could not set KEX: %s", err.Error())
|
return nil, err
|
||||||
return kexCurrent
|
|
||||||
}
|
}
|
||||||
kexCurrent = kex
|
kexCurrent = kex
|
||||||
kexCurrentTime = time.Now()
|
kexCurrentTime = time.Now()
|
||||||
return kexCurrent
|
return kexCurrent, nil
|
||||||
}
|
}
|
||||||
return kexCurrent
|
return kexCurrent, nil
|
||||||
}
|
}
|
||||||
|
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
5
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
|||||||
|
|
||||||
offset := uint32(0)
|
offset := uint32(0)
|
||||||
for i, t := range h.getTagsSorted() {
|
for i, t := range h.getTagsSorted() {
|
||||||
v := data[Tag(t)]
|
v := data[t]
|
||||||
b.Write(v)
|
b.Write(v)
|
||||||
offset += uint32(len(v))
|
offset += uint32(len(v))
|
||||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
||||||
@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
|
|||||||
func (h HandshakeMessage) String() string {
|
func (h HandshakeMessage) String() string {
|
||||||
var pad string
|
var pad string
|
||||||
res := tagToString(h.Tag) + ":\n"
|
res := tagToString(h.Tag) + ":\n"
|
||||||
for _, t := range h.getTagsSorted() {
|
for _, tag := range h.getTagsSorted() {
|
||||||
tag := Tag(t)
|
|
||||||
if tag == TagPAD {
|
if tag == TagPAD {
|
||||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
||||||
} else {
|
} else {
|
||||||
|
57
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
57
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
generated
vendored
@ -1,6 +1,11 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,16 +15,54 @@ type Sealer interface {
|
|||||||
Overhead() int
|
Overhead() int
|
||||||
}
|
}
|
||||||
|
|
||||||
// CryptoSetup is a crypto setup
|
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||||
type CryptoSetup interface {
|
// It provides the parameters sent by the peer on a channel.
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
type TLSExtensionHandler interface {
|
||||||
|
Send(mint.HandshakeType, *mint.ExtensionList) error
|
||||||
|
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
||||||
|
GetPeerParams() <-chan TransportParameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// MintTLS combines some methods needed to interact with mint.
|
||||||
|
type MintTLS interface {
|
||||||
|
crypto.TLSExporter
|
||||||
|
|
||||||
|
// additional methods
|
||||||
|
Handshake() mint.Alert
|
||||||
|
State() mint.State
|
||||||
|
ConnectionState() mint.ConnectionState
|
||||||
|
|
||||||
|
SetCryptoStream(io.ReadWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
type baseCryptoSetup interface {
|
||||||
HandleCryptoStream() error
|
HandleCryptoStream() error
|
||||||
// TODO: clean up this interface
|
ConnectionState() ConnectionState
|
||||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
|
||||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
|
||||||
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
|
|
||||||
|
|
||||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CryptoSetup is the crypto setup used by gQUIC
|
||||||
|
type CryptoSetup interface {
|
||||||
|
baseCryptoSetup
|
||||||
|
|
||||||
|
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
||||||
|
type CryptoSetupTLS interface {
|
||||||
|
baseCryptoSetup
|
||||||
|
|
||||||
|
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||||
|
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
|
type ConnectionState struct {
|
||||||
|
HandshakeComplete bool // handshake is complete
|
||||||
|
ServerName string // server name requested by client, if any (server side only)
|
||||||
|
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||||
|
}
|
||||||
|
127
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go
generated
vendored
127
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go
generated
vendored
@ -1,127 +0,0 @@
|
|||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
gocrypto "crypto"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
|
||||||
mconf := &mint.Config{
|
|
||||||
NonBlocking: true,
|
|
||||||
CipherSuites: []mint.CipherSuite{
|
|
||||||
mint.TLS_AES_128_GCM_SHA256,
|
|
||||||
mint.TLS_AES_256_GCM_SHA384,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if tlsConf != nil {
|
|
||||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
|
||||||
for i, certChain := range tlsConf.Certificates {
|
|
||||||
mconf.Certificates[i] = &mint.Certificate{
|
|
||||||
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
|
||||||
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
|
||||||
}
|
|
||||||
for j, cert := range certChain.Certificate {
|
|
||||||
c, err := x509.ParseCertificate(cert)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mconf.Certificates[i].Chain[j] = c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mconf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mintTLS interface {
|
|
||||||
// These two methods are the same as the crypto.TLSExporter interface.
|
|
||||||
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
|
|
||||||
GetCipherSuite() mint.CipherSuiteParams
|
|
||||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
|
||||||
// additional methods
|
|
||||||
Handshake() mint.Alert
|
|
||||||
State() mint.ConnectionState
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
|
||||||
|
|
||||||
type mintController struct {
|
|
||||||
conn *mint.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ mintTLS = &mintController{}
|
|
||||||
|
|
||||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
|
||||||
return mc.conn.State().CipherSuite
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
|
||||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) Handshake() mint.Alert {
|
|
||||||
return mc.conn.Handshake()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) State() mint.ConnectionState {
|
|
||||||
return mc.conn.State()
|
|
||||||
}
|
|
||||||
|
|
||||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
|
||||||
// so we wrap a stream such that implements a net.Conn
|
|
||||||
type fakeConn struct {
|
|
||||||
stream io.ReadWriter
|
|
||||||
pers protocol.Perspective
|
|
||||||
remoteAddr net.Addr
|
|
||||||
|
|
||||||
blockRead bool
|
|
||||||
writeBuffer bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ net.Conn = &fakeConn{}
|
|
||||||
|
|
||||||
func (c *fakeConn) Read(b []byte) (int, error) {
|
|
||||||
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
c.blockRead = true // block the next Read call
|
|
||||||
return c.stream.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
|
||||||
if c.pers == protocol.PerspectiveClient {
|
|
||||||
return c.stream.Write(p)
|
|
||||||
}
|
|
||||||
// Buffer all writes by the server.
|
|
||||||
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
|
|
||||||
return c.writeBuffer.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Continue() error {
|
|
||||||
c.blockRead = false
|
|
||||||
if c.pers == protocol.PerspectiveClient {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// write all contents of the write buffer to the stream.
|
|
||||||
_, err := c.stream.Write(c.writeBuffer.Bytes())
|
|
||||||
c.writeBuffer.Reset()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Close() error { return nil }
|
|
||||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
|
||||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
|
||||||
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
|
|
||||||
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
|
|
||||||
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
|
|
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
|||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
var pubs_kexs []struct{Length uint32; Value []byte}
|
var pubsKexs []struct {
|
||||||
var last_len uint32
|
Length uint32
|
||||||
|
Value []byte
|
||||||
for i := 0; i < len(pubs)-3; i += int(last_len)+3 {
|
}
|
||||||
|
var lastLen uint32
|
||||||
|
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
|
||||||
// the PUBS value is always prepended by 3 byte little endian length field
|
// the PUBS value is always prepended by 3 byte little endian length field
|
||||||
|
|
||||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len);
|
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
||||||
}
|
}
|
||||||
if last_len == 0 {
|
if lastLen == 0 {
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
if i+3+int(last_len) > len(pubs) {
|
if i+3+int(lastLen) > len(pubs) {
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]})
|
pubsKexs = append(pubsKexs, struct {
|
||||||
|
Length uint32
|
||||||
|
Value []byte
|
||||||
|
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
|
||||||
}
|
}
|
||||||
|
|
||||||
if c255Foundat >= len(pubs_kexs) {
|
if c255Foundat >= len(pubsKexs) {
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
if pubs_kexs[c255Foundat].Length != 32 {
|
if pubsKexs[c255Foundat].Length != 32 {
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
|
||||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user