update vendor

This commit is contained in:
Zhixing Wang 2018-05-17 13:14:45 +01:00
parent 030fbd8521
commit 25d223c384
186 changed files with 14653 additions and 5690 deletions

View File

@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
off. off.
## DTLS Support
Mint has partial support for DTLS, but that support is not yet complete
and may still contain serious defects.
## Quickstart ## Quickstart
Installation is the same as for any other Go package: Installation is the same as for any other Go package:

View File

@ -46,6 +46,7 @@ const (
AlertBadCertificateHashValue Alert = 114 AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115 AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120 AlertNoApplicationProtocol Alert = 120
AlertStatelessRetry Alert = 253
AlertWouldBlock Alert = 254 AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255 AlertNoAlert Alert = 255
) )
@ -82,6 +83,7 @@ var alertText = map[Alert]string{
AlertUnknownPSKIdentity: "unknown PSK identity", AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol", AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation", AlertNoRenegotiation: "no renegotiation",
AlertStatelessRetry: "stateless retry",
AlertWouldBlock: "would have blocked", AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert", AlertNoAlert: "no alert",
} }

View File

@ -3,6 +3,7 @@ package mint
import ( import (
"bytes" "bytes"
"crypto" "crypto"
"crypto/x509"
"hash" "hash"
"time" "time"
) )
@ -49,29 +50,31 @@ import (
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; // WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) // CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ClientStateStart struct { type clientStateStart struct {
Caps Capabilities Config *Config
Opts ConnectionOptions Opts ConnectionOptions
Params ConnectionParameters Params ConnectionParameters
cookie []byte cookie []byte
firstClientHello *HandshakeMessage firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage helloRetryRequest *HandshakeMessage
hsCtx *HandshakeContext
} }
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateStart{}
if hm != nil {
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") func (state clientStateStart) State() State {
return nil, nil, AlertUnexpectedMessage return StateClientStart
} }
func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
// key_shares // key_shares
offeredDH := map[NamedGroup][]byte{} offeredDH := map[NamedGroup][]byte{}
ks := KeyShareExtension{ ks := KeyShareExtension{
HandshakeType: HandshakeTypeClientHello, HandshakeType: HandshakeTypeClientHello,
Shares: make([]KeyShareEntry, len(state.Caps.Groups)), Shares: make([]KeyShareEntry, len(state.Config.Groups)),
} }
for i, group := range state.Caps.Groups { for i, group := range state.Config.Groups {
pub, priv, err := newKeyShare(group) pub, priv, err := newKeyShare(group)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
@ -86,10 +89,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
logf(logTypeHandshake, "opts: %+v", state.Opts) logf(logTypeHandshake, "opts: %+v", state.Opts)
// supported_versions, supported_groups, signature_algorithms, server_name // supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName) sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups} sg := SupportedGroupsExtension{Groups: state.Config.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes}
state.Params.ServerName = state.Opts.ServerName state.Params.ServerName = state.Opts.ServerName
@ -101,7 +104,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
// Construct base ClientHello // Construct base ClientHello
ch := &ClientHelloBody{ ch := &ClientHelloBody{
CipherSuites: state.Caps.CipherSuites, LegacyVersion: wireVersion(state.hsCtx.hIn),
CipherSuites: state.Config.CipherSuites,
} }
_, err := prng.Read(ch.Random[:]) _, err := prng.Read(ch.Random[:])
if err != nil { if err != nil {
@ -133,8 +137,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
} }
// Run the external extension handler. // Run the external extension handler.
if state.Caps.ExtensionHandler != nil { if state.Config.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
@ -150,7 +154,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
var earlySecret []byte var earlySecret []byte
var clientEarlyTrafficKeys keySet var clientEarlyTrafficKeys keySet
var clientHello *HandshakeMessage var clientHello *HandshakeMessage
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key offeredPSK = key
// Narrow ciphersuites to ones that match PSK hash // Narrow ciphersuites to ones that match PSK hash
@ -168,8 +172,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
} }
ch.CipherSuites = compatibleSuites ch.CipherSuites = compatibleSuites
// TODO(ekr@rtfm.com): Check that the ticket can be used for early
// data.
// Signal early data if we're going to do it // Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 { if state.Config.AllowEarlyData && state.helloRetryRequest == nil {
state.Params.ClientSendingEarlyData = true state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{} ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed) err = ch.Extensions.Add(ed)
@ -180,11 +186,11 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
} }
// Signal supported PSK key exchange modes // Signal supported PSK key exchange modes
if len(state.Caps.PSKModes) == 0 { if len(state.Config.PSKModes) == 0 {
logf(logTypeHandshake, "PSK selected, but no PSKModes") logf(logTypeHandshake, "PSK selected, but no PSKModes")
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes}
err = ch.Extensions.Add(kem) err = ch.Extensions.Add(kem)
if err != nil { if err != nil {
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
@ -241,7 +247,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so // If we got here, the earlier marshal succeeded (in ch.Truncated()), so
// this one should too. // this one should too.
clientHello, _ = HandshakeMessageFromBody(ch) clientHello, _ = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
// Compute early traffic keys // Compute early traffic keys
h := params.Hash.New() h := params.Hash.New()
@ -251,11 +257,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else { } else {
clientHello, err = HandshakeMessageFromBody(ch) clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
@ -263,10 +266,12 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
} }
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
nextState := ClientStateWaitSH{ state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
Caps: state.Caps, nextState := clientStateWaitSH{
Config: state.Config,
Opts: state.Opts, Opts: state.Opts,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
OfferedDH: offeredDH, OfferedDH: offeredDH,
OfferedPSK: offeredPSK, OfferedPSK: offeredPSK,
@ -279,22 +284,23 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand
} }
toSend := []HandshakeAction{ toSend := []HandshakeAction{
SendHandshakeMessage{clientHello}, QueueHandshakeMessage{clientHello},
SendQueuedHandshake{},
} }
if state.Params.ClientSendingEarlyData { if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{ toSend = append(toSend, []HandshakeAction{
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...) }...)
} }
return nextState, toSend, AlertNoAlert return nextState, toSend, AlertNoAlert
} }
type ClientStateWaitSH struct { type clientStateWaitSH struct {
Caps Capabilities Config *Config
Opts ConnectionOptions Opts ConnectionOptions
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
OfferedDH map[NamedGroup][]byte OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey OfferedPSK PreSharedKey
PSK []byte PSK []byte
@ -307,49 +313,73 @@ type ClientStateWaitSH struct {
clientHello *HandshakeMessage clientHello *HandshakeMessage
} }
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitSH{}
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") func (state clientStateWaitSH) State() State {
return StateClientWaitSH
}
func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil || hm.msgType != HandshakeTypeServerHello {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
bodyGeneric, err := hm.ToBody() sh := &ServerHelloBody{}
if _, err := sh.Unmarshal(hm.body); err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Common SH/HRR processing first.
// 1. Check that sh.version is TLS 1.2
if sh.Version != tls12Version {
logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version)
return nil, nil, AlertIllegalParameter
}
// 2. Check that it responded with a valid version.
supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello}
foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
if !foundSupportedVersions {
switch body := bodyGeneric.(type) { logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension")
case *HelloRetryRequestBody: return nil, nil, AlertMissingExtension
hrr := body
if state.helloRetryRequest != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
return nil, nil, AlertUnexpectedMessage
} }
if supportedVersions.Versions[0] != supportedVersion {
// Check that the version sent by the server is the one we support logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0])
if hrr.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
return nil, nil, AlertProtocolVersion return nil, nil, AlertProtocolVersion
} }
// 3. Check that the server provided a supported ciphersuite
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites { for _, suite := range state.Config.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
} }
if !supportedCipherSuite { if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure return nil, nil, AlertHandshakeFailure
} }
// Now check for the sentinel.
if sh.Random == hrrRandomSentinel {
// This is actually HRR.
hrr := sh
// Narrow the supported ciphersuites to the server-provided one // Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite}
// Handle external extensions. // Handle external extensions.
if state.Caps.ExtensionHandler != nil { if state.Config.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
@ -358,10 +388,14 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
// The only thing we know how to respond to in an HRR is the Cookie // The only thing we know how to respond to in an HRR is the Cookie
// extension, so if there is either no Cookie extension or anything other // extension, so if there is either no Cookie extension or anything other
// than a Cookie extension, we have to fail. // than a Cookie extension and SupportedVersions we have to fail.
serverCookie := new(CookieExtension) serverCookie := new(CookieExtension)
foundCookie := hrr.Extensions.Find(serverCookie) foundCookie, err := hrr.Extensions.Find(serverCookie)
if !foundCookie || len(hrr.Extensions) != 1 { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err)
return nil, nil, AlertDecodeError
}
if !foundCookie || len(hrr.Extensions) != 2 {
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
return nil, nil, AlertIllegalParameter return nil, nil, AlertIllegalParameter
} }
@ -376,37 +410,26 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
body: h.Sum(nil), body: h.Sum(nil),
} }
state.hsCtx.receivedEndOfFlight()
// TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT
// mode. In DTLS, we also need to bump the sequence number.
// This is a pre-existing defect in Mint. Issue #175.
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return ClientStateStart{ return clientStateStart{
Caps: state.Caps, Config: state.Config,
Opts: state.Opts, Opts: state.Opts,
hsCtx: state.hsCtx,
cookie: serverCookie.Cookie, cookie: serverCookie.Cookie,
firstClientHello: firstClientHello, firstClientHello: firstClientHello,
helloRetryRequest: hm, helloRetryRequest: hm,
}.Next(nil) }, []HandshakeAction{ResetOut{1}}, AlertNoAlert
case *ServerHelloBody:
sh := body
// Check that the version sent by the server is the one we support
if sh.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
} }
// This is SH.
// Handle external extensions. // Handle external extensions.
if state.Caps.ExtensionHandler != nil { if state.Config.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
@ -417,15 +440,22 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundPSK := sh.Extensions.Find(&serverPSK) foundExts, err := sh.Extensions.Parse(
foundKeyShare := sh.Extensions.Find(&serverKeyShare) []ExtensionBody{
&serverPSK,
&serverKeyShare,
})
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err)
return nil, nil, AlertDecodeError
}
if foundPSK && (serverPSK.SelectedIdentity == 0) { if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true state.Params.UsingPSK = true
} }
var dhSecret []byte var dhSecret []byte
if foundKeyShare { if foundExts[ExtensionTypeKeyShare] {
sks := serverKeyShare.Shares[0] sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group] priv, ok := state.OfferedDH[sks.Group]
if !ok { if !ok {
@ -488,114 +518,152 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{ nextState := clientStateWaitEE{
Caps: state.Caps, Config: state.Config,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: params, cryptoParams: params,
handshakeHash: handshakeHash, handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret, masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
} }
toSend := []HandshakeAction{ toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
} }
// We're definitely not going to have to send anything with
// early data.
if !state.Params.ClientSendingEarlyData {
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)})
}
return nextState, toSend, AlertNoAlert return nextState, toSend, AlertNoAlert
} }
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) type clientStateWaitEE struct {
return nil, nil, AlertUnexpectedMessage Config *Config
}
type ClientStateWaitEE struct {
Caps Capabilities
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
handshakeHash hash.Hash handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte masterSecret []byte
clientHandshakeTrafficSecret []byte clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte
} }
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitEE{}
func (state clientStateWaitEE) State() State {
return StateClientWaitEE
}
func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
ee := EncryptedExtensionsBody{} ee := EncryptedExtensionsBody{}
_, err := ee.Unmarshal(hm.body) if err := safeUnmarshal(&ee, hm.body); err != nil {
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
// Handle external extensions. // Handle external extensions.
if state.Caps.ExtensionHandler != nil { if state.Config.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
} }
serverALPN := ALPNExtension{} serverALPN := &ALPNExtension{}
serverEarlyData := EarlyDataExtension{} serverEarlyData := &EarlyDataExtension{}
gotALPN := ee.Extensions.Find(&serverALPN) foundExts, err := ee.Extensions.Parse(
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) []ExtensionBody{
serverALPN,
serverEarlyData,
})
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err)
return nil, nil, AlertDecodeError
}
if gotALPN && len(serverALPN.Protocols) > 0 { state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData]
if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 {
state.Params.NextProto = serverALPN.Protocols[0] state.Params.NextProto = serverALPN.Protocols[0]
} }
state.handshakeHash.Write(hm.Marshal()) state.handshakeHash.Write(hm.Marshal())
toSend := []HandshakeAction{}
if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData {
// We didn't get 0-RTT, so rekey to handshake.
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}
if state.Params.UsingPSK { if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{ nextState := clientStateWaitFinished{
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates, certificates: state.Config.Certificates,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
} }
return nextState, nil, AlertNoAlert return nextState, toSend, AlertNoAlert
} }
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
nextState := ClientStateWaitCertCR{ nextState := clientStateWaitCertCR{
AuthCertificate: state.AuthCertificate, Config: state.Config,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
} }
return nextState, nil, AlertNoAlert return nextState, toSend, AlertNoAlert
} }
type ClientStateWaitCertCR struct { type clientStateWaitCertCR struct {
AuthCertificate func(chain []CertificateEntry) error Config *Config
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
handshakeHash hash.Hash handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte masterSecret []byte
clientHandshakeTrafficSecret []byte clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte
} }
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitCertCR{}
func (state clientStateWaitCertCR) State() State {
return StateClientWaitCertCR
}
func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil { if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
@ -612,12 +680,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
switch body := bodyGeneric.(type) { switch body := bodyGeneric.(type) {
case *CertificateBody: case *CertificateBody:
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{ nextState := clientStateWaitCV{
AuthCertificate: state.AuthCertificate, Config: state.Config,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: body, serverCertificate: body,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
@ -635,12 +703,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
state.Params.UsingClientAuth = true state.Params.UsingClientAuth = true
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
nextState := ClientStateWaitCert{ nextState := clientStateWaitCert{
AuthCertificate: state.AuthCertificate, Config: state.Config,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: body, serverCertificateRequest: body,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
@ -652,13 +720,13 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
type ClientStateWaitCert struct { type clientStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error Config *Config
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
handshakeHash hash.Hash handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody serverCertificateRequest *CertificateRequestBody
masterSecret []byte masterSecret []byte
@ -666,15 +734,24 @@ type ClientStateWaitCert struct {
serverHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte
} }
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitCert{}
func (state clientStateWaitCert) State() State {
return StateClientWaitCert
}
func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil || hm.msgType != HandshakeTypeCertificate { if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
cert := &CertificateBody{} cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body) if err := safeUnmarshal(cert, hm.body); err != nil {
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -682,12 +759,12 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H
state.handshakeHash.Write(hm.Marshal()) state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{ nextState := clientStateWaitCV{
AuthCertificate: state.AuthCertificate, Config: state.Config,
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: cert, serverCertificate: cert,
serverCertificateRequest: state.serverCertificateRequest, serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
@ -697,13 +774,13 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H
return nextState, nil, AlertNoAlert return nextState, nil, AlertNoAlert
} }
type ClientStateWaitCV struct { type clientStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error Config *Config
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
handshakeHash hash.Hash handshakeHash hash.Hash
certificates []*Certificate
serverCertificate *CertificateBody serverCertificate *CertificateBody
serverCertificateRequest *CertificateRequestBody serverCertificateRequest *CertificateRequestBody
@ -712,15 +789,24 @@ type ClientStateWaitCV struct {
serverHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte
} }
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitCV{}
func (state clientStateWaitCV) State() State {
return StateClientWaitCV
}
func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
} }
certVerify := CertificateVerifyBody{} certVerify := CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body) if err := safeUnmarshal(&certVerify, hm.body); err != nil {
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -734,46 +820,89 @@ func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []Han
return nil, nil, AlertHandshakeFailure return nil, nil, AlertHandshakeFailure
} }
if state.AuthCertificate != nil { certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList))
err := state.AuthCertificate(state.serverCertificate.CertificateList) rawCerts := make([][]byte, len(state.serverCertificate.CertificateList))
for i, certEntry := range state.serverCertificate.CertificateList {
certs[i] = certEntry.CertData
rawCerts[i] = certEntry.CertData.Raw
}
var verifiedChains [][]*x509.Certificate
if !state.Config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: state.Config.RootCAs,
CurrentTime: state.Config.time(),
DNSName: state.Config.ServerName,
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
var err error
verifiedChains, err = certs[0].Verify(opts)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err)
return nil, nil, AlertBadCertificate
}
}
if state.Config.VerifyPeerCertificate != nil {
if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err)
return nil, nil, AlertBadCertificate return nil, nil, AlertBadCertificate
} }
} else {
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
} }
state.handshakeHash.Write(hm.Marshal()) state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{ nextState := clientStateWaitFinished{
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash, handshakeHash: state.handshakeHash,
certificates: state.certificates, certificates: state.Config.Certificates,
serverCertificateRequest: state.serverCertificateRequest, serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret, masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
peerCertificates: certs,
verifiedChains: verifiedChains,
} }
return nextState, nil, AlertNoAlert return nextState, nil, AlertNoAlert
} }
type ClientStateWaitFinished struct { type clientStateWaitFinished struct {
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
handshakeHash hash.Hash handshakeHash hash.Hash
certificates []*Certificate certificates []*Certificate
serverCertificateRequest *CertificateRequestBody serverCertificateRequest *CertificateRequestBody
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
masterSecret []byte masterSecret []byte
clientHandshakeTrafficSecret []byte clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte
} }
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { var _ HandshakeState = &clientStateWaitFinished{}
func (state clientStateWaitFinished) State() State {
return StateClientWaitFinished
}
func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
}
if hm == nil || hm.msgType != HandshakeTypeFinished { if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
@ -788,8 +917,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
_, err := fin.Unmarshal(hm.body) if err := safeUnmarshal(fin, hm.body); err != nil {
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError return nil, nil, AlertDecodeError
} }
@ -822,25 +950,32 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
toSend := []HandshakeAction{} toSend := []HandshakeAction{}
if state.Params.UsingEarlyData { if state.Params.UsingEarlyData {
logf(logTypeHandshake, "Sending end of early data")
// Note: We only send EOED if the server is actually going to use the early // Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will // data. Otherwise, it will never see it, and the transcripts will
// mismatch. // mismatch.
// EOED marshal is infallible // EOED marshal is infallible
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) eoedm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(&EndOfEarlyDataBody{})
toSend = append(toSend, SendHandshakeMessage{eoedm}) toSend = append(toSend, QueueHandshakeMessage{eoedm})
state.handshakeHash.Write(eoedm.Marshal()) state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) // And then rekey to handshake
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}
if state.Params.UsingClientAuth { if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest // Extract constraints from certicateRequest
schemes := SignatureAlgorithmsExtension{} schemes := SignatureAlgorithmsExtension{}
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err)
return nil, nil, AlertDecodeError
}
if !gotSchemes { if !gotSchemes {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found")
return nil, nil, AlertIllegalParameter return nil, nil, AlertIllegalParameter
} }
@ -851,13 +986,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
certificate := &CertificateBody{} certificate := &CertificateBody{}
certm, err := HandshakeMessageFromBody(certificate) certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
toSend = append(toSend, SendHandshakeMessage{certm}) toSend = append(toSend, QueueHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal()) state.handshakeHash.Write(certm.Marshal())
} else { } else {
// Create and send Certificate, CertificateVerify // Create and send Certificate, CertificateVerify
@ -867,13 +1002,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
for i, entry := range cert.Chain { for i, entry := range cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry} certificate.CertificateList[i] = CertificateEntry{CertData: entry}
} }
certm, err := HandshakeMessageFromBody(certificate) certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
toSend = append(toSend, SendHandshakeMessage{certm}) toSend = append(toSend, QueueHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal()) state.handshakeHash.Write(certm.Marshal())
hcv := state.handshakeHash.Sum(nil) hcv := state.handshakeHash.Sum(nil)
@ -887,13 +1022,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
certvm, err := HandshakeMessageFromBody(certificateVerify) certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
} }
toSend = append(toSend, SendHandshakeMessage{certvm}) toSend = append(toSend, QueueHandshakeMessage{certvm})
state.handshakeHash.Write(certvm.Marshal()) state.handshakeHash.Write(certvm.Marshal())
} }
} }
@ -909,7 +1044,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
VerifyDataLen: len(clientFinishedData), VerifyDataLen: len(clientFinishedData),
VerifyData: clientFinishedData, VerifyData: clientFinishedData,
} }
finm, err := HandshakeMessageFromBody(fin) finm, err := state.hsCtx.hOut.HandshakeMessageFromBody(fin)
if err != nil { if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
return nil, nil, AlertInternalError return nil, nil, AlertInternalError
@ -923,20 +1058,26 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState,
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
toSend = append(toSend, []HandshakeAction{ toSend = append(toSend, []HandshakeAction{
SendHandshakeMessage{finm}, QueueHandshakeMessage{finm},
RekeyIn{Label: "application", KeySet: serverTrafficKeys}, SendQueuedHandshake{},
RekeyOut{Label: "application", KeySet: clientTrafficKeys}, RekeyIn{epoch: EpochApplicationData, KeySet: serverTrafficKeys},
RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys},
}...) }...)
state.hsCtx.receivedEndOfFlight()
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{ nextState := stateConnected{
Params: state.Params, Params: state.Params,
hsCtx: state.hsCtx,
isClient: true, isClient: true,
cryptoParams: state.cryptoParams, cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret, resumptionSecret: resumptionSecret,
clientTrafficSecret: clientTrafficSecret, clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret, serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret, exporterSecret: exporterSecret,
peerCertificates: state.peerCertificates,
verifiedChains: state.verifiedChains,
} }
return nextState, toSend, AlertNoAlert return nextState, toSend, AlertNoAlert
} }

View File

@ -5,9 +5,14 @@ import (
"strconv" "strconv"
) )
var ( const (
supportedVersion uint16 = 0x7f15 // draft-21 supportedVersion uint16 = 0x7f16 // draft-22
tls12Version uint16 = 0x0303
tls10Version uint16 = 0x0301
dtls12WireVersion uint16 = 0xfefd
)
var (
// Flags for some minor compat issues // Flags for some minor compat issues
allowWrongVersionNumber = true allowWrongVersionNumber = true
allowPKCS1 = true allowPKCS1 = true
@ -20,6 +25,7 @@ const (
RecordTypeAlert RecordType = 21 RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22 RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23 RecordTypeApplicationData RecordType = 23
RecordTypeAck RecordType = 25
) )
// enum {...} HandshakeType; // enum {...} HandshakeType;
@ -42,6 +48,13 @@ const (
HandshakeTypeMessageHash HandshakeType = 254 HandshakeTypeMessageHash HandshakeType = 254
) )
var hrrRandomSentinel = [32]byte{
0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11,
0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e,
0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
}
// uint8 CipherSuite[2]; // uint8 CipherSuite[2];
type CipherSuite uint16 type CipherSuite uint16
@ -150,3 +163,104 @@ const (
KeyUpdateNotRequested KeyUpdateRequest = 0 KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1 KeyUpdateRequested KeyUpdateRequest = 1
) )
type State uint8
const (
StateInit = 0
// states valid for the client
StateClientStart State = iota
StateClientWaitSH
StateClientWaitEE
StateClientWaitCert
StateClientWaitCV
StateClientWaitFinished
StateClientWaitCertCR
StateClientConnected
// states valid for the server
StateServerStart State = iota
StateServerRecvdCH
StateServerNegotiated
StateServerReadPastEarlyData
StateServerWaitEOED
StateServerWaitFlight2
StateServerWaitCert
StateServerWaitCV
StateServerWaitFinished
StateServerConnected
)
func (s State) String() string {
switch s {
case StateClientStart:
return "Client START"
case StateClientWaitSH:
return "Client WAIT_SH"
case StateClientWaitEE:
return "Client WAIT_EE"
case StateClientWaitCert:
return "Client WAIT_CERT"
case StateClientWaitCV:
return "Client WAIT_CV"
case StateClientWaitFinished:
return "Client WAIT_FINISHED"
case StateClientWaitCertCR:
return "Client WAIT_CERT_CR"
case StateClientConnected:
return "Client CONNECTED"
case StateServerStart:
return "Server START"
case StateServerRecvdCH:
return "Server RECVD_CH"
case StateServerNegotiated:
return "Server NEGOTIATED"
case StateServerReadPastEarlyData:
return "Server READ_PAST_EARLY_DATA"
case StateServerWaitEOED:
return "Server WAIT_EOED"
case StateServerWaitFlight2:
return "Server WAIT_FLIGHT2"
case StateServerWaitCert:
return "Server WAIT_CERT"
case StateServerWaitCV:
return "Server WAIT_CV"
case StateServerWaitFinished:
return "Server WAIT_FINISHED"
case StateServerConnected:
return "Server CONNECTED"
default:
return fmt.Sprintf("unknown state: %d", s)
}
}
// Epochs for DTLS (also used for key phase labelling)
type Epoch uint16
const (
EpochClear Epoch = 0
EpochEarlyData Epoch = 1
EpochHandshakeData Epoch = 2
EpochApplicationData Epoch = 3
EpochUpdate Epoch = 4
)
func (e Epoch) label() string {
switch e {
case EpochClear:
return "clear"
case EpochEarlyData:
return "early data"
case EpochHandshakeData:
return "handshake"
case EpochApplicationData:
return "application data"
}
return "Application data (updated)"
}
func assert(b bool) {
if !b {
panic("Assertion failed")
}
}

View File

@ -4,6 +4,7 @@ import (
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -12,8 +13,6 @@ import (
"time" "time"
) )
var WouldBlock = fmt.Errorf("Would have blocked")
type Certificate struct { type Certificate struct {
Chain []*x509.Certificate Chain []*x509.Certificate
PrivateKey crypto.Signer PrivateKey crypto.Signer
@ -36,16 +35,20 @@ type PreSharedKeyCache interface {
Size() int Size() int
} }
type PSKMapCache map[string]PreSharedKey // A CookieHandler can be used to give the application more fine-grained control over Cookies.
// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie.
// A CookieHandler does two things: // When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie.
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// - validates this byte string echoed by the client in the ClientHello
type CookieHandler interface { type CookieHandler interface {
// Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// If Generate returns nil, mint will not send a HelloRetryRequest.
Generate(*Conn) ([]byte, error) Generate(*Conn) ([]byte, error)
// Validate is called when receiving a ClientHello containing a Cookie.
// If validation failed, the handshake is aborted.
Validate(*Conn, []byte) bool Validate(*Conn, []byte) bool
} }
type PSKMapCache map[string]PreSharedKey
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
psk, ok = cache[key] psk, ok = cache[key]
return return
@ -74,14 +77,49 @@ type Config struct {
AllowEarlyData bool AllowEarlyData bool
// Require the client to echo a cookie. // Require the client to echo a cookie.
RequireCookie bool RequireCookie bool
// If cookies are required and no CookieHandler is set, a default cookie handler is used. // A CookieHandler can be used to set and validate a cookie.
// The default cookie handler uses 32 random bytes as a cookie. // The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector.
// If no CookieHandler is set, mint will always send a cookie.
// The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent.
CookieHandler CookieHandler CookieHandler CookieHandler
// The CookieProtector is used to encrypt / decrypt cookies.
// It should make sure that the Cookie cannot be read and tampered with by the client.
// If non-blocking mode is used, and cookies are required, this field has to be set.
// In blocking mode, a default cookie protector is used, if this is unused.
CookieProtector CookieProtector
// The ExtensionHandler is used to add custom extensions.
ExtensionHandler AppExtensionHandler
RequireClientAuth bool RequireClientAuth bool
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureSkipVerify bool
// Shared fields // Shared fields
Certificates []*Certificate Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error // VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a TLS client or server. It
// receives the raw ASN.1 certificates provided by the peer and also
// any verified chains that normal processing found. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify then this callback will be considered but
// the verifiedChains argument will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
CipherSuites []CipherSuite CipherSuites []CipherSuite
Groups []NamedGroup Groups []NamedGroup
SignatureSchemes []SignatureScheme SignatureSchemes []SignatureScheme
@ -89,6 +127,7 @@ type Config struct {
PSKs PreSharedKeyCache PSKs PreSharedKeyCache
PSKModes []PSKKeyExchangeMode PSKModes []PSKKeyExchangeMode
NonBlocking bool NonBlocking bool
UseDTLS bool
// The same config object can be shared among different connections, so it // The same config object can be shared among different connections, so it
// needs its own mutex // needs its own mutex
@ -110,10 +149,16 @@ func (c *Config) Clone() *Config {
EarlyDataLifetime: c.EarlyDataLifetime, EarlyDataLifetime: c.EarlyDataLifetime,
AllowEarlyData: c.AllowEarlyData, AllowEarlyData: c.AllowEarlyData,
RequireCookie: c.RequireCookie, RequireCookie: c.RequireCookie,
CookieHandler: c.CookieHandler,
CookieProtector: c.CookieProtector,
ExtensionHandler: c.ExtensionHandler,
RequireClientAuth: c.RequireClientAuth, RequireClientAuth: c.RequireClientAuth,
Time: c.Time,
RootCAs: c.RootCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
Certificates: c.Certificates, Certificates: c.Certificates,
AuthCertificate: c.AuthCertificate, VerifyPeerCertificate: c.VerifyPeerCertificate,
CipherSuites: c.CipherSuites, CipherSuites: c.CipherSuites,
Groups: c.Groups, Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes, SignatureSchemes: c.SignatureSchemes,
@ -121,6 +166,7 @@ func (c *Config) Clone() *Config {
PSKs: c.PSKs, PSKs: c.PSKs,
PSKModes: c.PSKModes, PSKModes: c.PSKModes,
NonBlocking: c.NonBlocking, NonBlocking: c.NonBlocking,
UseDTLS: c.UseDTLS,
} }
} }
@ -147,28 +193,6 @@ func (c *Config) Init(isClient bool) error {
if len(c.PSKModes) == 0 { if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes c.PSKModes = defaultPSKModes
} }
// If there is no certificate, generate one
if !isClient && len(c.Certificates) == 0 {
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
priv, err := newSigningKey(RSA_PSS_SHA256)
if err != nil {
return err
}
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
if err != nil {
return err
}
c.Certificates = []*Certificate{
{
Chain: []*x509.Certificate{cert},
PrivateKey: priv,
},
}
}
return nil return nil
} }
@ -183,6 +207,14 @@ func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0 return len(c.ServerName) > 0
} }
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Now
}
return t()
}
var ( var (
defaultSupportedCipherSuites = []CipherSuite{ defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256,
@ -214,10 +246,13 @@ var (
) )
type ConnectionState struct { type ConnectionState struct {
HandshakeState string // string representation of the handshake state. HandshakeState State
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
NextProto string // Selected ALPN proto NextProto string // Selected ALPN proto
UsingPSK bool // Are we using PSK.
UsingEarlyData bool // Did we negotiate 0-RTT.
} }
// Conn implements the net.Conn interface, as with "crypto/tls" // Conn implements the net.Conn interface, as with "crypto/tls"
@ -228,9 +263,7 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
EarlyData []byte state stateConnected
state StateConnected
hState HandshakeState hState HandshakeState
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
handshakeAlert Alert handshakeAlert Alert
@ -238,18 +271,28 @@ type Conn struct {
readBuffer []byte readBuffer []byte
in, out *RecordLayer in, out *RecordLayer
hIn, hOut *HandshakeLayer hsCtx *HandshakeContext
extHandler AppExtensionHandler
} }
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient} c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
c.in = NewRecordLayer(c.conn) if !config.UseDTLS {
c.out = NewRecordLayer(c.conn) c.in = NewRecordLayerTLS(c.conn, directionRead)
c.hIn = NewHandshakeLayer(c.in) c.out = NewRecordLayerTLS(c.conn, directionWrite)
c.hIn.nonblocking = c.config.NonBlocking c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
c.hOut = NewHandshakeLayer(c.out) c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
} else {
c.in = NewRecordLayerDTLS(c.conn, directionRead)
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
c.hsCtx.timeoutMS = initialTimeout
c.hsCtx.timers = newTimerSet()
c.hsCtx.waitingNextFlight = true
}
c.in.label = c.label()
c.out.label = c.label()
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
return c return c
} }
@ -267,8 +310,12 @@ func (c *Conn) consumeRecord() error {
// We do not support fragmentation of post-handshake handshake messages. // We do not support fragmentation of post-handshake handshake messages.
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
start := 0 start := 0
headerLen := handshakeHeaderLenTLS
if c.config.UseDTLS {
headerLen = handshakeHeaderLenDTLS
}
for start < len(pt.fragment) { for start < len(pt.fragment) {
if len(pt.fragment[start:]) < handshakeHeaderLen { if len(pt.fragment[start:]) < headerLen {
return fmt.Errorf("Post-handshake handshake message too short for header") return fmt.Errorf("Post-handshake handshake message too short for header")
} }
@ -276,14 +323,15 @@ func (c *Conn) consumeRecord() error {
hm.msgType = HandshakeType(pt.fragment[start]) hm.msgType = HandshakeType(pt.fragment[start])
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { if len(pt.fragment[start+headerLen:]) < hmLen {
return fmt.Errorf("Post-handshake handshake message too short for body") return fmt.Errorf("Post-handshake handshake message too short for body")
} }
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen]
// Advance state machine
state, actions, alert := c.state.Next(hm)
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
state, actions, alert := c.state.ProcessMessage(hm)
if alert != AlertNoAlert { if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert) logf(logTypeHandshake, "Error in state transition: %v", alert)
c.sendAlert(alert) c.sendAlert(alert)
@ -299,18 +347,15 @@ func (c *Conn) consumeRecord() error {
} }
} }
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
var connected bool var connected bool
c.state, connected = state.(StateConnected) c.state, connected = state.(stateConnected)
if !connected { if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert) logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert) c.sendAlert(alert)
return io.EOF return io.EOF
} }
start += handshakeHeaderLen + hmLen start += headerLen + hmLen
} }
case RecordTypeAlert: case RecordTypeAlert:
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
@ -332,17 +377,54 @@ func (c *Conn) consumeRecord() error {
return io.EOF return io.EOF
} }
case RecordTypeAck:
if !c.hsCtx.hIn.datagram {
logf(logTypeHandshake, "Received ACK in TLS mode")
return AlertUnexpectedMessage
}
return c.hsCtx.processAck(pt.fragment)
case RecordTypeApplicationData: case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...) c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
} }
return err return err
} }
func readPartial(in *[]byte, buffer []byte) int {
logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
read := copy(buffer, *in)
*in = (*in)[read:]
logf(logTypeVerbose, "Returning %v", string(buffer))
return read
}
// Read application data up to the size of buffer. Handshake and alert records // Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly. // are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) { func (c *Conn) Read(buffer []byte) (int, error) {
if _, connected := c.hState.(stateConnected); !connected {
// Clients can't call Read prior to handshake completion.
if c.isClient {
return 0, errors.New("Read called before the handshake completed")
}
// Neither can servers that don't allow early data.
if !c.config.AllowEarlyData {
return 0, errors.New("Read called before the handshake completed")
}
// If there's no early data, then return WouldBlock
if len(c.hsCtx.earlyData) == 0 {
return 0, AlertWouldBlock
}
return readPartial(&c.hsCtx.earlyData, buffer), nil
}
// The handshake is now connected.
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert { if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert return 0, alert
@ -352,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) {
return 0, nil return 0, nil
} }
// Run our timers.
if c.config.UseDTLS {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return 0, AlertInternalError
}
}
// Lock the input channel // Lock the input channel
c.in.Lock() c.in.Lock()
defer c.in.Unlock() defer c.in.Unlock()
@ -361,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) {
// err can be nil if consumeRecord processed a non app-data // err can be nil if consumeRecord processed a non app-data
// record. // record.
if err != nil { if err != nil {
if c.config.NonBlocking || err != WouldBlock { if c.config.NonBlocking || err != AlertWouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err) logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err return 0, err
} }
} }
} }
var read int return readPartial(&c.readBuffer, buffer), nil
n := len(buffer)
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
if len(c.readBuffer) <= n {
buffer = buffer[:len(c.readBuffer)]
copy(buffer, c.readBuffer)
read = len(c.readBuffer)
c.readBuffer = c.readBuffer[:0]
} else {
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
copy(buffer[:n], c.readBuffer[:n])
c.readBuffer = c.readBuffer[n:]
read = n
}
logf(logTypeVerbose, "Returning %v", string(buffer))
return read, nil
} }
// Write application data // Write application data
@ -393,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) {
c.out.Lock() c.out.Lock()
defer c.out.Unlock() defer c.out.Unlock()
if !c.Writable() {
return 0, errors.New("Write called before the handshake completed (and early data not in use)")
}
// Send full-size fragments // Send full-size fragments
var start int var start int
sent := 0 sent := 0
@ -495,84 +572,44 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
} }
switch action := actionGeneric.(type) { switch action := actionGeneric.(type) {
case SendHandshakeMessage: case QueueHandshakeMessage:
err := c.hOut.WriteMessage(action.Message) logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType)
err := c.hsCtx.hOut.QueueMessage(action.Message)
if err != nil { if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError return AlertInternalError
} }
case SendQueuedHandshake:
_, err := c.hsCtx.hOut.SendQueuedMessages()
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
if c.config.UseDTLS {
c.hsCtx.timers.start(retransmitTimerLabel,
c.hsCtx.handshakeRetransmit,
c.hsCtx.timeoutMS)
}
case RekeyIn: case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil { if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError return AlertInternalError
} }
case RekeyOut: case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil { if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError return AlertInternalError
} }
case SendEarlyData: case ResetOut:
logf(logTypeHandshake, "%s Sending early data...", label) logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
_, err := c.Write(c.EarlyData) c.out.ResetClear(action.seq)
if err != nil {
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
return AlertInternalError
}
case ReadPastEarlyData:
logf(logTypeHandshake, "%s Reading past early data...", label)
// Scan past all records that fail to decrypt
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok := err.(DecryptError)
for ok {
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok = err.(DecryptError)
}
case ReadEarlyData:
logf(logTypeHandshake, "%s Reading early data...", label)
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
for t == RecordTypeApplicationData {
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in in
// PeekRecordType.
pt, err := c.in.ReadRecord()
if err != nil {
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
c.EarlyData = append(c.EarlyData, pt.fragment...)
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
}
logf(logTypeHandshake, "%s Done reading early data", label)
case StorePSK: case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
@ -585,7 +622,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
} }
default: default:
logf(logTypeHandshake, "%s Unknown actionuction type", label) logf(logTypeHandshake, "%s Unknown action type", label)
assert(false)
return AlertInternalError return AlertInternalError
} }
@ -602,33 +640,13 @@ func (c *Conn) HandshakeSetup() Alert {
return AlertInternalError return AlertInternalError
} }
// Set things up
caps := Capabilities{
CipherSuites: c.config.CipherSuites,
Groups: c.config.Groups,
SignatureSchemes: c.config.SignatureSchemes,
PSKs: c.config.PSKs,
PSKModes: c.config.PSKModes,
AllowEarlyData: c.config.AllowEarlyData,
RequireCookie: c.config.RequireCookie,
CookieHandler: c.config.CookieHandler,
RequireClientAuth: c.config.RequireClientAuth,
NextProtos: c.config.NextProtos,
Certificates: c.config.Certificates,
ExtensionHandler: c.extHandler,
}
opts := ConnectionOptions{ opts := ConnectionOptions{
ServerName: c.config.ServerName, ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos, NextProtos: c.config.NextProtos,
EarlyData: c.EarlyData,
}
if caps.RequireCookie && caps.CookieHandler == nil {
caps.CookieHandler = &defaultCookieHandler{}
} }
if c.isClient { if c.isClient {
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
if alert != AlertNoAlert { if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert) logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert return alert
@ -642,14 +660,56 @@ func (c *Conn) HandshakeSetup() Alert {
} }
} }
} else { } else {
state = ServerStateStart{Caps: caps, conn: c} if c.config.RequireCookie && c.config.CookieProtector == nil {
logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.")
if c.config.NonBlocking {
logf(logTypeHandshake, "Not possible in non-blocking mode.")
return AlertInternalError
}
var err error
c.config.CookieProtector, err = NewDefaultCookieProtector()
if err != nil {
logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
return AlertInternalError
}
}
state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
} }
c.hState = state c.hState = state
return AlertNoAlert return AlertNoAlert
} }
type handshakeMessageReader interface {
ReadMessage() (*HandshakeMessage, Alert)
}
type handshakeMessageReaderImpl struct {
hsCtx *HandshakeContext
}
var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
var hm *HandshakeMessage
var err error
for {
hm, err = r.hsCtx.hIn.ReadMessage()
if err == AlertWouldBlock {
return nil, AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "Error reading message: %v", err)
return nil, AlertCloseNotify
}
if hm != nil {
break
}
}
return hm, AlertNoAlert
}
// Handshake causes a TLS handshake on the connection. The `isClient` member // Handshake causes a TLS handshake on the connection. The `isClient` member
// determines whether a client or server handshake is performed. If a // determines whether a client or server handshake is performed. If a
// handshake has already been performed, then its result will be returned. // handshake has already been performed, then its result will be returned.
@ -669,48 +729,48 @@ func (c *Conn) Handshake() Alert {
return AlertNoAlert return AlertNoAlert
} }
var alert Alert
if c.hState == nil { if c.hState == nil {
logf(logTypeHandshake, "%s First time through handshake, setting up", label) logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label)
alert = c.HandshakeSetup() alert := c.HandshakeSetup()
if alert != AlertNoAlert { if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) {
return alert return alert
} }
} else {
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
} }
logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
state := c.hState state := c.hState
_, connected := state.(StateConnected) _, connected := state.(stateConnected)
hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
for !connected {
var alert Alert
var actions []HandshakeAction var actions []HandshakeAction
for !connected { // Advance the state machine
// Read a handshake message state, actions, alert = state.Next(hmr)
hm, err := c.hIn.ReadMessage() if alert == AlertWouldBlock {
if err == WouldBlock { logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
logf(logTypeHandshake, "%s Would block reading message: %v", label, err) // If we blocked, then run our timers to see if any have expired.
if c.hsCtx.hIn.datagram {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return AlertInternalError
}
}
return AlertWouldBlock return AlertWouldBlock
} }
if err != nil { if alert == AlertCloseNotify {
logf(logTypeHandshake, "%s Error reading message: %v", label, err) logf(logTypeHandshake, "%s Error reading message: %s", label, alert)
c.sendAlert(AlertCloseNotify) c.sendAlert(AlertCloseNotify)
return AlertCloseNotify return AlertCloseNotify
} }
logf(logTypeHandshake, "Read message with type: %v", hm.msgType) if alert != AlertNoAlert && alert != AlertStatelessRetry {
// Advance the state machine
state, actions, alert = state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert) logf(logTypeHandshake, "Error in state transition: %v", alert)
return alert return alert
} }
for index, action := range actions { for index, action := range actions {
logf(logTypeHandshake, "%s taking next action (%d)", label, index) logf(logTypeHandshake, "%s taking next action (%d)", label, index)
alert = c.takeAction(action) if alert := c.takeAction(action); alert != AlertNoAlert {
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert) logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert) c.sendAlert(alert)
return alert return alert
@ -719,14 +779,14 @@ func (c *Conn) Handshake() Alert {
c.hState = state c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState()) logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(stateConnected)
if connected {
c.state = state.(stateConnected)
c.handshakeComplete = true
_, connected = state.(StateConnected) if !c.isClient {
} // Send NewSessionTicket if configured to
if c.config.SendSessionTickets {
c.state = state.(StateConnected)
// Send NewSessionTicket if acting as server
if !c.isClient && c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket( actions, alert := c.state.NewSessionTicket(
c.config.TicketLen, c.config.TicketLen,
c.config.TicketLifetime, c.config.TicketLifetime,
@ -742,7 +802,25 @@ func (c *Conn) Handshake() Alert {
} }
} }
c.handshakeComplete = true // If there is early data, move it into the main buffer
if c.hsCtx.earlyData != nil {
c.readBuffer = c.hsCtx.earlyData
c.hsCtx.earlyData = nil
}
} else {
assert(c.hsCtx.earlyData == nil)
}
}
if c.config.NonBlocking {
if alert == AlertStatelessRetry {
return AlertStatelessRetry
}
return AlertNoAlert
}
}
return AlertNoAlert return AlertNoAlert
} }
@ -775,12 +853,15 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
return nil return nil
} }
func (c *Conn) GetHsState() string { func (c *Conn) GetHsState() State {
return reflect.TypeOf(c.hState).Name() if c.hState == nil {
return StateInit
}
return c.hState.State()
} }
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(StateConnected) _, connected := c.hState.(stateConnected)
if !connected { if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected") return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
} }
@ -796,7 +877,7 @@ func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]b
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
} }
func (c *Conn) State() ConnectionState { func (c *Conn) ConnectionState() ConnectionState {
state := ConnectionState{ state := ConnectionState{
HandshakeState: c.GetHsState(), HandshakeState: c.GetHsState(),
} }
@ -804,16 +885,32 @@ func (c *Conn) State() ConnectionState {
if c.handshakeComplete { if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto state.NextProto = c.state.Params.NextProto
state.VerifiedChains = c.state.verifiedChains
state.PeerCertificates = c.state.peerCertificates
state.UsingPSK = c.state.Params.UsingPSK
state.UsingEarlyData = c.state.Params.UsingEarlyData
} }
return state return state
} }
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { func (c *Conn) Writable() bool {
if c.hState != nil { // If we're connected, we're writable.
return fmt.Errorf("Can't set extension handler after setup") if _, connected := c.hState.(stateConnected); connected {
return true
} }
c.extHandler = h // If we're a client in 0-RTT, then we're writable.
return nil if c.isClient && c.out.cipher.epoch == EpochEarlyData {
return true
}
return false
}
func (c *Conn) label() string {
if c.isClient {
return "client"
}
return "server"
} }

86
vendor/github.com/bifurcation/mint/cookie-protector.go generated vendored Normal file
View File

@ -0,0 +1,86 @@
package mint
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// CookieProtector is used to create and verify a cookie
type CookieProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const cookieSecretSize = 32
const cookieNonceSize = 32
// The DefaultCookieProtector is a simple implementation for the CookieProtector.
type DefaultCookieProtector struct {
secret []byte
}
var _ CookieProtector = &DefaultCookieProtector{}
// NewDefaultCookieProtector creates a source for source address tokens
func NewDefaultCookieProtector() (CookieProtector, error) {
secret := make([]byte, cookieSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &DefaultCookieProtector{secret: secret}, nil
}
// NewToken encodes data into a new token.
func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, cookieNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) {
if len(p) < cookieNonceSize {
return nil, fmt.Errorf("Token too short: %d", len(p))
}
nonce := p[:cookieNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
}
func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View File

@ -331,40 +331,6 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
} }
} }
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
// XXX(rlb): Copied from crypto/x509 // XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct { type ecdsaSignature struct {
R, S *big.Int R, S *big.Int
@ -652,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
} }
} }
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
priv, err := newSigningKey(alg)
if err != nil {
return nil, nil, err
}
cert, err := newSelfSigned(name, alg, priv)
if err != nil {
return nil, nil, err
}
return priv, cert, nil
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}

222
vendor/github.com/bifurcation/mint/dtls.go generated vendored Normal file
View File

@ -0,0 +1,222 @@
package mint
import (
"fmt"
"github.com/bifurcation/mint/syntax"
"time"
)
const (
initialMtu = 1200
initialTimeout = 100
)
// labels for timers
const (
retransmitTimerLabel = "handshake retransmit"
ackTimerLabel = "ack timer"
)
type SentHandshakeFragment struct {
seq uint32
offset int
fragLength int
record uint64
acked bool
}
type DtlsAck struct {
RecordNumbers []uint64 `tls:"head=2"`
}
func wireVersion(h *HandshakeLayer) uint16 {
if h.datagram {
return dtls12WireVersion
}
return tls12Version
}
func dtlsConvertVersion(version uint16) uint16 {
if version == tls12Version {
return dtls12WireVersion
}
if version == tls10Version {
return 0xfeff
}
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
}
// TODO(ekr@rtfm.com): Move these to state-machine.go
func (h *HandshakeContext) handshakeRetransmit() error {
if _, err := h.hOut.SendQueuedMessages(); err != nil {
return err
}
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
// TODO(ekr@rtfm.com): Back off timer
return nil
}
func (h *HandshakeContext) sendAck() error {
toack := h.hIn.recvdRecords
count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
if len(toack) > count {
toack = toack[:count]
}
logf(logTypeHandshake, "Sending ACK: [%x]", toack)
ack := &DtlsAck{toack}
body, err := syntax.Marshal(&ack)
if err != nil {
return err
}
err = h.hOut.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAck,
fragment: body,
})
if err != nil {
return err
}
return nil
}
func (h *HandshakeContext) processAck(data []byte) error {
// Cancel the retransmit timer because we will be resending
// and possibly re-arming later.
h.timers.cancel(retransmitTimerLabel)
ack := &DtlsAck{}
read, err := syntax.Unmarshal(data, &ack)
if err != nil {
return err
}
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
for _, r := range ack.RecordNumbers {
for _, m := range h.sentFragments {
if r == m.record {
logf(logTypeHandshake, "Marking %v %v(%v) as acked",
m.seq, m.offset, m.fragLength)
m.acked = true
}
}
}
count, err := h.hOut.SendQueuedMessages()
if err != nil {
return err
}
if count == 0 {
logf(logTypeHandshake, "All messages ACKed")
h.hOut.ClearQueuedMessages()
return nil
}
// Reset the timer
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
return nil
}
func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
return c.hsCtx.timers.remaining()
}
func (h *HandshakeContext) receivedHandshakeMessage() {
logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
// This just enables tests.
if h.hIn == nil {
return
}
if !h.hIn.datagram {
return
}
if h.waitingNextFlight {
logf(logTypeHandshake, "Received the start of the flight")
// Clear the outgoing DTLS queue and terminate the retransmit timer
h.hOut.ClearQueuedMessages()
h.timers.cancel(retransmitTimerLabel)
// OK, we're not waiting any more.
h.waitingNextFlight = false
}
// Now pre-emptively arm the ACK timer if it's not armed already.
// We'll automatically dis-arm it at the end of the handshake.
if h.timers.getTimer(ackTimerLabel) == nil {
h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
}
}
func (h *HandshakeContext) receivedEndOfFlight() {
logf(logTypeHandshake, "%p Received the end of the flight", h)
if !h.hIn.datagram {
return
}
// Empty incoming queue
h.hIn.queued = nil
// Note that we are waiting for the next flight.
h.waitingNextFlight = true
// Clear the ACK queue.
h.hIn.recvdRecords = nil
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
}
func (h *HandshakeContext) receivedFinalFlight() {
logf(logTypeHandshake, "%p Received final flight", h)
if !h.hIn.datagram {
return
}
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
// But send an ACK immediately.
h.sendAck()
}
func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
for _, f := range h.sentFragments {
if !f.acked {
continue
}
if f.seq != seq {
continue
}
if f.offset > offset {
continue
}
// At this point, we know that the stored fragment starts
// at or before what we want to send, so check where the end
// is.
if f.offset+f.fragLength < offset+fraglen {
continue
}
return true
}
return false
}

View File

@ -3,7 +3,6 @@ package mint
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/bifurcation/mint/syntax" "github.com/bifurcation/mint/syntax"
) )
@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error {
return nil return nil
} }
func (el ExtensionList) Find(dst ExtensionBody) bool { func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
found := make(map[ExtensionType]bool)
for _, dst := range dsts {
for _, ext := range el { for _, ext := range el {
if ext.ExtensionType == dst.Type() { if ext.ExtensionType == dst.Type() {
_, err := dst.Unmarshal(ext.ExtensionData) if found[dst.Type()] {
return err == nil return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
}
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return nil, err
}
found[dst.Type()] = true
} }
} }
return false }
return found, nil
}
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return true, err
}
return true, nil
}
}
return false, nil
} }
// struct { // struct {
@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
// ProtocolVersion versions<2..254>; // ProtocolVersion versions<2..254>;
// } SupportedVersions; // } SupportedVersions;
type SupportedVersionsExtension struct { type SupportedVersionsExtension struct {
HandshakeType HandshakeType
Versions []uint16
}
type SupportedVersionsClientHelloInner struct {
Versions []uint16 `tls:"head=1,min=2,max=254"` Versions []uint16 `tls:"head=1,min=2,max=254"`
} }
type SupportedVersionsServerHelloInner struct {
Version uint16
}
func (sv SupportedVersionsExtension) Type() ExtensionType { func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions return ExtensionTypeSupportedVersions
} }
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sv) switch sv.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
default:
return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
} }
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sv) switch sv.HandshakeType {
case HandshakeTypeClientHello:
var inner SupportedVersionsClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = inner.Versions
return read, nil
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
var inner SupportedVersionsServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = []uint16{inner.Version}
return read, nil
default:
return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
} }
// struct { // struct {
@ -562,25 +624,3 @@ func (c CookieExtension) Marshal() ([]byte, error) {
func (c *CookieExtension) Unmarshal(data []byte) (int, error) { func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, c) return syntax.Unmarshal(data, c)
} }
// defaultCookieLength is the default length of a cookie
const defaultCookieLength = 32
type defaultCookieHandler struct {
data []byte
}
var _ CookieHandler = &defaultCookieHandler{}
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
h.data = make([]byte, defaultCookieLength)
if _, err := prng.Read(h.data); err != nil {
return nil, err
}
return h.data, nil
}
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
return bytes.Equal(h.data, data)
}

View File

@ -66,8 +66,8 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
f.remainder = f.remainder[copied:] f.remainder = f.remainder[copied:]
f.writeOffset += copied f.writeOffset += copied
if f.writeOffset < len(f.working) { if f.writeOffset < len(f.working) {
logf(logTypeFrameReader, "Read would have blocked 1") logf(logTypeVerbose, "Read would have blocked 1")
return nil, nil, WouldBlock return nil, nil, AlertWouldBlock
} }
// Reset the write offset, because we are now full. // Reset the write offset, because we are now full.
f.writeOffset = 0 f.writeOffset = 0
@ -93,6 +93,6 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
f.state = kFrameReaderBody f.state = kFrameReaderBody
} }
logf(logTypeFrameReader, "Read would have blocked 2") logf(logTypeVerbose, "Read would have blocked 2")
return nil, nil, WouldBlock return nil, nil, AlertWouldBlock
} }

View File

@ -7,7 +7,8 @@ import (
) )
const ( const (
handshakeHeaderLen = 4 // handshake message header length handshakeHeaderLenTLS = 4 // handshake message header length
handshakeHeaderLenDTLS = 12 // handshake message header length
maxHandshakeMessageLen = 1 << 24 // max handshake message length maxHandshakeMessageLen = 1 << 24 // max handshake message length
) )
@ -27,28 +28,42 @@ const (
// opaque msg<0..2^24-1> // opaque msg<0..2^24-1>
// } Handshake; // } Handshake;
// //
// TODO: File a spec bug
type HandshakeMessage struct { type HandshakeMessage struct {
// Omitted: length
msgType HandshakeType msgType HandshakeType
seq uint32
body []byte body []byte
datagram bool
offset uint32 // Used for DTLS
length uint32
cipher *cipherState
} }
// Note: This could be done with the `syntax` module, using the simplified // Note: This could be done with the `syntax` module, using the simplified
// syntax as discussed above. However, since this is so simple, there's not // syntax as discussed above. However, since this is so simple, there's not
// much benefit to doing so. // much benefit to doing so.
// When datagram is set, we marshal this as a whole DTLS record.
func (hm *HandshakeMessage) Marshal() []byte { func (hm *HandshakeMessage) Marshal() []byte {
if hm == nil { if hm == nil {
return []byte{} return []byte{}
} }
msgLen := len(hm.body) fragLen := len(hm.body)
data := make([]byte, 4+len(hm.body)) var data []byte
data[0] = byte(hm.msgType)
data[1] = byte(msgLen >> 16) if hm.datagram {
data[2] = byte(msgLen >> 8) data = make([]byte, handshakeHeaderLenDTLS+fragLen)
data[3] = byte(msgLen) } else {
copy(data[4:], hm.body) data = make([]byte, handshakeHeaderLenTLS+fragLen)
}
tmp := data
tmp = encodeUint(uint64(hm.msgType), 1, tmp)
tmp = encodeUint(uint64(hm.length), 3, tmp)
if hm.datagram {
tmp = encodeUint(uint64(hm.seq), 2, tmp)
tmp = encodeUint(uint64(hm.offset), 3, tmp)
tmp = encodeUint(uint64(fragLen), 3, tmp)
}
copy(tmp, hm.body)
return data return data
} }
@ -61,8 +76,6 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
body = new(ClientHelloBody) body = new(ClientHelloBody)
case HandshakeTypeServerHello: case HandshakeTypeServerHello:
body = new(ServerHelloBody) body = new(ServerHelloBody)
case HandshakeTypeHelloRetryRequest:
body = new(HelloRetryRequestBody)
case HandshakeTypeEncryptedExtensions: case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody) body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate: case HandshakeTypeCertificate:
@ -83,62 +96,104 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
} }
_, err := body.Unmarshal(hm.body) err := safeUnmarshal(body, hm.body)
return body, err return body, err
} }
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
data, err := body.Marshal() data, err := body.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &HandshakeMessage{ m := &HandshakeMessage{
msgType: body.Type(), msgType: body.Type(),
body: data, body: data,
}, nil seq: h.msgSeq,
datagram: h.datagram,
length: uint32(len(data)),
}
h.msgSeq++
return m, nil
} }
type HandshakeLayer struct { type HandshakeLayer struct {
ctx *HandshakeContext // The handshake we are attached to
nonblocking bool // Should we operate in nonblocking mode nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records conn *RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader frame *frameReader // The buffered frame reader
datagram bool // Is this DTLS?
msgSeq uint32 // The DTLS message sequence number
queued []*HandshakeMessage // In/out queue
sent []*HandshakeMessage // Sent messages for DTLS
recvdRecords []uint64 // Records we have received.
maxFragmentLen int
} }
type handshakeLayerFrameDetails struct{} type handshakeLayerFrameDetails struct {
datagram bool
}
func (d handshakeLayerFrameDetails) headerLen() int { func (d handshakeLayerFrameDetails) headerLen() int {
return handshakeHeaderLen if d.datagram {
return handshakeHeaderLenDTLS
}
return handshakeHeaderLenTLS
} }
func (d handshakeLayerFrameDetails) defaultReadLen() int { func (d handshakeLayerFrameDetails) defaultReadLen() int {
return handshakeHeaderLen + maxFragmentLen return d.headerLen() + maxFragmentLen
} }
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
logf(logTypeIO, "Header=%x", hdr) logf(logTypeIO, "Header=%x", hdr)
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil // The length of this fragment (as opposed to the message)
// is always the last three bytes for both TLS and DTLS
val, _ := decodeUint(hdr[len(hdr)-3:], 3)
return int(val), nil
} }
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{} h := HandshakeLayer{}
h.ctx = c
h.conn = r h.conn = r
h.frame = newFrameReader(&handshakeLayerFrameDetails{}) h.datagram = false
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
h.maxFragmentLen = maxFragmentLen
return &h
}
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
h.datagram = true
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
h.maxFragmentLen = initialMtu // Not quite right
return &h return &h
} }
func (h *HandshakeLayer) readRecord() error { func (h *HandshakeLayer) readRecord() error {
logf(logTypeIO, "Trying to read record") logf(logTypeVerbose, "Trying to read record")
pt, err := h.conn.ReadRecord() pt, err := h.conn.readRecordAnyEpoch()
if err != nil { if err != nil {
return err return err
} }
if pt.contentType != RecordTypeHandshake && switch pt.contentType {
pt.contentType != RecordTypeAlert { case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
default:
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
} }
if pt.contentType == RecordTypeAck {
if !h.datagram {
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
}
logf(logTypeIO, "read ACK")
return h.ctx.processAck(pt.fragment)
}
if pt.contentType == RecordTypeAlert { if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1]) logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 { if len(pt.fragment) < 2 {
@ -148,7 +203,19 @@ func (h *HandshakeLayer) readRecord() error {
return Alert(pt.fragment[1]) return Alert(pt.fragment[1])
} }
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) assert(h.ctx.hIn.conn != nil)
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
// This is out of order but we're dropping it.
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
return nil
}
// Anything else shouldn't happen.
return AlertIllegalParameter
}
h.recvdRecords = append(h.recvdRecords, pt.seq)
h.frame.addChunk(pt.fragment) h.frame.addChunk(pt.fragment)
return nil return nil
@ -171,83 +238,314 @@ func (h *HandshakeLayer) sendAlert(err Alert) error {
return nil return nil
} }
func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
h.msgSeq = seq + 1
var i int
var m *HandshakeMessage
for i, m = range h.queued {
if m.seq > seq {
break
}
}
h.queued = h.queued[i:]
}
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
if hm.seq < h.msgSeq {
return nil, nil
}
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
// out of order.
h.ctx.receivedHandshakeMessage()
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
// TODO(ekr@rtfm.com): Check the length?
// This is complete.
h.noteMessageDelivered(hm.seq)
return hm, nil
}
// Now insert sorted.
var i int
for i = 0; i < len(h.queued); i++ {
f := h.queued[i]
if hm.seq < f.seq {
break
}
if hm.offset < f.offset {
break
}
}
tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
tmp = append(tmp, h.queued[:i]...)
tmp = append(tmp, hm)
tmp = append(tmp, h.queued[i:]...)
h.queued = tmp
return h.checkMessageAvailable()
}
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
if len(h.queued) == 0 {
return nil, nil
}
hm := h.queued[0]
if hm.seq != h.msgSeq {
return nil, nil
}
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
// TODO(ekr@rtfm.com): Check the length?
// This is complete.
h.noteMessageDelivered(hm.seq)
return hm, nil
}
// OK, this at least might complete the message.
end := uint32(0)
buf := make([]byte, hm.length)
for _, f := range h.queued {
// Out of fragments
if f.seq > hm.seq {
break
}
if f.length != uint32(len(buf)) {
return nil, fmt.Errorf("Mismatched DTLS length")
}
if f.offset > end {
break
}
if f.offset+uint32(len(f.body)) > end {
// OK, this is adding something we don't know about
copy(buf[f.offset:], f.body)
end = f.offset + uint32(len(f.body))
if end == hm.length {
h2 := *hm
h2.offset = 0
h2.body = buf
h.noteMessageDelivered(hm.seq)
return &h2, nil
}
}
}
return nil, nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var hdr, body []byte var hdr, body []byte
var err error var err error
for { hm, err := h.checkMessageAvailable()
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) if err != nil {
if h.frame.needed() > 0 {
logf(logTypeHandshake, "Trying to read a new record")
err = h.readRecord()
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err return nil, err
} }
if hm != nil {
return hm, nil
}
for {
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeVerbose, "Trying to read a new record")
err = h.readRecord()
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err
}
}
hdr, body, err = h.frame.process() hdr, body, err = h.frame.process()
if err == nil { if err == nil {
break break
} }
if err != nil && (h.nonblocking || err != WouldBlock) { if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err return nil, err
} }
} }
logf(logTypeHandshake, "read handshake message") logf(logTypeHandshake, "read handshake message")
hm := &HandshakeMessage{} hm = &HandshakeMessage{}
hm.msgType = HandshakeType(hdr[0]) hm.msgType = HandshakeType(hdr[0])
hm.datagram = h.datagram
hm.body = make([]byte, len(body)) hm.body = make([]byte, len(body))
copy(hm.body, body) copy(hm.body, body)
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
if h.datagram {
tmp, hdr := decodeUint(hdr[1:], 3)
hm.length = uint32(tmp)
tmp, hdr = decodeUint(hdr, 2)
hm.seq = uint32(tmp)
tmp, hdr = decodeUint(hdr, 3)
hm.offset = uint32(tmp)
return h.newFragmentReceived(hm)
}
hm.length = uint32(len(body))
return hm, nil return hm, nil
} }
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
return h.WriteMessages([]*HandshakeMessage{hm}) hm.cipher = h.conn.cipher
h.queued = append(h.queued, hm)
return nil
} }
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
logf(logTypeHandshake, "Sending outgoing messages")
count, err := h.WriteMessages(h.queued)
if !h.datagram {
h.ClearQueuedMessages()
}
return count, err
}
func (h *HandshakeLayer) ClearQueuedMessages() {
logf(logTypeHandshake, "Clearing outgoing hs message queue")
h.queued = nil
}
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
var buf []byte
// Figure out if we're going to want the full header or just
// the body
hdrlen := 0
if hm.datagram {
hdrlen = handshakeHeaderLenDTLS
} else if start == 0 {
hdrlen = handshakeHeaderLenTLS
}
// Compute the amount of body we can fit in
room -= hdrlen
if room == 0 {
// This works because we are doing one record per
// message
panic("Too short max fragment len")
}
bodylen := len(hm.body) - start
if bodylen > room {
bodylen = room
}
body := hm.body[start : start+bodylen]
// Now see if this chunk has been ACKed. This doesn't produce ideal
// retransmission but is simple.
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
return false, start + bodylen, nil
}
// Encode the data.
if hdrlen > 0 {
hm2 := *hm
hm2.offset = uint32(start)
hm2.body = body
buf = hm2.Marshal()
hm = &hm2
} else {
buf = body
}
if h.datagram {
// Remember that we sent this.
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
hm.seq,
start,
len(body),
h.conn.cipher.combineSeq(true),
false,
})
}
return true, start + bodylen, h.conn.writeRecordWithPadding(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
},
hm.cipher, 0)
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
start := int(0)
if len(hm.body) > maxHandshakeMessageLen {
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
}
written := 0
wrote := false
// Always make one pass through to allow EOED (which is empty).
for {
var err error
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
if err != nil {
return 0, err
}
if wrote {
written++
}
if start >= len(hm.body) {
break
}
}
return written, nil
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
written := 0
for _, hm := range hms { for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
wrote, err := h.WriteMessage(hm)
if err != nil {
return 0, err
}
written += wrote
}
return written, nil
} }
// Write out headers and bodies func encodeUint(v uint64, size int, out []byte) []byte {
buffer := []byte{} for i := size - 1; i >= 0; i-- {
for _, msg := range hms { out[i] = byte(v & 0xff)
msgLen := len(msg.body) v >>= 8
if msgLen > maxHandshakeMessageLen { }
return fmt.Errorf("tls.handshakelayer: Message too large to send") return out[size:]
} }
buffer = append(buffer, msg.Marshal()...) func decodeUint(in []byte, size int) (uint64, []byte) {
val := uint64(0)
for i := 0; i < size; i++ {
val <<= 8
val += uint64(in[i])
}
return val, in[size:]
} }
// Send full-size fragments type marshalledPDU interface {
var start int Marshal() ([]byte, error)
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { Unmarshal(data []byte) (int, error)
err := h.conn.WriteRecord(&TLSPlaintext{ }
contentType: RecordTypeHandshake,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return err
}
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start:],
})
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
read, err := pdu.Unmarshal(data)
if err != nil { if err != nil {
return err return err
} }
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
} }
return nil return nil
} }

View File

@ -25,15 +25,14 @@ type HandshakeMessageBody interface {
// Extension extensions<0..2^16-1>; // Extension extensions<0..2^16-1>;
// } ClientHello; // } ClientHello;
type ClientHelloBody struct { type ClientHelloBody struct {
// Omitted: clientVersion LegacyVersion uint16
// Omitted: legacySessionID
// Omitted: legacyCompressionMethods
Random [32]byte Random [32]byte
LegacySessionID []byte
CipherSuites []CipherSuite CipherSuites []CipherSuite
Extensions ExtensionList Extensions ExtensionList
} }
type clientHelloBodyInner struct { type clientHelloBodyInnerTLS struct {
LegacyVersion uint16 LegacyVersion uint16
Random [32]byte Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"` LegacySessionID []byte `tls:"head=1,max=32"`
@ -42,13 +41,33 @@ type clientHelloBodyInner struct {
Extensions []Extension `tls:"head=2"` Extensions []Extension `tls:"head=2"`
} }
type clientHelloBodyInnerDTLS struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
EmptyCookie uint8
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType { func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello return HandshakeTypeClientHello
} }
func (ch ClientHelloBody) Marshal() ([]byte, error) { func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{ if ch.LegacyVersion == tls12Version {
LegacyVersion: 0x0303, return syntax.Marshal(clientHelloBodyInnerTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
} else {
return syntax.Marshal(clientHelloBodyInnerDTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random, Random: ch.Random,
LegacySessionID: []byte{}, LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites, CipherSuites: ch.CipherSuites,
@ -57,25 +76,51 @@ func (ch ClientHelloBody) Marshal() ([]byte, error) {
}) })
} }
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var inner clientHelloBodyInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
} }
// We are strict about these things because we only support 1.3 func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
if inner.LegacyVersion != 0x0303 { var read int
return 0, fmt.Errorf("tls.clienthello: Incorrect version number") var err error
// Note that this might be 0, in which case we do TLS. That
// makes the tests easier.
if ch.LegacyVersion != dtls12WireVersion {
var inner clientHelloBodyInnerTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
} }
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method") return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
} }
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions ch.Extensions = inner.Extensions
} else {
var inner clientHelloBodyInnerDTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if inner.EmptyCookie != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid cookie")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
}
return read, nil return read, nil
} }
@ -90,10 +135,15 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
} }
chm, err := HandshakeMessageFromBody(&ch) body, err := ch.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
chm := &HandshakeMessage{
msgType: ch.Type(),
body: body,
length: uint32(len(body)),
}
chData := chm.Marshal() chData := chm.Marshal()
psk := PreSharedKeyExtension{ psk := PreSharedKeyExtension{
@ -116,38 +166,19 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) {
} }
// struct { // struct {
// ProtocolVersion server_version; // ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// CipherSuite cipher_suite;
// Extension extensions<2..2^16-1>;
// } HelloRetryRequest;
type HelloRetryRequestBody struct {
Version uint16
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2,min=2"`
}
func (hrr HelloRetryRequestBody) Type() HandshakeType {
return HandshakeTypeHelloRetryRequest
}
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(hrr)
}
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, hrr)
}
// struct {
// ProtocolVersion version;
// Random random; // Random random;
// opaque legacy_session_id_echo<0..32>;
// CipherSuite cipher_suite; // CipherSuite cipher_suite;
// Extension extensions<0..2^16-1>; // uint8 legacy_compression_method = 0;
// Extension extensions<6..2^16-1>;
// } ServerHello; // } ServerHello;
type ServerHelloBody struct { type ServerHelloBody struct {
Version uint16 Version uint16
Random [32]byte Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuite CipherSuite CipherSuite CipherSuite
LegacyCompressionMethod uint8
Extensions ExtensionList `tls:"head=2"` Extensions ExtensionList `tls:"head=2"`
} }

View File

@ -148,7 +148,7 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
} }
if len(candidatesByName) == 0 { if len(candidatesByName) == 0 {
return nil, 0, fmt.Errorf("No certificates available for server name") return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName)
} }
candidates = candidatesByName candidates = candidatesByName
@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
} }
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) {
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData using = gotEarlyData && usingPSK && allowEarlyData
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) rejected = gotEarlyData && !using
return usingEarlyData logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected)
return
} }
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {

View File

@ -1,7 +1,6 @@
package mint package mint
import ( import (
"bytes"
"crypto/cipher" "crypto/cipher"
"fmt" "fmt"
"io" "io"
@ -10,7 +9,8 @@ import (
const ( const (
sequenceNumberLen = 8 // sequence number length sequenceNumberLen = 8 // sequence number length
recordHeaderLen = 5 // record header length recordHeaderLenTLS = 5 // record header length (TLS)
recordHeaderLenDTLS = 13 // record header length (DTLS)
maxFragmentLen = 1 << 14 // max number of bytes in a record maxFragmentLen = 1 << 14 // max number of bytes in a record
) )
@ -20,9 +20,16 @@ func (err DecryptError) Error() string {
return string(err) return string(err)
} }
type direction uint8
const (
directionWrite = direction(1)
directionRead = direction(2)
)
// struct { // struct {
// ContentType type; // ContentType type;
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ // ProtocolVersion record_version [0301 for CH, 0303 for others]
// uint16 length; // uint16 length;
// opaque fragment[TLSPlaintext.length]; // opaque fragment[TLSPlaintext.length];
// } TLSPlaintext; // } TLSPlaintext;
@ -30,87 +37,177 @@ type TLSPlaintext struct {
// Omitted: record_version (static) // Omitted: record_version (static)
// Omitted: length (computed from fragment) // Omitted: length (computed from fragment)
contentType RecordType contentType RecordType
epoch Epoch
seq uint64
fragment []byte fragment []byte
} }
type cipherState struct {
epoch Epoch // DTLS epoch
ivLength int // Length of the seq and nonce fields
seq uint64 // Zero-padded sequence number
iv []byte // Buffer for the IV
cipher cipher.AEAD // AEAD cipher
}
type RecordLayer struct { type RecordLayer struct {
sync.Mutex sync.Mutex
label string
direction direction
version uint16 // The current version number
conn io.ReadWriter // The underlying connection conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader frame *frameReader // The buffered frame reader
nextData []byte // The next record to send nextData []byte // The next record to send
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read cachedError error // Error on the last record read
ivLength int // Length of the seq and nonce fields cipher *cipherState
seq []byte // Zero-padded sequence number readCiphers map[Epoch]*cipherState
nonce []byte // Buffer for per-record nonces
cipher cipher.AEAD // AEAD cipher datagram bool
} }
type recordLayerFrameDetails struct{} type recordLayerFrameDetails struct {
datagram bool
}
func (d recordLayerFrameDetails) headerLen() int { func (d recordLayerFrameDetails) headerLen() int {
return recordHeaderLen if d.datagram {
return recordHeaderLenDTLS
}
return recordHeaderLenTLS
} }
func (d recordLayerFrameDetails) defaultReadLen() int { func (d recordLayerFrameDetails) defaultReadLen() int {
return recordHeaderLen + maxFragmentLen return d.headerLen() + maxFragmentLen
} }
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return (int(hdr[3]) << 8) | int(hdr[4]), nil return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
} }
func NewRecordLayer(conn io.ReadWriter) *RecordLayer { func newCipherStateNull() *cipherState {
return &cipherState{EpochClear, 0, 0, nil, nil}
}
func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
cipher, err := factory(key)
if err != nil {
return nil, err
}
return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
}
func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
r := RecordLayer{} r := RecordLayer{}
r.label = ""
r.direction = dir
r.conn = conn r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{}) r.frame = newFrameReader(recordLayerFrameDetails{false})
r.ivLength = 0 r.cipher = newCipherStateNull()
r.version = tls10Version
return &r return &r
} }
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
var err error r := RecordLayer{}
r.cipher, err = cipher(key) r.label = ""
r.direction = dir
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{true})
r.cipher = newCipherStateNull()
r.readCiphers = make(map[Epoch]*cipherState, 0)
r.readCiphers[0] = r.cipher
r.datagram = true
return &r
}
func (r *RecordLayer) SetVersion(v uint16) {
r.version = v
}
func (r *RecordLayer) ResetClear(seq uint64) {
r.cipher = newCipherStateNull()
r.cipher.seq = seq
}
func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
cipher, err := newCipherStateAead(epoch, factory, key, iv)
if err != nil { if err != nil {
return err return err
} }
r.cipher = cipher
r.ivLength = len(iv) if r.datagram && r.direction == directionRead {
r.seq = bytes.Repeat([]byte{0}, r.ivLength) r.readCiphers[epoch] = cipher
r.nonce = make([]byte, r.ivLength) }
copy(r.nonce, iv)
return nil return nil
} }
func (r *RecordLayer) incrementSequenceNumber() { // TODO(ekr@rtfm.com): This is never used, which is a bug.
if r.ivLength == 0 { func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
if !r.datagram {
return return
} }
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { _, ok := r.readCiphers[epoch]
r.seq[i]++ assert(ok)
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] delete(r.readCiphers, epoch)
if r.seq[i] != 0 {
return
}
} }
func (c *cipherState) combineSeq(datagram bool) uint64 {
seq := c.seq
if datagram {
seq |= uint64(c.epoch) << 48
}
return seq
}
func (c *cipherState) computeNonce(seq uint64) []byte {
nonce := make([]byte, len(c.iv))
copy(nonce, c.iv)
s := seq
offset := len(c.iv)
for i := 0; i < 8; i++ {
nonce[(offset-i)-1] ^= byte(s & 0xff)
s >>= 8
}
logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
return nonce
}
func (c *cipherState) incrementSequenceNumber() {
if c.seq >= (1<<48 - 1) {
// Not allowed to let sequence number wrap. // Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does. // Instead, must renegotiate before it does.
// Not likely enough to bother. // Not likely enough to bother. This is the
// DTLS limit.
panic("TLS: sequence number wraparound") panic("TLS: sequence number wraparound")
} }
c.seq++
}
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { func (c *cipherState) overhead() int {
if c.cipher == nil {
return 0
}
return c.cipher.Overhead()
}
func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
assert(r.direction == directionWrite)
logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
// Expand the fragment to hold contentType, padding, and overhead // Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment) originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen plaintextLen := originalLen + 1 + padLen
ciphertextLen := plaintextLen + r.cipher.Overhead() ciphertextLen := plaintextLen + cipher.overhead()
// Assemble the revised plaintext // Assemble the revised plaintext
out := &TLSPlaintext{ out := &TLSPlaintext{
contentType: RecordTypeApplicationData, contentType: RecordTypeApplicationData,
fragment: make([]byte, ciphertextLen), fragment: make([]byte, ciphertextLen),
} }
@ -122,25 +219,28 @@ func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
// Encrypt the fragment // Encrypt the fragment
payload := out.fragment[:plaintextLen] payload := out.fragment[:plaintextLen]
r.cipher.Seal(payload[:0], r.nonce, payload, nil) cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
return out return out
} }
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
if len(pt.fragment) < r.cipher.Overhead() { assert(r.direction == directionRead)
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
if len(pt.fragment) < r.cipher.overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
return nil, 0, DecryptError(msg) return nil, 0, DecryptError(msg)
} }
decryptLen := len(pt.fragment) - r.cipher.Overhead() decryptLen := len(pt.fragment) - r.cipher.overhead()
out := &TLSPlaintext{ out := &TLSPlaintext{
contentType: pt.contentType, contentType: pt.contentType,
fragment: make([]byte, decryptLen), fragment: make([]byte, decryptLen),
} }
// Decrypt // Decrypt
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
if err != nil { if err != nil {
logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
} }
@ -155,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
// Truncate the message to remove contentType, padding, overhead // Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen] out.fragment = out.fragment[:newLen]
out.seq = seq
return out, padLen, nil return out, padLen, nil
} }
@ -163,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
var err error var err error
for { for {
pt, err = r.nextRecord() pt, err = r.nextRecord(false)
if err == nil { if err == nil {
break break
} }
if !block || err != WouldBlock { if !block || err != AlertWouldBlock {
return 0, err return 0, err
} }
} }
@ -175,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
} }
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord() pt, err := r.nextRecord(false)
// Consume the cached record if there was one // Consume the cached record if there was one
r.cachedRecord = nil r.cachedRecord = nil
@ -184,9 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
return pt, err return pt, err
} }
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
pt, err := r.nextRecord(true)
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
cipher := r.cipher
if r.cachedRecord != nil { if r.cachedRecord != nil {
logf(logTypeIO, "Returning cached record") logf(logTypeIO, "%s Returning cached record", r.label)
return r.cachedRecord, r.cachedError return r.cachedRecord, r.cachedError
} }
@ -194,34 +306,35 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
// //
// 1. We get a frame // 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case // 2. We try to read off the socket and get nothing, in which case
// return WouldBlock // returnAlertWouldBlock
// 3. We get an error. // 3. We get an error.
err := WouldBlock var err error
err = AlertWouldBlock
var header, body []byte var header, body []byte
for err != nil { for err != nil {
if r.frame.needed() > 0 { if r.frame.needed() > 0 {
buf := make([]byte, recordHeaderLen+maxFragmentLen) buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
n, err := r.conn.Read(buf) n, err := r.conn.Read(buf)
if err != nil { if err != nil {
logf(logTypeIO, "Error reading, %v", err) logf(logTypeIO, "%s Error reading, %v", r.label, err)
return nil, err return nil, err
} }
if n == 0 { if n == 0 {
return nil, WouldBlock return nil, AlertWouldBlock
} }
logf(logTypeIO, "Read %v bytes", n) logf(logTypeIO, "%s Read %v bytes", r.label, n)
buf = buf[:n] buf = buf[:n]
r.frame.addChunk(buf) r.frame.addChunk(buf)
} }
header, body, err = r.frame.process() header, body, err = r.frame.process()
// Loop around on WouldBlock to see if some // Loop around onAlertWouldBlock to see if some
// data is now available. // data is now available.
if err != nil && err != WouldBlock { if err != nil && err != AlertWouldBlock {
return nil, err return nil, err
} }
} }
@ -231,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
switch RecordType(header[0]) { switch RecordType(header[0]) {
default: default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
pt.contentType = RecordType(header[0]) pt.contentType = RecordType(header[0])
} }
@ -241,7 +354,8 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
} }
// Validate size < max // Validate size < max
size := (int(header[3]) << 8) + int(header[4]) size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
if size > maxFragmentLen+256 { if size > maxFragmentLen+256 {
return nil, fmt.Errorf("tls.record: Ciphertext size too big") return nil, fmt.Errorf("tls.record: Ciphertext size too big")
} }
@ -249,33 +363,67 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
pt.fragment = make([]byte, size) pt.fragment = make([]byte, size)
copy(pt.fragment, body) copy(pt.fragment, body)
// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
// Attempt to decrypt fragment // Attempt to decrypt fragment
if r.cipher != nil { seq := cipher.seq
pt, _, err = r.decrypt(pt) if r.datagram {
// TODO(ekr@rtfm.com): Handle duplicates.
seq, _ = decodeUint(header[3:11], 8)
epoch := Epoch(seq >> 48)
// Look up the cipher suite from the epoch
c, ok := r.readCiphers[epoch]
if !ok {
logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
return nil, AlertWouldBlock
}
if epoch != cipher.epoch {
logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
cipher.epoch, allowOldEpoch)
if !allowOldEpoch {
return nil, AlertWouldBlock
}
cipher = c
}
}
if cipher.cipher != nil {
logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
pt, _, err = r.decrypt(pt, seq)
if err != nil { if err != nil {
logf(logTypeIO, "%s Decryption failed", r.label)
return nil, err return nil, err
} }
} }
pt.epoch = cipher.epoch
// Check that plaintext length is not too long // Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen { if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big") return nil, fmt.Errorf("tls.record: Plaintext size too big")
} }
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
r.cachedRecord = pt r.cachedRecord = pt
r.incrementSequenceNumber() cipher.incrementSequenceNumber()
return pt, nil return pt, nil
} }
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
return r.WriteRecordWithPadding(pt, 0) return r.writeRecordWithPadding(pt, r.cipher, 0)
} }
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
if r.cipher != nil { return r.writeRecordWithPadding(pt, r.cipher, padLen)
pt = r.encrypt(pt, padLen) }
func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
seq := cipher.combineSeq(r.datagram)
if cipher.cipher != nil {
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
pt = r.encrypt(cipher, seq, pt, padLen)
} else if padLen > 0 { } else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records") return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
} }
@ -285,12 +433,26 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error
} }
length := len(pt.fragment) length := len(pt.fragment)
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} var header []byte
if !r.datagram {
header = []byte{byte(pt.contentType),
byte(r.version >> 8), byte(r.version & 0xff),
byte(length >> 8), byte(length)}
} else {
header = make([]byte, 13)
version := dtlsConvertVersion(r.version)
copy(header, []byte{byte(pt.contentType),
byte(version >> 8), byte(version & 0xff),
})
encodeUint(seq, 8, header[3:])
encodeUint(uint64(length), 2, header[11:])
}
record := append(header, pt.fragment...) record := append(header, pt.fragment...)
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
r.incrementSequenceNumber() cipher.incrementSequenceNumber()
_, err := r.conn.Write(record) _, err := r.conn.Write(record)
return err return err
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package mint package mint
import ( import (
"crypto/x509"
"time" "time"
) )
@ -8,32 +9,35 @@ import (
// state transitions. // state transitions.
type HandshakeAction interface{} type HandshakeAction interface{}
type SendHandshakeMessage struct { type QueueHandshakeMessage struct {
Message *HandshakeMessage Message *HandshakeMessage
} }
type SendQueuedHandshake struct{}
type SendEarlyData struct{} type SendEarlyData struct{}
type ReadEarlyData struct{}
type ReadPastEarlyData struct{}
type RekeyIn struct { type RekeyIn struct {
Label string epoch Epoch
KeySet keySet KeySet keySet
} }
type RekeyOut struct { type RekeyOut struct {
Label string epoch Epoch
KeySet keySet KeySet keySet
} }
type ResetOut struct {
seq uint64
}
type StorePSK struct { type StorePSK struct {
PSK PreSharedKey PSK PreSharedKey
} }
type HandshakeState interface { type HandshakeState interface {
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert)
State() State
} }
type AppExtensionHandler interface { type AppExtensionHandler interface {
@ -41,35 +45,11 @@ type AppExtensionHandler interface {
Receive(hs HandshakeType, el *ExtensionList) error Receive(hs HandshakeType, el *ExtensionList) error
} }
// Capabilities objects represent the capabilities of a TLS client or server,
// as an input to TLS negotiation
type Capabilities struct {
// For both client and server
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
PSKs PreSharedKeyCache
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
ExtensionHandler AppExtensionHandler
// For client
PSKModes []PSKKeyExchangeMode
// For server
NextProtos []string
AllowEarlyData bool
RequireCookie bool
CookieHandler CookieHandler
RequireClientAuth bool
}
// ConnectionOptions objects represent per-connection settings for a client // ConnectionOptions objects represent per-connection settings for a client
// initiating a connection // initiating a connection
type ConnectionOptions struct { type ConnectionOptions struct {
ServerName string ServerName string
NextProtos []string NextProtos []string
EarlyData []byte
} }
// ConnectionParameters objects represent the parameters negotiated for a // ConnectionParameters objects represent the parameters negotiated for a
@ -79,6 +59,7 @@ type ConnectionParameters struct {
UsingDH bool UsingDH bool
ClientSendingEarlyData bool ClientSendingEarlyData bool
UsingEarlyData bool UsingEarlyData bool
RejectedEarlyData bool
UsingClientAuth bool UsingClientAuth bool
CipherSuite CipherSuite CipherSuite CipherSuite
@ -86,18 +67,50 @@ type ConnectionParameters struct {
NextProto string NextProto string
} }
// StateConnected is symmetric between client and server // Working state for the handshake.
type StateConnected struct { type HandshakeContext struct {
timeoutMS uint32
timers *timerSet
recvdRecords []uint64
sentFragments []*SentHandshakeFragment
hIn, hOut *HandshakeLayer
waitingNextFlight bool
earlyData []byte
}
func (hc *HandshakeContext) SetVersion(version uint16) {
if hc.hIn.conn != nil {
hc.hIn.conn.SetVersion(version)
}
if hc.hOut.conn != nil {
hc.hOut.conn.SetVersion(version)
}
}
// stateConnected is symmetric between client and server
type stateConnected struct {
Params ConnectionParameters Params ConnectionParameters
hsCtx *HandshakeContext
isClient bool isClient bool
cryptoParams CipherSuiteParams cryptoParams CipherSuiteParams
resumptionSecret []byte resumptionSecret []byte
clientTrafficSecret []byte clientTrafficSecret []byte
serverTrafficSecret []byte serverTrafficSecret []byte
exporterSecret []byte exporterSecret []byte
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
} }
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { var _ HandshakeState = &stateConnected{}
func (state stateConnected) State() State {
if state.isClient {
return StateClientConnected
}
return StateServerConnected
}
func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys keySet var trafficKeys keySet
if state.isClient { if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
@ -109,20 +122,21 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
} }
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
if err != nil { if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
return nil, AlertInternalError return nil, AlertInternalError
} }
toSend := []HandshakeAction{ toSend := []HandshakeAction{
SendHandshakeMessage{kum}, QueueHandshakeMessage{kum},
RekeyOut{Label: "update", KeySet: trafficKeys}, SendQueuedHandshake{},
RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys},
} }
return toSend, AlertNoAlert return toSend, AlertNoAlert
} }
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime) tkt, err := NewSessionTicket(length, lifetime)
if err != nil { if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
@ -149,7 +163,7 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
TicketAgeAdd: tkt.TicketAgeAdd, TicketAgeAdd: tkt.TicketAgeAdd,
} }
tktm, err := HandshakeMessageFromBody(tkt) tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt)
if err != nil { if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
return nil, AlertInternalError return nil, AlertInternalError
@ -157,12 +171,18 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
toSend := []HandshakeAction{ toSend := []HandshakeAction{
StorePSK{newPSK}, StorePSK{newPSK},
SendHandshakeMessage{tktm}, QueueHandshakeMessage{tktm},
SendQueuedHandshake{},
} }
return toSend, AlertNoAlert return toSend, AlertNoAlert
} }
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { // Next does nothing for this state.
func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
return state, nil, AlertNoAlert
}
func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil { if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message") logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
@ -187,20 +207,18 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
} }
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}}
// If requested, roll outbound keys and send a KeyUpdate // If requested, roll outbound keys and send a KeyUpdate
if body.KeyUpdateRequest == KeyUpdateRequested { if body.KeyUpdateRequest == KeyUpdateRequested {
logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest)
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
if alert != AlertNoAlert { if alert != AlertNoAlert {
return nil, nil, alert return nil, nil, alert
} }
toSend = append(toSend, moreToSend...) toSend = append(toSend, moreToSend...)
} }
return state, toSend, AlertNoAlert return state, toSend, AlertNoAlert
case *NewSessionTicketBody: case *NewSessionTicketBody:
// XXX: Allow NewSessionTicket in both directions? // XXX: Allow NewSessionTicket in both directions?
if !state.isClient { if !state.isClient {
@ -209,7 +227,6 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
psk := PreSharedKey{ psk := PreSharedKey{
CipherSuite: state.cryptoParams.Suite, CipherSuite: state.cryptoParams.Suite,
IsResumption: true, IsResumption: true,

View File

@ -72,3 +72,13 @@ The available annotations right now are all related to vectors:
fragment[TLSPlaintext.length]`. Note, however, that in cases where the length fragment[TLSPlaintext.length]`. Note, however, that in cases where the length
immediately preceds the array, these can be reframed as vectors with immediately preceds the array, these can be reframed as vectors with
appropriate sizes. appropriate sizes.
QUIC Extensions Syntax
======================
syntax also supports some minor extensions to allow implementing QUIC.
* The `varint` annotation describes a QUIC-style varint
* `head=none` means no header, i.e., the bytes are encoded directly on the wire.
On reading, the decoder will consume all available data.
* `head=varint` means to encode the header as a varint

View File

@ -16,12 +16,22 @@ func Unmarshal(data []byte, v interface{}) (int, error) {
return d.unmarshal(v) return d.unmarshal(v)
} }
// Unmarshaler is the interface implemented by types that can
// unmarshal a TLS description of themselves. Note that unlike the
// JSON unmarshaler interface, it is not known a priori how much of
// the input data will be consumed. So the Unmarshaler must state
// how much of the input data it consumed.
type Unmarshaler interface {
UnmarshalTLS([]byte) (int, error)
}
// These are the options that can be specified in the struct tag. Right now, // These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else // all of them apply to variable-length vectors and nothing else
type decOpts struct { type decOpts struct {
head uint // length of length in bytes head uint // length of length in bytes
min uint // minimum size in bytes min uint // minimum size in bytes
max uint // maximum size in bytes max uint // maximum size in bytes
varint bool // whether to decode as a varint
} }
type decodeState struct { type decodeState struct {
@ -65,8 +75,14 @@ func typeDecoder(t reflect.Type) decoderFunc {
return newTypeDecoder(t) return newTypeDecoder(t)
} }
var (
unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
)
func newTypeDecoder(t reflect.Type) decoderFunc { func newTypeDecoder(t reflect.Type) decoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) {
return unmarshalerDecoder
}
switch t.Kind() { switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@ -77,6 +93,8 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
return newSliceDecoder(t) return newSliceDecoder(t)
case reflect.Struct: case reflect.Struct:
return newStructDecoder(t) return newStructDecoder(t)
case reflect.Ptr:
return newPointerDecoder(t)
default: default:
panic(fmt.Errorf("Unsupported type (%s)", t)) panic(fmt.Errorf("Unsupported type (%s)", t))
} }
@ -84,35 +102,87 @@ func newTypeDecoder(t reflect.Type) decoderFunc {
///// Specific decoders below ///// Specific decoders below
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
var uintLen int um, ok := v.Interface().(Unmarshaler)
switch v.Elem().Kind() { if !ok {
case reflect.Uint8: panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder"))
uintLen = 1
case reflect.Uint16:
uintLen = 2
case reflect.Uint32:
uintLen = 4
case reflect.Uint64:
uintLen = 8
} }
buf := make([]byte, uintLen) read, err := um.UnmarshalTLS(d.Bytes())
n, err := d.Read(buf)
if err != nil { if err != nil {
panic(err) panic(err)
} }
if n != uintLen {
if read > d.Len() {
panic(fmt.Errorf("Invalid return value from UnmarshalTLS"))
}
d.Next(read)
return read
}
//////////
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
if opts.varint {
return varintDecoder(d, v, opts)
}
uintLen := int(v.Elem().Type().Size())
buf := d.Next(uintLen)
if len(buf) != uintLen {
panic(fmt.Errorf("Insufficient data to read uint")) panic(fmt.Errorf("Insufficient data to read uint"))
} }
return setUintFromBuffer(v, buf)
}
func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
l, val := readVarint(d)
uintLen := int(v.Elem().Type().Size())
if uintLen < l {
panic(fmt.Errorf("Uint too small to fit varint: %d < %d", uintLen, l))
}
v.Elem().SetUint(val)
return l
}
func readVarint(d *decodeState) (int, uint64) {
// Read the first octet and decide the size of the presented varint
first := d.Next(1)
if len(first) != 1 {
panic(fmt.Errorf("Insufficient data to read varint length"))
}
twoBits := uint(first[0] >> 6)
varintLen := 1 << twoBits
rest := d.Next(varintLen - 1)
if len(rest) != varintLen-1 {
panic(fmt.Errorf("Insufficient data to read varint"))
}
buf := append(first, rest...)
buf[0] &= 0x3f
return len(buf), decodeUintFromBuffer(buf)
}
func decodeUintFromBuffer(buf []byte) uint64 {
val := uint64(0) val := uint64(0)
for _, b := range buf { for _, b := range buf {
val = (val << 8) + uint64(b) val = (val << 8) + uint64(b)
} }
v.Elem().SetUint(val) return val
return uintLen }
func setUintFromBuffer(v reflect.Value, buf []byte) int {
v.Elem().SetUint(decodeUintFromBuffer(buf))
return len(buf)
} }
////////// //////////
@ -143,44 +213,57 @@ type sliceDecoder struct {
} }
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
var length uint64
var read int
var data []byte
if opts.head == 0 { if opts.head == 0 {
panic(fmt.Errorf("Cannot decode a slice without a header length")) panic(fmt.Errorf("Cannot decode a slice without a header length"))
} }
lengthBytes := make([]byte, opts.head) // If the caller indicated there is no header, then read everything from the buffer
n, err := d.Read(lengthBytes) if opts.head == headValueNoHead {
if err != nil { for {
panic(err) chunk := d.Next(1024)
data = append(data, chunk...)
if len(chunk) != 1024 {
break
} }
if uint(n) != opts.head {
panic(fmt.Errorf("Not enough data to read header"))
} }
length = uint64(len(data))
length := uint(0) if opts.max > 0 && length > uint64(opts.max) {
for _, b := range lengthBytes {
length = (length << 8) + uint(b)
}
if opts.max > 0 && length > opts.max {
panic(fmt.Errorf("Length of vector exceeds declared max")) panic(fmt.Errorf("Length of vector exceeds declared max"))
} }
if length < opts.min { if length < uint64(opts.min) {
panic(fmt.Errorf("Length of vector below declared min"))
}
} else {
if opts.head != headValueVarint {
lengthBytes := d.Next(int(opts.head))
if len(lengthBytes) != int(opts.head) {
panic(fmt.Errorf("Not enough data to read header"))
}
read = len(lengthBytes)
length = decodeUintFromBuffer(lengthBytes)
} else {
read, length = readVarint(d)
}
if opts.max > 0 && length > uint64(opts.max) {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < uint64(opts.min) {
panic(fmt.Errorf("Length of vector below declared min")) panic(fmt.Errorf("Length of vector below declared min"))
} }
data := make([]byte, length) data = d.Next(int(length))
n, err = d.Read(data) if len(data) != int(length) {
if err != nil { panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length))
panic(err)
} }
if uint(n) != length {
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
} }
elemBuf := &decodeState{} elemBuf := &decodeState{}
elemBuf.Write(data) elemBuf.Write(data)
elems := []reflect.Value{} elems := []reflect.Value{}
read := int(opts.head)
for elemBuf.Len() > 0 { for elemBuf.Len() > 0 {
elem := reflect.New(sd.elementType) elem := reflect.New(sd.elementType)
read += sd.elementDec(elemBuf, elem, opts) read += sd.elementDec(elemBuf, elem, opts)
@ -234,6 +317,7 @@ func newStructDecoder(t reflect.Type) decoderFunc {
head: tagOpts["head"], head: tagOpts["head"],
max: tagOpts["max"], max: tagOpts["max"],
min: tagOpts["min"], min: tagOpts["min"],
varint: tagOpts[varintOption] > 0,
} }
sd.fieldDecs[i] = typeDecoder(f.Type) sd.fieldDecs[i] = typeDecoder(f.Type)
@ -241,3 +325,20 @@ func newStructDecoder(t reflect.Type) decoderFunc {
return sd.decode return sd.decode
} }
//////////
type pointerDecoder struct {
base decoderFunc
}
func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
v.Elem().Set(reflect.New(v.Elem().Type().Elem()))
return pd.base(d, v.Elem(), opts)
}
func newPointerDecoder(t reflect.Type) decoderFunc {
baseDecoder := typeDecoder(t.Elem())
pd := pointerDecoder{base: baseDecoder}
return pd.decode
}

View File

@ -16,12 +16,19 @@ func Marshal(v interface{}) ([]byte, error) {
return e.Bytes(), nil return e.Bytes(), nil
} }
// Marshaler is the interface implemented by types that
// have a defined TLS encoding.
type Marshaler interface {
MarshalTLS() ([]byte, error)
}
// These are the options that can be specified in the struct tag. Right now, // These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else // all of them apply to variable-length vectors and nothing else
type encOpts struct { type encOpts struct {
head uint // length of length in bytes head uint // length of length in bytes
min uint // minimum size in bytes min uint // minimum size in bytes
max uint // maximum size in bytes max uint // maximum size in bytes
varint bool // whether to encode as a varint
} }
type encodeState struct { type encodeState struct {
@ -62,8 +69,14 @@ func typeEncoder(t reflect.Type) encoderFunc {
return newTypeEncoder(t) return newTypeEncoder(t)
} }
var (
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
)
func newTypeEncoder(t reflect.Type) encoderFunc { func newTypeEncoder(t reflect.Type) encoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument if t.Implements(marshalerType) {
return marshalerEncoder
}
switch t.Kind() { switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@ -74,6 +87,8 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
return newSliceEncoder(t) return newSliceEncoder(t)
case reflect.Struct: case reflect.Struct:
return newStructEncoder(t) return newStructEncoder(t)
case reflect.Ptr:
return newPointerEncoder(t)
default: default:
panic(fmt.Errorf("Unsupported type (%s)", t)) panic(fmt.Errorf("Unsupported type (%s)", t))
} }
@ -81,19 +96,65 @@ func newTypeEncoder(t reflect.Type) encoderFunc {
///// Specific encoders below ///// Specific encoders below
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
u := v.Uint() if v.Kind() == reflect.Ptr && v.IsNil() {
switch v.Type().Kind() { panic(fmt.Errorf("Cannot encode nil pointer"))
case reflect.Uint8:
e.WriteByte(byte(u))
case reflect.Uint16:
e.Write([]byte{byte(u >> 8), byte(u)})
case reflect.Uint32:
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
case reflect.Uint64:
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
} }
m, ok := v.Interface().(Marshaler)
if !ok {
panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder"))
}
b, err := m.MarshalTLS()
if err == nil {
_, err = e.Write(b)
}
if err != nil {
panic(err)
}
}
//////////
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if opts.varint {
varintEncoder(e, v, opts)
return
}
writeUint(e, v.Uint(), int(v.Type().Size()))
}
func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
writeVarint(e, v.Uint())
}
func writeVarint(e *encodeState, u uint64) {
if (u >> 62) > 0 {
panic(fmt.Errorf("uint value is too big for varint"))
}
var varintLen int
for _, len := range []uint{1, 2, 4, 8} {
if u < (uint64(1) << (8*len - 2)) {
varintLen = int(len)
break
}
}
twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen]
shift := uint(8*varintLen - 2)
writeUint(e, u|(twoBits<<shift), varintLen)
}
func writeUint(e *encodeState, u uint64, len int) {
data := make([]byte, len)
for i := 0; i < len; i += 1 {
data[i] = byte(u >> uint(8*(len-i-1)))
}
e.Write(data)
} }
////////// //////////
@ -121,27 +182,34 @@ type sliceEncoder struct {
} }
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
arrayState := &encodeState{} arrayState := &encodeState{}
se.ae.encode(arrayState, v, opts) se.ae.encode(arrayState, v, opts)
n := uint(arrayState.Len()) n := uint(arrayState.Len())
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
if opts.max > 0 && n > opts.max { if opts.max > 0 && n > opts.max {
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
} }
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
}
if n < opts.min { if n < opts.min {
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
} }
for i := int(opts.head - 1); i >= 0; i -= 1 { switch opts.head {
e.WriteByte(byte(n >> (8 * uint(i)))) case headValueNoHead:
// None.
case headValueVarint:
writeVarint(e, uint64(n))
default:
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
} }
writeUint(e, uint64(n), int(opts.head))
}
e.Write(arrayState.Bytes()) e.Write(arrayState.Bytes())
} }
@ -179,9 +247,30 @@ func newStructEncoder(t reflect.Type) encoderFunc {
head: tagOpts["head"], head: tagOpts["head"],
max: tagOpts["max"], max: tagOpts["max"],
min: tagOpts["min"], min: tagOpts["min"],
varint: tagOpts[varintOption] > 0,
} }
se.fieldEncs[i] = typeEncoder(f.Type) se.fieldEncs[i] = typeEncoder(f.Type)
} }
return se.encode return se.encode
} }
//////////
type pointerEncoder struct {
base encoderFunc
}
func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer"))
}
pe.base(e, v.Elem(), opts)
}
func newPointerEncoder(t reflect.Type) encoderFunc {
baseEncoder := typeEncoder(t.Elem())
pe := pointerEncoder{base: baseEncoder}
return pe.encode
}

View File

@ -5,16 +5,27 @@ import (
"strings" "strings"
) )
// `tls:"head=2,min=2,max=255"` // `tls:"head=2,min=2,max=255,varint"`
type tagOptions map[string]uint type tagOptions map[string]uint
var (
varintOption = "varint"
headOptionNone = "none"
headOptionVarint = "varint"
headValueNoHead = uint(255)
headValueVarint = uint(254)
)
// parseTag parses a struct field's "tls" tag as a comma-separated list of // parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers // name=value pairs, where the values MUST be unsigned integers, or in
// the special case of head, "none" or "varint"
func parseTag(tag string) tagOptions { func parseTag(tag string) tagOptions {
opts := tagOptions{} opts := tagOptions{}
for _, token := range strings.Split(tag, ",") { for _, token := range strings.Split(tag, ",") {
if strings.Index(token, "=") == -1 { if token == varintOption {
opts[varintOption] = 1
continue continue
} }
@ -22,7 +33,16 @@ func parseTag(tag string) tagOptions {
if len(parts[0]) == 0 { if len(parts[0]) == 0 {
continue continue
} }
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
if len(parts) == 1 {
continue
}
if parts[0] == "head" && parts[1] == headOptionNone {
opts[parts[0]] = headValueNoHead
} else if parts[0] == "head" && parts[1] == headOptionVarint {
opts[parts[0]] = headValueVarint
} else if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val) opts[parts[0]] = uint(val)
} }
} }

122
vendor/github.com/bifurcation/mint/timer.go generated vendored Normal file
View File

@ -0,0 +1,122 @@
package mint
import (
"time"
)
// This is a simple timer implementation. Timers are stored in a sorted
// list.
// TODO(ekr@rtfm.com): Add a way to uncouple these from the system
// clock.
type timerCb func() error
type timer struct {
label string
cb timerCb
deadline time.Time
duration uint32
}
type timerSet struct {
ts []*timer
}
func newTimerSet() *timerSet {
return &timerSet{}
}
func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer {
now := time.Now()
t := timer{
label,
cb,
now.Add(time.Millisecond * time.Duration(delayMs)),
delayMs,
}
logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline)
var i int
ntimers := len(ts.ts)
for i = 0; i < ntimers; i++ {
if t.deadline.Before(ts.ts[i].deadline) {
break
}
}
tmp := make([]*timer, 0, ntimers+1)
tmp = append(tmp, ts.ts[:i]...)
tmp = append(tmp, &t)
tmp = append(tmp, ts.ts[i:]...)
ts.ts = tmp
return &t
}
// TODO(ekr@rtfm.com): optimize this now that the list is sorted.
// We should be able to do just one list manipulation, as long
// as we're careful about how we handle inserts during callbacks.
func (ts *timerSet) check(now time.Time) error {
for i, t := range ts.ts {
if now.After(t.deadline) {
ts.ts = append(ts.ts[:i], ts.ts[:i+1]...)
if t.cb != nil {
logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline)
cb := t.cb
t.cb = nil
err := cb()
if err != nil {
return err
}
}
} else {
break
}
}
return nil
}
// Returns the next time any of the timers would fire.
func (ts *timerSet) remaining() (bool, time.Duration) {
for _, t := range ts.ts {
if t.cb != nil {
return true, time.Until(t.deadline)
}
}
return false, time.Duration(0)
}
func (ts *timerSet) cancel(label string) {
for _, t := range ts.ts {
if t.label == label {
t.cancel()
}
}
}
func (ts *timerSet) getTimer(label string) *timer {
for _, t := range ts.ts {
if t.label == label && t.cb != nil {
return t
}
}
return nil
}
func (ts *timerSet) getAllTimers() []string {
var ret []string
for _, t := range ts.ts {
if t.cb != nil {
ret = append(ret, t.label)
}
}
return ret
}
func (t *timer) cancel() {
logf(logTypeHandshake, "Timer %s cancelled", t.label)
t.cb = nil
t.label = ""
}

View File

@ -51,11 +51,14 @@ func (l *Listener) Accept() (c net.Conn, err error) {
// Listener and wraps each connection with Server. // Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include // The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate. // at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener { func NewListener(inner net.Listener, config *Config) (net.Listener, error) {
if config != nil && config.NonBlocking {
return nil, errors.New("listening not possible in non-blocking mode")
}
l := new(Listener) l := new(Listener)
l.Listener = inner l.Listener = inner
l.config = config l.config = config
return l return l, nil
} }
// Listen creates a TLS listener accepting connections on the // Listen creates a TLS listener accepting connections on the
@ -70,7 +73,7 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewListener(l, config), nil return NewListener(l, config)
} }
type TimeoutError struct{} type TimeoutError struct{}
@ -87,6 +90,10 @@ func (TimeoutError) Temporary() bool { return true }
// DialWithDialer interprets a nil configuration as equivalent to the zero // DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults. // configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if config != nil && config.NonBlocking {
return nil, errors.New("dialing not possible in non-blocking mode")
}
// We want the Timeout and Deadline values from dialer to cover the // We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we // whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now. // also need to start our own timers now.
@ -121,16 +128,20 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
if config == nil { if config == nil {
config = &Config{} config = &Config{}
} else {
config = config.Clone()
} }
// If no ServerName is set, infer the ServerName // If no ServerName is set, infer the ServerName
// from the hostname we're connecting to. // from the hostname we're connecting to.
if config.ServerName == "" { if config.ServerName == "" {
// Make a copy to avoid polluting argument or default. config.ServerName = hostname
c := config.Clone()
c.ServerName = hostname
config = c
} }
// Set up DTLS as needed.
config.UseDTLS = (network == "udp")
conn := Client(rawConn, config) conn := Client(rawConn, config)
if timeout == 0 { if timeout == 0 {

22
vendor/github.com/cheekybits/genny/LICENSE generated vendored Normal file
View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2014 cheekybits
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

2
vendor/github.com/cheekybits/genny/generic/doc.go generated vendored Normal file
View File

@ -0,0 +1,2 @@
// Package generic contains the generic marker types.
package generic

13
vendor/github.com/cheekybits/genny/generic/generic.go generated vendored Normal file
View File

@ -0,0 +1,13 @@
package generic
// Type is the placeholder type that indicates a generic value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
// var GenericType generic.Type
type Type interface{}
// Number is the placehoder type that indiccates a generic numerical value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
// var GenericType generic.Number
type Number float64

View File

@ -30,9 +30,9 @@ type TwoQueueCache struct {
size int size int
recentSize int recentSize int
recent *simplelru.LRU recent simplelru.LRUCache
frequent *simplelru.LRU frequent simplelru.LRUCache
recentEvict *simplelru.LRU recentEvict simplelru.LRUCache
lock sync.RWMutex lock sync.RWMutex
} }
@ -84,7 +84,8 @@ func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCa
return c, nil return c, nil
} }
func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) { // Get looks up a key's value from the cache.
func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -105,6 +106,7 @@ func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) {
return nil, false return nil, false
} }
// Add adds a value to the cache.
func (c *TwoQueueCache) Add(key, value interface{}) { func (c *TwoQueueCache) Add(key, value interface{}) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -160,12 +162,15 @@ func (c *TwoQueueCache) ensureSpace(recentEvict bool) {
c.frequent.RemoveOldest() c.frequent.RemoveOldest()
} }
// Len returns the number of items in the cache.
func (c *TwoQueueCache) Len() int { func (c *TwoQueueCache) Len() int {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
return c.recent.Len() + c.frequent.Len() return c.recent.Len() + c.frequent.Len()
} }
// Keys returns a slice of the keys in the cache.
// The frequently used keys are first in the returned slice.
func (c *TwoQueueCache) Keys() []interface{} { func (c *TwoQueueCache) Keys() []interface{} {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
@ -174,6 +179,7 @@ func (c *TwoQueueCache) Keys() []interface{} {
return append(k1, k2...) return append(k1, k2...)
} }
// Remove removes the provided key from the cache.
func (c *TwoQueueCache) Remove(key interface{}) { func (c *TwoQueueCache) Remove(key interface{}) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -188,6 +194,7 @@ func (c *TwoQueueCache) Remove(key interface{}) {
} }
} }
// Purge is used to completely clear the cache.
func (c *TwoQueueCache) Purge() { func (c *TwoQueueCache) Purge() {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -196,13 +203,17 @@ func (c *TwoQueueCache) Purge() {
c.recentEvict.Purge() c.recentEvict.Purge()
} }
// Contains is used to check if the cache contains a key
// without updating recency or frequency.
func (c *TwoQueueCache) Contains(key interface{}) bool { func (c *TwoQueueCache) Contains(key interface{}) bool {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
return c.frequent.Contains(key) || c.recent.Contains(key) return c.frequent.Contains(key) || c.recent.Contains(key)
} }
func (c *TwoQueueCache) Peek(key interface{}) (interface{}, bool) { // Peek is used to inspect the cache value of a key
// without updating recency or frequency.
func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
if val, ok := c.frequent.Peek(key); ok { if val, ok := c.frequent.Peek(key); ok {

View File

@ -18,11 +18,11 @@ type ARCCache struct {
size int // Size is the total capacity of the cache size int // Size is the total capacity of the cache
p int // P is the dynamic preference towards T1 or T2 p int // P is the dynamic preference towards T1 or T2
t1 *simplelru.LRU // T1 is the LRU for recently accessed items t1 simplelru.LRUCache // T1 is the LRU for recently accessed items
b1 *simplelru.LRU // B1 is the LRU for evictions from t1 b1 simplelru.LRUCache // B1 is the LRU for evictions from t1
t2 *simplelru.LRU // T2 is the LRU for frequently accessed items t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items
b2 *simplelru.LRU // B2 is the LRU for evictions from t2 b2 simplelru.LRUCache // B2 is the LRU for evictions from t2
lock sync.RWMutex lock sync.RWMutex
} }
@ -60,11 +60,11 @@ func NewARC(size int) (*ARCCache, error) {
} }
// Get looks up a key's value from the cache. // Get looks up a key's value from the cache.
func (c *ARCCache) Get(key interface{}) (interface{}, bool) { func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
// Ff the value is contained in T1 (recent), then // If the value is contained in T1 (recent), then
// promote it to T2 (frequent) // promote it to T2 (frequent)
if val, ok := c.t1.Peek(key); ok { if val, ok := c.t1.Peek(key); ok {
c.t1.Remove(key) c.t1.Remove(key)
@ -153,7 +153,7 @@ func (c *ARCCache) Add(key, value interface{}) {
// Remove from B2 // Remove from B2
c.b2.Remove(key) c.b2.Remove(key)
// Add the key to the frequntly used list // Add the key to the frequently used list
c.t2.Add(key, value) c.t2.Add(key, value)
return return
} }
@ -247,7 +247,7 @@ func (c *ARCCache) Contains(key interface{}) bool {
// Peek is used to inspect the cache value of a key // Peek is used to inspect the cache value of a key
// without updating recency or frequency. // without updating recency or frequency.
func (c *ARCCache) Peek(key interface{}) (interface{}, bool) { func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
if val, ok := c.t1.Peek(key); ok { if val, ok := c.t1.Peek(key); ok {

21
vendor/github.com/hashicorp/golang-lru/doc.go generated vendored Normal file
View File

@ -0,0 +1,21 @@
// Package lru provides three different LRU caches of varying sophistication.
//
// Cache is a simple LRU cache. It is based on the
// LRU implementation in groupcache:
// https://github.com/golang/groupcache/tree/master/lru
//
// TwoQueueCache tracks frequently used and recently used entries separately.
// This avoids a burst of accesses from taking out frequently used entries,
// at the cost of about 2x computational overhead and some extra bookkeeping.
//
// ARCCache is an adaptive replacement cache. It tracks recent evictions as
// well as recent usage in both the frequent and recent caches. Its
// computational overhead is comparable to TwoQueueCache, but the memory
// overhead is linear with the size of the cache.
//
// ARC has been patented by IBM, so do not use it if that is problematic for
// your program.
//
// All caches in this package take locks while operating, and are therefore
// thread-safe for consumers.
package lru

View File

@ -1,6 +1,3 @@
// This package provides a simple LRU cache. It is based on the
// LRU implementation in groupcache:
// https://github.com/golang/groupcache/tree/master/lru
package lru package lru
import ( import (
@ -11,11 +8,11 @@ import (
// Cache is a thread-safe fixed size LRU cache. // Cache is a thread-safe fixed size LRU cache.
type Cache struct { type Cache struct {
lru *simplelru.LRU lru simplelru.LRUCache
lock sync.RWMutex lock sync.RWMutex
} }
// New creates an LRU of the given size // New creates an LRU of the given size.
func New(size int) (*Cache, error) { func New(size int) (*Cache, error) {
return NewWithEvict(size, nil) return NewWithEvict(size, nil)
} }
@ -33,7 +30,7 @@ func NewWithEvict(size int, onEvicted func(key interface{}, value interface{}))
return c, nil return c, nil
} }
// Purge is used to completely clear the cache // Purge is used to completely clear the cache.
func (c *Cache) Purge() { func (c *Cache) Purge() {
c.lock.Lock() c.lock.Lock()
c.lru.Purge() c.lru.Purge()
@ -41,30 +38,30 @@ func (c *Cache) Purge() {
} }
// Add adds a value to the cache. Returns true if an eviction occurred. // Add adds a value to the cache. Returns true if an eviction occurred.
func (c *Cache) Add(key, value interface{}) bool { func (c *Cache) Add(key, value interface{}) (evicted bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
return c.lru.Add(key, value) return c.lru.Add(key, value)
} }
// Get looks up a key's value from the cache. // Get looks up a key's value from the cache.
func (c *Cache) Get(key interface{}) (interface{}, bool) { func (c *Cache) Get(key interface{}) (value interface{}, ok bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
return c.lru.Get(key) return c.lru.Get(key)
} }
// Check if a key is in the cache, without updating the recent-ness // Contains checks if a key is in the cache, without updating the
// or deleting it for being stale. // recent-ness or deleting it for being stale.
func (c *Cache) Contains(key interface{}) bool { func (c *Cache) Contains(key interface{}) bool {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
return c.lru.Contains(key) return c.lru.Contains(key)
} }
// Returns the key value (or undefined if not found) without updating // Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key. // the "recently used"-ness of the key.
func (c *Cache) Peek(key interface{}) (interface{}, bool) { func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
return c.lru.Peek(key) return c.lru.Peek(key)
@ -73,16 +70,15 @@ func (c *Cache) Peek(key interface{}) (interface{}, bool) {
// ContainsOrAdd checks if a key is in the cache without updating the // ContainsOrAdd checks if a key is in the cache without updating the
// recent-ness or deleting it for being stale, and if not, adds the value. // recent-ness or deleting it for being stale, and if not, adds the value.
// Returns whether found and whether an eviction occurred. // Returns whether found and whether an eviction occurred.
func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evict bool) { func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.lru.Contains(key) { if c.lru.Contains(key) {
return true, false return true, false
} else {
evict := c.lru.Add(key, value)
return false, evict
} }
evicted = c.lru.Add(key, value)
return false, evicted
} }
// Remove removes the provided key from the cache. // Remove removes the provided key from the cache.

View File

@ -36,7 +36,7 @@ func NewLRU(size int, onEvict EvictCallback) (*LRU, error) {
return c, nil return c, nil
} }
// Purge is used to completely clear the cache // Purge is used to completely clear the cache.
func (c *LRU) Purge() { func (c *LRU) Purge() {
for k, v := range c.items { for k, v := range c.items {
if c.onEvict != nil { if c.onEvict != nil {
@ -48,7 +48,7 @@ func (c *LRU) Purge() {
} }
// Add adds a value to the cache. Returns true if an eviction occurred. // Add adds a value to the cache. Returns true if an eviction occurred.
func (c *LRU) Add(key, value interface{}) bool { func (c *LRU) Add(key, value interface{}) (evicted bool) {
// Check for existing item // Check for existing item
if ent, ok := c.items[key]; ok { if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent) c.evictList.MoveToFront(ent)
@ -78,17 +78,18 @@ func (c *LRU) Get(key interface{}) (value interface{}, ok bool) {
return return
} }
// Check if a key is in the cache, without updating the recent-ness // Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale. // or deleting it for being stale.
func (c *LRU) Contains(key interface{}) (ok bool) { func (c *LRU) Contains(key interface{}) (ok bool) {
_, ok = c.items[key] _, ok = c.items[key]
return ok return ok
} }
// Returns the key value (or undefined if not found) without updating // Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key. // the "recently used"-ness of the key.
func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) { func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
if ent, ok := c.items[key]; ok { var ent *list.Element
if ent, ok = c.items[key]; ok {
return ent.Value.(*entry).value, true return ent.Value.(*entry).value, true
} }
return nil, ok return nil, ok
@ -96,7 +97,7 @@ func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
// Remove removes the provided key from the cache, returning if the // Remove removes the provided key from the cache, returning if the
// key was contained. // key was contained.
func (c *LRU) Remove(key interface{}) bool { func (c *LRU) Remove(key interface{}) (present bool) {
if ent, ok := c.items[key]; ok { if ent, ok := c.items[key]; ok {
c.removeElement(ent) c.removeElement(ent)
return true return true
@ -105,7 +106,7 @@ func (c *LRU) Remove(key interface{}) bool {
} }
// RemoveOldest removes the oldest item from the cache. // RemoveOldest removes the oldest item from the cache.
func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) { func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) {
ent := c.evictList.Back() ent := c.evictList.Back()
if ent != nil { if ent != nil {
c.removeElement(ent) c.removeElement(ent)
@ -116,7 +117,7 @@ func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) {
} }
// GetOldest returns the oldest entry // GetOldest returns the oldest entry
func (c *LRU) GetOldest() (interface{}, interface{}, bool) { func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) {
ent := c.evictList.Back() ent := c.evictList.Back()
if ent != nil { if ent != nil {
kv := ent.Value.(*entry) kv := ent.Value.(*entry)

View File

@ -0,0 +1,37 @@
package simplelru
// LRUCache is the interface for simple LRU cache.
type LRUCache interface {
// Adds a value to the cache, returns true if an eviction occurred and
// updates the "recently used"-ness of the key.
Add(key, value interface{}) bool
// Returns key's value from the cache and
// updates the "recently used"-ness of the key. #value, isFound
Get(key interface{}) (value interface{}, ok bool)
// Check if a key exsists in cache without updating the recent-ness.
Contains(key interface{}) (ok bool)
// Returns key's value without updating the "recently used"-ness of the key.
Peek(key interface{}) (value interface{}, ok bool)
// Removes a key from the cache.
Remove(key interface{}) bool
// Removes the oldest entry from cache.
RemoveOldest() (interface{}, interface{}, bool)
// Returns the oldest entry from the cache. #key, value, isFound
GetOldest() (interface{}, interface{}, bool)
// Returns a slice of the keys in the cache, from oldest to newest.
Keys() []interface{}
// Returns the number of items in the cache.
Len() int
// Clear all cache entries
Purge()
}

21
vendor/github.com/isofew/go-stun/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Vasily Vasilyev
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

284
vendor/github.com/isofew/go-stun/stun/agent.go generated vendored Normal file
View File

@ -0,0 +1,284 @@
package stun
import (
"errors"
"math/rand"
"net"
"sync"
"time"
)
var DefaultConfig = &Config{
RetransmissionTimeout: 500 * time.Millisecond,
TransactionTimeout: 39500 * time.Millisecond,
Software: "pixelbender/go-stun",
}
type Handler interface {
ServeSTUN(msg *Message, tr Transport)
}
type HandlerFunc func(msg *Message, tr Transport)
func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) {
h(msg, tr)
}
type Config struct {
// AuthMethod returns a key for MESSAGE-INTEGRITY attribute
AuthMethod AuthMethod
// Retransmission timeout, default is 500 milliseconds
RetransmissionTimeout time.Duration
// Transaction timeout, default is 39.5 seconds
TransactionTimeout time.Duration
// Fingerprint, if true all outgoing messages contain FINGERPRINT attribute
Fingerprint bool
// Software is a SOFTWARE attribute value for outgoing messages, if not empty
Software string
// Logf, if set all sent and received messages printed using Logf
Logf func(format string, args ...interface{})
}
func (c *Config) attrs() []Attr {
if c == nil {
return nil
}
var a []Attr
if c.Software != "" {
a = append(a, String(AttrSoftware, c.Software))
}
if c.Fingerprint {
a = append(a, Fingerprint)
}
return a
}
func (c *Config) Clone() *Config {
r := *c
return &r
}
type Agent struct {
config *Config
Handler Handler
m mux
}
func NewAgent(config *Config) *Agent {
if config == nil {
config = DefaultConfig
}
return &Agent{
config: config,
}
}
func (a *Agent) Send(msg *Message, tr Transport) (err error) {
msg = &Message{
msg.Type,
msg.Transaction,
append(a.config.attrs(), msg.Attributes...),
}
if log := a.config.Logf; log != nil {
log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
}
b := msg.Marshal(getBuffer()[:0])
_, err = tr.Write(b)
putBuffer(b)
return
}
func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error {
if c, ok := c.(net.PacketConn); ok {
return a.ServePacket(c, stop)
}
var (
b = getBuffer()
p int
)
defer putBuffer(b)
for {
select {
case <-stop:
return nil
default:
}
if p >= len(b) {
return errBufferOverflow
}
n, err := c.Read(b[p:])
if err != nil {
return err
}
p += n
n = 0
for n < p {
r, err := a.ServeTransport(b[n:p], c)
if err != nil {
return err
}
n += r
}
if n > 0 {
if n < p {
p = copy(b, b[n:p])
} else {
p = 0
}
}
}
}
func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error {
b := getBuffer()
defer putBuffer(b)
// don't close the connection since we're going to reuse it
// defer c.Close()
for {
select {
case <-stop:
return nil
default:
}
n, addr, err := c.ReadFrom(b)
if err != nil {
return err
}
if n > 0 {
a.ServeTransport(b[:n], &packetConn{c, addr})
}
}
}
func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) {
msg := &Message{}
n, err = msg.Unmarshal(b)
if err != nil {
return
}
a.ServeSTUN(msg, tr)
return
}
func (a *Agent) ServeSTUN(msg *Message, tr Transport) {
if log := a.config.Logf; log != nil {
log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg)
}
if a.m.serve(msg, tr) {
return
}
if h := a.Handler; h != nil {
go h.ServeSTUN(msg, tr)
}
}
func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) {
var (
start = time.Now()
rto = a.config.RetransmissionTimeout
udp = to.LocalAddr().Network() == "udp"
tx = a.m.newTx()
)
defer a.m.closeTx(tx)
req = &Message{req.Type, tx.id, req.Attributes}
if err = a.Send(req, to); err != nil {
return
}
for {
d := a.config.TransactionTimeout - time.Since(start)
if d < 0 {
err = errTimeout
return
}
if udp && d > rto {
d = rto
}
res, from, err = tx.Receive(d)
if udp && err == errTimeout && d == rto {
rto <<= 1
a.Send(req, to)
continue
}
return
}
}
type mux struct {
sync.RWMutex
t map[string]*transaction
}
func (m *mux) serve(msg *Message, tr Transport) bool {
m.RLock()
tx, ok := m.t[string(msg.Transaction)]
m.RUnlock()
if ok {
tx.msg, tx.from = msg, tr
tx.Done()
return true
}
return false
}
func (m *mux) newTx() *transaction {
tx := &transaction{id: NewTransaction()}
m.Lock()
if m.t == nil {
m.t = make(map[string]*transaction)
} else {
for m.t[string(tx.id)] != nil {
rand.Read(tx.id[4:])
}
}
m.t[string(tx.id)] = tx
m.Unlock()
return tx
}
func (m *mux) closeTx(tx *transaction) {
m.Lock()
delete(m.t, string(tx.id))
m.Unlock()
}
func (m *mux) Close() {
m.Lock()
defer m.Unlock()
for _, it := range m.t {
it.Close()
}
m.t = nil
}
type transaction struct {
sync.WaitGroup
id []byte
from Transport
msg *Message
err error
}
func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
tx.Add(1)
t := time.AfterFunc(d, tx.timeout)
tx.Wait()
t.Stop()
if err = tx.err; err != nil {
return
}
return tx.msg, tx.from, nil
}
func (tx *transaction) timeout() {
tx.err = errTimeout
tx.Done()
}
func (tx *transaction) Close() {
tx.err = errCanceled
tx.Done()
}
var errCanceled = errors.New("stun: transaction canceled")
var errTimeout = errors.New("stun: transaction timeout")

438
vendor/github.com/isofew/go-stun/stun/attribute.go generated vendored Normal file
View File

@ -0,0 +1,438 @@
package stun
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"fmt"
"hash/crc32"
"net"
"strconv"
)
// Attribute represents a STUN attribute.
type Attr interface {
Type() uint16
Marshal(p []byte) []byte
Unmarshal(b []byte) error
}
// IP address family
const (
IPv4 = 0x01
IPv6 = 0x02
)
// IP address family
const (
ChangeIP uint64 = 0x04
ChangePort = 0x02
)
func newAttr(typ uint16) Attr {
switch typ {
case AttrMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress,
AttrXorMappedAddress, AttrAlternateServer, AttrResponseOrigin, AttrOtherAddress,
AttrResponseAddress, AttrSourceAddress, AttrChangedAddress, AttrReflectedFrom:
return &addr{typ: typ}
case AttrRequestedAddressFamily, AttrRequestedTransport:
return &number{typ: typ, size: 4, pad: 24}
case AttrChannelNumber, AttrResponsePort:
return &number{typ: typ, size: 4, pad: 16}
case AttrLifetime, AttrConnectionID, AttrCacheTimeout,
AttrBandwidth, AttrTimerVal,
AttrTransactionTransmitCounter,
AttrEcnCheck, AttrChangeRequest, AttrPriority:
return &number{typ: typ, size: 4}
case AttrIceControlled, AttrIceControlling:
return &number{typ: typ, size: 8}
case AttrUsername, AttrRealm, AttrNonce, AttrSoftware, AttrPassword, AttrThirdPartyAuthorization,
AttrData, AttrAccessToken, AttrReservationToken, AttrMobilityTicket, AttrPadding, AttrUnknownAttributes:
return &raw{typ: typ}
case AttrMessageIntegrity:
return &integrity{}
case AttrErrorCode:
return &Error{}
case AttrEvenPort:
return &number{typ: typ, size: 1}
case AttrDontFragment, AttrUseCandidate:
return flag(typ)
case AttrFingerprint:
return &fingerprint{}
}
return nil
}
func AttrName(typ uint16) string {
if r, ok := attrNames[typ]; ok {
return r
}
return "0x" + strconv.FormatUint(uint64(typ), 16)
}
func Int(typ uint16, v uint64) Attr {
switch typ {
case AttrRequestedAddressFamily, AttrRequestedTransport:
return &number{typ, 4, 24, v}
case AttrChannelNumber, AttrResponsePort:
return &number{typ, 4, 16, v}
case AttrIceControlled, AttrIceControlling:
return &number{typ, 8, 0, v}
case AttrEvenPort:
return &number{typ, 1, 0, v}
default:
return &number{typ, 4, 0, v}
}
}
type number struct {
typ uint16
size, pad uint8
v uint64
}
func (a *number) Type() uint16 { return a.typ }
func (a *number) Marshal(p []byte) []byte {
r, b := grow(p, int(a.size))
switch a.size {
case 1:
b[0] = byte(a.v)
case 4:
be.PutUint32(b, uint32(a.v<<a.pad))
case 8:
be.PutUint64(b, a.v<<a.pad)
}
return r
}
func (a *number) Unmarshal(b []byte) error {
if len(b) < int(a.size) {
return errFormat
}
switch a.size {
case 1:
a.v = uint64(b[0])
case 4:
a.v = uint64(be.Uint32(b) >> a.pad)
case 8:
a.v = be.Uint64(b) >> a.pad
}
return nil
}
func (a *number) String() string {
return "0x" + strconv.FormatUint(a.v, 16)
}
func Flag(v uint16) Attr {
return flag(v)
}
type flag uint16
func (attr flag) Type() uint16 { return uint16(attr) }
func (flag) Marshal(p []byte) []byte { return p }
func (flag) Unmarshal(b []byte) error { return nil }
// Error represents the ERROR-CODE attribute.
type Error struct {
Code int
Reason string
}
func NewError(code int) *Error {
return &Error{code, ErrorText(code)}
}
func (*Error) Type() uint16 { return AttrErrorCode }
func (e *Error) Marshal(p []byte) []byte {
r, b := grow(p, 4+len(e.Reason))
b[0] = 0
b[1] = 0
b[2] = byte(e.Code / 100)
b[3] = byte(e.Code % 100)
copy(b[4:], e.Reason)
return r
}
func (e *Error) Unmarshal(b []byte) error {
if len(b) < 4 {
return errFormat
}
e.Code = int(b[2])*100 + int(b[3])
e.Reason = getString(b[4:])
return nil
}
func (e *Error) Error() string { return e.String() }
func (e *Error) String() string { return fmt.Sprintf("%d %s", e.Code, e.Reason) }
// ErrorText returns a text for the STUN error code. It returns the empty string if the code is unknown.
func ErrorText(code int) string { return errorText[code] }
func getString(b []byte) string {
for i := len(b); i >= 0; i-- {
if b[i-1] > 0 {
return string(b[:i])
}
}
return ""
}
func Addr(typ uint16, v net.Addr) Attr {
ip, port := SockAddr(v)
return &addr{typ, ip, port}
}
func SockAddr(v net.Addr) (net.IP, int) {
switch a := v.(type) {
case *net.UDPAddr:
return a.IP, a.Port
case *net.TCPAddr:
return a.IP, a.Port
case *net.IPAddr:
return a.IP, 0
default:
return net.IPv4zero, 0
}
}
func sameAddr(a, b net.Addr) bool {
aip, aport := SockAddr(a)
bip, bport := SockAddr(b)
return aip.Equal(bip) && aport == bport
}
func NewAddr(network string, ip net.IP, port int) net.Addr {
switch network {
case "udp", "udp4", "udp6":
return &net.UDPAddr{IP: ip, Port: port}
case "tcp", "tcp4", "tcp6":
return &net.TCPAddr{IP: ip, Port: port}
}
return &net.IPAddr{IP: ip}
}
func IP(typ uint16, ip net.IP) Attr { return &addr{typ, ip, 0} }
type addr struct {
typ uint16
IP net.IP
Port int
}
func (addr *addr) Type() uint16 { return addr.typ }
func (addr *addr) Addr(network string) net.Addr {
return NewAddr(network, addr.IP, addr.Port)
}
func (addr *addr) Xored() bool {
switch addr.typ {
case AttrXorMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress:
return true
default:
return false
}
}
func (addr *addr) Marshal(p []byte) []byte {
return addr.MarshalAddr(p, nil)
}
func (addr *addr) MarshalAddr(p, tx []byte) []byte {
fam, ip := IPv4, addr.IP.To4()
if ip == nil {
fam, ip = IPv6, addr.IP
}
r, b := grow(p, 4+len(ip))
b[0] = 0
b[1] = byte(fam)
if addr.Xored() && tx != nil {
be.PutUint16(b[2:], uint16(addr.Port)^0x2112)
b = b[4:]
for i, it := range ip {
b[i] = it ^ tx[i]
}
} else {
be.PutUint16(b[2:], uint16(addr.Port))
copy(b[4:], ip)
}
return r
}
func (addr *addr) Unmarshal(b []byte) error {
return addr.UnmarshalAddr(b, nil)
}
func (addr *addr) UnmarshalAddr(b, tx []byte) error {
if len(b) < 4 {
return errFormat
}
n, port := net.IPv4len, int(be.Uint16(b[2:]))
if b[1] == IPv6 {
n = net.IPv6len
}
if b = b[4:]; len(b) < n {
return errFormat
}
addr.IP = make(net.IP, n)
if addr.Xored() && tx != nil {
for i, it := range b {
addr.IP[i] = it ^ tx[i]
}
addr.Port = port ^ 0x2112
} else {
copy(addr.IP, b)
addr.Port = port
}
return nil
}
func (addr *addr) Equal(a *addr) bool {
return addr == a || (addr != nil && a != nil && addr.IP.Equal(a.IP) && addr.Port == a.Port)
}
func (addr *addr) String() string {
if addr.Port == 0 {
return addr.IP.String()
}
return net.JoinHostPort(addr.IP.String(), strconv.Itoa(addr.Port))
}
func Bytes(typ uint16, v []byte) Attr { return &raw{typ, v} }
type raw struct {
typ uint16
data []byte
}
func (attr *raw) Type() uint16 { return attr.typ }
func (attr *raw) Marshal(p []byte) []byte { return append(p, attr.data...) }
func (attr *raw) Unmarshal(p []byte) error {
attr.data = p
return nil
}
func (attr *raw) String() string { return string(attr.data) }
func String(typ uint16, v string) Attr {
return &str{typ, v}
}
type str struct {
typ uint16
data string
}
func (attr *str) Type() uint16 { return attr.typ }
func (attr *str) Marshal(p []byte) []byte { return append(p, attr.data...) }
func (attr *str) Unmarshal(p []byte) error {
attr.data = string(p)
return nil
}
func (attr *str) String() string { return attr.data }
func MessageIntegrity(key []byte) Attr {
return &integrity{key: key}
}
type integrity struct {
key, sum, raw []byte
}
func (*integrity) Type() uint16 {
return AttrMessageIntegrity
}
func (attr *integrity) Marshal(p []byte) []byte {
return append(p, attr.sum...)
}
func (attr *integrity) Unmarshal(b []byte) error {
if len(b) < 20 {
return errFormat
}
attr.sum = b
return nil
}
func (attr *integrity) MarshalSum(p, raw []byte) []byte {
n := len(raw) - 4
be.PutUint16(raw[2:], uint16(n+4))
return attr.Sum(attr.key, raw[:n], p)
}
func (attr *integrity) UnmarshalSum(p, raw []byte) error {
attr.raw = raw
return attr.Unmarshal(p)
}
func (attr *integrity) Sum(key, data, p []byte) []byte {
h := hmac.New(sha1.New, key)
h.Write(data)
return h.Sum(p)
}
func (attr *integrity) Check(key []byte) bool {
r := attr.raw
if len(r) < 44 {
return r == nil
}
be.PutUint16(r[2:], uint16(len(r)-20))
h := attr.Sum(key, r[:len(r)-24], nil)
return bytes.Equal(h, attr.sum)
}
var Fingerprint Attr = &fingerprint{}
type fingerprint struct {
sum uint32
raw []byte
}
func (*fingerprint) Type() uint16 {
return AttrFingerprint
}
func (attr *fingerprint) Marshal(p []byte) []byte {
r, b := grow(p, 4)
be.PutUint32(b, attr.sum)
return r
}
func (attr *fingerprint) Unmarshal(b []byte) error {
if len(b) < 4 {
return errFormat
}
attr.sum = be.Uint32(b)
return nil
}
func (attr *fingerprint) MarshalSum(p, raw []byte) []byte {
n := len(raw) - 4
be.PutUint16(raw[2:], uint16(n-12))
v := attr.Sum(raw[:n])
r, b := grow(p, 4)
be.PutUint32(b, v)
return r
}
func (attr *fingerprint) UnmarshalSum(p, raw []byte) error {
attr.raw = raw
return attr.Unmarshal(p)
}
func (attr *fingerprint) Sum(p []byte) uint32 {
return crc32.ChecksumIEEE(p) ^ 0x5354554e
}
func (attr *fingerprint) Check() bool {
r := attr.raw
if len(r) < 28 {
return r == nil
}
be.PutUint16(r[2:], uint16(len(r)-20))
return attr.Sum(r[:len(r)-8]) == attr.sum
}

110
vendor/github.com/isofew/go-stun/stun/conn.go generated vendored Normal file
View File

@ -0,0 +1,110 @@
package stun
import (
"github.com/pkg/errors"
"net"
)
type Conn struct {
net.Conn
agent *Agent
sess *Session
}
func NewConn(conn net.Conn, config *Config, stop chan struct{}) *Conn {
a := NewAgent(config)
go a.ServeConn(conn, stop)
return &Conn{conn, a, nil}
}
func (c *Conn) Network() string {
return c.LocalAddr().Network()
}
func (c *Conn) Discover() (net.Addr, error) {
res, err := c.Request(&Message{Type: MethodBinding})
if err != nil {
return nil, err
}
mapped := res.GetAddr(c.Network(), AttrXorMappedAddress, AttrMappedAddress)
if mapped != nil {
return mapped, nil
}
return nil, errors.New("stun: bad response, no mapped address")
}
func (c *Conn) Request(req *Message) (res *Message, err error) {
res, _, err = c.RequestTransport(req, c.Conn)
return
}
func (c *Conn) RequestTransport(req *Message, to Transport) (res *Message, from Transport, err error) {
sess := c.sess
auth := c.agent.config.AuthMethod
if to == nil {
to = c.Conn
}
for {
msg := &Message{
req.Type,
NewTransaction(),
append(sess.attrs(), req.Attributes...),
}
res, from, err = c.agent.RoundTrip(msg, to)
if err != nil {
return
}
code := res.GetError()
if code == nil {
// FIXME: authorize response...
if sess != nil {
c.sess = sess
}
return
}
err = code
switch code.Code {
case CodeUnauthorized, CodeStaleNonce:
if auth == nil {
return
}
sess = &Session{
Realm: res.GetString(AttrRealm),
Nonce: res.GetString(AttrNonce),
}
if err = auth(sess); err != nil {
return
}
auth = nil
default:
return
}
}
}
type Session struct {
Realm string
Nonce string
Username string
Key []byte
}
func (s *Session) attrs() []Attr {
if s == nil {
return nil
}
var a []Attr
if s.Realm != "" {
a = append(a, String(AttrRealm, s.Realm))
}
if s.Nonce != "" {
a = append(a, String(AttrNonce, s.Nonce))
}
if s.Username != "" {
a = append(a, String(AttrUsername, s.Username))
}
if s.Key != nil {
a = append(a, MessageIntegrity(s.Key))
}
return a
}

203
vendor/github.com/isofew/go-stun/stun/gen.go generated vendored Normal file
View File

@ -0,0 +1,203 @@
//+build ignore
//go:generate go run gen.go
// This program generates STUN parameters: methods, attributes and error codes by reading IANA registry.
package main
import (
"bytes"
"encoding/xml"
"fmt"
"go/format"
"io/ioutil"
"net/http"
"os"
"regexp"
"strings"
"unicode"
"unicode/utf8"
)
func main() {
err := generate("http://www.iana.org/assignments/stun-parameters/stun-parameters.xml", "registry.go")
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func loadRegistry(url string) (*Registry, error) {
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
r := &Registry{}
err = xml.Unmarshal(b, r)
if err != nil {
return nil, err
}
name := regexp.MustCompile("^([\\w-]+)(.*was\\s+([\\w-]+))?")
for _, reg := range r.Registry {
for _, r := range reg.Records {
m := name.FindStringSubmatch(r.Description)
if m != nil {
r.Value = strings.ToLower(r.Value)
r.Name = m[1]
if r.Name == "Reserved" && m[3] != "" {
r.Name = m[3]
r.Deprecated = true
}
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "rfc")
r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "RFC-")
}
}
}
return r, nil
}
type Record struct {
Value string `xml:"value"`
Description string `xml:"description"`
Name string
Deprecated bool
Ref struct {
Type string `xml:"type,attr"`
Data string `xml:"data,attr"`
} `xml:"xref"`
}
func (r *Record) IsValid() bool {
return r.Name != "Reserved" && r.Name != "Unassigned" && (r.Ref.Type == "rfc" || r.Ref.Type == "draft")
}
type Registry struct {
Title string `xml:"title"`
Updated string `xml:"updated"`
Registry []struct {
Id string `xml:"id,attr"`
Records []*Record `xml:"record"`
} `xml:"registry"`
}
func (reg *Registry) GetRecords(id string) []*Record {
for _, it := range reg.Registry {
if it.Id == id {
return it.Records
}
}
return nil
}
func generate(url, file string) error {
reg, err := loadRegistry(url)
if err != nil {
return err
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "package stun\n\n")
fmt.Fprintf(b, "// Do not edit. This file is generated by 'go generate gen.go'\n")
fmt.Fprintf(b, "// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).\n")
fmt.Fprintf(b, "// %s, Updated: %s.\n\n", reg.Title, reg.Updated)
genMethods(reg.GetRecords("stun-parameters-2"), b)
genAttributes(reg.GetRecords("stun-parameters-4"), b)
genErrors(reg.GetRecords("stun-parameters-6"), b)
src, err := format.Source(b.Bytes())
if err != nil {
return err
}
if err = ioutil.WriteFile(file, src, 0644); err != nil {
return err
}
return nil
}
func genMethods(records []*Record, b *bytes.Buffer) error {
c, m := &bytes.Buffer{}, &bytes.Buffer{}
ref := ""
for _, it := range records {
if it.IsValid() {
if ref == it.Ref.Data {
fmt.Fprintf(c, "Method%v uint16 = %s\n", it.Name, it.Value)
} else {
ref = it.Ref.Data
fmt.Fprintf(c, "Method%v uint16 = %s // RFC %s\n", it.Name, it.Value, ref)
}
fmt.Fprintf(m, "Method%v: \"%s\",\n", it.Name, it.Name)
}
}
fmt.Fprintf(b, "// STUN methods.\n")
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
fmt.Fprintf(b, "// STUN method names.\n")
fmt.Fprintf(b, "var methodNames = map[uint16]string{\n%s}\n", m.Bytes())
return nil
}
func genAttributes(records []*Record, b *bytes.Buffer) error {
c, d, m := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}
ref := ""
for _, it := range records {
if it.IsValid() {
a := c
if it.Deprecated {
a = d
}
v := strings.Replace(it.Name, "_", "-", -1)
n := strings.Replace(v, "-", " ", -1)
parts := strings.Fields(n)
for i, s := range parts {
switch s {
case "", "ID":
default:
r, n := utf8.DecodeRuneInString(s)
s = string(unicode.ToUpper(r)) + strings.ToLower(s[n:])
}
parts[i] = s
}
n = strings.Join(parts, "")
if ref == it.Ref.Data || it.Deprecated {
fmt.Fprintf(a, "Attr%v uint16 = %s\n", n, it.Value)
} else {
ref = it.Ref.Data
fmt.Fprintf(a, "Attr%v uint16 = %s // RFC %s\n", n, it.Value, ref)
}
fmt.Fprintf(m, "Attr%v: \"%s\",\n", n, v)
}
}
fmt.Fprintf(b, "// STUN attributes.\n")
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
fmt.Fprintf(b, "// Deprecated: For backwards compatibility only.\n")
fmt.Fprintf(b, "const (\n%s)\n", d.Bytes())
fmt.Fprintf(b, "// STUN attribute names.\n")
fmt.Fprintf(b, "var attrNames = map[uint16]string{\n%s}\n", m.Bytes())
return nil
}
func genErrors(records []*Record, b *bytes.Buffer) error {
c, m := &bytes.Buffer{}, &bytes.Buffer{}
ref := ""
for _, it := range records {
if it.IsValid() {
n := strings.Replace(strings.Title(it.Description), " ", "", -1)
if ref == it.Ref.Data {
fmt.Fprintf(c, "Code%v int = %s\n", n, it.Value)
} else {
ref = it.Ref.Data
fmt.Fprintf(c, "Code%v int = %s // RFC %s\n", n, it.Value, ref)
}
fmt.Fprintf(m, "Code%v: \"%s\",\n", n, it.Description)
}
}
fmt.Fprintf(b, "// STUN error codes.\n")
fmt.Fprintf(b, "const (\n%s)\n", c.Bytes())
fmt.Fprintf(b, "// STUN error texts.\n")
fmt.Fprintf(b, "var errorText = map[int]string{\n%s}\n", m.Bytes())
return nil
}

351
vendor/github.com/isofew/go-stun/stun/message.go generated vendored Normal file
View File

@ -0,0 +1,351 @@
package stun
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net"
"sort"
"strconv"
)
const (
KindRequest uint16 = 0x0000
KindIndication uint16 = 0x0010
KindResponse uint16 = 0x0100
KindError uint16 = 0x0110
)
// Message represents a STUN message.
type Message struct {
Type uint16
Transaction []byte
Attributes []Attr
}
func (m *Message) Marshal(p []byte) []byte {
pos := len(p)
r, b := grow(p, 20)
be.PutUint16(b, m.Type)
if m.Transaction != nil {
copy(b[4:], m.Transaction)
} else {
copy(b[4:], magicCookie)
rand.Read(b[8:20])
}
sort.Sort(byPosition(m.Attributes))
for _, attr := range m.Attributes {
r = m.marshalAttr(r, attr, pos)
}
be.PutUint16(r[pos+2:], uint16(len(r)-pos-20))
return r
}
func (m *Message) marshalAttr(p []byte, attr Attr, pos int) []byte {
h := len(p)
r, b := grow(p, 4)
be.PutUint16(b, attr.Type())
switch v := attr.(type) {
case *addr:
r = v.MarshalAddr(r, r[pos+4:])
case *integrity:
r = v.MarshalSum(r, r[pos:])
case *fingerprint:
r = v.MarshalSum(r, r[pos:])
default:
r = v.Marshal(r)
}
n := len(r) - h - 4
be.PutUint16(r[h+2:], uint16(n))
if pad := n & 3; pad != 0 {
r, b = grow(r, 4-pad)
for i := range b {
b[i] = 0
}
}
return r
}
func (m *Message) Unmarshal(b []byte) (n int, err error) {
if len(b) < 20 {
err = io.EOF
return
}
l := int(be.Uint16(b[2:])) + 20
if len(b) < l {
err = io.EOF
return
}
pos, p := 20, make([]byte, l)
copy(p, b[:l])
m.Type = be.Uint16(p)
m.Transaction = p[4:20]
for pos < len(p) {
s, attr, err := m.unmarshalAttr(p, pos)
if err != nil {
return 0, err
}
pos += s
if attr != nil {
m.Attributes = append(m.Attributes, attr)
}
}
return l, nil
}
func (m *Message) unmarshalAttr(p []byte, pos int) (n int, attr Attr, err error) {
b := p[pos:]
if len(b) < 4 {
err = errFormat
return
}
typ := be.Uint16(b)
attr, n = newAttr(typ), int(be.Uint16(b[2:]))+4
if len(b) < n {
err = errFormat
return
}
b = b[4:n]
if attr != nil {
switch v := attr.(type) {
case *addr:
err = v.UnmarshalAddr(b, m.Transaction)
case *integrity:
err = v.UnmarshalSum(b, p[:pos+n])
case *fingerprint:
err = v.UnmarshalSum(b, p[:pos+n])
default:
err = attr.Unmarshal(b)
}
} else if typ < 0x8000 {
err = errFormat
}
if err != nil {
err = &errAttribute{err, typ}
return
}
if pad := n & 3; pad != 0 {
n += 4 - pad
if len(p) < pos+n {
err = errFormat
}
}
return
}
func (m *Message) Kind() uint16 {
return m.Type & 0x110
}
func (m *Message) Method() uint16 {
return m.Type &^ 0x110
}
func (m *Message) Add(attr Attr) {
m.Attributes = append(m.Attributes, attr)
}
func (m *Message) Set(attr Attr) {
m.Del(attr.Type())
m.Add(attr)
}
func (m *Message) Del(typ uint16) {
n := 0
for _, a := range m.Attributes {
if a.Type() != typ {
m.Attributes[n] = a
n++
}
}
m.Attributes = m.Attributes[:n]
}
func (m *Message) Get(typ uint16) (attr Attr) {
for _, attr = range m.Attributes {
if attr.Type() == typ {
return
}
}
return nil
}
func (m *Message) Has(typ uint16) bool {
for _, attr := range m.Attributes {
if attr.Type() == typ {
return true
}
}
return false
}
func (m *Message) GetString(typ uint16) string {
if str, ok := m.Get(typ).(fmt.Stringer); ok {
return str.String()
}
return ""
}
func (m *Message) GetAddr(network string, typ ...uint16) net.Addr {
for _, t := range typ {
if addr, ok := m.Get(t).(*addr); ok {
return addr.Addr(network)
}
}
return nil
}
func (m *Message) GetInt(typ uint16) (v uint64, ok bool) {
attr := m.Get(typ)
if r, ok := attr.(*number); ok {
return r.v, true
}
return
}
func (m *Message) GetBytes(typ uint16) []byte {
if attr, ok := m.Get(typ).(*raw); ok {
return attr.data
}
return nil
}
func (m *Message) GetError() *Error {
if err, ok := m.Get(AttrErrorCode).(*Error); ok {
return err
}
return nil
}
func (m *Message) CheckIntegrity(key []byte) bool {
if attr, ok := m.Get(AttrMessageIntegrity).(*integrity); ok {
return attr.Check(key)
}
return false
}
func (m *Message) CheckFingerprint() bool {
if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok {
return attr.Check()
}
return false
}
func (m *Message) String() string {
sort.Sort(byPosition(m.Attributes))
// TODO: use sprintf
b := &bytes.Buffer{}
b.WriteString(MethodName(m.Type))
b.WriteByte('{')
tx := m.Transaction
if tx == nil {
b.WriteString("nil")
} else if bytes.Equal(magicCookie, tx[:4]) {
b.WriteString(hex.EncodeToString(tx[4:]))
} else {
b.WriteString(hex.EncodeToString(tx))
}
for _, attr := range m.Attributes {
b.WriteString(", ")
b.WriteString(AttrName(attr.Type()))
switch v := attr.(type) {
case *raw:
b.WriteString(": \"")
b.Write(v.data)
b.WriteByte('"')
case *str:
b.WriteString(": \"")
b.WriteString(v.data)
b.WriteByte('"')
case flag, *integrity, *fingerprint:
default:
b.WriteString(fmt.Sprintf(": %v", attr))
}
}
b.WriteByte('}')
return b.String()
}
func MethodName(typ uint16) string {
if r, ok := methodNames[typ&^0x110]; ok {
switch typ & 0x110 {
case KindRequest:
return r + "Request"
case KindIndication:
return r + "Indication"
case KindResponse:
return r + "Response"
case KindError:
return r + "Error"
}
}
return "0x" + strconv.FormatUint(uint64(typ), 16)
}
func UnmarshalMessage(b []byte) (*Message, error) {
m := &Message{}
if _, err := m.Unmarshal(b); err != nil {
return nil, err
}
return m, nil
}
var magicCookie = []byte{0x21, 0x12, 0xa4, 0x42}
var alphanum = dict("01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
type dict []byte
func (d dict) rand(n int) string {
m, b := len(d), make([]byte, n)
for i := range b {
b[i] = d[rand.Intn(m)]
}
return string(b)
}
func NewTransaction() []byte {
id := make([]byte, 16)
copy(id, magicCookie)
rand.Read(id[4:]) // TODO: configure random source
return id
}
type byPosition []Attr
func (s byPosition) Len() int { return len(s) }
func (s byPosition) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byPosition) Less(i, j int) bool {
a, b := s[i].Type(), s[j].Type()
switch b {
case a:
return i < j
case AttrMessageIntegrity:
return a != AttrFingerprint
case AttrFingerprint:
return true
default:
return i < j
}
}
type errAttribute struct {
error
typ uint16
}
func (err errAttribute) Error() string {
return "attribute " + AttrName(err.typ) + ": " + err.error.Error()
}

172
vendor/github.com/isofew/go-stun/stun/nat.go generated vendored Normal file
View File

@ -0,0 +1,172 @@
package stun
import (
"errors"
"net"
)
const (
EndpointIndependent = "endpoint-independent"
AddressDependent = "address-dependent"
AddressPortDependent = "address-port-dependent"
)
type Detector struct {
*Conn
}
func NewDetector(c *Conn) *Detector {
d := &Detector{c}
d.agent.Handler = &Server{agent: c.agent}
return d
}
func (d *Detector) Hairpinning() error {
mapped, err := d.Discover()
if err != nil {
return err
}
conn, err := net.Dial(d.Network(), mapped.String())
if err != nil {
return err
}
// not using stop channel
c := NewConn(conn, d.agent.config, make(chan struct{}))
defer c.Close()
_, err = c.Discover()
return err
}
func (d *Detector) DiscoverChange(change uint64) error {
req := &Message{Type: MethodBinding, Attributes: []Attr{Int(AttrChangeRequest, change)}}
_, from, err := d.RequestTransport(req, d)
if err != nil {
return err
}
ip, port := SockAddr(d.RemoteAddr())
chip, chport := SockAddr(from.RemoteAddr())
if change&ChangeIP != 0 {
if ip.Equal(chip) {
return errors.New("stun: bad response, ip address is not changed")
}
} else if change&ChangePort != 0 {
if port == chport {
return errors.New("stun: bad response, port is not changed")
}
}
return nil
}
func (d *Detector) Filtering() (string, error) {
n := d.Network()
if n != "udp" {
return "", errors.New("stun: filtering test is not applicable to " + n)
}
_, err := d.Request(&Message{Type: MethodBinding})
if err != nil {
return "", err
}
err = d.DiscoverChange(ChangeIP | ChangePort)
switch err {
case nil:
return EndpointIndependent, nil
case errTimeout:
err = d.DiscoverChange(ChangePort)
switch err {
case nil:
return AddressDependent, nil
case errTimeout:
return AddressPortDependent, nil
}
}
return "", err
}
func (d *Detector) DiscoverOther(addr net.Addr) (net.Addr, error) {
n := addr.Network()
conn, err := net.Dial(n, addr.String())
if err != nil {
return nil, err
}
defer conn.Close()
// not using stop channel
go d.agent.ServeConn(conn, make(chan struct{}))
res, _, err := d.RequestTransport(&Message{Type: MethodBinding}, conn)
if err != nil {
return nil, err
}
mapped := res.GetAddr(n, AttrXorMappedAddress, AttrMappedAddress)
if mapped != nil {
return mapped, nil
}
return nil, errors.New("stun: bad response, no mapped address")
}
func (d *Detector) Mapping() (string, error) {
n := d.Network()
msg, err := d.Request(&Message{Type: MethodBinding})
if err != nil {
return "", err
}
mapped, other := msg.GetAddr(n, AttrXorMappedAddress), msg.GetAddr(n, AttrOtherAddress)
if mapped == nil {
return "", errors.New("stun: bad response, no mapped address")
}
if other == nil {
return "", errors.New("stun: bad response, no other address")
}
ip, _ := SockAddr(mapped)
if ip.IsLoopback() {
return EndpointIndependent, nil
}
for _, it := range local {
if it.IP.Equal(ip) {
return EndpointIndependent, nil
}
}
ip, _ = SockAddr(other)
_, port := SockAddr(d.RemoteAddr())
a, err := d.DiscoverOther(NewAddr(n, ip, port))
if err != nil {
return "", err
}
if sameAddr(a, mapped) {
return EndpointIndependent, nil
}
b, err := d.DiscoverOther(other)
if err != nil {
return "", err
}
if sameAddr(b, a) {
return AddressDependent, nil
}
return AddressPortDependent, nil
}
func LocalAddrs() []*net.IPAddr {
return local
}
var local []*net.IPAddr
func init() {
ifaces, err := net.Interfaces()
if err != nil {
return
}
for _, iface := range ifaces {
addrs, _ := iface.Addrs()
for _, it := range addrs {
var ip net.IP
switch it := it.(type) {
case *net.IPNet:
ip = it.IP
case *net.IPAddr:
ip = it.IP
}
if ip != nil && ip.IsGlobalUnicast() {
local = append(local, &net.IPAddr{ip, iface.Name})
}
}
}
}

179
vendor/github.com/isofew/go-stun/stun/registry.go generated vendored Normal file
View File

@ -0,0 +1,179 @@
package stun
// Do not edit. This file is generated by 'go generate gen.go'
// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).
// Session Traversal Utilities for NAT (STUN) Parameters, Updated: 2016-11-14.
// STUN methods.
const (
MethodBinding uint16 = 0x001 // RFC 5389
MethodSharedSecret uint16 = 0x002
MethodAllocate uint16 = 0x003 // RFC 5766
MethodRefresh uint16 = 0x004
MethodSend uint16 = 0x006
MethodData uint16 = 0x007
MethodCreatePermission uint16 = 0x008
MethodChannelBind uint16 = 0x009
MethodConnect uint16 = 0x00a // RFC 6062
MethodConnectionBind uint16 = 0x00b
MethodConnectionAttempt uint16 = 0x00c
)
// STUN method names.
var methodNames = map[uint16]string{
MethodBinding: "Binding",
MethodSharedSecret: "SharedSecret",
MethodAllocate: "Allocate",
MethodRefresh: "Refresh",
MethodSend: "Send",
MethodData: "Data",
MethodCreatePermission: "CreatePermission",
MethodChannelBind: "ChannelBind",
MethodConnect: "Connect",
MethodConnectionBind: "ConnectionBind",
MethodConnectionAttempt: "ConnectionAttempt",
}
// STUN attributes.
const (
AttrMappedAddress uint16 = 0x0001 // RFC 5389
AttrChangeRequest uint16 = 0x0003 // RFC 5780
AttrUsername uint16 = 0x0006 // RFC 5389
AttrMessageIntegrity uint16 = 0x0008
AttrErrorCode uint16 = 0x0009
AttrUnknownAttributes uint16 = 0x000a
AttrChannelNumber uint16 = 0x000c // RFC 5766
AttrLifetime uint16 = 0x000d
AttrXorPeerAddress uint16 = 0x0012
AttrData uint16 = 0x0013
AttrRealm uint16 = 0x0014 // RFC 5389
AttrNonce uint16 = 0x0015
AttrXorRelayedAddress uint16 = 0x0016 // RFC 5766
AttrRequestedAddressFamily uint16 = 0x0017 // RFC 6156
AttrEvenPort uint16 = 0x0018 // RFC 5766
AttrRequestedTransport uint16 = 0x0019
AttrDontFragment uint16 = 0x001a
AttrAccessToken uint16 = 0x001b // RFC 7635
AttrXorMappedAddress uint16 = 0x0020 // RFC 5389
AttrReservationToken uint16 = 0x0022 // RFC 5766
AttrPriority uint16 = 0x0024 // RFC 5245
AttrUseCandidate uint16 = 0x0025
AttrPadding uint16 = 0x0026 // RFC 5780
AttrResponsePort uint16 = 0x0027
AttrConnectionID uint16 = 0x002a // RFC 6062
AttrSoftware uint16 = 0x8022 // RFC 5389
AttrAlternateServer uint16 = 0x8023
AttrTransactionTransmitCounter uint16 = 0x8025 // RFC 7982
AttrCacheTimeout uint16 = 0x8027 // RFC 5780
AttrFingerprint uint16 = 0x8028 // RFC 5389
AttrIceControlled uint16 = 0x8029 // RFC 5245
AttrIceControlling uint16 = 0x802a
AttrResponseOrigin uint16 = 0x802b // RFC 5780
AttrOtherAddress uint16 = 0x802c
AttrEcnCheck uint16 = 0x802d // RFC 6679
AttrThirdPartyAuthorization uint16 = 0x802e // RFC 7635
AttrMobilityTicket uint16 = 0x8030 // RFC 8016
)
// Deprecated: For backwards compatibility only.
const (
AttrResponseAddress uint16 = 0x0002
AttrSourceAddress uint16 = 0x0004
AttrChangedAddress uint16 = 0x0005
AttrPassword uint16 = 0x0007
AttrReflectedFrom uint16 = 0x000b
AttrBandwidth uint16 = 0x0010
AttrTimerVal uint16 = 0x0021
)
// STUN attribute names.
var attrNames = map[uint16]string{
AttrMappedAddress: "MAPPED-ADDRESS",
AttrResponseAddress: "RESPONSE-ADDRESS",
AttrChangeRequest: "CHANGE-REQUEST",
AttrSourceAddress: "SOURCE-ADDRESS",
AttrChangedAddress: "CHANGED-ADDRESS",
AttrUsername: "USERNAME",
AttrPassword: "PASSWORD",
AttrMessageIntegrity: "MESSAGE-INTEGRITY",
AttrErrorCode: "ERROR-CODE",
AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES",
AttrReflectedFrom: "REFLECTED-FROM",
AttrChannelNumber: "CHANNEL-NUMBER",
AttrLifetime: "LIFETIME",
AttrBandwidth: "BANDWIDTH",
AttrXorPeerAddress: "XOR-PEER-ADDRESS",
AttrData: "DATA",
AttrRealm: "REALM",
AttrNonce: "NONCE",
AttrXorRelayedAddress: "XOR-RELAYED-ADDRESS",
AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY",
AttrEvenPort: "EVEN-PORT",
AttrRequestedTransport: "REQUESTED-TRANSPORT",
AttrDontFragment: "DONT-FRAGMENT",
AttrAccessToken: "ACCESS-TOKEN",
AttrXorMappedAddress: "XOR-MAPPED-ADDRESS",
AttrTimerVal: "TIMER-VAL",
AttrReservationToken: "RESERVATION-TOKEN",
AttrPriority: "PRIORITY",
AttrUseCandidate: "USE-CANDIDATE",
AttrPadding: "PADDING",
AttrResponsePort: "RESPONSE-PORT",
AttrConnectionID: "CONNECTION-ID",
AttrSoftware: "SOFTWARE",
AttrAlternateServer: "ALTERNATE-SERVER",
AttrTransactionTransmitCounter: "TRANSACTION-TRANSMIT-COUNTER",
AttrCacheTimeout: "CACHE-TIMEOUT",
AttrFingerprint: "FINGERPRINT",
AttrIceControlled: "ICE-CONTROLLED",
AttrIceControlling: "ICE-CONTROLLING",
AttrResponseOrigin: "RESPONSE-ORIGIN",
AttrOtherAddress: "OTHER-ADDRESS",
AttrEcnCheck: "ECN-CHECK",
AttrThirdPartyAuthorization: "THIRD-PARTY-AUTHORIZATION",
AttrMobilityTicket: "MOBILITY-TICKET",
}
// STUN error codes.
const (
CodeTryAlternate int = 300 // RFC 5389
CodeBadRequest int = 400
CodeUnauthorized int = 401
CodeForbidden int = 403 // RFC 5766
CodeMobilityForbidden int = 405 // RFC 8016
CodeUnknownAttribute int = 420 // RFC 5389
CodeAllocationMismatch int = 437 // RFC 5766
CodeStaleNonce int = 438 // RFC 5389
CodeAddressFamilyNotSupported int = 440 // RFC 6156
CodeWrongCredentials int = 441 // RFC 5766
CodeUnsupportedTransportProtocol int = 442
CodePeerAddressFamilyMismatch int = 443 // RFC 6156
CodeConnectionAlreadyExists int = 446 // RFC 6062
CodeConnectionTimeoutOrFailure int = 447
CodeAllocationQuotaReached int = 486 // RFC 5766
CodeRoleConflict int = 487 // RFC 5245
CodeServerError int = 500 // RFC 5389
CodeInsufficientCapacity int = 508 // RFC 5766
)
// STUN error texts.
var errorText = map[int]string{
CodeTryAlternate: "Try Alternate",
CodeBadRequest: "Bad Request",
CodeUnauthorized: "Unauthorized",
CodeForbidden: "Forbidden",
CodeMobilityForbidden: "Mobility Forbidden",
CodeUnknownAttribute: "Unknown Attribute",
CodeAllocationMismatch: "Allocation Mismatch",
CodeStaleNonce: "Stale Nonce",
CodeAddressFamilyNotSupported: "Address Family not Supported",
CodeWrongCredentials: "Wrong Credentials",
CodeUnsupportedTransportProtocol: "Unsupported Transport Protocol",
CodePeerAddressFamilyMismatch: "Peer Address Family Mismatch",
CodeConnectionAlreadyExists: "Connection Already Exists",
CodeConnectionTimeoutOrFailure: "Connection Timeout or Failure",
CodeAllocationQuotaReached: "Allocation Quota Reached",
CodeRoleConflict: "Role Conflict",
CodeServerError: "Server Error",
CodeInsufficientCapacity: "Insufficient Capacity",
}

126
vendor/github.com/isofew/go-stun/stun/server.go generated vendored Normal file
View File

@ -0,0 +1,126 @@
package stun
import (
"net"
"sync"
)
func ListenAndServe(network, laddr string, config *Config) error {
srv := NewServer(config)
return srv.ListenAndServe(network, laddr)
}
type Server struct {
agent *Agent
mu sync.RWMutex
conns []net.PacketConn
}
func NewServer(config *Config) *Server {
srv := &Server{agent: NewAgent(config)}
srv.agent.Handler = srv
return srv
}
func (srv *Server) ListenAndServe(network, laddr string) error {
c, err := net.ListenPacket(network, laddr)
if err != nil {
return err
}
srv.addConn(c)
defer srv.removeConn(c)
// not using stop channel
return srv.agent.ServePacket(c, make(chan struct{}))
}
func (srv *Server) ServeSTUN(msg *Message, from Transport) {
if msg.Type == MethodBinding {
to := from
mapped := from.RemoteAddr()
ip, port := SockAddr(from.LocalAddr())
res := &Message{
Type: MethodBinding | KindResponse,
Transaction: msg.Transaction,
Attributes: []Attr{
Addr(AttrXorMappedAddress, mapped),
Addr(AttrMappedAddress, mapped),
},
}
srv.mu.RLock()
defer srv.mu.RUnlock()
if ch, ok := msg.GetInt(AttrChangeRequest); ok && ch != 0 {
for _, c := range srv.conns {
chip, chport := SockAddr(c.LocalAddr())
if chip.IsUnspecified() {
continue
}
if ch&ChangeIP != 0 {
if !ip.Equal(chip) {
to = &packetConn{c, mapped}
break
}
} else if ch&ChangePort != 0 {
if ip.Equal(chip) && port != chport {
to = &packetConn{c, mapped}
break
}
}
}
}
if len(srv.conns) < 2 {
srv.agent.Send(res, to)
return
}
other:
for _, a := range srv.conns {
aip, aport := SockAddr(a.LocalAddr())
if aip.IsUnspecified() || !ip.Equal(aip) || port == aport {
continue
}
for _, b := range srv.conns {
bip, bport := SockAddr(b.LocalAddr())
if bip.IsUnspecified() || bip.Equal(ip) || aport != bport {
continue
}
res.Set(Addr(AttrOtherAddress, b.LocalAddr()))
break other
}
}
srv.agent.Send(res, to)
}
}
func (srv *Server) addConn(c net.PacketConn) {
srv.mu.Lock()
srv.conns = append(srv.conns, c)
srv.mu.Unlock()
}
func (srv *Server) removeConn(c net.PacketConn) {
srv.mu.Lock()
l := srv.conns
for i, it := range l {
if it == c {
srv.conns = append(l[:i], l[i+1:]...)
break
}
}
srv.mu.Unlock()
}
func (srv *Server) Close() error {
srv.mu.RLock()
defer srv.mu.RUnlock()
for _, it := range srv.conns {
it.Close()
}
srv.conns = nil
return nil
}

129
vendor/github.com/isofew/go-stun/stun/stun.go generated vendored Normal file
View File

@ -0,0 +1,129 @@
package stun
import (
"crypto/md5"
"crypto/tls"
"errors"
"net"
"net/url"
"strings"
)
func Discover(uri string) (net.PacketConn, net.Addr, error) {
stop := make(chan struct{})
conn, err := Dial(uri, nil, stop)
if err != nil {
return nil, nil, err
}
addr, err := conn.Discover()
if err != nil {
conn.Close()
return nil, nil, err
}
// TODO: hijack
// stop reading conn before returning it
close(stop)
// note. serveconn/packet func is blocked by read/readfrom at the time
// we send the signal, which means it will still consume one more
// packet and we can only read starting from the second packet.
// (not too much of a problem, since we'll punch a few packets anyway)
return conn.Conn.(net.PacketConn), addr, nil
}
type AuthMethod func(sess *Session) error
// LongTermAuthMethod returns AuthMethod for long-term credentials.
// Key = MD5(username ":" realm ":" SASLprep(password)).
// SASLprep is defined in RFC 4013.
func LongTermAuthMethod(username, password string) AuthMethod {
return func(sess *Session) error {
h := md5.New()
h.Write([]byte(username + ":" + sess.Realm + ":" + password))
sess.Username = username
sess.Key = h.Sum(nil)
return nil
}
}
// ShotTermAuthMethod returns AuthMethod for short-term credentials.
// Key = SASLprep(password).
// SASLprep is defined in RFC 4013.
func ShortTermAuthMethod(password string) AuthMethod {
key := []byte(password)
return func(sess *Session) error {
sess.Key = key
return nil
}
}
func Dial(uri string, config *Config, stop chan struct{}) (*Conn, error) {
secure, network, addr, auth, err := parseURI(uri)
if err != nil {
return nil, err
}
var conn net.Conn
if secure {
conn, err = tls.Dial(network, addr, nil)
} else {
if strings.HasPrefix(network, "udp") {
conn, err = dialUDP(network, addr)
} else {
conn, err = dialTCP(network, addr)
}
}
if err != nil {
return nil, err
}
if auth != nil {
config = config.Clone()
config.AuthMethod = auth
}
return NewConn(conn, config, stop), nil
}
func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, err error) {
var u *url.URL
if u, err = url.Parse(uri); err != nil {
return
}
host, port, e := net.SplitHostPort(u.Opaque)
if e != nil {
host = u.Opaque
}
if a := u.User; a != nil {
if password, ok := a.Password(); ok {
auth = LongTermAuthMethod(a.Username(), password)
} else {
auth = ShortTermAuthMethod(a.Username())
}
}
network = u.Query().Get("transport")
if network == "" {
network = "udp"
}
switch u.Scheme {
case "stun", "turn":
if port == "" {
port = "3478"
}
switch network {
case "udp", "udp4", "udp6", "tcp", "tcp4", "tcp6":
default:
err = errors.New("stun: unsupported transport: " + network)
}
case "stuns", "turns":
if port == "" {
port = "5478"
}
secure = true
switch network {
case "tcp", "tcp4", "tcp6":
default:
err = errors.New("stun: unsupported transport: " + network)
}
default:
err = errors.New("stun: unsupported scheme " + u.Scheme)
}
addr = net.JoinHostPort(host, port)
return
}

96
vendor/github.com/isofew/go-stun/stun/transport.go generated vendored Normal file
View File

@ -0,0 +1,96 @@
package stun
import (
"encoding/binary"
"errors"
"net"
)
type Listener interface {
Addr() net.Addr
Close() error
}
type Transport interface {
LocalAddr() net.Addr
RemoteAddr() net.Addr
Write(p []byte) (int, error)
Close() error
}
type Marshaler interface {
Marshal(b []byte) []byte
}
type TransportHandler interface {
ServeTransport(b []byte, tr Transport) (int, error)
}
func dialUDP(network, raddr string) (net.Conn, error) {
addr, err := net.ResolveUDPAddr(network, raddr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP(network, nil)
if err != nil {
return nil, err
}
return &packetConn{conn, addr}, nil
}
func dialTCP(network, raddr string) (net.Conn, error) {
addr, err := net.ResolveTCPAddr(network, raddr)
if err != nil {
return nil, err
}
return net.DialTCP(network, nil, addr)
}
type packetConn struct {
net.PacketConn
addr net.Addr
}
func (t *packetConn) Read(p []byte) (n int, err error) {
n, _, err = t.ReadFrom(p)
return
}
func (t *packetConn) Write(p []byte) (int, error) {
return t.WriteTo(p, t.addr)
}
func (t *packetConn) RemoteAddr() net.Addr {
return t.addr
}
var (
errBufferOverflow = errors.New("stun: buffer overflow")
errFormat = errors.New("stun: format error")
)
func getBuffer() []byte {
return make([]byte, 2048)
}
func putBuffer(b []byte) {
if cap(b) >= 2048 {
}
}
func grow(p []byte, n int) (b, a []byte) {
l := len(p)
r := l + n
if r > cap(p) {
b = make([]byte, (1+((r-1)>>10))<<10)[:r]
a = b[l:r]
if l > 0 {
copy(b, p[:l])
}
} else {
return p[:r], p[l:r]
}
return
}
var be = binary.BigEndian

View File

@ -1,6 +1,18 @@
# Changelog # Changelog
## v0.6.0 (unreleased) ## v0.8.0 (unreleased)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37 - Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows - Added `quic.Config` options for maximal flow control windows

View File

@ -1,7 +0,0 @@
package main
import (
_ "github.com/clipperhouse/linkedlist"
_ "github.com/clipperhouse/slice"
_ "github.com/clipperhouse/stringer"
)

View File

@ -1,34 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
SendingAllowed() bool
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
ShouldSendRetransmittablePacket() bool
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@ -1,34 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
// +gen linkedlist
type Packet struct {
PacketNumber protocol.PacketNumber
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
var fs []wire.Frame
for _, frame := range p.Frames {
switch frame.(type) {
case *wire.AckFrame:
continue
case *wire.StopWaitingFrame:
continue
}
fs = append(fs, frame)
}
return fs
}

View File

@ -1,141 +0,0 @@
package ackhandler
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
lowerLimit protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
version protocol.VersionNumber
}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 {
return errInvalidPacketNumber
}
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
}
if packetNumber <= h.lowerLimit {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
// SetLowerLimit sets a lower limit for acking packets.
// Packets with packet numbers smaller or equal than p will not be acked.
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
h.lowerLimit = p
h.packetHistory.DeleteUpTo(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
}
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true
}
if h.version < protocol.Version39 {
// Always send an ack every 20 packets in order to allow the peer to discard
// information from the SentPacketManager and provide an RTT measurement.
// From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
ack := &wire.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].First,
PacketReceivedTime: h.largestObservedReceivedTime,
}
if len(ackRanges) > 1 {
ack.AckRanges = ackRanges
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -1,455 +0,0 @@
package ackhandler
import (
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// The default RTT used before an RTT sample is taken.
// Note: This constant is also defined in the congestion package.
defaultInitialRTT = 100 * time.Millisecond
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
var (
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
)
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
skippedPackets []protocol.PacketNumber
numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet
LargestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
packetHistory *PacketList
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetHistory: NewPacketList(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
}
}
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
if f := h.packetHistory.Front(); f != nil {
return f.Value.PacketNumber - 1
}
return h.LargestAcked
}
func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
if packet.PacketNumber <= h.lastSentPacketNumber {
return errPacketNumberNotIncreasing
}
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
}
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
h.numNonRetransmittablePackets = 0
} else {
h.numNonRetransmittablePackets++
}
h.congestion.OnPacketSent(
now,
h.bytesInFlight,
packet.PacketNumber,
packet.Length,
isRetransmittable,
)
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
if ackFrame.LargestAcked > h.lastSentPacketNumber {
return errAckForUnsentPacket
}
// duplicate or out-of-order ACK
if withPacketNumber <= h.largestReceivedPacketWithAck {
return ErrDuplicateOrOutOfOrderAck
}
h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
return nil
}
h.LargestAcked = ackFrame.LargestAcked
if h.skippedPacketsAcked(ackFrame) {
return ErrAckForSkippedPacket
}
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
if rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
if len(ackedPackets) > 0 {
for _, p := range ackedPackets {
if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
}
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
ackRangeIndex := 0
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
packetNumber := packet.PacketNumber
// Ignore packets below the LowestAcked
if packetNumber < ackFrame.LowestAcked {
continue
}
// Break after LargestAcked is reached
if packetNumber > ackFrame.LargestAcked {
break
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if packetNumber >= ackRange.First { // packet i contained in ACK range
if packetNumber > ackRange.Last {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
}
ackedPackets = append(ackedPackets, el)
}
} else {
ackedPackets = append(ackedPackets, el)
}
}
return ackedPackets, nil
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
return true
}
// Packets are sorted by number, so we can stop searching
if packet.PacketNumber > largestAcked {
break
}
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
return
}
// TODO(#497): TLP
if !h.handshakeComplete {
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = time.Now().Add(h.computeRTOTimeout())
}
}
func (h *sentPacketHandler) detectLostPackets() {
h.lossTime = time.Time{}
now := time.Now()
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber > h.LargestAcked {
break
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() {
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
}
if len(lostPackets) > 0 {
for _, p := range lostPackets {
h.queuePacketForRetransmission(p)
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
}
func (h *sentPacketHandler) OnAlarm() {
// TODO(#497): TLP
if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission()
h.handshakeCount++
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets()
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
}
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
h.bytesInFlight -= packetElement.Value.Length
h.rtoCount = 0
h.handshakeCount = 0
// TODO(#497): h.tlpCount = 0
h.packetHistory.Remove(packetElement)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
return h.largestInOrderAcked() + 1
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
h.bytesInFlight,
h.congestion.GetCongestionWindow())
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
}
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
}
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
packet := &el.Value
utils.Debugf(
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
packet.PacketNumber,
h.packetHistory.Len(),
)
h.queuePacketForRetransmission(el)
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
var handshakePackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, el)
}
}
for _, el := range handshakePackets {
h.queuePacketForRetransmission(el)
}
}
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.retransmissionQueue = append(h.retransmissionQueue, packet)
h.packetHistory.Remove(packetElement)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := 2 * h.rttStats.SmoothedRTT()
if duration == 0 {
duration = 2 * defaultInitialRTT
}
duration = utils.MaxDuration(duration, minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
rto := h.congestion.RetransmissionDelay()
if rto == 0 {
rto = defaultRTOTimeout
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto = rto << h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lioa := h.largestInOrderAcked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p <= lioa {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

View File

@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install: install:
- rmdir c:\go /s /q - rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip - appveyor DownloadFile https://storage.googleapis.com/golang/go1.10.2.windows-amd64.zip
- 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL - 7z x go1.10.2.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH% - echo %PATH%
- echo %GOPATH% - echo %GOPATH%

View File

@ -8,19 +8,20 @@ import (
var bufferPool sync.Pool var bufferPool sync.Pool
func getPacketBuffer() []byte { func getPacketBuffer() *[]byte {
return bufferPool.Get().([]byte) return bufferPool.Get().(*[]byte)
} }
func putPacketBuffer(buf []byte) { func putPacketBuffer(buf *[]byte) {
if cap(buf) != int(protocol.MaxReceivePacketSize) { if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!") panic("putPacketBuffer called with packet of wrong size!")
} }
bufferPool.Put(buf[:0]) bufferPool.Put(buf)
} }
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxReceivePacketSize) b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
} }
} }

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -22,24 +23,29 @@ type client struct {
conn connection conn connection
hostname string hostname string
handshakeChan <-chan handshakeEvent
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
tls handshake.MintTLS // only used when using TLS
connectionID protocol.ConnectionID srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialVersion protocol.VersionNumber
version protocol.VersionNumber version protocol.VersionNumber
session packetHandler session packetHandler
logger utils.Logger
} }
var ( var (
// make it possible to mock connection ID generation in the tests // make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID generateConnectionID = protocol.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
) )
@ -57,69 +63,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := generateConnectionID()
if err != nil {
return nil, err
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.establishSecureConnection(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial( func Dial(
@ -129,14 +72,57 @@ func Dial(
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (Session, error) { ) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0]
srcConnID, err := generateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := sess.WaitUntilHandshakeComplete(); err != nil { destConnID := srcConnID
if version.UsesTLS() {
destConnID, err = generateConnectionID()
if err != nil {
return nil, err return nil, err
} }
return sess, nil }
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: version,
versionNegotiationChan: make(chan struct{}),
logger: utils.DefaultLogger,
}
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
if err := c.dial(); err != nil {
return nil, err
}
return c.session, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set // populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -167,6 +153,18 @@ func populateClientConfig(config *Config) *Config {
if maxReceiveConnectionFlowControlWindow == 0 { if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
} }
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{ return &Config{
Versions: versions, Versions: versions,
@ -175,29 +173,87 @@ func populateClientConfig(config *Config) *Config {
RequestConnectionIDOmission: config.RequestConnectionIDOmission, RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
} }
} }
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) func (c *client) dial() error {
func (c *client) establishSecureConnection() error { var err error
if err := c.createNewSession(c.version, nil); err != nil { if c.version.UsesTLS() {
err = c.dialTLS()
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
}
return err
}
func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err return err
} }
go c.listen() go c.listen()
return c.establishSecureConnection()
}
func (c *client) dialTLS() error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
}
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
c.logger.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
}
return nil
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error var runErr error
errorChan := make(chan struct{}) errorChan := make(chan struct{})
go func() { go func() {
// session.run() returns as soon as the session is closed runErr = c.session.run() // returns as soon as the session is closed
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
// run the new session
runErr = c.session.run()
}
close(errorChan) close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID) c.logger.Infof("Connection %s closed.", c.srcConnID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close() c.conn.Close()
}
}() }()
// wait until the server accepts the QUIC version (or an error occurs) // wait until the server accepts the QUIC version (or an error occurs)
@ -210,96 +266,95 @@ func (c *client) establishSecureConnection() error {
select { select {
case <-errorChan: case <-errorChan:
return runErr return runErr
case ev := <-c.handshakeChan: case err := <-c.session.handshakeStatus():
if ev.err != nil { return err
return ev.err
}
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
} }
} }
// Listen listens // Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() { func (c *client) listen() {
var err error var err error
for { for {
var n int var n int
var addr net.Addr var addr net.Addr
data := getPacketBuffer() data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize] data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes // The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err = c.conn.Read(data) n, addr, err = c.conn.Read(data)
if err != nil { if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") { if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.mutex.Lock()
if c.session != nil {
c.session.Close(err) c.session.Close(err)
} }
c.mutex.Unlock()
}
break break
} }
data = data[:n] if err := c.handlePacket(addr, data[:n]); err != nil {
c.logger.Errorf("error handling packet: %s", err.Error())
c.handlePacket(addr, data) }
} }
} }
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
rcvTime := time.Now() rcvTime := time.Now()
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version) hdr, err := wire.ParseHeaderSentByServer(r, c.version)
// drop the packet if we can't parse the header
if err != nil { if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header
return
} }
// reject packets with truncated connection id if we didn't request truncation // reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := wire.ParsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
}
isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */
// handle Version Negotiation Packets // handle Version Negotiation Packets
if isVersionNegotiationPacket { if hdr.IsVersionNegotiation {
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated { if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return return errors.New("received a delayed Version Negotiation Packet")
} }
// version negotiation packets have no payload // version negotiation packets have no payload
if err := c.handleVersionNegotiationPacket(hdr); err != nil { if err := c.handleVersionNegotiationPacket(hdr); err != nil {
c.session.Close(err) c.session.Close(err)
} }
return return nil
}
if hdr.IsPublicHeader {
return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime)
}
return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
}
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
// TODO(#1003): add support for server-chosen connection IDs
// reject packets with the wrong connection ID
if !hdr.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
}
if hdr.IsLongHeader {
if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake {
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
}
c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen)
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
} }
// this is the first packet we are receiving // this is the first packet we are receiving
@ -312,9 +367,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
header: hdr, header: hdr,
data: packet[len(packet)-r.Len():], data: packetData,
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil
}
func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
}
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(r)
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
return nil
} }
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
@ -327,42 +421,66 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
} }
} }
c.receivedVersionNegotiationPacket = true c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion return qerr.InvalidVersion
} }
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version // switch to negotiated version
initialVersion := c.version c.initialVersion = c.version
c.version = newVersion c.version = newVersion
var err error var err error
c.connectionID, err = utils.GenerateConnectionID() c.destConnID, err = generateConnectionID()
if err != nil { if err != nil {
return err return err
} }
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) // in gQUIC, there's only one connection ID
if !c.version.UsesTLS() {
// create a new session and close the old one c.srcConnID = c.destConnID
// the new session must be created first to update client member variables }
oldSession := c.session c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
defer oldSession.Close(errCloseSessionForNewVersion) c.session.Close(errCloseSessionForNewVersion)
return c.createNewSession(initialVersion, hdr.SupportedVersions) return nil
} }
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { func (c *client) createNewGQUICSession() (err error) {
var err error c.mutex.Lock()
utils.Debugf("createNewSession with initial version %s", initialVersion) defer c.mutex.Unlock()
c.session, c.handshakeChan, err = newClientSession( c.session, err = newClientSession(
c.conn, c.conn,
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.destConnID,
c.tlsConf, c.tlsConf,
c.config, c.config,
initialVersion, c.initialVersion,
negotiatedVersions, c.negotiatedVersions,
c.logger,
)
return err
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.destConnID,
c.srcConnID,
c.config,
c.tls,
paramsChan,
1,
c.logger,
) )
return err return err
} }

View File

@ -1,11 +1,16 @@
coverage: coverage:
round: nearest round: nearest
ignore: ignore:
- ackhandler/packet_linkedlist.go - streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- h2quic/gzipreader.go - h2quic/gzipreader.go
- h2quic/response.go - h2quic/response.go
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go - internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go - internal/utils/packetinterval_linkedlist.go
- internal/utils/linkedlist/linkedlist.go
status: status:
project: project:
default: default:

View File

@ -1,183 +0,0 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
// Note: This constant is also defined in the ackhandler package.
initialRTTus = 100 * 1000
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
halfWindow float32 = 0.5
quarterWindow float32 = 0.25
)
type rttSample struct {
rtt time.Duration
time time.Time
}
// RTTStats provides round-trip statistics
type RTTStats struct {
initialRTTus int64
recentMinRTTwindow time.Duration
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
numMinRTTsamplesRemaining uint32
newMinRTT rttSample
recentMinRTT rttSample
halfWindowRTT rttSample
quarterWindowRTT rttSample
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{
initialRTTus: initialRTTus,
recentMinRTTwindow: utils.InfDuration,
}
}
// InitialRTTus is the initial RTT in us
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
// minRTT for the entire connection if SampleNewMinRtt was never called.
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// GetQuarterWindowRTT gets the quarter window RTT
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
// GetHalfWindowRTT gets the half window RTT
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
r.recentMinRTTwindow = recentMinRTTwindow
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
r.updateRecentMinRTT(sendDelta, now)
// Correct for ackDelay if information received from the peer results in a
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
// measure for smoothedRTT.
sample := sendDelta
if sample > ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if r.smoothedRTT == 0 {
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
if r.numMinRTTsamplesRemaining > 0 {
r.numMinRTTsamplesRemaining--
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
r.newMinRTT = rttSample{rtt: sample, time: now}
}
if r.numMinRTTsamplesRemaining == 0 {
r.recentMinRTT = r.newMinRTT
r.halfWindowRTT = r.newMinRTT
r.quarterWindowRTT = r.newMinRTT
}
}
// Update the three recent rtt samples.
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
r.recentMinRTT = rttSample{rtt: sample, time: now}
r.halfWindowRTT = r.recentMinRTT
r.quarterWindowRTT = r.recentMinRTT
} else if sample <= r.halfWindowRTT.rtt {
r.halfWindowRTT = rttSample{rtt: sample, time: now}
r.quarterWindowRTT = r.halfWindowRTT
} else if sample <= r.quarterWindowRTT.rtt {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
// Expire old min rtt samples.
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
r.recentMinRTT = r.halfWindowRTT
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
}
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
// |numSamples| UpdateRTT calls.
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
r.numMinRTTsamplesRemaining = numSamples
r.newMinRTT = rttSample{}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
r.initialRTTus = initialRTTus
r.numMinRTTsamplesRemaining = 0
r.recentMinRTTwindow = utils.InfDuration
r.recentMinRTT = rttSample{}
r.halfWindowRTT = rttSample{}
r.quarterWindowRTT = rttSample{}
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}

View File

@ -16,23 +16,48 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number. // A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber type VersionNumber = protocol.VersionNumber
// VersionGQUIC39 is gQUIC version 39.
const VersionGQUIC39 = protocol.Version39
// A Cookie can be used to verify the ownership of the client address. // A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams // Stream is the interface implemented by QUIC streams
type Stream interface { type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream. // Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline. // after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader io.Reader
// Write writes data to the stream. // Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true // Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline. // after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer io.Closer
StreamID() StreamID // CancelWrite aborts sending on this stream.
// Reset closes the stream with an error. // It must not be called after Close.
Reset(error) // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed. // The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely). // This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
@ -53,18 +78,63 @@ type Stream interface {
SetDeadline(t time.Time) error SetDeadline(t time.Time) error
} }
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
type Session interface { type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error) AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached. // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// New streams always have the smallest possible stream ID. AcceptUniStream() (ReceiveStream, error)
// TODO: Enable testing for the special error // OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error) OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened. // OpenStreamSync opens a new bidirectional QUIC stream.
// It always picks the smallest possible stream ID. // It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error) OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address. // LocalAddr returns the local address.
LocalAddr() net.Addr LocalAddr() net.Addr
// RemoteAddr returns the address of the peer. // RemoteAddr returns the address of the peer.
@ -74,13 +144,9 @@ type Session interface {
// The context is cancelled when the session is closed. // The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
Context() context.Context Context() context.Context
} // ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. ConnectionState() ConnectionState
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
} }
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
@ -113,6 +179,17 @@ type Config struct {
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64 MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool KeepAlive bool
} }

View File

@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View File

@ -0,0 +1,46 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet)
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time
OnAlarm() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@ -0,0 +1,29 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
// There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
}

View File

@ -1,13 +1,10 @@
// Generated by: main // This file was automatically generated by genny.
// TypeWriter: linkedlist // Any changes will be lost if this file is regenerated.
// Directive: +gen on Packet // see https://github.com/cheekybits/genny
package ackhandler package ackhandler
// List is a modification of http://golang.org/pkg/container/list/ // Linked list implementation from the Go standard library.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// PacketElement is an element of a linked list. // PacketElement is an element of a linked list.
type PacketElement struct { type PacketElement struct {
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
return nil return nil
} }
// PacketList represents a doubly linked list. // PacketList is a linked list of Packets.
// The zero value for PacketList is an empty list ready to use.
type PacketList struct { type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element len int // current list length excluding (this) sentinel element
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
// The complexity is O(1). // The complexity is O(1).
func (l *PacketList) Len() int { return l.len } func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement { func (l *PacketList) Front() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
return l.root.next return l.root.next
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement { func (l *PacketList) Back() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
return l.root.prev return l.root.prev
} }
// lazyInit lazily initializes a zero PacketList value. // lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() { func (l *PacketList) lazyInit() {
if l.root.next == nil { if l.root.next == nil {
l.Init() l.Init()
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
return e return e
} }
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at). // insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at) return l.insert(&PacketElement{Value: v}, at)
} }
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
// Remove removes e from l if e is an element of list l. // Remove removes e from l if e is an element of list l.
// It returns the element value e.Value. // It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet { func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l { if e.list == l {
// if e.list == l, l must have been initialized when e was inserted // if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketElement) and l.remove will crash // in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e) l.remove(e)
} }
return e.Value return e.Value
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e. // InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev) return l.insertValue(v, mark.prev)
} }
// InsertAfter inserts a new element e with value v immediately after mark and returns e. // InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark) return l.insertValue(v, mark)
} }
// MoveToFront moves element e to the front of list l. // MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) { func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e { if e.list != l || l.root.next == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root) l.insert(l.remove(e), &l.root)
} }
// MoveToBack moves element e to the back of list l. // MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) { func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e { if e.list != l || l.root.prev == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark. // MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) { func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
} }
// MoveAfter moves element e to its new position after mark. // MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) { func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
} }
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) { func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
} }
// PushFrontList inserts a copy of an other list at the front of list l. // PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) { func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {

View File

@ -0,0 +1,215 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -74,17 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return nil return nil
} }
// DeleteUpTo deletes all entries up to (and including) p // DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) { func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1) if p <= h.lowestInReceivedPacketNumbers {
return
}
h.lowestInReceivedPacketNumbers = p
nextEl := h.ranges.Front() nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl { for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next() nextEl = el.Next()
if p >= el.Value.Start && p < el.Value.End { if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p + 1 el.Value.Start = p
} else if el.Value.End <= p { // delete a whole range } else if el.Value.End < p { // delete a whole range
h.ranges.Remove(el) h.ranges.Remove(el)
} else { // no ranges affected. Nothing to do } else { // no ranges affected. Nothing to do
return return
@ -101,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
ackRanges := make([]wire.AckRange, h.ranges.Len()) ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0 i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() { for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End} ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++ i++
} }
return ackRanges return ackRanges
@ -111,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{} ackRange := wire.AckRange{}
if h.ranges.Len() > 0 { if h.ranges.Len() > 0 {
r := h.ranges.Back().Value r := h.ranges.Back().Value
ackRange.First = r.Start ackRange.Smallest = r.Start
ackRange.Last = r.End ackRange.Largest = r.End
} }
return ackRange return ackRange
} }

View File

@ -0,0 +1,40 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendTLP means that a TLP probe packet should be sent
SendTLP
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendTLP:
return "tlp"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@ -0,0 +1,649 @@
package ackhandler
import (
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Maximum number of tail loss probes before an RTO fires.
maxTLPs = 2
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
allowTLP bool
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
numRTOs int
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
logger utils.Logger
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
}
}
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if p := h.packetHistory.FirstOutstanding(); p != nil {
return p.PacketNumber
}
return h.largestAcked + 1
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
h.packetHistory.SentPacket(packet)
h.updateLossDetectionAlarm()
}
}
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
var p []*Packet
for _, packet := range packets {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
p = append(p, packet)
}
}
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
h.lastSentPacketNumber = packet.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
packet.largestAcked = ackFrame.LargestAcked()
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
packet.canBeRetransmitted = true
if h.numRTOs > 0 {
h.numRTOs--
}
h.allowTLP = false
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
return isRetransmittable
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
largestAcked := ackFrame.LargestAcked()
if largestAcked > h.lastSentPacketNumber {
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if h.skippedPacketsAcked(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p, rcvTime); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
return err
}
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
ackedPackets = append(ackedPackets, p)
}
} else {
ackedPackets = append(ackedPackets, p)
}
return true, nil
})
if h.logger.Debug() && len(ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(ackedPackets))
for i, p := range ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
return true
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
return
}
if !h.handshakeComplete {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO or TLP alarm
alarmDuration := h.computeRTOTimeout()
if h.tlpCount < maxTLPs {
tlpAlarm := h.computeTLPTimeout()
// if the RTO duration is shorter than the TLP duration, use the RTO duration
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
}
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
h.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*Packet
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
if packet.PacketNumber > h.largestAcked {
return false, nil
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
}
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
return true, nil
})
if h.logger.Debug() && len(lostPackets) > 0 {
pns := make([]protocol.PacketNumber, len(lostPackets))
for i, p := range lostPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
}
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
h.packetHistory.Remove(p.PacketNumber)
}
return nil
}
func (h *sentPacketHandler) OnAlarm() error {
now := time.Now()
var err error
if !h.handshakeComplete {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode")
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode")
}
// Early retransmit or time loss detection
err = h.detectLostPackets(now, h.bytesInFlight)
} else if h.tlpCount < maxTLPs {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in TLP mode")
}
h.allowTLP = true
h.tlpCount++
} else {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in RTO mode")
}
// RTO
h.rtoCount++
h.numRTOs += 2
err = h.queueRTOs()
}
if err != nil {
return err
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// only report the acking of this packet to the congestion controller if:
// * it is a retransmittable packet
// * this packet wasn't retransmitted yet
if p.isRetransmission {
// that the parent doesn't exist is expected to happen every time the original packet was already acked
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
if len(parent.retransmittedAs) == 1 {
parent.retransmittedAs = nil
} else {
// remove this packet from the slice of retransmission
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
for _, pn := range parent.retransmittedAs {
if pn != p.PacketNumber {
retransmittedAs = append(retransmittedAs, pn)
}
}
parent.retransmittedAs = retransmittedAs
}
}
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if h.rtoCount > 0 {
h.verifyRTO(p.PacketNumber)
}
if err := h.stopRetransmissionsFor(p); err != nil {
return err
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
for _, r := range p.retransmittedAs {
packet := h.packetHistory.GetPacket(r)
if packet == nil {
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
}
h.stopRetransmissionsFor(packet)
}
return nil
}
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
if pn <= h.largestSentBeforeRTO {
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
// Replace SRTT with latest_rtt and increase the variance to prevent
// a spurious RTO from happening again.
h.rttStats.ExpireSmoothedMetrics()
return
}
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.allowTLP {
return SendTLP
}
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.nextPacketSendTime
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
if h.numRTOs > 0 {
// RTO probes should not be paced, but must be sent immediately.
return h.numRTOs
}
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
}
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
// retransmit the oldest two packets
func (h *sentPacketHandler) queueRTOs() error {
h.largestSentBeforeRTO = h.lastSentPacketNumber
// Queue the first two outstanding packets for retransmission.
// This does NOT declare this packets as lost:
// They are still tracked in the packet history and count towards the bytes in flight.
for i := 0; i < 2; i++ {
if p := h.packetHistory.FirstOutstanding(); p != nil {
h.logger.Debugf("Queueing packet %#x for retransmission (RTO)", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
}
return nil
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
}
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
// TODO(#1236): include the max_ack_delay
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()
if rtt == 0 {
rto = defaultRTOTimeout
} else {
rto = rtt + 4*h.rttStats.MeanDeviation()
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto = rto << h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lowestUnacked := h.lowestUnacked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p < lowestUnacked {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

View File

@ -0,0 +1,127 @@
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}

View File

@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr
} }
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) { func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
if ack.LargestAcked >= s.nextLeastUnacked { largestAcked := ack.LargestAcked()
s.nextLeastUnacked = ack.LargestAcked + 1 if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
} }
} }

View File

@ -16,11 +16,10 @@ import (
// allow a 10 shift right to divide. // allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3) // 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling // where 0.100 is 100 ms which is the scaling round trip time.
// round trip time.
const cubeScale = 40 const cubeScale = 40
const cubeCongestionWindowScale = 410 const cubeCongestionWindowScale = 410
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
const defaultNumConnections = 2 const defaultNumConnections = 2
@ -32,39 +31,35 @@ const beta float32 = 0.7
// new concurrent flows and speed up convergence. // new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85 const betaLastMax float32 = 0.85
// If true, Cubic's epoch is shifted when the sender is application-limited.
const shiftQuicCubicEpochWhenAppLimited = true
const maxCubicTimeInterval = 30 * time.Millisecond
// Cubic implements the cubic algorithm from TCP // Cubic implements the cubic algorithm from TCP
type Cubic struct { type Cubic struct {
clock Clock clock Clock
// Number of connections to simulate. // Number of connections to simulate.
numConnections int numConnections int
// Time when this cycle started, after last loss event. // Time when this cycle started, after last loss event.
epoch time.Time epoch time.Time
// Time when sender went into application-limited period. Zero if not in
// application-limited period. // Max congestion window used just before last loss event.
appLimitedStartTime time.Time
// Time when we updated last_congestion_window.
lastUpdateTime time.Time
// Last congestion window (in packets) used.
lastCongestionWindow protocol.PacketNumber
// Max congestion window (in packets) used just before last loss event.
// Note: to improve fairness to other streams an additional back off is // Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value. // applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.PacketNumber lastMaxCongestionWindow protocol.ByteCount
// Number of acked packets since the cycle started (epoch).
ackedPacketsCount protocol.PacketNumber // Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets. // TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.PacketNumber estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function. // Origin point of cubic function.
originPointCongestionWindow protocol.PacketNumber originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second. // Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32 timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function. // Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.PacketNumber lastTargetCongestionWindow protocol.ByteCount
} }
// NewCubic returns a new Cubic instance // NewCubic returns a new Cubic instance
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
// Reset is called after a timeout to reset the cubic state // Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() { func (c *Cubic) Reset() {
c.epoch = time.Time{} c.epoch = time.Time{}
c.appLimitedStartTime = time.Time{}
c.lastUpdateTime = time.Time{}
c.lastCongestionWindow = 0
c.lastMaxCongestionWindow = 0 c.lastMaxCongestionWindow = 0
c.ackedPacketsCount = 0 c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0 c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0 c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0 c.timeToOriginPoint = 0
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
} }
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use // OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence. // the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() { func (c *Cubic) OnApplicationLimited() {
if shiftQuicCubicEpochWhenAppLimited { // When sender is not using the available congestion window, the window does
// When sender is not using the available congestion window, Cubic's epoch // not grow. But to be RTT-independent, Cubic assumes that the sender has been
// should not continue growing. Record the time when sender goes into an // using the entire window during the time since the beginning of the current
// app-limited period here, to compensate later when cwnd growth happens. // "epoch" (the end of the last loss recovery period). Since
if c.appLimitedStartTime.IsZero() { // application-limited periods break this assumption, we reset the epoch when
c.appLimitedStartTime = c.clock.Now() // in such a period. This reset effectively freezes congestion window growth
} // through application-limited periods and allows Cubic growth to continue
} else { // when the entire window is being used.
// When sender is not using the available congestion window, Cubic's epoch
// should not continue growing. Reset the epoch when in such a period.
c.epoch = time.Time{} c.epoch = time.Time{}
} }
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after // CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new // a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window. // congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber { func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow < c.lastMaxCongestionWindow { if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another // We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up. // flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow)) c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else { } else {
c.lastMaxCongestionWindow = currentCongestionWindow c.lastMaxCongestionWindow = currentCongestionWindow
} }
c.epoch = time.Time{} // Reset time. c.epoch = time.Time{} // Reset time.
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta()) return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
} }
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. // CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window // Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last // follows a cubic function that depends on the time passed since last
// packet loss. // packet loss.
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber { func (c *Cubic) CongestionWindowAfterAck(
c.ackedPacketsCount++ // Packets acked. ackedBytes protocol.ByteCount,
currentTime := c.clock.Now() currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
// Cubic is "independent" of RTT, the update is limited by the time elapsed. eventTime time.Time,
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) { ) protocol.ByteCount {
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow) c.ackedBytesCount += ackedBytes
}
c.lastCongestionWindow = currentCongestionWindow
c.lastUpdateTime = currentTime
if c.epoch.IsZero() { if c.epoch.IsZero() {
// First ACK after a loss event. // First ACK after a loss event.
c.epoch = currentTime // Start of epoch. c.epoch = eventTime // Start of epoch.
c.ackedPacketsCount = 1 // Reset count. c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic. // Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow { if c.lastMaxCongestionWindow <= currentCongestionWindow {
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow c.originPointCongestionWindow = c.lastMaxCongestionWindow
} }
} else {
// If sender was app-limited, then freeze congestion window growth during
// app-limited period. Continue growth now by shifting the epoch-start
// through the app-limited period.
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
shift := currentTime.Sub(c.appLimitedStartTime)
c.epoch = c.epoch.Add(shift)
c.appLimitedStartTime = time.Time{}
}
} }
// Change the time unit from microseconds to 2^10 fractions per second. Take // Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a // the round trip time in account. This is done to allow us to use shift as a
// divide operator. // divide operator.
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000 elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime offset := int64(c.timeToOriginPoint) - elapsedTime
// Right-shifts of negative, signed numbers have
// implementation-dependent behavior. Force the offset to be
// positive, similar to the kernel implementation.
if offset < 0 { if offset < 0 {
offset = -offset offset = -offset
} }
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
var targetCongestionWindow protocol.PacketNumber deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) { if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else { } else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
} }
// With dynamic beta/alpha based on number of active streams, it is possible // Limit the CWND increase to half the acked bytes.
// for the required_ack_count to become much lower than acked_packets_count_ targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// suddenly, leading to more than one iteration through the following loop.
for { // Increase the window by approximately Alpha * 1 MSS of bytes every
// Update estimated TCP congestion_window. // time we ack an estimated tcp window of bytes. For small
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha()) // congestion windows (less than 25), the formula below will
if c.ackedPacketsCount < requiredAckCount { // increase slightly slower than linearly per estimated tcp window
break // of bytes.
} c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
c.ackedPacketsCount -= requiredAckCount c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow++
}
// We have a new cubic congestion window. // We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow c.lastTargetCongestionWindow = targetCongestionWindow
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
if targetCongestionWindow < c.estimatedTCPcongestionWindow { if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow targetCongestionWindow = c.estimatedTCPcongestionWindow
} }
return targetCongestionWindow return targetCongestionWindow
} }

View File

@ -9,8 +9,8 @@ import (
const ( const (
maxBurstBytes = 3 * protocol.DefaultTCPMSS maxBurstBytes = 3 * protocol.DefaultTCPMSS
defaultMinimumCongestionWindow protocol.PacketNumber = 2
renoBeta float32 = 0.7 // Reno backoff factor. renoBeta float32 = 0.7 // Reno backoff factor.
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
) )
type cubicSender struct { type cubicSender struct {
@ -31,12 +31,6 @@ type cubicSender struct {
// Track the largest packet number outstanding when a CWND cutback occurs. // Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber largestSentAtLastCutback protocol.PacketNumber
// Congestion window in packets.
congestionWindow protocol.PacketNumber
// Slow start congestion window in packets, aka ssthresh.
slowstartThreshold protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart. // Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost // Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool lastCutbackExitedSlowstart bool
@ -44,24 +38,35 @@ type cubicSender struct {
// When true, exit slow start with large cutback of congestion window. // When true, exit slow start with large cutback of congestion window.
slowStartLargeReduction bool slowStartLargeReduction bool
// Minimum congestion window in packets. // Congestion window in packets.
minCongestionWindow protocol.PacketNumber congestionWindow protocol.ByteCount
// Maximum number of outstanding packets for tcp. // Minimum congestion window in packets.
maxTCPCongestionWindow protocol.PacketNumber minCongestionWindow protocol.ByteCount
// Maximum congestion window.
maxCongestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowstartThreshold protocol.ByteCount
// Number of connections to simulate. // Number of connections to simulate.
numConnections int numConnections int
// ACK counter for the Reno implementation. // ACK counter for the Reno implementation.
congestionWindowCount protocol.ByteCount numAckedPackets uint64
initialCongestionWindow protocol.PacketNumber initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.PacketNumber initialMaxCongestionWindow protocol.ByteCount
minSlowStartExitWindow protocol.ByteCount
} }
var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
// NewCubicSender makes a new cubic sender // NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo { func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
return &cubicSender{ return &cubicSender{
rttStats: rttStats, rttStats: rttStats,
initialCongestionWindow: initialCongestionWindow, initialCongestionWindow: initialCongestionWindow,
@ -69,28 +74,37 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
congestionWindow: initialCongestionWindow, congestionWindow: initialCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow, minCongestionWindow: defaultMinimumCongestionWindow,
slowstartThreshold: initialMaxCongestionWindow, slowstartThreshold: initialMaxCongestionWindow,
maxTCPCongestionWindow: initialMaxCongestionWindow, maxCongestionWindow: initialMaxCongestionWindow,
numConnections: defaultNumConnections, numConnections: defaultNumConnections,
cubic: NewCubic(clock), cubic: NewCubic(clock),
reno: reno, reno: reno,
} }
} }
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { // TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() { if c.InRecovery() {
// PRR is used when in recovery. // PRR is used when in recovery.
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
}
if c.GetCongestionWindow() > bytesInFlight {
return 0 return 0
} }
return utils.InfDuration }
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return delay
} }
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { func (c *cubicSender) OnPacketSent(
// Only update bytesInFlight for data packets. sentTime time.Time,
bytesInFlight protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
if !isRetransmittable { if !isRetransmittable {
return false return
} }
if c.InRecovery() { if c.InRecovery() {
// PRR is used when in recovery. // PRR is used when in recovery.
@ -98,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
} }
c.largestSentPacketNumber = packetNumber c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber) c.hybridSlowStart.OnPacketSent(packetNumber)
return true
} }
func (c *cubicSender) InRecovery() bool { func (c *cubicSender) InRecovery() bool {
@ -110,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
} }
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS return c.congestionWindow
} }
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount { func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS return c.slowstartThreshold
} }
func (c *cubicSender) ExitSlowstart() { func (c *cubicSender) ExitSlowstart() {
c.slowstartThreshold = c.congestionWindow c.slowstartThreshold = c.congestionWindow
} }
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber { func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
return c.slowstartThreshold return c.slowstartThreshold
} }
@ -131,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
} }
} }
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() { if c.InRecovery() {
// PRR is used when in recovery. // PRR is used when in recovery.
c.prr.OnPacketAcked(ackedBytes) c.prr.OnPacketAcked(ackedBytes)
return return
} }
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight) c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() { if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
} }
} }
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { func (c *cubicSender) OnPacketLost(
packetNumber protocol.PacketNumber,
lostBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected. // already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback { if packetNumber <= c.largestSentAtLastCutback {
@ -152,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++ c.stats.slowstartPacketsLost++
c.stats.slowstartBytesLost += lostBytes c.stats.slowstartBytesLost += lostBytes
if c.slowStartLargeReduction { if c.slowStartLargeReduction {
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS { // Reduce congestion window by lost_bytes for every loss.
// Reduce congestion window by 1 for every mss of bytes lost. c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
}
c.slowstartThreshold = c.congestionWindow c.slowstartThreshold = c.congestionWindow
} }
} }
@ -166,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++ c.stats.slowstartPacketsLost++
} }
c.prr.OnPacketLost(bytesInFlight) c.prr.OnPacketLost(priorInFlight)
// TODO(chromium): Separate out all of slow start into a separate class. // TODO(chromium): Separate out all of slow start into a separate class.
if c.slowStartLargeReduction && c.InSlowStart() { if c.slowStartLargeReduction && c.InSlowStart() {
c.congestionWindow = c.congestionWindow - 1 if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
} else if c.reno { } else if c.reno {
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta()) c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else { } else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
} }
// Enforce a minimum congestion window.
if c.congestionWindow < c.minCongestionWindow { if c.congestionWindow < c.minCongestionWindow {
c.congestionWindow = c.minCongestionWindow c.congestionWindow = c.minCongestionWindow
} }
@ -184,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.largestSentAtLastCutback = c.largestSentPacketNumber c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start // reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery. // counting again when we're out of recovery.
c.congestionWindowCount = 0 c.numAckedPackets = 0
} }
func (c *cubicSender) RenoBeta() float32 { func (c *cubicSender) RenoBeta() float32 {
@ -197,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
// Called when we receive an ack. Normal TCP tracks how many packets one ack // Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet. // represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { func (c *cubicSender) maybeIncreaseCwnd(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using // Do not increase the congestion window unless the sender is close to using
// the current window. // the current window.
if !c.isCwndLimited(bytesInFlight) { if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited() c.cubic.OnApplicationLimited()
return return
} }
if c.congestionWindow >= c.maxTCPCongestionWindow { if c.congestionWindow >= c.maxCongestionWindow {
return return
} }
if c.InSlowStart() { if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK. // TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow++ c.congestionWindow += protocol.DefaultTCPMSS
return return
} }
// Congestion avoidance
if c.reno { if c.reno {
// Classic Reno congestion avoidance. // Classic Reno congestion avoidance.
c.congestionWindowCount++ c.numAckedPackets++
// Divide by num_connections to smoothly increase the CWND at a faster // Divide by num_connections to smoothly increase the CWND at a faster
// rate than conventional Reno. // rate than conventional Reno.
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow { if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
c.congestionWindow++ c.congestionWindow += protocol.DefaultTCPMSS
c.congestionWindowCount = 0 c.numAckedPackets = 0
} }
} else { } else {
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT())) c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
} }
} }
@ -278,21 +306,13 @@ func (c *cubicSender) OnConnectionMigration() {
c.largestSentAtLastCutback = 0 c.largestSentAtLastCutback = 0
c.lastCutbackExitedSlowstart = false c.lastCutbackExitedSlowstart = false
c.cubic.Reset() c.cubic.Reset()
c.congestionWindowCount = 0 c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow c.congestionWindow = c.initialCongestionWindow
c.slowstartThreshold = c.initialMaxCongestionWindow c.slowstartThreshold = c.initialMaxCongestionWindow
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow c.maxCongestionWindow = c.initialMaxCongestionWindow
} }
// SetSlowStartLargeReduction allows enabling the SSLR experiment // SetSlowStartLargeReduction allows enabling the SSLR experiment
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) { func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled c.slowStartLargeReduction = enabled
} }
// RetransmissionDelay gives the time to retransmission
func (c *cubicSender) RetransmissionDelay() time.Duration {
if c.rttStats.SmoothedRTT() == 0 {
return 0
}
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
}

View File

@ -8,16 +8,15 @@ import (
// A SendAlgorithm performs congestion control and calculates the congestion window // A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface { type SendAlgorithm interface {
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
GetCongestionWindow() protocol.ByteCount GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart() MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int) SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool) OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration() OnConnectionMigration()
RetransmissionDelay() time.Duration
// Experiments // Experiments
SetSlowStartLargeReduction(enabled bool) SetSlowStartLargeReduction(enabled bool)
@ -31,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
// Stuff only used in testing // Stuff only used in testing
HybridSlowStart() *HybridSlowStart HybridSlowStart() *HybridSlowStart
SlowstartThreshold() protocol.PacketNumber SlowstartThreshold() protocol.ByteCount
RenoBeta() float32 RenoBeta() float32
InRecovery() bool InRecovery() bool
} }

View File

@ -1,10 +1,7 @@
package congestion package congestion
import ( import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
// OnPacketLost should be called on the first loss that triggers a recovery // OnPacketLost should be called on the first loss that triggers a recovery
// period and all other methods in this class should only be called when in // period and all other methods in this class should only be called when in
// recovery. // recovery.
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) { func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
p.bytesSentSinceLoss = 0 p.bytesSentSinceLoss = 0
p.bytesInFlightBeforeLoss = bytesInFlight p.bytesInFlightBeforeLoss = priorInFlight
p.bytesDeliveredSinceLoss = 0 p.bytesDeliveredSinceLoss = 0
p.ackCountSinceLoss = 0 p.ackCountSinceLoss = 0
} }
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
p.ackCountSinceLoss++ p.ackCountSinceLoss++
} }
// TimeUntilSend calculates the time until a packet can be sent // CanSend returns if packets can be sent
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration { func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
// Return QuicTime::Zero In order to ensure limited transmit always works. // Return QuicTime::Zero In order to ensure limited transmit always works.
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS { if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
return 0 return true
} }
if congestionWindow > bytesInFlight { if congestionWindow > bytesInFlight {
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead // During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
// of sending the entire available window. This prevents burst retransmits // of sending the entire available window. This prevents burst retransmits
// when more packets are lost than the CWND reduction. // when more packets are lost than the CWND reduction.
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS // limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss { return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
return utils.InfDuration
}
return 0
} }
// Implement Proportional Rate Reduction (RFC6937). // Implement Proportional Rate Reduction (RFC6937).
// Checks a simplified version of the PRR formula that doesn't use division: // Checks a simplified version of the PRR formula that doesn't use division:
// AvailableSendWindow = // AvailableSendWindow =
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent // CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss { return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
return 0
}
return utils.InfDuration
} }

View File

@ -0,0 +1,101 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
// RTTStats provides round-trip statistics
type RTTStats struct {
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
// If no valid updates have occurred, it returns the initial RTT.
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
if r.smoothedRTT != 0 {
return r.smoothedRTT
}
return defaultInitialRTT
}
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
// Correct for ackDelay if information received from the peer results in a
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if r.smoothedRTT == 0 {
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
return cert.Certificate[0], nil return cert.Certificate[0], nil
} }
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
c := cc.config conf := c.config
c, err := maybeGetConfigForClient(c, sni) conf, err := maybeGetConfigForClient(conf, sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// The rest of this function is mostly copied from crypto/tls.getCertificate // The rest of this function is mostly copied from crypto/tls.getCertificate
if c.GetCertificate != nil { if conf.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil { if cert != nil || err != nil {
return cert, err return cert, err
} }
} }
if len(c.Certificates) == 0 { if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate return nil, errNoMatchingCertificate
} }
if len(c.Certificates) == 1 || c.NameToCertificate == nil { if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work. // There's only one choice, so no point doing any work.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
name := strings.ToLower(sni) name := strings.ToLower(sni)
@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
name = name[:len(name)-1] name = name[:len(name)-1]
} }
if cert, ok := c.NameToCertificate[name]; ok { if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil return cert, nil
} }
@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
for i := range labels { for i := range labels {
labels[i] = "*" labels[i] = "*"
candidate := strings.Join(labels, ".") candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok { if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil return cert, nil
} }
} }
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {

View File

@ -18,6 +18,7 @@ type CertManager interface {
GetLeafCertHash() (uint64, error) GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error Verify(hostname string) error
GetChain() []*x509.Certificate
} }
type certManager struct { type certManager struct {
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
return nil return nil
} }
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte { func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes() return getCommonCertificateHashes()
} }

View File

@ -1,61 +0,0 @@
// +build ignore
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
}
// copy because ChaCha20Poly1305 expects array pointers
var MyKey, OtherKey [32]byte
copy(MyKey[:], myKey)
copy(OtherKey[:], otherKey)
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
if err != nil {
return nil, err
}
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
if err != nil {
return nil, err
}
return &aeadChacha20Poly1305{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View File

@ -1,71 +0,0 @@
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View File

@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
if _, err := rand.Read(c.secret[:]); err != nil { if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key") return nil, errors.New("Curve25519: could not create private key")
} }
// See https://cr.yp.to/ecdh.html
c.secret[0] &= 248
c.secret[31] &= 127
c.secret[31] |= 64
curve25519.ScalarBaseMult(&c.public, &c.secret) curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil return c, nil
} }

View File

@ -1,13 +1,16 @@
package crypto package crypto
import ( import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
const ( const (
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" serverExporterLabel = "EXPORTER-QUIC server 1rtt"
) )
// A TLSExporter gets the negotiated ciphersuite and computes exporter // A TLSExporter gets the negotiated ciphersuite and computes exporter
@ -16,6 +19,14 @@ type TLSExporter interface {
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
} }
func qhkdfExpand(secret []byte, label string, length int) []byte {
qlabel := make([]byte, 2+1+5+len(label))
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance // DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string var myLabel, otherLabel string
@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil return key, iv, nil
} }

View File

@ -6,7 +6,6 @@ import (
"io" "io"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
} else { } else {
info.Write([]byte("QUIC key expansion\x00")) info.Write([]byte("QUIC key expansion\x00"))
} }
utils.BigEndian.WriteUint64(&info, uint64(connID)) info.Write(connID)
info.Write(chlo) info.Write(chlo)
info.Write(scfg) info.Write(scfg)
info.Write(cert) info.Write(cert)

View File

@ -2,13 +2,12 @@ package crypto
import ( import (
"crypto" "crypto"
"encoding/binary"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID) clientSecret, serverSecret := computeSecrets(connectionID)
@ -28,17 +27,15 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
} }
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8) handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
binary.BigEndian.PutUint64(connID, uint64(connectionID)) clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
return return
} }
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) key = qhkdfExpand(secret, "key", 16)
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) iv = qhkdfExpand(secret, "iv", 12)
return return
} }

View File

@ -1,10 +1,11 @@
package crypto package crypto
import ( import (
"encoding/binary" "bytes"
"errors" "errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/fnv128a"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
} }
hash := fnv128a.New() hash := fnv.New128a()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src[12:]) hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer { if n.perspective == protocol.PerspectiveServer {
@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
} else { } else {
hash.Write([]byte("Server")) hash.Write([]byte("Server"))
} }
testHigh, testLow := hash.Sum128() sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
low := binary.LittleEndian.Uint64(src) if !bytes.Equal(sum[:12], src[:12]) {
high := binary.LittleEndian.Uint32(src[8:]) return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
if uint32(testHigh&0xffffffff) != high || testLow != low {
return nil, errors.New("NullAEAD: failed to authenticate received data")
} }
return src[12:], nil return src[12:], nil
} }
@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
dst = dst[:12+len(src)] dst = dst[:12+len(src)]
} }
hash := fnv128a.New() hash := fnv.New128a()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src) hash.Write(src)
@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
} else { } else {
hash.Write([]byte("Client")) hash.Write([]byte("Client"))
} }
sum := make([]byte, 0, 16)
high, low := hash.Sum128() sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src) copy(dst[12:], src)
binary.LittleEndian.PutUint64(dst, low) copy(dst, sum[:12])
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
return dst return dst
} }
func (n *nullAEADFNV128a) Overhead() int { func (n *nullAEADFNV128a) Overhead() int {
return 12 return 12
} }
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}

View File

@ -1,76 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// StkSource is used to create and verify source address tokens
type StkSource interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
type stkSource struct {
aead cipher.AEAD
}
const stkKeySize = 16
// Chrome currently sets this to 12, but discusses changing it to 16. We start
// at 16 :)
const stkNonceSize = 16
// NewStkSource creates a source for source address tokens
func NewStkSource() (StkSource, error) {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
key, err := deriveKey(secret)
if err != nil {
return nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
if err != nil {
return nil, err
}
return &stkSource{aead: aead}, nil
}
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, stkNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
return s.aead.Seal(nonce, nonce, data, nil), nil
}
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
if len(p) < stkNonceSize {
return nil, fmt.Errorf("STK too short: %d", len(p))
}
nonce := p[:stkNonceSize]
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
}
func deriveKey(secret []byte) ([]byte, error) {
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
key := make([]byte, stkKeySize)
if _, err := io.ReadFull(r, key); err != nil {
return nil, err
}
return key, nil
}

View File

@ -4,41 +4,38 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type baseFlowController struct { type baseFlowController struct {
mutex sync.RWMutex // for sending data
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount bytesSent protocol.ByteCount
sendWindow protocol.ByteCount sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time // for receiving data
mutex sync.RWMutex
bytesRead protocol.ByteCount bytesRead protocol.ByteCount
highestReceived protocol.ByteCount highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount receiveWindow protocol.ByteCount
receiveWindowIncrement protocol.ByteCount receiveWindowSize protocol.ByteCount
maxReceiveWindowIncrement protocol.ByteCount maxReceiveWindowSize protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
logger utils.Logger
} }
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.bytesSent += n c.bytesSent += n
} }
// UpdateSendWindow should be called after receiving a WindowUpdateFrame // UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated // it returns true if the window was actually updated
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
if offset > c.sendWindow { if offset > c.sendWindow {
c.sendWindow = offset c.sendWindow = offset
} }
@ -57,52 +54,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
defer c.mutex.Unlock() defer c.mutex.Unlock()
// pretend we sent a WindowUpdate when reading the first byte // pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate // this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 { if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now() c.startNewAutoTuningEpoch()
} }
c.bytesRead += n c.bytesRead += n
} }
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
}
// getWindowUpdate updates the receive window, if necessary // getWindowUpdate updates the receive window, if necessary
// it returns the new offset // it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
diff := c.receiveWindow - c.bytesRead if !c.hasWindowUpdate() {
// update the window when more than half of it was already consumed
if diff >= (c.receiveWindowIncrement / 2) {
return 0 return 0
} }
c.maybeAdjustWindowIncrement() c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement c.receiveWindow = c.bytesRead + c.receiveWindowSize
c.lastWindowUpdateTime = time.Now()
return c.receiveWindow return c.receiveWindow
} }
func (c *baseFlowController) IsBlocked() bool { // maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
c.mutex.RLock() // For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
defer c.mutex.RUnlock() func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
return c.sendWindowSize() == 0 // don't do anything if less than half the window has been consumed
} if bytesReadInEpoch <= c.receiveWindowSize/2 {
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *baseFlowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return return
} }
rtt := c.rttStats.SmoothedRTT() rtt := c.rttStats.SmoothedRTT()
if rtt == 0 { if rtt == 0 {
return return
} }
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
// interval between the window updates is sufficiently large, no need to increase the increment if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
if timeSinceLastWindowUpdate >= 2*rtt { // window is consumed too fast, try to increase the window size
return c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
} }
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) c.startNewAutoTuningEpoch()
}
func (c *baseFlowController) startNewAutoTuningEpoch() {
c.epochStartTime = time.Now()
c.epochStartOffset = c.bytesRead
} }
func (c *baseFlowController) checkFlowControlViolation() bool { func (c *baseFlowController) checkFlowControlViolation() bool {

View File

@ -2,16 +2,18 @@ package flowcontrol
import ( import (
"fmt" "fmt"
"time"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type connectionFlowController struct { type connectionFlowController struct {
lastBlockedAt protocol.ByteCount
baseFlowController baseFlowController
queueWindowUpdate func()
} }
var _ ConnectionFlowController = &connectionFlowController{} var _ ConnectionFlowController = &connectionFlowController{}
@ -21,25 +23,37 @@ var _ ConnectionFlowController = &connectionFlowController{}
func NewConnectionFlowController( func NewConnectionFlowController(
receiveWindow protocol.ByteCount, receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController { ) ConnectionFlowController {
return &connectionFlowController{ return &connectionFlowController{
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
rttStats: rttStats, rttStats: rttStats,
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
}, },
queueWindowUpdate: queueWindowUpdate,
} }
} }
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.baseFlowController.sendWindowSize() return c.baseFlowController.sendWindowSize()
} }
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
// IncrementHighestReceived adds an increment to the highestReceived value // IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock() c.mutex.Lock()
@ -52,26 +66,34 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
return nil return nil
} }
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
c.mutex.Lock()
hasWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if hasWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() oldWindowSize := c.receiveWindowSize
oldWindowIncrement := c.receiveWindowIncrement
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if oldWindowIncrement < c.receiveWindowIncrement { if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
} }
c.mutex.Unlock()
return offset return offset
} }
// EnsureMinimumWindowIncrement sets a minimum window increment // EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows // it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
if inc > c.receiveWindowIncrement { c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) c.startNewAutoTuningEpoch()
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
} }
c.mutex.Unlock()
} }

View File

@ -5,17 +5,19 @@ import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface { type flowController interface {
// for sending // for sending
SendWindowSize() protocol.ByteCount SendWindowSize() protocol.ByteCount
IsBlocked() bool
UpdateSendWindow(protocol.ByteCount) UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount) AddBytesSent(protocol.ByteCount)
// for receiving // for receiving
AddBytesRead(protocol.ByteCount) AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
MaybeQueueWindowUpdate() // queues a window update, if necessary
} }
// A StreamFlowController is a flow controller for a QUIC stream. // A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface { type StreamFlowController interface {
flowController flowController
// for sending
IsBlocked() (bool, protocol.ByteCount)
// for receiving // for receiving
// UpdateHighestReceived should be called when a new highest offset is received // UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
@ -25,13 +27,15 @@ type StreamFlowController interface {
// The ConnectionFlowController is the flow controller for the connection. // The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface { type ConnectionFlowController interface {
flowController flowController
// for sending
IsNewlyBlocked() (bool, protocol.ByteCount)
} }
type connectionFlowControllerI interface { type connectionFlowControllerI interface {
ConnectionFlowController ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally // The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending // for sending
EnsureMinimumWindowIncrement(protocol.ByteCount) EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving // for receiving
IncrementHighestReceived(protocol.ByteCount) error IncrementHighestReceived(protocol.ByteCount) error
} }

View File

@ -3,7 +3,7 @@ package flowcontrol
import ( import (
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
@ -14,6 +14,8 @@ type streamFlowController struct {
streamID protocol.StreamID streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control contributesToConnection bool // does the stream contribute to connection level flow control
@ -30,18 +32,22 @@ func NewStreamFlowController(
receiveWindow protocol.ByteCount, receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount, initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController { ) StreamFlowController {
return &streamFlowController{ return &streamFlowController{
streamID: streamID, streamID: streamID,
contributesToConnection: contributesToConnection, contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI), connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
rttStats: rttStats, rttStats: rttStats,
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow, sendWindow: initialSendWindow,
logger: logger,
}, },
} }
} }
@ -102,9 +108,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
} }
func (c *streamFlowController) SendWindowSize() protocol.ByteCount { func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
window := c.baseFlowController.sendWindowSize() window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection { if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize()) window = utils.MinByteCount(window, c.connection.SendWindowSize())
@ -112,17 +115,44 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return window return window
} }
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { // IsBlocked says if it is blocked by stream-level flow control.
c.mutex.Lock() // If it is blocked, the offset is returned.
defer c.mutex.Unlock() func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 {
return false, 0
}
return true, c.sendWindow
}
oldWindowIncrement := c.receiveWindowIncrement func (c *streamFlowController) MaybeQueueWindowUpdate() {
offset := c.baseFlowController.getWindowUpdate() c.mutex.Lock()
if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) c.mutex.Unlock()
if hasWindowUpdate {
c.queueWindowUpdate()
}
if c.contributesToConnection { if c.contributesToConnection {
c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) c.connection.MaybeQueueWindowUpdate()
} }
} }
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
if c.receivedFinalOffset {
c.mutex.Unlock()
return 0
}
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
}
c.mutex.Unlock()
return offset return offset
} }

View File

@ -6,7 +6,7 @@ import (
"net" "net"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/bifurcation/mint"
) )
const ( const (
@ -29,17 +29,17 @@ type token struct {
// A CookieGenerator generates Cookies // A CookieGenerator generates Cookies
type CookieGenerator struct { type CookieGenerator struct {
cookieSource crypto.StkSource cookieProtector mint.CookieProtector
} }
// NewCookieGenerator initializes a new CookieGenerator // NewCookieGenerator initializes a new CookieGenerator
func NewCookieGenerator() (*CookieGenerator, error) { func NewCookieGenerator() (*CookieGenerator, error) {
stkSource, err := crypto.NewStkSource() cookieProtector, err := mint.NewDefaultCookieProtector()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &CookieGenerator{ return &CookieGenerator{
cookieSource: stkSource, cookieProtector: cookieProtector,
}, nil }, nil
} }
@ -52,7 +52,7 @@ func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return g.cookieSource.NewToken(data) return g.cookieProtector.NewToken(data)
} }
// DecodeToken decodes a Cookie // DecodeToken decodes a Cookie
@ -62,7 +62,7 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
return nil, nil return nil, nil
} }
data, err := g.cookieSource.DecodeToken(encrypted) data, err := g.cookieProtector.DecodeToken(encrypted)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,36 +7,44 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type cookieHandler struct { // A CookieHandler generates and validates cookies.
// The cookie is sent in the TLS Retry.
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
type CookieHandler struct {
callback func(net.Addr, *Cookie) bool callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator cookieGenerator *CookieGenerator
logger utils.Logger
} }
var _ mint.CookieHandler = &cookieHandler{} var _ mint.CookieHandler = &CookieHandler{}
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { // NewCookieHandler creates a new CookieHandler.
func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator() cookieGenerator, err := NewCookieGenerator()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cookieHandler{ return &CookieHandler{
callback: callback, callback: callback,
cookieGenerator: cookieGenerator, cookieGenerator: cookieGenerator,
logger: logger,
}, nil }, nil
} }
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { // Generate a new cookie for a mint connection.
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) { if h.callback(conn.RemoteAddr(), nil) {
return nil, nil return nil, nil
} }
return h.cookieGenerator.NewToken(conn.RemoteAddr()) return h.cookieGenerator.NewToken(conn.RemoteAddr())
} }
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { // Validate a cookie.
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token) data, err := h.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
return false return false
} }
return h.callback(conn.RemoteAddr(), data) return h.callback(conn.RemoteAddr(), data)

View File

@ -38,13 +38,12 @@ type cryptoSetupClient struct {
lastSentCHLO []byte lastSentCHLO []byte
certManager crypto.CertManager certManager crypto.CertManager
divNonceChan chan []byte divNonceChan chan struct{}
diversificationNonce []byte diversificationNonce []byte
clientHelloCounter int clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
receivedSecurePacket bool receivedSecurePacket bool
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
@ -52,9 +51,11 @@ type cryptoSetupClient struct {
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
params *TransportParameters params *TransportParameters
logger utils.Logger
} }
var _ CryptoSetup = &cryptoSetupClient{} var _ CryptoSetup = &cryptoSetupClient{}
@ -74,15 +75,17 @@ func NewCryptoSetupClient(
tlsConfig *tls.Config, tlsConfig *tls.Config,
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cryptoSetupClient{ divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
hostname: hostname, hostname: hostname,
connID: connID, connID: connID,
@ -90,19 +93,20 @@ func NewCryptoSetupClient(
certManager: crypto.NewCertManager(tlsConfig), certManager: crypto.NewCertManager(tlsConfig),
params: params, params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
initialVersion: initialVersion, initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte), divNonceChan: divNonceChan,
}, nil logger: logger,
}
return cs, nil
} }
func (h *cryptoSetupClient) HandleCryptoStream() error { func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage) messageChan := make(chan HandshakeMessage)
errorChan := make(chan error) errorChan := make(chan error, 1)
go func() { go func() {
for { for {
@ -116,37 +120,30 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
}() }()
for { for {
err := h.maybeUpgradeCrypto() if err := h.maybeUpgradeCrypto(); err != nil {
if err != nil {
return err return err
} }
h.mutex.RLock() h.mutex.RLock()
sendCHLO := h.secureAEAD == nil sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock() h.mutex.RUnlock()
if sendCHLO { if sendCHLO {
err = h.sendCHLO() if err := h.sendCHLO(); err != nil {
if err != nil {
return err return err
} }
} }
var message HandshakeMessage var message HandshakeMessage
select { select {
case divNonce := <-h.divNonceChan: case <-h.divNonceChan:
if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
h.diversificationNonce = divNonce
// there's no message to process, but we should try upgrading the crypto again // there's no message to process, but we should try upgrading the crypto again
continue continue
case message = <-messageChan: case message = <-messageChan:
case err = <-errorChan: case err := <-errorChan:
return err return err
} }
utils.Debugf("Got %s", message) h.logger.Debugf("Got %s", message)
switch message.Tag { switch message.Tag {
case TagREJ: case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil { if err := h.handleREJMessage(message.Data); err != nil {
@ -159,8 +156,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
} }
// blocks until the session has received the parameters // blocks until the session has received the parameters
h.paramsChan <- *params h.paramsChan <- *params
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.aeadChanged) close(h.handshakeEvent)
default: default:
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
@ -211,7 +208,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
err = h.certManager.Verify(h.hostname) err = h.certManager.Verify(h.hostname)
if err != nil { if err != nil {
utils.Infof("Certificate validation failed: %s", err.Error()) h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid return qerr.ProofInvalid
} }
} }
@ -219,7 +216,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof { if !validProof {
utils.Infof("Server proof verification failed") h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid return qerr.ProofInvalid
} }
@ -277,6 +274,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData) params, err := readHelloMap(cryptoData)
if err != nil { if err != nil {
@ -322,6 +320,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
if h.secureAEAD != nil { if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil return data, protocol.EncryptionSecure, nil
} }
@ -373,16 +372,28 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
return nil, errors.New("CryptoSetupClient: no encryption level specified") return nil, errors.New("CryptoSetupClient: no encryption level specified")
} }
func (h *cryptoSetupClient) DiversificationNonce() []byte { func (h *cryptoSetupClient) ConnectionState() ConnectionState {
panic("not needed for cryptoSetupClient") h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
} }
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.divNonceChan <- data h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
} }
return nil
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType { }
panic("not needed for cryptoSetupServer") h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
} }
func (h *cryptoSetupClient) sendCHLO() error { func (h *cryptoSetupClient) sendCHLO() error {
@ -403,7 +414,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
Data: tags, Data: tags,
} }
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
message.Write(b) message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes()) _, err = h.cryptoStream.Write(b.Bytes())
@ -462,7 +473,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
for _, tag := range tags { for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
} }
paddingSize := protocol.ClientHelloMinimumSize - size paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 { if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
} }
@ -500,10 +511,9 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil { if err != nil {
return err return err
} }
h.logger.Debugf("Creating AEAD for secure encryption.")
h.aeadChanged <- protocol.EncryptionSecure h.handshakeEvent <- struct{}{}
} }
return nil return nil
} }

View File

@ -19,10 +19,12 @@ import (
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX // KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() crypto.KeyExchange type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session // The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct { type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID connID protocol.ConnectionID
remoteAddr net.Addr remoteAddr net.Addr
scfg *ServerConfig scfg *ServerConfig
@ -42,7 +44,7 @@ type cryptoSetupServer struct {
receivedParams bool receivedParams bool
paramsChan chan<- TransportParameters paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction keyExchange KeyExchangeFunction
@ -51,7 +53,9 @@ type cryptoSetupServer struct {
params *TransportParameters params *TransportParameters
mutex sync.RWMutex sni string // need to fill out the ConnectionState
logger utils.Logger
} }
var _ CryptoSetup = &cryptoSetupServer{} var _ CryptoSetup = &cryptoSetupServer{}
@ -71,12 +75,14 @@ func NewCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig, scfg *ServerConfig,
params *TransportParameters, params *TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil { if err != nil {
@ -88,6 +94,7 @@ func NewCryptoSetup(
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
version: version, version: version,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg, scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
@ -96,7 +103,8 @@ func NewCryptoSetup(
acceptSTKCallback: acceptSTK, acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
paramsChan: paramsChan, paramsChan: paramsChan,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
logger: logger,
}, nil }, nil
} }
@ -112,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
utils.Debugf("Got %s", message) h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data) done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil { if err != nil {
return err return err
@ -139,6 +147,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if sni == "" { if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
} }
h.sni = sni
// prevent version downgrade attacks // prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
@ -182,7 +191,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, err := h.cryptoStream.Write(reply); err != nil { if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err return false, err
} }
h.aeadChanged <- protocol.EncryptionForwardSecure h.handshakeEvent <- struct{}{}
close(h.sentSHLO) close(h.sentSHLO)
return true, nil return true, nil
} }
@ -205,10 +214,11 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true h.receivedForwardSecurePacket = true
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan // wait for the send on the handshakeEvent chan
<-h.sentSHLO <-h.sentSHLO
close(h.aeadChanged) close(h.handshakeEvent)
} }
return res, protocol.EncryptionForwardSecure, nil return res, protocol.EncryptionForwardSecure, nil
} }
@ -219,6 +229,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
if h.secureAEAD != nil { if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil { if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil return res, protocol.EncryptionSecure, nil
} }
@ -294,17 +305,13 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
func (h *cryptoSetupServer) acceptSTK(token []byte) bool { func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token) stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("STK invalid: %s", err.Error()) h.logger.Debugf("STK invalid: %s", err.Error())
return false return false
} }
return h.acceptSTKCallback(h.remoteAddr, stk) return h.acceptSTKCallback(h.remoteAddr, stk)
} }
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
if len(chlo) < protocol.ClientHelloMinimumSize {
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
}
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr) token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -341,7 +348,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
var serverReply bytes.Buffer var serverReply bytes.Buffer
message.Write(&serverReply) message.Write(&serverReply)
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil return serverReply.Bytes(), nil
} }
@ -365,11 +372,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err return nil, err
} }
h.diversificationNonce = make([]byte, 32)
if _, err = rand.Read(h.diversificationNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC] clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce) err = h.validateClientNonce(clientNonce)
if err != nil { if err != nil {
@ -400,14 +402,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.logger.Debugf("Creating AEAD for secure encryption.")
h.aeadChanged <- protocol.EncryptionSecure h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key // Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer var fsNonce bytes.Buffer
fsNonce.Write(clientNonce) fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce) fsNonce.Write(serverNonce)
ephermalKex := h.keyExchange() ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil { if err != nil {
return nil, err return nil, err
@ -427,6 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap() replyMap := h.params.getHelloMap()
// add crypto parameters // add crypto parameters
@ -445,21 +451,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
} }
var reply bytes.Buffer var reply bytes.Buffer
message.Write(&reply) message.Write(&reply)
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil return reply.Bytes(), nil
} }
// DiversificationNonce returns the diversification nonce func (h *cryptoSetupServer) ConnectionState() ConnectionState {
func (h *cryptoSetupServer) DiversificationNonce() []byte { h.mutex.Lock()
return h.diversificationNonce defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
} }
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
} }
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {

View File

@ -1,10 +1,9 @@
package handshake package handshake
import ( import (
"crypto/tls" "errors"
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
@ -12,6 +11,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
@ -20,64 +22,33 @@ type cryptoSetupTLS struct {
perspective protocol.Perspective perspective protocol.Perspective
tls mintTLS
conn *fakeConn
nextPacketType protocol.PacketType
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
aead crypto.AEAD aead crypto.AEAD
aeadChanged chan<- protocol.EncryptionLevel tls MintTLS
cryptoStream *CryptoStreamConn
handshakeEvent chan<- struct{}
} }
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer( func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter, tls MintTLS,
connID protocol.ConnectionID, cryptoStream *CryptoStreamConn,
tlsConfig *tls.Config, nullAEAD crypto.AEAD,
remoteAddr net.Addr, handshakeEvent chan<- struct{},
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
checkCookie func(net.Addr, *Cookie) bool,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) CryptoSetupTLS {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
if err != nil {
return nil, err
}
mintConf.RequireCookie = true
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
if err != nil {
return nil, err
}
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveServer,
remoteAddr: remoteAddr,
}
mintConn := mint.Server(conn, mintConf)
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer, tls: tls,
tls: &mintController{mintConn}, cryptoStream: cryptoStream,
conn: conn,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
}, nil }
} }
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
@ -85,59 +56,44 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
connID protocol.ConnectionID, connID protocol.ConnectionID,
hostname string, hostname string,
tlsConfig *tls.Config, handshakeEvent chan<- struct{},
params *TransportParameters, tls MintTLS,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetupTLS, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
if err != nil {
return nil, err
}
mintConf.ServerName = hostname
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveClient,
}
mintConn := mint.Client(conn, mintConf)
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cryptoSetupTLS{ return &cryptoSetupTLS{
conn: conn,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn}, tls: tls,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, handshakeEvent: handshakeEvent,
nextPacketType: protocol.PacketTypeInitial,
}, nil }, nil
} }
func (h *cryptoSetupTLS) HandleCryptoStream() error { func (h *cryptoSetupTLS) HandleCryptoStream() error {
handshakeLoop: if h.perspective == protocol.PerspectiveServer {
for { // mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
switch alert := h.tls.Handshake(); alert { // send out that data now
case mint.AlertNoAlert: // handshake complete if _, err := h.cryptoStream.Flush(); err != nil {
break handshakeLoop
case mint.AlertWouldBlock:
h.determineNextPacketType()
if err := h.conn.Continue(); err != nil {
return err return err
} }
default: }
handshakeLoop:
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
} }
switch h.tls.State() {
case mint.StateClientStart: // this happens if a stateless retry is performed
return ErrCloseSessionForRetry
case mint.StateClientConnected, mint.StateServerConnected:
break handshakeLoop
}
} }
aead, err := h.keyDerivation(h.tls, h.perspective) aead, err := h.keyDerivation(h.tls, h.perspective)
@ -148,28 +104,23 @@ handshakeLoop:
h.aead = aead h.aead = aead
h.mutex.Unlock() h.mutex.Unlock()
// signal to the outside world that the handshake completed h.handshakeEvent <- struct{}{}
h.aeadChanged <- protocol.EncryptionForwardSecure close(h.handshakeEvent)
close(h.aeadChanged)
return nil return nil
} }
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.aead != nil { if h.aead == nil {
data, err := h.aead.Open(dst, src, packetNumber, associatedData) return nil, errors.New("no 1-RTT sealer")
if err != nil {
return nil, protocol.EncryptionUnspecified, err
} }
return data, protocol.EncryptionForwardSecure, nil return h.aead.Open(dst, src, packetNumber, associatedData)
}
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return data, protocol.EncryptionUnencrypted, nil
} }
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
@ -204,39 +155,13 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupTLS) determineNextPacketType() error { func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
state := h.tls.State().HandshakeState mintConnState := h.tls.ConnectionState()
if h.perspective == protocol.PerspectiveServer { return ConnectionState{
switch state { // TODO: set the ServerName, once mint exports it
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest HandshakeComplete: h.aead != nil,
h.nextPacketType = protocol.PacketTypeRetry PeerCertificates: mintConnState.PeerCertificates,
case "ServerStateWaitFinished":
h.nextPacketType = protocol.PacketTypeHandshake
default:
// TODO: accept 0-RTT data
return fmt.Errorf("Unexpected handshake state: %s", state)
} }
return nil
}
// client
if state != "ClientStateWaitSH" {
h.nextPacketType = protocol.PacketTypeHandshake
}
return nil
}
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.nextPacketType
}
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
panic("diversification nonce not needed for TLS")
} }

View File

@ -0,0 +1,101 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
// The CryptoStreamConn is used as the net.Conn passed to mint.
// It has two operating modes:
// 1. It can read and write to bytes.Buffers.
// 2. It can use a quic.Stream for reading and writing.
// The buffer-mode is only used by the server, in order to statelessly handle retries.
type CryptoStreamConn struct {
remoteAddr net.Addr
// the buffers are used before the session is initialized
readBuf bytes.Buffer
writeBuf bytes.Buffer
// stream will be set once the session is initialized
stream io.ReadWriter
}
var _ net.Conn = &CryptoStreamConn{}
// NewCryptoStreamConn creates a new CryptoStreamConn
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
return &CryptoStreamConn{remoteAddr: remoteAddr}
}
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
if c.stream != nil {
return c.stream.Read(b)
}
return c.readBuf.Read(b)
}
// AddDataForReading adds data to the read buffer.
// This data will ONLY be read when the stream has not been set.
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
c.readBuf.Write(data)
}
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
if c.stream != nil {
return c.stream.Write(p)
}
return c.writeBuf.Write(p)
}
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
func (c *CryptoStreamConn) GetDataForWriting() []byte {
defer c.writeBuf.Reset()
data := make([]byte, c.writeBuf.Len())
copy(data, c.writeBuf.Bytes())
return data
}
// SetStream sets the stream.
// After setting the stream, the read and write buffer won't be used any more.
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
c.stream = stream
}
// Flush copies the contents of the write buffer to the stream
func (c *CryptoStreamConn) Flush() (int, error) {
n, err := io.Copy(c.stream, &c.writeBuf)
return int(n), err
}
// Close is not implemented
func (c *CryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *CryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr returns the remote address
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
// SetReadDeadline is not implemented
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View File

@ -6,7 +6,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
var ( var (
@ -24,27 +23,26 @@ var (
// used for all connections for 60 seconds is negligible. Thus we can amortise // used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a // the Diffie-Hellman key generation at the server over all the connections in a
// small time span. // small time span.
func getEphermalKEX() (res crypto.KeyExchange) { func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock() kexMutex.RLock()
res = kexCurrent res := kexCurrent
t := kexCurrentTime t := kexCurrentTime
kexMutex.RUnlock() kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime { if res != nil && time.Since(t) < kexLifetime {
return res return res, nil
} }
kexMutex.Lock() kexMutex.Lock()
defer kexMutex.Unlock() defer kexMutex.Unlock()
// Check if still unfulfilled // Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
utils.Errorf("could not set KEX: %s", err.Error()) return nil, err
return kexCurrent
} }
kexCurrent = kex kexCurrent = kex
kexCurrentTime = time.Now() kexCurrentTime = time.Now()
return kexCurrent return kexCurrent, nil
} }
return kexCurrent return kexCurrent, nil
} }

View File

@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
offset := uint32(0) offset := uint32(0)
for i, t := range h.getTagsSorted() { for i, t := range h.getTagsSorted() {
v := data[Tag(t)] v := data[t]
b.Write(v) b.Write(v)
offset += uint32(len(v)) offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t)) binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
func (h HandshakeMessage) String() string { func (h HandshakeMessage) String() string {
var pad string var pad string
res := tagToString(h.Tag) + ":\n" res := tagToString(h.Tag) + ":\n"
for _, t := range h.getTagsSorted() { for _, tag := range h.getTagsSorted() {
tag := Tag(t)
if tag == TagPAD { if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag])) pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else { } else {

View File

@ -1,6 +1,11 @@
package handshake package handshake
import ( import (
"crypto/x509"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -10,16 +15,54 @@ type Sealer interface {
Overhead() int Overhead() int
} }
// CryptoSetup is a crypto setup // A TLSExtensionHandler sends and received the QUIC TLS extension.
type CryptoSetup interface { // It provides the parameters sent by the peer on a channel.
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
// MintTLS combines some methods needed to interact with mint.
type MintTLS interface {
crypto.TLSExporter
// additional methods
Handshake() mint.Alert
State() mint.State
ConnectionState() mint.ConnectionState
SetCryptoStream(io.ReadWriter)
}
type baseCryptoSetup interface {
HandleCryptoStream() error HandleCryptoStream() error
// TODO: clean up this interface ConnectionState() ConnectionState
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
type ConnectionState struct {
HandshakeComplete bool // handshake is complete
ServerName string // server name requested by client, if any (server side only)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
}

View File

@ -1,127 +0,0 @@
package handshake
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,
CipherSuites: []mint.CipherSuite{
mint.TLS_AES_128_GCM_SHA256,
mint.TLS_AES_256_GCM_SHA384,
},
}
if tlsConf != nil {
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
}
for j, cert := range certChain.Certificate {
c, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
mconf.Certificates[i].Chain[j] = c
}
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
}
return mconf, nil
}
type mintTLS interface {
// These two methods are the same as the crypto.TLSExporter interface.
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
GetCipherSuite() mint.CipherSuiteParams
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
// additional methods
Handshake() mint.Alert
State() mint.ConnectionState
}
var _ crypto.TLSExporter = (mintTLS)(nil)
type mintController struct {
conn *mint.Conn
}
var _ mintTLS = &mintController{}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.ConnectionState {
return mc.conn.State()
}
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {
stream io.ReadWriter
pers protocol.Perspective
remoteAddr net.Addr
blockRead bool
writeBuffer bytes.Buffer
}
var _ net.Conn = &fakeConn{}
func (c *fakeConn) Read(b []byte) (int, error) {
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
return 0, nil
}
c.blockRead = true // block the next Read call
return c.stream.Read(b)
}
func (c *fakeConn) Write(p []byte) (int, error) {
if c.pers == protocol.PerspectiveClient {
return c.stream.Write(p)
}
// Buffer all writes by the server.
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
return c.writeBuffer.Write(p)
}
func (c *fakeConn) Continue() error {
c.blockRead = false
if c.pers == protocol.PerspectiveClient {
return nil
}
// write all contents of the write buffer to the stream.
_, err := c.stream.Write(c.writeBuffer.Bytes())
c.writeBuffer.Reset()
return err
}
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
func (c *fakeConn) SetDeadline(time.Time) error { return nil }

View File

@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
} }
var pubs_kexs []struct{Length uint32; Value []byte} var pubsKexs []struct {
var last_len uint32 Length uint32
Value []byte
for i := 0; i < len(pubs)-3; i += int(last_len)+3 { }
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field // the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len); err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil { if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
} }
if last_len == 0 { if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
if i+3+int(last_len) > len(pubs) { if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]}) pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
} }
if c255Foundat >= len(pubs_kexs) { if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS") return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
} }
if pubs_kexs[c255Foundat].Length != 32 { if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return err return err
} }
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
if err != nil { if err != nil {
return err return err
} }

Some files were not shown because too many files have changed in this diff Show More