diff --git a/vendor/github.com/bifurcation/mint/README.md b/vendor/github.com/bifurcation/mint/README.md index 0ac41e0..9fa05dd 100644 --- a/vendor/github.com/bifurcation/mint/README.md +++ b/vendor/github.com/bifurcation/mint/README.md @@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns with earlier TLS versions. However, unnecessary parts will be ruthlessly cut off. +## DTLS Support + +Mint has partial support for DTLS, but that support is not yet complete +and may still contain serious defects. + + ## Quickstart Installation is the same as for any other Go package: diff --git a/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/bifurcation/mint/alert.go index 5e31035..430e455 100644 --- a/vendor/github.com/bifurcation/mint/alert.go +++ b/vendor/github.com/bifurcation/mint/alert.go @@ -46,6 +46,7 @@ const ( AlertBadCertificateHashValue Alert = 114 AlertUnknownPSKIdentity Alert = 115 AlertNoApplicationProtocol Alert = 120 + AlertStatelessRetry Alert = 253 AlertWouldBlock Alert = 254 AlertNoAlert Alert = 255 ) @@ -82,6 +83,7 @@ var alertText = map[Alert]string{ AlertUnknownPSKIdentity: "unknown PSK identity", AlertNoApplicationProtocol: "no application protocol", AlertNoRenegotiation: "no renegotiation", + AlertStatelessRetry: "stateless retry", AlertWouldBlock: "would have blocked", AlertNoAlert: "no alert", } diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go index 290a930..07e7f53 100644 --- a/vendor/github.com/bifurcation/mint/client-state-machine.go +++ b/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -3,6 +3,7 @@ package mint import ( "bytes" "crypto" + "crypto/x509" "hash" "time" ) @@ -49,29 +50,31 @@ import ( // WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; // CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) -type ClientStateStart struct { - Caps Capabilities +type clientStateStart struct { + Config *Config Opts ConnectionOptions Params ConnectionParameters cookie []byte firstClientHello *HandshakeMessage helloRetryRequest *HandshakeMessage + hsCtx *HandshakeContext } -func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") - return nil, nil, AlertUnexpectedMessage - } +var _ HandshakeState = &clientStateStart{} +func (state clientStateStart) State() State { + return StateClientStart +} + +func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { // key_shares offeredDH := map[NamedGroup][]byte{} ks := KeyShareExtension{ HandshakeType: HandshakeTypeClientHello, - Shares: make([]KeyShareEntry, len(state.Caps.Groups)), + Shares: make([]KeyShareEntry, len(state.Config.Groups)), } - for i, group := range state.Caps.Groups { + for i, group := range state.Config.Groups { pub, priv, err := newKeyShare(group) if err != nil { logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) @@ -86,10 +89,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand logf(logTypeHandshake, "opts: %+v", state.Opts) // supported_versions, supported_groups, signature_algorithms, server_name - sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} + sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}} sni := ServerNameExtension(state.Opts.ServerName) - sg := SupportedGroupsExtension{Groups: state.Caps.Groups} - sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + sg := SupportedGroupsExtension{Groups: state.Config.Groups} + sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} state.Params.ServerName = state.Opts.ServerName @@ -101,7 +104,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand // Construct base ClientHello ch := &ClientHelloBody{ - CipherSuites: state.Caps.CipherSuites, + LegacyVersion: wireVersion(state.hsCtx.hIn), + CipherSuites: state.Config.CipherSuites, } _, err := prng.Read(ch.Random[:]) if err != nil { @@ -133,8 +137,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) if err != nil { logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) return nil, nil, AlertInternalError @@ -150,7 +154,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand var earlySecret []byte var clientEarlyTrafficKeys keySet var clientHello *HandshakeMessage - if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { + if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok { offeredPSK = key // Narrow ciphersuites to ones that match PSK hash @@ -168,8 +172,10 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } ch.CipherSuites = compatibleSuites + // TODO(ekr@rtfm.com): Check that the ticket can be used for early + // data. // Signal early data if we're going to do it - if len(state.Opts.EarlyData) > 0 { + if state.Config.AllowEarlyData && state.helloRetryRequest == nil { state.Params.ClientSendingEarlyData = true ed = &EarlyDataExtension{} err = ch.Extensions.Add(ed) @@ -180,11 +186,11 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } // Signal supported PSK key exchange modes - if len(state.Caps.PSKModes) == 0 { + if len(state.Config.PSKModes) == 0 { logf(logTypeHandshake, "PSK selected, but no PSKModes") return nil, nil, AlertInternalError } - kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} + kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes} err = ch.Extensions.Add(kem) if err != nil { logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) @@ -241,7 +247,7 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand // If we got here, the earlier marshal succeeded (in ch.Truncated()), so // this one should too. - clientHello, _ = HandshakeMessageFromBody(ch) + clientHello, _ = state.hsCtx.hOut.HandshakeMessageFromBody(ch) // Compute early traffic keys h := params.Hash.New() @@ -251,11 +257,8 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else if len(state.Opts.EarlyData) > 0 { - logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") - return nil, nil, AlertInternalError } else { - clientHello, err = HandshakeMessageFromBody(ch) + clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) if err != nil { logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) return nil, nil, AlertInternalError @@ -263,10 +266,12 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") - nextState := ClientStateWaitSH{ - Caps: state.Caps, + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. + nextState := clientStateWaitSH{ + Config: state.Config, Opts: state.Opts, Params: state.Params, + hsCtx: state.hsCtx, OfferedDH: offeredDH, OfferedPSK: offeredPSK, @@ -279,22 +284,23 @@ func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } toSend := []HandshakeAction{ - SendHandshakeMessage{clientHello}, + QueueHandshakeMessage{clientHello}, + SendQueuedHandshake{}, } if state.Params.ClientSendingEarlyData { toSend = append(toSend, []HandshakeAction{ - RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, - SendEarlyData{}, + RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, }...) } return nextState, toSend, AlertNoAlert } -type ClientStateWaitSH struct { - Caps Capabilities +type clientStateWaitSH struct { + Config *Config Opts ConnectionOptions Params ConnectionParameters + hsCtx *HandshakeContext OfferedDH map[NamedGroup][]byte OfferedPSK PreSharedKey PSK []byte @@ -307,49 +313,73 @@ type ClientStateWaitSH struct { clientHello *HandshakeMessage } -func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") +var _ HandshakeState = &clientStateWaitSH{} + +func (state clientStateWaitSH) State() State { + return StateClientWaitSH +} + +func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } + + if hm == nil || hm.msgType != HandshakeTypeServerHello { + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message") return nil, nil, AlertUnexpectedMessage } - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) - return nil, nil, AlertDecodeError + sh := &ServerHelloBody{} + if _, err := sh.Unmarshal(hm.body); err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message") + return nil, nil, AlertUnexpectedMessage } - switch body := bodyGeneric.(type) { - case *HelloRetryRequestBody: - hrr := body + // Common SH/HRR processing first. + // 1. Check that sh.version is TLS 1.2 + if sh.Version != tls12Version { + logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version) + return nil, nil, AlertIllegalParameter + } - if state.helloRetryRequest != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") - return nil, nil, AlertUnexpectedMessage - } + // 2. Check that it responded with a valid version. + supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello} + foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundSupportedVersions { + logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension") + return nil, nil, AlertMissingExtension + } + if supportedVersions.Versions[0] != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0]) + return nil, nil, AlertProtocolVersion + } + // 3. Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Config.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } - // Check that the version sent by the server is the one we support - if hrr.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) - return nil, nil, AlertProtocolVersion - } + // Now check for the sentinel. - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) - return nil, nil, AlertHandshakeFailure - } + if sh.Random == hrrRandomSentinel { + // This is actually HRR. + hrr := sh // Narrow the supported ciphersuites to the server-provided one - state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} + state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite} // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) if err != nil { logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) return nil, nil, AlertInternalError @@ -358,10 +388,14 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han // The only thing we know how to respond to in an HRR is the Cookie // extension, so if there is either no Cookie extension or anything other - // than a Cookie extension, we have to fail. + // than a Cookie extension and SupportedVersions we have to fail. serverCookie := new(CookieExtension) - foundCookie := hrr.Extensions.Find(serverCookie) - if !foundCookie || len(hrr.Extensions) != 1 { + foundCookie, err := hrr.Extensions.Find(serverCookie) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err) + return nil, nil, AlertDecodeError + } + if !foundCookie || len(hrr.Extensions) != 2 { logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) return nil, nil, AlertIllegalParameter } @@ -376,226 +410,260 @@ func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []Han body: h.Sum(nil), } + state.hsCtx.receivedEndOfFlight() + + // TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT + // mode. In DTLS, we also need to bump the sequence number. + // This is a pre-existing defect in Mint. Issue #175. logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") - return ClientStateStart{ - Caps: state.Caps, + return clientStateStart{ + Config: state.Config, Opts: state.Opts, + hsCtx: state.hsCtx, cookie: serverCookie.Cookie, firstClientHello: firstClientHello, helloRetryRequest: hm, - }.Next(nil) - - case *ServerHelloBody: - sh := body - - // Check that the version sent by the server is the one we support - if sh.Version != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) - return nil, nil, AlertProtocolVersion - } - - // Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Caps.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Do PSK or key agreement depending on extensions - serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} - serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} - - foundPSK := sh.Extensions.Find(&serverPSK) - foundKeyShare := sh.Extensions.Find(&serverKeyShare) - - if foundPSK && (serverPSK.SelectedIdentity == 0) { - state.Params.UsingPSK = true - } - - var dhSecret []byte - if foundKeyShare { - sks := serverKeyShare.Shares[0] - priv, ok := state.OfferedDH[sks.Group] - if !ok { - logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingDH = true - dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) - } - - suite := sh.CipherSuite - state.Params.CipherSuite = suite - - params, ok := cipherSuiteMap[suite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(hm.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - if params.Hash != state.earlyHash { - logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", - state.earlyHash, suite, params.Hash) - } - - earlySecret = state.earlySecret - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if dhSecret == nil { - dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") - nextState := ClientStateWaitEE{ - Caps: state.Caps, - Params: state.Params, - cryptoParams: params, - handshakeHash: handshakeHash, - certificates: state.Caps.Certificates, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, - } - toSend := []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, - } - return nextState, toSend, AlertNoAlert + }, []HandshakeAction{ResetOut{1}}, AlertNoAlert } - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) - return nil, nil, AlertUnexpectedMessage + // This is SH. + // Handle external extensions. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Do PSK or key agreement depending on extensions + serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} + serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + + foundExts, err := sh.Extensions.Parse( + []ExtensionBody{ + &serverPSK, + &serverKeyShare, + }) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err) + return nil, nil, AlertDecodeError + } + + if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) { + state.Params.UsingPSK = true + } + + var dhSecret []byte + if foundExts[ExtensionTypeKeyShare] { + sks := serverKeyShare.Shares[0] + priv, ok := state.OfferedDH[sks.Group] + if !ok { + logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingDH = true + dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) + } + + suite := sh.CipherSuite + state.Params.CipherSuite = suite + + params, ok := cipherSuiteMap[suite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(hm.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + if params.Hash != state.earlyHash { + logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", + state.earlyHash, suite, params.Hash) + } + + earlySecret = state.earlySecret + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if dhSecret == nil { + dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") + nextState := clientStateWaitEE{ + Config: state.Config, + Params: state.Params, + hsCtx: state.hsCtx, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, + } + toSend := []HandshakeAction{ + RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, + } + // We're definitely not going to have to send anything with + // early data. + if !state.Params.ClientSendingEarlyData { + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)}) + } + + return nextState, toSend, AlertNoAlert } -type ClientStateWaitEE struct { - Caps Capabilities - AuthCertificate func(chain []CertificateEntry) error +type clientStateWaitEE struct { + Config *Config Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash - certificates []*Certificate masterSecret []byte clientHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte } -func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &clientStateWaitEE{} + +func (state clientStateWaitEE) State() State { + return StateClientWaitEE +} + +func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") return nil, nil, AlertUnexpectedMessage } ee := EncryptedExtensionsBody{} - _, err := ee.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(&ee, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) return nil, nil, AlertDecodeError } // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) if err != nil { logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) return nil, nil, AlertInternalError } } - serverALPN := ALPNExtension{} - serverEarlyData := EarlyDataExtension{} + serverALPN := &ALPNExtension{} + serverEarlyData := &EarlyDataExtension{} - gotALPN := ee.Extensions.Find(&serverALPN) - state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) + foundExts, err := ee.Extensions.Parse( + []ExtensionBody{ + serverALPN, + serverEarlyData, + }) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) + return nil, nil, AlertDecodeError + } - if gotALPN && len(serverALPN.Protocols) > 0 { + state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData] + + if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 { state.Params.NextProto = serverALPN.Protocols[0] } state.handshakeHash.Write(hm.Marshal()) + toSend := []HandshakeAction{} + + if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData { + // We didn't get 0-RTT, so rekey to handshake. + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } + if state.Params.UsingPSK { logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") - nextState := ClientStateWaitFinished{ + nextState := clientStateWaitFinished{ Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, + certificates: state.Config.Certificates, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") - nextState := ClientStateWaitCertCR{ - AuthCertificate: state.AuthCertificate, + nextState := clientStateWaitCertCR{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } -type ClientStateWaitCertCR struct { - AuthCertificate func(chain []CertificateEntry) error +type clientStateWaitCertCR struct { + Config *Config Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash - certificates []*Certificate masterSecret []byte clientHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte } -func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &clientStateWaitCertCR{} + +func (state clientStateWaitCertCR) State() State { + return StateClientWaitCertCR +} + +func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil { logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") return nil, nil, AlertUnexpectedMessage @@ -612,12 +680,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [ switch body := bodyGeneric.(type) { case *CertificateBody: logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") - nextState := ClientStateWaitCV{ - AuthCertificate: state.AuthCertificate, + nextState := clientStateWaitCV{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, serverCertificate: body, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -635,12 +703,12 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [ state.Params.UsingClientAuth = true logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") - nextState := ClientStateWaitCert{ - AuthCertificate: state.AuthCertificate, + nextState := clientStateWaitCert{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, serverCertificateRequest: body, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -652,13 +720,13 @@ func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, [ return nil, nil, AlertUnexpectedMessage } -type ClientStateWaitCert struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash +type clientStateWaitCert struct { + Config *Config + Params ConnectionParameters + hsCtx *HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash - certificates []*Certificate serverCertificateRequest *CertificateRequestBody masterSecret []byte @@ -666,15 +734,24 @@ type ClientStateWaitCert struct { serverHandshakeTrafficSecret []byte } -func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &clientStateWaitCert{} + +func (state clientStateWaitCert) State() State { + return StateClientWaitCert +} + +func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeCertificate { logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") return nil, nil, AlertUnexpectedMessage } cert := &CertificateBody{} - _, err := cert.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(cert, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -682,12 +759,12 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H state.handshakeHash.Write(hm.Marshal()) logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") - nextState := ClientStateWaitCV{ - AuthCertificate: state.AuthCertificate, + nextState := clientStateWaitCV{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, serverCertificate: cert, serverCertificateRequest: state.serverCertificateRequest, masterSecret: state.masterSecret, @@ -697,13 +774,13 @@ func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H return nextState, nil, AlertNoAlert } -type ClientStateWaitCV struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams - handshakeHash hash.Hash +type clientStateWaitCV struct { + Config *Config + Params ConnectionParameters + hsCtx *HandshakeContext + cryptoParams CipherSuiteParams + handshakeHash hash.Hash - certificates []*Certificate serverCertificate *CertificateBody serverCertificateRequest *CertificateRequestBody @@ -712,15 +789,24 @@ type ClientStateWaitCV struct { serverHandshakeTrafficSecret []byte } -func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &clientStateWaitCV{} + +func (state clientStateWaitCV) State() State { + return StateClientWaitCV +} + +func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") return nil, nil, AlertUnexpectedMessage } certVerify := CertificateVerifyBody{} - _, err := certVerify.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(&certVerify, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -734,46 +820,89 @@ func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []Han return nil, nil, AlertHandshakeFailure } - if state.AuthCertificate != nil { - err := state.AuthCertificate(state.serverCertificate.CertificateList) + certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList)) + rawCerts := make([][]byte, len(state.serverCertificate.CertificateList)) + for i, certEntry := range state.serverCertificate.CertificateList { + certs[i] = certEntry.CertData + rawCerts[i] = certEntry.CertData.Raw + } + + var verifiedChains [][]*x509.Certificate + if !state.Config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: state.Config.RootCAs, + CurrentTime: state.Config.time(), + DNSName: state.Config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + var err error + verifiedChains, err = certs[0].Verify(opts) if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") + logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) + return nil, nil, AlertBadCertificate + } + } + + if state.Config.VerifyPeerCertificate != nil { + if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err) return nil, nil, AlertBadCertificate } - } else { - logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate") } state.handshakeHash.Write(hm.Marshal()) logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") - nextState := ClientStateWaitFinished{ + nextState := clientStateWaitFinished{ Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, - certificates: state.certificates, + certificates: state.Config.Certificates, serverCertificateRequest: state.serverCertificateRequest, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + peerCertificates: certs, + verifiedChains: verifiedChains, } return nextState, nil, AlertNoAlert } -type ClientStateWaitFinished struct { +type clientStateWaitFinished struct { Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash certificates []*Certificate serverCertificateRequest *CertificateRequestBody + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate masterSecret []byte clientHandshakeTrafficSecret []byte serverHandshakeTrafficSecret []byte } -func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &clientStateWaitFinished{} + +func (state clientStateWaitFinished) State() State { + return StateClientWaitFinished +} + +func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeFinished { logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") return nil, nil, AlertUnexpectedMessage @@ -788,8 +917,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} - _, err := fin.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(fin, hm.body); err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) return nil, nil, AlertDecodeError } @@ -822,25 +950,32 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, toSend := []HandshakeAction{} if state.Params.UsingEarlyData { + logf(logTypeHandshake, "Sending end of early data") // Note: We only send EOED if the server is actually going to use the early // data. Otherwise, it will never see it, and the transcripts will // mismatch. // EOED marshal is infallible - eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) - toSend = append(toSend, SendHandshakeMessage{eoedm}) + eoedm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(&EndOfEarlyDataBody{}) + toSend = append(toSend, QueueHandshakeMessage{eoedm}) + state.handshakeHash.Write(eoedm.Marshal()) logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - } - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) + // And then rekey to handshake + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } if state.Params.UsingClientAuth { // Extract constraints from certicateRequest schemes := SignatureAlgorithmsExtension{} - gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) + gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err) + return nil, nil, AlertDecodeError + } if !gotSchemes { - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found") return nil, nil, AlertIllegalParameter } @@ -851,13 +986,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) certificate := &CertificateBody{} - certm, err := HandshakeMessageFromBody(certificate) + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) return nil, nil, AlertInternalError } - toSend = append(toSend, SendHandshakeMessage{certm}) + toSend = append(toSend, QueueHandshakeMessage{certm}) state.handshakeHash.Write(certm.Marshal()) } else { // Create and send Certificate, CertificateVerify @@ -867,13 +1002,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, for i, entry := range cert.Chain { certificate.CertificateList[i] = CertificateEntry{CertData: entry} } - certm, err := HandshakeMessageFromBody(certificate) + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) return nil, nil, AlertInternalError } - toSend = append(toSend, SendHandshakeMessage{certm}) + toSend = append(toSend, QueueHandshakeMessage{certm}) state.handshakeHash.Write(certm.Marshal()) hcv := state.handshakeHash.Sum(nil) @@ -887,13 +1022,13 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) return nil, nil, AlertInternalError } - certvm, err := HandshakeMessageFromBody(certificateVerify) + certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) return nil, nil, AlertInternalError } - toSend = append(toSend, SendHandshakeMessage{certvm}) + toSend = append(toSend, QueueHandshakeMessage{certvm}) state.handshakeHash.Write(certvm.Marshal()) } } @@ -909,7 +1044,7 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, VerifyDataLen: len(clientFinishedData), VerifyData: clientFinishedData, } - finm, err := HandshakeMessageFromBody(fin) + finm, err := state.hsCtx.hOut.HandshakeMessageFromBody(fin) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) return nil, nil, AlertInternalError @@ -923,20 +1058,26 @@ func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) toSend = append(toSend, []HandshakeAction{ - SendHandshakeMessage{finm}, - RekeyIn{Label: "application", KeySet: serverTrafficKeys}, - RekeyOut{Label: "application", KeySet: clientTrafficKeys}, + QueueHandshakeMessage{finm}, + SendQueuedHandshake{}, + RekeyIn{epoch: EpochApplicationData, KeySet: serverTrafficKeys}, + RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, }...) + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") - nextState := StateConnected{ + nextState := stateConnected{ Params: state.Params, + hsCtx: state.hsCtx, isClient: true, cryptoParams: state.cryptoParams, resumptionSecret: resumptionSecret, clientTrafficSecret: clientTrafficSecret, serverTrafficSecret: serverTrafficSecret, exporterSecret: exporterSecret, + peerCertificates: state.peerCertificates, + verifiedChains: state.verifiedChains, } return nextState, toSend, AlertNoAlert } diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go index dfda7c3..05af3e9 100644 --- a/vendor/github.com/bifurcation/mint/common.go +++ b/vendor/github.com/bifurcation/mint/common.go @@ -5,9 +5,14 @@ import ( "strconv" ) -var ( - supportedVersion uint16 = 0x7f15 // draft-21 +const ( + supportedVersion uint16 = 0x7f16 // draft-22 + tls12Version uint16 = 0x0303 + tls10Version uint16 = 0x0301 + dtls12WireVersion uint16 = 0xfefd +) +var ( // Flags for some minor compat issues allowWrongVersionNumber = true allowPKCS1 = true @@ -20,6 +25,7 @@ const ( RecordTypeAlert RecordType = 21 RecordTypeHandshake RecordType = 22 RecordTypeApplicationData RecordType = 23 + RecordTypeAck RecordType = 25 ) // enum {...} HandshakeType; @@ -42,6 +48,13 @@ const ( HandshakeTypeMessageHash HandshakeType = 254 ) +var hrrRandomSentinel = [32]byte{ + 0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, + 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91, + 0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, + 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c, +} + // uint8 CipherSuite[2]; type CipherSuite uint16 @@ -150,3 +163,104 @@ const ( KeyUpdateNotRequested KeyUpdateRequest = 0 KeyUpdateRequested KeyUpdateRequest = 1 ) + +type State uint8 + +const ( + StateInit = 0 + + // states valid for the client + StateClientStart State = iota + StateClientWaitSH + StateClientWaitEE + StateClientWaitCert + StateClientWaitCV + StateClientWaitFinished + StateClientWaitCertCR + StateClientConnected + // states valid for the server + StateServerStart State = iota + StateServerRecvdCH + StateServerNegotiated + StateServerReadPastEarlyData + StateServerWaitEOED + StateServerWaitFlight2 + StateServerWaitCert + StateServerWaitCV + StateServerWaitFinished + StateServerConnected +) + +func (s State) String() string { + switch s { + case StateClientStart: + return "Client START" + case StateClientWaitSH: + return "Client WAIT_SH" + case StateClientWaitEE: + return "Client WAIT_EE" + case StateClientWaitCert: + return "Client WAIT_CERT" + case StateClientWaitCV: + return "Client WAIT_CV" + case StateClientWaitFinished: + return "Client WAIT_FINISHED" + case StateClientWaitCertCR: + return "Client WAIT_CERT_CR" + case StateClientConnected: + return "Client CONNECTED" + case StateServerStart: + return "Server START" + case StateServerRecvdCH: + return "Server RECVD_CH" + case StateServerNegotiated: + return "Server NEGOTIATED" + case StateServerReadPastEarlyData: + return "Server READ_PAST_EARLY_DATA" + case StateServerWaitEOED: + return "Server WAIT_EOED" + case StateServerWaitFlight2: + return "Server WAIT_FLIGHT2" + case StateServerWaitCert: + return "Server WAIT_CERT" + case StateServerWaitCV: + return "Server WAIT_CV" + case StateServerWaitFinished: + return "Server WAIT_FINISHED" + case StateServerConnected: + return "Server CONNECTED" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +// Epochs for DTLS (also used for key phase labelling) +type Epoch uint16 + +const ( + EpochClear Epoch = 0 + EpochEarlyData Epoch = 1 + EpochHandshakeData Epoch = 2 + EpochApplicationData Epoch = 3 + EpochUpdate Epoch = 4 +) + +func (e Epoch) label() string { + switch e { + case EpochClear: + return "clear" + case EpochEarlyData: + return "early data" + case EpochHandshakeData: + return "handshake" + case EpochApplicationData: + return "application data" + } + return "Application data (updated)" +} + +func assert(b bool) { + if !b { + panic("Assertion failed") + } +} diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go index 08eb58d..6dcabc0 100644 --- a/vendor/github.com/bifurcation/mint/conn.go +++ b/vendor/github.com/bifurcation/mint/conn.go @@ -4,6 +4,7 @@ import ( "crypto" "crypto/x509" "encoding/hex" + "errors" "fmt" "io" "net" @@ -12,8 +13,6 @@ import ( "time" ) -var WouldBlock = fmt.Errorf("Would have blocked") - type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer @@ -36,16 +35,20 @@ type PreSharedKeyCache interface { Size() int } -type PSKMapCache map[string]PreSharedKey - -// A CookieHandler does two things: -// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest -// - validates this byte string echoed by the client in the ClientHello +// A CookieHandler can be used to give the application more fine-grained control over Cookies. +// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie. +// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie. type CookieHandler interface { + // Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest + // If Generate returns nil, mint will not send a HelloRetryRequest. Generate(*Conn) ([]byte, error) + // Validate is called when receiving a ClientHello containing a Cookie. + // If validation failed, the handshake is aborted. Validate(*Conn, []byte) bool } +type PSKMapCache map[string]PreSharedKey + func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { psk, ok = cache[key] return @@ -74,14 +77,49 @@ type Config struct { AllowEarlyData bool // Require the client to echo a cookie. RequireCookie bool - // If cookies are required and no CookieHandler is set, a default cookie handler is used. - // The default cookie handler uses 32 random bytes as a cookie. - CookieHandler CookieHandler + // A CookieHandler can be used to set and validate a cookie. + // The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector. + // If no CookieHandler is set, mint will always send a cookie. + // The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent. + CookieHandler CookieHandler + // The CookieProtector is used to encrypt / decrypt cookies. + // It should make sure that the Cookie cannot be read and tampered with by the client. + // If non-blocking mode is used, and cookies are required, this field has to be set. + // In blocking mode, a default cookie protector is used, if this is unused. + CookieProtector CookieProtector + // The ExtensionHandler is used to add custom extensions. + ExtensionHandler AppExtensionHandler RequireClientAuth bool + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + // Shared fields - Certificates []*Certificate - AuthCertificate func(chain []CertificateEntry) error + Certificates []*Certificate + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify then this callback will be considered but + // the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + CipherSuites []CipherSuite Groups []NamedGroup SignatureSchemes []SignatureScheme @@ -89,6 +127,7 @@ type Config struct { PSKs PreSharedKeyCache PSKModes []PSKKeyExchangeMode NonBlocking bool + UseDTLS bool // The same config object can be shared among different connections, so it // needs its own mutex @@ -110,17 +149,24 @@ func (c *Config) Clone() *Config { EarlyDataLifetime: c.EarlyDataLifetime, AllowEarlyData: c.AllowEarlyData, RequireCookie: c.RequireCookie, + CookieHandler: c.CookieHandler, + CookieProtector: c.CookieProtector, + ExtensionHandler: c.ExtensionHandler, RequireClientAuth: c.RequireClientAuth, + Time: c.Time, + RootCAs: c.RootCAs, + InsecureSkipVerify: c.InsecureSkipVerify, - Certificates: c.Certificates, - AuthCertificate: c.AuthCertificate, - CipherSuites: c.CipherSuites, - Groups: c.Groups, - SignatureSchemes: c.SignatureSchemes, - NextProtos: c.NextProtos, - PSKs: c.PSKs, - PSKModes: c.PSKModes, - NonBlocking: c.NonBlocking, + Certificates: c.Certificates, + VerifyPeerCertificate: c.VerifyPeerCertificate, + CipherSuites: c.CipherSuites, + Groups: c.Groups, + SignatureSchemes: c.SignatureSchemes, + NextProtos: c.NextProtos, + PSKs: c.PSKs, + PSKModes: c.PSKModes, + NonBlocking: c.NonBlocking, + UseDTLS: c.UseDTLS, } } @@ -147,28 +193,6 @@ func (c *Config) Init(isClient bool) error { if len(c.PSKModes) == 0 { c.PSKModes = defaultPSKModes } - - // If there is no certificate, generate one - if !isClient && len(c.Certificates) == 0 { - logf(logTypeHandshake, "Generating key name=%v", c.ServerName) - priv, err := newSigningKey(RSA_PSS_SHA256) - if err != nil { - return err - } - - cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv) - if err != nil { - return err - } - - c.Certificates = []*Certificate{ - { - Chain: []*x509.Certificate{cert}, - PrivateKey: priv, - }, - } - } - return nil } @@ -183,6 +207,14 @@ func (c *Config) ValidForClient() bool { return len(c.ServerName) > 0 } +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + var ( defaultSupportedCipherSuites = []CipherSuite{ TLS_AES_128_GCM_SHA256, @@ -214,10 +246,13 @@ var ( ) type ConnectionState struct { - HandshakeState string // string representation of the handshake state. - CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement - NextProto string // Selected ALPN proto + HandshakeState State + CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + NextProto string // Selected ALPN proto + UsingPSK bool // Are we using PSK. + UsingEarlyData bool // Did we negotiate 0-RTT. } // Conn implements the net.Conn interface, as with "crypto/tls" @@ -228,9 +263,7 @@ type Conn struct { conn net.Conn isClient bool - EarlyData []byte - - state StateConnected + state stateConnected hState HandshakeState handshakeMutex sync.Mutex handshakeAlert Alert @@ -238,18 +271,28 @@ type Conn struct { readBuffer []byte in, out *RecordLayer - hIn, hOut *HandshakeLayer - - extHandler AppExtensionHandler + hsCtx *HandshakeContext } func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient} - c.in = NewRecordLayer(c.conn) - c.out = NewRecordLayer(c.conn) - c.hIn = NewHandshakeLayer(c.in) - c.hIn.nonblocking = c.config.NonBlocking - c.hOut = NewHandshakeLayer(c.out) + c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}} + if !config.UseDTLS { + c.in = NewRecordLayerTLS(c.conn, directionRead) + c.out = NewRecordLayerTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out) + } else { + c.in = NewRecordLayerDTLS(c.conn, directionRead) + c.out = NewRecordLayerDTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out) + c.hsCtx.timeoutMS = initialTimeout + c.hsCtx.timers = newTimerSet() + c.hsCtx.waitingNextFlight = true + } + c.in.label = c.label() + c.out.label = c.label() + c.hsCtx.hIn.nonblocking = c.config.NonBlocking return c } @@ -267,8 +310,12 @@ func (c *Conn) consumeRecord() error { // We do not support fragmentation of post-handshake handshake messages. // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() start := 0 + headerLen := handshakeHeaderLenTLS + if c.config.UseDTLS { + headerLen = handshakeHeaderLenDTLS + } for start < len(pt.fragment) { - if len(pt.fragment[start:]) < handshakeHeaderLen { + if len(pt.fragment[start:]) < headerLen { return fmt.Errorf("Post-handshake handshake message too short for header") } @@ -276,14 +323,15 @@ func (c *Conn) consumeRecord() error { hm.msgType = HandshakeType(pt.fragment[start]) hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) - if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { + if len(pt.fragment[start+headerLen:]) < hmLen { return fmt.Errorf("Post-handshake handshake message too short for body") } - hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] - - // Advance state machine - state, actions, alert := c.state.Next(hm) + hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen] + // XXX: If we want to support more advanced cases, e.g., post-handshake + // authentication, we'll need to allow transitions other than + // Connected -> Connected + state, actions, alert := c.state.ProcessMessage(hm) if alert != AlertNoAlert { logf(logTypeHandshake, "Error in state transition: %v", alert) c.sendAlert(alert) @@ -299,18 +347,15 @@ func (c *Conn) consumeRecord() error { } } - // XXX: If we want to support more advanced cases, e.g., post-handshake - // authentication, we'll need to allow transitions other than - // Connected -> Connected var connected bool - c.state, connected = state.(StateConnected) + c.state, connected = state.(stateConnected) if !connected { logf(logTypeHandshake, "Disconnected after state transition: %v", alert) c.sendAlert(alert) return io.EOF } - start += handshakeHeaderLen + hmLen + start += headerLen + hmLen } case RecordTypeAlert: logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) @@ -332,17 +377,54 @@ func (c *Conn) consumeRecord() error { return io.EOF } + case RecordTypeAck: + if !c.hsCtx.hIn.datagram { + logf(logTypeHandshake, "Received ACK in TLS mode") + return AlertUnexpectedMessage + } + return c.hsCtx.processAck(pt.fragment) + case RecordTypeApplicationData: c.readBuffer = append(c.readBuffer, pt.fragment...) logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } return err } +func readPartial(in *[]byte, buffer []byte) int { + logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in))) + read := copy(buffer, *in) + *in = (*in)[read:] + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read +} + // Read application data up to the size of buffer. Handshake and alert records // are consumed by the Conn object directly. func (c *Conn) Read(buffer []byte) (int, error) { + if _, connected := c.hState.(stateConnected); !connected { + // Clients can't call Read prior to handshake completion. + if c.isClient { + return 0, errors.New("Read called before the handshake completed") + } + + // Neither can servers that don't allow early data. + if !c.config.AllowEarlyData { + return 0, errors.New("Read called before the handshake completed") + } + + // If there's no early data, then return WouldBlock + if len(c.hsCtx.earlyData) == 0 { + return 0, AlertWouldBlock + } + + return readPartial(&c.hsCtx.earlyData, buffer), nil + } + + // The handshake is now connected. logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) if alert := c.Handshake(); alert != AlertNoAlert { return 0, alert @@ -352,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) { return 0, nil } + // Run our timers. + if c.config.UseDTLS { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return 0, AlertInternalError + } + } + // Lock the input channel c.in.Lock() defer c.in.Unlock() @@ -361,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) { // err can be nil if consumeRecord processed a non app-data // record. if err != nil { - if c.config.NonBlocking || err != WouldBlock { + if c.config.NonBlocking || err != AlertWouldBlock { logf(logTypeIO, "conn.Read returns err=%v", err) return 0, err } } } - var read int - n := len(buffer) - logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) - if len(c.readBuffer) <= n { - buffer = buffer[:len(c.readBuffer)] - copy(buffer, c.readBuffer) - read = len(c.readBuffer) - c.readBuffer = c.readBuffer[:0] - } else { - logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) - copy(buffer[:n], c.readBuffer[:n]) - c.readBuffer = c.readBuffer[n:] - read = n - } - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read, nil + return readPartial(&c.readBuffer, buffer), nil } // Write application data @@ -393,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) { c.out.Lock() defer c.out.Unlock() + if !c.Writable() { + return 0, errors.New("Write called before the handshake completed (and early data not in use)") + } + // Send full-size fragments var start int sent := 0 @@ -495,84 +572,44 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } switch action := actionGeneric.(type) { - case SendHandshakeMessage: - err := c.hOut.WriteMessage(action.Message) + case QueueHandshakeMessage: + logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType) + err := c.hsCtx.hOut.QueueMessage(action.Message) if err != nil { logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) return AlertInternalError } + case SendQueuedHandshake: + _, err := c.hsCtx.hOut.SendQueuedMessages() + if err != nil { + logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) + return AlertInternalError + } + if c.config.UseDTLS { + c.hsCtx.timers.start(retransmitTimerLabel, + c.hsCtx.handshakeRetransmit, + c.hsCtx.timeoutMS) + } case RekeyIn: - logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) - err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) + err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) if err != nil { logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) return AlertInternalError } case RekeyOut: - logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) - err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet) + err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) if err != nil { logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) return AlertInternalError } - case SendEarlyData: - logf(logTypeHandshake, "%s Sending early data...", label) - _, err := c.Write(c.EarlyData) - if err != nil { - logf(logTypeHandshake, "%s Error writing early data: %v", label, err) - return AlertInternalError - } - - case ReadPastEarlyData: - logf(logTypeHandshake, "%s Reading past early data...", label) - // Scan past all records that fail to decrypt - _, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok := err.(DecryptError) - - for ok { - _, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok = err.(DecryptError) - } - - case ReadEarlyData: - logf(logTypeHandshake, "%s Reading early data...", label) - t, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type(1): %v", label, t) - - for t == RecordTypeApplicationData { - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in in - // PeekRecordType. - pt, err := c.in.ReadRecord() - if err != nil { - logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) - return AlertInternalError - } - - logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) - c.EarlyData = append(c.EarlyData, pt.fragment...) - - t, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type (2): %v", label, t) - } - logf(logTypeHandshake, "%s Done reading early data", label) + case ResetOut: + logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq) + c.out.ResetClear(action.seq) case StorePSK: logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) @@ -585,7 +622,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } default: - logf(logTypeHandshake, "%s Unknown actionuction type", label) + logf(logTypeHandshake, "%s Unknown action type", label) + assert(false) return AlertInternalError } @@ -602,33 +640,13 @@ func (c *Conn) HandshakeSetup() Alert { return AlertInternalError } - // Set things up - caps := Capabilities{ - CipherSuites: c.config.CipherSuites, - Groups: c.config.Groups, - SignatureSchemes: c.config.SignatureSchemes, - PSKs: c.config.PSKs, - PSKModes: c.config.PSKModes, - AllowEarlyData: c.config.AllowEarlyData, - RequireCookie: c.config.RequireCookie, - CookieHandler: c.config.CookieHandler, - RequireClientAuth: c.config.RequireClientAuth, - NextProtos: c.config.NextProtos, - Certificates: c.config.Certificates, - ExtensionHandler: c.extHandler, - } opts := ConnectionOptions{ ServerName: c.config.ServerName, NextProtos: c.config.NextProtos, - EarlyData: c.EarlyData, - } - - if caps.RequireCookie && caps.CookieHandler == nil { - caps.CookieHandler = &defaultCookieHandler{} } if c.isClient { - state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) + state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil) if alert != AlertNoAlert { logf(logTypeHandshake, "Error initializing client state: %v", alert) return alert @@ -642,14 +660,56 @@ func (c *Conn) HandshakeSetup() Alert { } } } else { - state = ServerStateStart{Caps: caps, conn: c} + if c.config.RequireCookie && c.config.CookieProtector == nil { + logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.") + if c.config.NonBlocking { + logf(logTypeHandshake, "Not possible in non-blocking mode.") + return AlertInternalError + } + var err error + c.config.CookieProtector, err = NewDefaultCookieProtector() + if err != nil { + logf(logTypeHandshake, "Error initializing cookie source: %v", alert) + return AlertInternalError + } + } + state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx} } c.hState = state - return AlertNoAlert } +type handshakeMessageReader interface { + ReadMessage() (*HandshakeMessage, Alert) +} + +type handshakeMessageReaderImpl struct { + hsCtx *HandshakeContext +} + +var _ handshakeMessageReader = &handshakeMessageReaderImpl{} + +func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { + var hm *HandshakeMessage + var err error + for { + hm, err = r.hsCtx.hIn.ReadMessage() + if err == AlertWouldBlock { + return nil, AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "Error reading message: %v", err) + return nil, AlertCloseNotify + } + if hm != nil { + break + } + } + + return hm, AlertNoAlert +} + // Handshake causes a TLS handshake on the connection. The `isClient` member // determines whether a client or server handshake is performed. If a // handshake has already been performed, then its result will be returned. @@ -669,48 +729,48 @@ func (c *Conn) Handshake() Alert { return AlertNoAlert } - var alert Alert if c.hState == nil { - logf(logTypeHandshake, "%s First time through handshake, setting up", label) - alert = c.HandshakeSetup() - if alert != AlertNoAlert { + logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label) + alert := c.HandshakeSetup() + if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) { return alert } - } else { - logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState) } + logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState) state := c.hState - _, connected := state.(StateConnected) - - var actions []HandshakeAction + _, connected := state.(stateConnected) + hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx} for !connected { - // Read a handshake message - hm, err := c.hIn.ReadMessage() - if err == WouldBlock { - logf(logTypeHandshake, "%s Would block reading message: %v", label, err) + var alert Alert + var actions []HandshakeAction + + // Advance the state machine + state, actions, alert = state.Next(hmr) + if alert == AlertWouldBlock { + logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) + // If we blocked, then run our timers to see if any have expired. + if c.hsCtx.hIn.datagram { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return AlertInternalError + } + } return AlertWouldBlock } - if err != nil { - logf(logTypeHandshake, "%s Error reading message: %v", label, err) + if alert == AlertCloseNotify { + logf(logTypeHandshake, "%s Error reading message: %s", label, alert) c.sendAlert(AlertCloseNotify) return AlertCloseNotify } - logf(logTypeHandshake, "Read message with type: %v", hm.msgType) - - // Advance the state machine - state, actions, alert = state.Next(hm) - - if alert != AlertNoAlert { + if alert != AlertNoAlert && alert != AlertStatelessRetry { logf(logTypeHandshake, "Error in state transition: %v", alert) return alert } for index, action := range actions { logf(logTypeHandshake, "%s taking next action (%d)", label, index) - alert = c.takeAction(action) - if alert != AlertNoAlert { + if alert := c.takeAction(action); alert != AlertNoAlert { logf(logTypeHandshake, "Error during handshake actions: %v", alert) c.sendAlert(alert) return alert @@ -719,30 +779,48 @@ func (c *Conn) Handshake() Alert { c.hState = state logf(logTypeHandshake, "state is now %s", c.GetHsState()) + _, connected = state.(stateConnected) + if connected { + c.state = state.(stateConnected) + c.handshakeComplete = true - _, connected = state.(StateConnected) - } + if !c.isClient { + // Send NewSessionTicket if configured to + if c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) - c.state = state.(StateConnected) + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } - // Send NewSessionTicket if acting as server - if !c.isClient && c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) + // If there is early data, move it into the main buffer + if c.hsCtx.earlyData != nil { + c.readBuffer = c.hsCtx.earlyData + c.hsCtx.earlyData = nil + } - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert + } else { + assert(c.hsCtx.earlyData == nil) } } + + if c.config.NonBlocking { + if alert == AlertStatelessRetry { + return AlertStatelessRetry + } + return AlertNoAlert + } } - c.handshakeComplete = true return AlertNoAlert } @@ -775,12 +853,15 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error { return nil } -func (c *Conn) GetHsState() string { - return reflect.TypeOf(c.hState).Name() +func (c *Conn) GetHsState() State { + if c.hState == nil { + return StateInit + } + return c.hState.State() } func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - _, connected := c.hState.(StateConnected) + _, connected := c.hState.(stateConnected) if !connected { return nil, fmt.Errorf("Cannot compute exporter when state is not connected") } @@ -796,7 +877,7 @@ func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]b return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil } -func (c *Conn) State() ConnectionState { +func (c *Conn) ConnectionState() ConnectionState { state := ConnectionState{ HandshakeState: c.GetHsState(), } @@ -804,16 +885,32 @@ func (c *Conn) State() ConnectionState { if c.handshakeComplete { state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] state.NextProto = c.state.Params.NextProto + state.VerifiedChains = c.state.verifiedChains + state.PeerCertificates = c.state.peerCertificates + state.UsingPSK = c.state.Params.UsingPSK + state.UsingEarlyData = c.state.Params.UsingEarlyData } return state } -func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { - if c.hState != nil { - return fmt.Errorf("Can't set extension handler after setup") +func (c *Conn) Writable() bool { + // If we're connected, we're writable. + if _, connected := c.hState.(stateConnected); connected { + return true } - c.extHandler = h - return nil + // If we're a client in 0-RTT, then we're writable. + if c.isClient && c.out.cipher.epoch == EpochEarlyData { + return true + } + + return false +} + +func (c *Conn) label() string { + if c.isClient { + return "client" + } + return "server" } diff --git a/vendor/github.com/bifurcation/mint/cookie-protector.go b/vendor/github.com/bifurcation/mint/cookie-protector.go new file mode 100644 index 0000000..73dd80b --- /dev/null +++ b/vendor/github.com/bifurcation/mint/cookie-protector.go @@ -0,0 +1,86 @@ +package mint + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// CookieProtector is used to create and verify a cookie +type CookieProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const cookieSecretSize = 32 +const cookieNonceSize = 32 + +// The DefaultCookieProtector is a simple implementation for the CookieProtector. +type DefaultCookieProtector struct { + secret []byte +} + +var _ CookieProtector = &DefaultCookieProtector{} + +// NewDefaultCookieProtector creates a source for source address tokens +func NewDefaultCookieProtector() (CookieProtector, error) { + secret := make([]byte, cookieSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &DefaultCookieProtector{secret: secret}, nil +} + +// NewToken encodes data into a new token. +func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, cookieNonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) { + if len(p) < cookieNonceSize { + return nil, fmt.Errorf("Token too short: %d", len(p)) + } + nonce := p[:cookieNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil) +} + +func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/bifurcation/mint/crypto.go index 60d3437..ef7397d 100644 --- a/vendor/github.com/bifurcation/mint/crypto.go +++ b/vendor/github.com/bifurcation/mint/crypto.go @@ -331,40 +331,6 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { } } -func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { - sigAlg, ok := x509AlgMap[alg] - if !ok { - return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) - } - if len(name) == 0 { - return nil, fmt.Errorf("tls.selfsigned: No name provided") - } - - serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) - if err != nil { - return nil, err - } - - template := &x509.Certificate{ - SerialNumber: serial, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - SignatureAlgorithm: sigAlg, - Subject: pkix.Name{CommonName: name}, - DNSNames: []string{name}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - } - der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) - if err != nil { - return nil, err - } - - // It is safe to ignore the error here because we're parsing known-good data - cert, _ := x509.ParseCertificate(der) - return cert, nil -} - // XXX(rlb): Copied from crypto/x509 type ecdsaSignature struct { R, S *big.Int @@ -652,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), } } + +func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) { + priv, err := newSigningKey(alg) + if err != nil { + return nil, nil, err + } + + cert, err := newSelfSigned(name, alg, priv) + if err != nil { + return nil, nil, err + } + return priv, cert, nil +} + +func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { + sigAlg, ok := x509AlgMap[alg] + if !ok { + return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) + } + if len(name) == 0 { + return nil, fmt.Errorf("tls.selfsigned: No name provided") + } + + serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + SignatureAlgorithm: sigAlg, + Subject: pkix.Name{CommonName: name}, + DNSNames: []string{name}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) + if err != nil { + return nil, err + } + + // It is safe to ignore the error here because we're parsing known-good data + cert, _ := x509.ParseCertificate(der) + return cert, nil +} diff --git a/vendor/github.com/bifurcation/mint/dtls.go b/vendor/github.com/bifurcation/mint/dtls.go new file mode 100644 index 0000000..aa914e3 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/dtls.go @@ -0,0 +1,222 @@ +package mint + +import ( + "fmt" + "github.com/bifurcation/mint/syntax" + "time" +) + +const ( + initialMtu = 1200 + initialTimeout = 100 +) + +// labels for timers +const ( + retransmitTimerLabel = "handshake retransmit" + ackTimerLabel = "ack timer" +) + +type SentHandshakeFragment struct { + seq uint32 + offset int + fragLength int + record uint64 + acked bool +} + +type DtlsAck struct { + RecordNumbers []uint64 `tls:"head=2"` +} + +func wireVersion(h *HandshakeLayer) uint16 { + if h.datagram { + return dtls12WireVersion + } + return tls12Version +} + +func dtlsConvertVersion(version uint16) uint16 { + if version == tls12Version { + return dtls12WireVersion + } + if version == tls10Version { + return 0xfeff + } + panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) +} + +// TODO(ekr@rtfm.com): Move these to state-machine.go +func (h *HandshakeContext) handshakeRetransmit() error { + if _, err := h.hOut.SendQueuedMessages(); err != nil { + return err + } + + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + // TODO(ekr@rtfm.com): Back off timer + return nil +} + +func (h *HandshakeContext) sendAck() error { + toack := h.hIn.recvdRecords + + count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU + if len(toack) > count { + toack = toack[:count] + } + logf(logTypeHandshake, "Sending ACK: [%x]", toack) + + ack := &DtlsAck{toack} + body, err := syntax.Marshal(&ack) + if err != nil { + return err + } + err = h.hOut.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAck, + fragment: body, + }) + if err != nil { + return err + } + return nil +} + +func (h *HandshakeContext) processAck(data []byte) error { + // Cancel the retransmit timer because we will be resending + // and possibly re-arming later. + h.timers.cancel(retransmitTimerLabel) + + ack := &DtlsAck{} + read, err := syntax.Unmarshal(data, &ack) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers) + + for _, r := range ack.RecordNumbers { + for _, m := range h.sentFragments { + if r == m.record { + logf(logTypeHandshake, "Marking %v %v(%v) as acked", + m.seq, m.offset, m.fragLength) + m.acked = true + } + } + } + + count, err := h.hOut.SendQueuedMessages() + if err != nil { + return err + } + + if count == 0 { + logf(logTypeHandshake, "All messages ACKed") + h.hOut.ClearQueuedMessages() + return nil + } + + // Reset the timer + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + return nil +} + +func (c *Conn) GetDTLSTimeout() (bool, time.Duration) { + return c.hsCtx.timers.remaining() +} + +func (h *HandshakeContext) receivedHandshakeMessage() { + logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight) + // This just enables tests. + if h.hIn == nil { + return + } + + if !h.hIn.datagram { + return + } + + if h.waitingNextFlight { + logf(logTypeHandshake, "Received the start of the flight") + + // Clear the outgoing DTLS queue and terminate the retransmit timer + h.hOut.ClearQueuedMessages() + h.timers.cancel(retransmitTimerLabel) + + // OK, we're not waiting any more. + h.waitingNextFlight = false + } + + // Now pre-emptively arm the ACK timer if it's not armed already. + // We'll automatically dis-arm it at the end of the handshake. + if h.timers.getTimer(ackTimerLabel) == nil { + h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4) + } +} + +func (h *HandshakeContext) receivedEndOfFlight() { + logf(logTypeHandshake, "%p Received the end of the flight", h) + if !h.hIn.datagram { + return + } + + // Empty incoming queue + h.hIn.queued = nil + + // Note that we are waiting for the next flight. + h.waitingNextFlight = true + + // Clear the ACK queue. + h.hIn.recvdRecords = nil + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) +} + +func (h *HandshakeContext) receivedFinalFlight() { + logf(logTypeHandshake, "%p Received final flight", h) + if !h.hIn.datagram { + return + } + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) + + // But send an ACK immediately. + h.sendAck() +} + +func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool { + logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen) + for _, f := range h.sentFragments { + if !f.acked { + continue + } + + if f.seq != seq { + continue + } + + if f.offset > offset { + continue + } + + // At this point, we know that the stored fragment starts + // at or before what we want to send, so check where the end + // is. + if f.offset+f.fragLength < offset+fraglen { + continue + } + + return true + } + + return false +} diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go index 1dbe7bd..07cb16c 100644 --- a/vendor/github.com/bifurcation/mint/extensions.go +++ b/vendor/github.com/bifurcation/mint/extensions.go @@ -3,7 +3,6 @@ package mint import ( "bytes" "fmt" - "github.com/bifurcation/mint/syntax" ) @@ -77,14 +76,40 @@ func (el *ExtensionList) Add(src ExtensionBody) error { return nil } -func (el ExtensionList) Find(dst ExtensionBody) bool { - for _, ext := range el { - if ext.ExtensionType == dst.Type() { - _, err := dst.Unmarshal(ext.ExtensionData) - return err == nil +func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) { + found := make(map[ExtensionType]bool) + + for _, dst := range dsts { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + if found[dst.Type()] { + return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type()) + } + + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return nil, err + } + + found[dst.Type()] = true + } } } - return false + + return found, nil +} + +func (el ExtensionList) Find(dst ExtensionBody) (bool, error) { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + err := safeUnmarshal(dst, ext.ExtensionData) + if err != nil { + return true, err + } + return true, nil + } + } + return false, nil } // struct { @@ -529,19 +554,56 @@ func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { // ProtocolVersion versions<2..254>; // } SupportedVersions; type SupportedVersionsExtension struct { + HandshakeType HandshakeType + Versions []uint16 +} + +type SupportedVersionsClientHelloInner struct { Versions []uint16 `tls:"head=1,min=2,max=254"` } +type SupportedVersionsServerHelloInner struct { + Version uint16 +} + func (sv SupportedVersionsExtension) Type() ExtensionType { return ExtensionTypeSupportedVersions } func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sv) + switch sv.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions}) + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]}) + default: + return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } } func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sv) + switch sv.HandshakeType { + case HandshakeTypeClientHello: + var inner SupportedVersionsClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = inner.Versions + return read, nil + + case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: + var inner SupportedVersionsServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + sv.Versions = []uint16{inner.Version} + return read, nil + + default: + return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed") + } } // struct { @@ -562,25 +624,3 @@ func (c CookieExtension) Marshal() ([]byte, error) { func (c *CookieExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, c) } - -// defaultCookieLength is the default length of a cookie -const defaultCookieLength = 32 - -type defaultCookieHandler struct { - data []byte -} - -var _ CookieHandler = &defaultCookieHandler{} - -// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data -func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { - h.data = make([]byte, defaultCookieLength) - if _, err := prng.Read(h.data); err != nil { - return nil, err - } - return h.data, nil -} - -func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { - return bytes.Equal(h.data, data) -} diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go index 99ea470..4ccfc23 100644 --- a/vendor/github.com/bifurcation/mint/frame-reader.go +++ b/vendor/github.com/bifurcation/mint/frame-reader.go @@ -66,8 +66,8 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { f.remainder = f.remainder[copied:] f.writeOffset += copied if f.writeOffset < len(f.working) { - logf(logTypeFrameReader, "Read would have blocked 1") - return nil, nil, WouldBlock + logf(logTypeVerbose, "Read would have blocked 1") + return nil, nil, AlertWouldBlock } // Reset the write offset, because we are now full. f.writeOffset = 0 @@ -93,6 +93,6 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { f.state = kFrameReaderBody } - logf(logTypeFrameReader, "Read would have blocked 2") - return nil, nil, WouldBlock + logf(logTypeVerbose, "Read would have blocked 2") + return nil, nil, AlertWouldBlock } diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go index 2b04ac5..de17b30 100644 --- a/vendor/github.com/bifurcation/mint/handshake-layer.go +++ b/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -7,7 +7,8 @@ import ( ) const ( - handshakeHeaderLen = 4 // handshake message header length + handshakeHeaderLenTLS = 4 // handshake message header length + handshakeHeaderLenDTLS = 12 // handshake message header length maxHandshakeMessageLen = 1 << 24 // max handshake message length ) @@ -27,28 +28,42 @@ const ( // opaque msg<0..2^24-1> // } Handshake; // -// TODO: File a spec bug type HandshakeMessage struct { - // Omitted: length - msgType HandshakeType - body []byte + msgType HandshakeType + seq uint32 + body []byte + datagram bool + offset uint32 // Used for DTLS + length uint32 + cipher *cipherState } // Note: This could be done with the `syntax` module, using the simplified // syntax as discussed above. However, since this is so simple, there's not // much benefit to doing so. +// When datagram is set, we marshal this as a whole DTLS record. func (hm *HandshakeMessage) Marshal() []byte { if hm == nil { return []byte{} } - msgLen := len(hm.body) - data := make([]byte, 4+len(hm.body)) - data[0] = byte(hm.msgType) - data[1] = byte(msgLen >> 16) - data[2] = byte(msgLen >> 8) - data[3] = byte(msgLen) - copy(data[4:], hm.body) + fragLen := len(hm.body) + var data []byte + + if hm.datagram { + data = make([]byte, handshakeHeaderLenDTLS+fragLen) + } else { + data = make([]byte, handshakeHeaderLenTLS+fragLen) + } + tmp := data + tmp = encodeUint(uint64(hm.msgType), 1, tmp) + tmp = encodeUint(uint64(hm.length), 3, tmp) + if hm.datagram { + tmp = encodeUint(uint64(hm.seq), 2, tmp) + tmp = encodeUint(uint64(hm.offset), 3, tmp) + tmp = encodeUint(uint64(fragLen), 3, tmp) + } + copy(tmp, hm.body) return data } @@ -61,8 +76,6 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { body = new(ClientHelloBody) case HandshakeTypeServerHello: body = new(ServerHelloBody) - case HandshakeTypeHelloRetryRequest: - body = new(HelloRetryRequestBody) case HandshakeTypeEncryptedExtensions: body = new(EncryptedExtensionsBody) case HandshakeTypeCertificate: @@ -83,62 +96,104 @@ func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") } - _, err := body.Unmarshal(hm.body) + err := safeUnmarshal(body, hm.body) return body, err } -func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { +func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { data, err := body.Marshal() if err != nil { return nil, err } - return &HandshakeMessage{ - msgType: body.Type(), - body: data, - }, nil + m := &HandshakeMessage{ + msgType: body.Type(), + body: data, + seq: h.msgSeq, + datagram: h.datagram, + length: uint32(len(data)), + } + h.msgSeq++ + return m, nil } type HandshakeLayer struct { - nonblocking bool // Should we operate in nonblocking mode - conn *RecordLayer // Used for reading/writing records - frame *frameReader // The buffered frame reader + ctx *HandshakeContext // The handshake we are attached to + nonblocking bool // Should we operate in nonblocking mode + conn *RecordLayer // Used for reading/writing records + frame *frameReader // The buffered frame reader + datagram bool // Is this DTLS? + msgSeq uint32 // The DTLS message sequence number + queued []*HandshakeMessage // In/out queue + sent []*HandshakeMessage // Sent messages for DTLS + recvdRecords []uint64 // Records we have received. + maxFragmentLen int } -type handshakeLayerFrameDetails struct{} +type handshakeLayerFrameDetails struct { + datagram bool +} func (d handshakeLayerFrameDetails) headerLen() int { - return handshakeHeaderLen + if d.datagram { + return handshakeHeaderLenDTLS + } + return handshakeHeaderLenTLS } func (d handshakeLayerFrameDetails) defaultReadLen() int { - return handshakeHeaderLen + maxFragmentLen + return d.headerLen() + maxFragmentLen } func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { logf(logTypeIO, "Header=%x", hdr) - return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil + // The length of this fragment (as opposed to the message) + // is always the last three bytes for both TLS and DTLS + val, _ := decodeUint(hdr[len(hdr)-3:], 3) + return int(val), nil } -func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r - h.frame = newFrameReader(&handshakeLayerFrameDetails{}) + h.datagram = false + h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) + h.maxFragmentLen = maxFragmentLen + return &h +} + +func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { + h := HandshakeLayer{} + h.ctx = c + h.conn = r + h.datagram = true + h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) + h.maxFragmentLen = initialMtu // Not quite right return &h } func (h *HandshakeLayer) readRecord() error { - logf(logTypeIO, "Trying to read record") - pt, err := h.conn.ReadRecord() + logf(logTypeVerbose, "Trying to read record") + pt, err := h.conn.readRecordAnyEpoch() if err != nil { return err } - if pt.contentType != RecordTypeHandshake && - pt.contentType != RecordTypeAlert { + switch pt.contentType { + case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck: + default: return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) } + if pt.contentType == RecordTypeAck { + if !h.datagram { + return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS") + } + logf(logTypeIO, "read ACK") + return h.ctx.processAck(pt.fragment) + } + if pt.contentType == RecordTypeAlert { logf(logTypeIO, "read alert %v", pt.fragment[1]) if len(pt.fragment) < 2 { @@ -148,7 +203,19 @@ func (h *HandshakeLayer) readRecord() error { return Alert(pt.fragment[1]) } - logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) + assert(h.ctx.hIn.conn != nil) + if pt.epoch != h.ctx.hIn.conn.cipher.epoch { + // This is out of order but we're dropping it. + // TODO(ekr@rtfm.com): If server, need to retransmit Finished. + if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData { + return nil + } + + // Anything else shouldn't happen. + return AlertIllegalParameter + } + + h.recvdRecords = append(h.recvdRecords, pt.seq) h.frame.addChunk(pt.fragment) return nil @@ -171,83 +238,314 @@ func (h *HandshakeLayer) sendAlert(err Alert) error { return nil } +func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { + h.msgSeq = seq + 1 + var i int + var m *HandshakeMessage + for i, m = range h.queued { + if m.seq > seq { + break + } + } + h.queued = h.queued[i:] +} + +func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { + if hm.seq < h.msgSeq { + return nil, nil + } + + // TODO(ekr@rtfm.com): Send an ACK immediately if we got something + // out of order. + h.ctx.receivedHandshakeMessage() + + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { + // TODO(ekr@rtfm.com): Check the length? + // This is complete. + h.noteMessageDelivered(hm.seq) + return hm, nil + } + + // Now insert sorted. + var i int + for i = 0; i < len(h.queued); i++ { + f := h.queued[i] + if hm.seq < f.seq { + break + } + if hm.offset < f.offset { + break + } + } + tmp := make([]*HandshakeMessage, 0, len(h.queued)+1) + tmp = append(tmp, h.queued[:i]...) + tmp = append(tmp, hm) + tmp = append(tmp, h.queued[i:]...) + h.queued = tmp + + return h.checkMessageAvailable() +} + +func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { + if len(h.queued) == 0 { + return nil, nil + } + + hm := h.queued[0] + if hm.seq != h.msgSeq { + return nil, nil + } + + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { + // TODO(ekr@rtfm.com): Check the length? + // This is complete. + h.noteMessageDelivered(hm.seq) + return hm, nil + } + + // OK, this at least might complete the message. + end := uint32(0) + buf := make([]byte, hm.length) + + for _, f := range h.queued { + // Out of fragments + if f.seq > hm.seq { + break + } + + if f.length != uint32(len(buf)) { + return nil, fmt.Errorf("Mismatched DTLS length") + } + + if f.offset > end { + break + } + + if f.offset+uint32(len(f.body)) > end { + // OK, this is adding something we don't know about + copy(buf[f.offset:], f.body) + end = f.offset + uint32(len(f.body)) + if end == hm.length { + h2 := *hm + h2.offset = 0 + h2.body = buf + h.noteMessageDelivered(hm.seq) + return &h2, nil + } + } + + } + + return nil, nil +} + func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { var hdr, body []byte var err error + hm, err := h.checkMessageAvailable() + if err != nil { + return nil, err + } + if hm != nil { + return hm, nil + } for { - logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) + logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { - logf(logTypeHandshake, "Trying to read a new record") + logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() - } - if err != nil && (h.nonblocking || err != WouldBlock) { - return nil, err + + if err != nil && (h.nonblocking || err != AlertWouldBlock) { + return nil, err + } } hdr, body, err = h.frame.process() if err == nil { break } - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } logf(logTypeHandshake, "read handshake message") - hm := &HandshakeMessage{} + hm = &HandshakeMessage{} hm.msgType = HandshakeType(hdr[0]) - + hm.datagram = h.datagram hm.body = make([]byte, len(body)) copy(hm.body, body) + logf(logTypeHandshake, "Read message with type: %v", hm.msgType) + if h.datagram { + tmp, hdr := decodeUint(hdr[1:], 3) + hm.length = uint32(tmp) + tmp, hdr = decodeUint(hdr, 2) + hm.seq = uint32(tmp) + tmp, hdr = decodeUint(hdr, 3) + hm.offset = uint32(tmp) + return h.newFragmentReceived(hm) + } + + hm.length = uint32(len(body)) return hm, nil } -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { - return h.WriteMessages([]*HandshakeMessage{hm}) +func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { + hm.cipher = h.conn.cipher + h.queued = append(h.queued, hm) + return nil } -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { +func (h *HandshakeLayer) SendQueuedMessages() (int, error) { + logf(logTypeHandshake, "Sending outgoing messages") + count, err := h.WriteMessages(h.queued) + if !h.datagram { + h.ClearQueuedMessages() + } + return count, err +} + +func (h *HandshakeLayer) ClearQueuedMessages() { + logf(logTypeHandshake, "Clearing outgoing hs message queue") + h.queued = nil +} + +func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) { + var buf []byte + + // Figure out if we're going to want the full header or just + // the body + hdrlen := 0 + if hm.datagram { + hdrlen = handshakeHeaderLenDTLS + } else if start == 0 { + hdrlen = handshakeHeaderLenTLS + } + + // Compute the amount of body we can fit in + room -= hdrlen + if room == 0 { + // This works because we are doing one record per + // message + panic("Too short max fragment len") + } + bodylen := len(hm.body) - start + if bodylen > room { + bodylen = room + } + body := hm.body[start : start+bodylen] + + // Now see if this chunk has been ACKed. This doesn't produce ideal + // retransmission but is simple. + if h.ctx.fragmentAcked(hm.seq, start, bodylen) { + logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen) + return false, start + bodylen, nil + } + + // Encode the data. + if hdrlen > 0 { + hm2 := *hm + hm2.offset = uint32(start) + hm2.body = body + buf = hm2.Marshal() + hm = &hm2 + } else { + buf = body + } + + if h.datagram { + // Remember that we sent this. + h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{ + hm.seq, + start, + len(body), + h.conn.cipher.combineSeq(true), + false, + }) + } + return true, start + bodylen, h.conn.writeRecordWithPadding( + &TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buf, + }, + hm.cipher, 0) +} + +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) { + start := int(0) + + if len(hm.body) > maxHandshakeMessageLen { + return 0, fmt.Errorf("Tried to write a handshake message that's too long") + } + + written := 0 + wrote := false + + // Always make one pass through to allow EOED (which is empty). + for { + var err error + wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen) + if err != nil { + return 0, err + } + if wrote { + written++ + } + if start >= len(hm.body) { + break + } + } + + return written, nil +} + +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) { + written := 0 for _, hm := range hms { logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - } - - // Write out headers and bodies - buffer := []byte{} - for _, msg := range hms { - msgLen := len(msg.body) - if msgLen > maxHandshakeMessageLen { - return fmt.Errorf("tls.handshakelayer: Message too large to send") - } - - buffer = append(buffer, msg.Marshal()...) - } - - // Send full-size fragments - var start int - for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { - err := h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeHandshake, - fragment: buffer[start : start+maxFragmentLen], - }) + wrote, err := h.WriteMessage(hm) if err != nil { - return err + return 0, err } + written += wrote } + return written, nil +} - // Send a final partial fragment if necessary - if start < len(buffer) { - err := h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeHandshake, - fragment: buffer[start:], - }) +func encodeUint(v uint64, size int, out []byte) []byte { + for i := size - 1; i >= 0; i-- { + out[i] = byte(v & 0xff) + v >>= 8 + } + return out[size:] +} - if err != nil { - return err - } +func decodeUint(in []byte, size int) (uint64, []byte) { + val := uint64(0) + + for i := 0; i < size; i++ { + val <<= 8 + val += uint64(in[i]) + } + return val, in[size:] +} + +type marshalledPDU interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +func safeUnmarshal(pdu marshalledPDU, data []byte) error { + read, err := pdu.Unmarshal(data) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") } return nil } diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go index 339bbcd..5a229f1 100644 --- a/vendor/github.com/bifurcation/mint/handshake-messages.go +++ b/vendor/github.com/bifurcation/mint/handshake-messages.go @@ -25,15 +25,14 @@ type HandshakeMessageBody interface { // Extension extensions<0..2^16-1>; // } ClientHello; type ClientHelloBody struct { - // Omitted: clientVersion - // Omitted: legacySessionID - // Omitted: legacyCompressionMethods - Random [32]byte - CipherSuites []CipherSuite - Extensions ExtensionList + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte + CipherSuites []CipherSuite + Extensions ExtensionList } -type clientHelloBodyInner struct { +type clientHelloBodyInnerTLS struct { LegacyVersion uint16 Random [32]byte LegacySessionID []byte `tls:"head=1,max=32"` @@ -42,40 +41,86 @@ type clientHelloBodyInner struct { Extensions []Extension `tls:"head=2"` } +type clientHelloBodyInnerDTLS struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + EmptyCookie uint8 + CipherSuites []CipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []Extension `tls:"head=2"` +} + func (ch ClientHelloBody) Type() HandshakeType { return HandshakeTypeClientHello } func (ch ClientHelloBody) Marshal() ([]byte, error) { - return syntax.Marshal(clientHelloBodyInner{ - LegacyVersion: 0x0303, - Random: ch.Random, - LegacySessionID: []byte{}, - CipherSuites: ch.CipherSuites, - LegacyCompressionMethods: []byte{0}, - Extensions: ch.Extensions, - }) + if ch.LegacyVersion == tls12Version { + return syntax.Marshal(clientHelloBodyInnerTLS{ + LegacyVersion: ch.LegacyVersion, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) + } else { + return syntax.Marshal(clientHelloBodyInnerDTLS{ + LegacyVersion: ch.LegacyVersion, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) + } + } func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { - var inner clientHelloBodyInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } + var read int + var err error - // We are strict about these things because we only support 1.3 - if inner.LegacyVersion != 0x0303 { - return 0, fmt.Errorf("tls.clienthello: Incorrect version number") - } + // Note that this might be 0, in which case we do TLS. That + // makes the tests easier. + if ch.LegacyVersion != dtls12WireVersion { + var inner clientHelloBodyInnerTLS + read, err = syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } - if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { - return 0, fmt.Errorf("tls.clienthello: Invalid compression method") - } + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } - ch.Random = inner.Random - ch.CipherSuites = inner.CipherSuites - ch.Extensions = inner.Extensions + ch.LegacyVersion = inner.LegacyVersion + ch.Random = inner.Random + ch.LegacySessionID = inner.LegacySessionID + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + } else { + var inner clientHelloBodyInnerDTLS + read, err = syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if inner.EmptyCookie != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid cookie") + } + + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } + + ch.LegacyVersion = inner.LegacyVersion + ch.Random = inner.Random + ch.LegacySessionID = inner.LegacySessionID + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + } return read, nil } @@ -90,10 +135,15 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) { return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") } - chm, err := HandshakeMessageFromBody(&ch) + body, err := ch.Marshal() if err != nil { return nil, err } + chm := &HandshakeMessage{ + msgType: ch.Type(), + body: body, + length: uint32(len(body)), + } chData := chm.Marshal() psk := PreSharedKeyExtension{ @@ -116,39 +166,20 @@ func (ch ClientHelloBody) Truncated() ([]byte, error) { } // struct { -// ProtocolVersion server_version; -// CipherSuite cipher_suite; -// Extension extensions<2..2^16-1>; -// } HelloRetryRequest; -type HelloRetryRequestBody struct { - Version uint16 - CipherSuite CipherSuite - Extensions ExtensionList `tls:"head=2,min=2"` -} - -func (hrr HelloRetryRequestBody) Type() HandshakeType { - return HandshakeTypeHelloRetryRequest -} - -func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) { - return syntax.Marshal(hrr) -} - -func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, hrr) -} - -// struct { -// ProtocolVersion version; +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ // Random random; +// opaque legacy_session_id_echo<0..32>; // CipherSuite cipher_suite; -// Extension extensions<0..2^16-1>; +// uint8 legacy_compression_method = 0; +// Extension extensions<6..2^16-1>; // } ServerHello; type ServerHelloBody struct { - Version uint16 - Random [32]byte - CipherSuite CipherSuite - Extensions ExtensionList `tls:"head=2"` + Version uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuite CipherSuite + LegacyCompressionMethod uint8 + Extensions ExtensionList `tls:"head=2"` } func (sh ServerHelloBody) Type() HandshakeType { diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go index f4ead72..2c80b8d 100644 --- a/vendor/github.com/bifurcation/mint/negotiation.go +++ b/vendor/github.com/bifurcation/mint/negotiation.go @@ -148,7 +148,7 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme } if len(candidatesByName) == 0 { - return nil, 0, fmt.Errorf("No certificates available for server name") + return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName) } candidates = candidatesByName @@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") } -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { - usingEarlyData := gotEarlyData && usingPSK && allowEarlyData - logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) - return usingEarlyData +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) { + using = gotEarlyData && usingPSK && allowEarlyData + rejected = gotEarlyData && !using + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected) + return } func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go index bcef613..5cf8ae2 100644 --- a/vendor/github.com/bifurcation/mint/record-layer.go +++ b/vendor/github.com/bifurcation/mint/record-layer.go @@ -1,7 +1,6 @@ package mint import ( - "bytes" "crypto/cipher" "fmt" "io" @@ -9,9 +8,10 @@ import ( ) const ( - sequenceNumberLen = 8 // sequence number length - recordHeaderLen = 5 // record header length - maxFragmentLen = 1 << 14 // max number of bytes in a record + sequenceNumberLen = 8 // sequence number length + recordHeaderLenTLS = 5 // record header length (TLS) + recordHeaderLenDTLS = 13 // record header length (DTLS) + maxFragmentLen = 1 << 14 // max number of bytes in a record ) type DecryptError string @@ -20,9 +20,16 @@ func (err DecryptError) Error() string { return string(err) } +type direction uint8 + +const ( + directionWrite = direction(1) + directionRead = direction(2) +) + // struct { // ContentType type; -// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ +// ProtocolVersion record_version [0301 for CH, 0303 for others] // uint16 length; // opaque fragment[TLSPlaintext.length]; // } TLSPlaintext; @@ -30,87 +37,177 @@ type TLSPlaintext struct { // Omitted: record_version (static) // Omitted: length (computed from fragment) contentType RecordType + epoch Epoch + seq uint64 fragment []byte } +type cipherState struct { + epoch Epoch // DTLS epoch + ivLength int // Length of the seq and nonce fields + seq uint64 // Zero-padded sequence number + iv []byte // Buffer for the IV + cipher cipher.AEAD // AEAD cipher +} + type RecordLayer struct { sync.Mutex - + label string + direction direction + version uint16 // The current version number conn io.ReadWriter // The underlying connection frame *frameReader // The buffered frame reader nextData []byte // The next record to send cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read - ivLength int // Length of the seq and nonce fields - seq []byte // Zero-padded sequence number - nonce []byte // Buffer for per-record nonces - cipher cipher.AEAD // AEAD cipher + cipher *cipherState + readCiphers map[Epoch]*cipherState + + datagram bool } -type recordLayerFrameDetails struct{} +type recordLayerFrameDetails struct { + datagram bool +} func (d recordLayerFrameDetails) headerLen() int { - return recordHeaderLen + if d.datagram { + return recordHeaderLenDTLS + } + return recordHeaderLenTLS } func (d recordLayerFrameDetails) defaultReadLen() int { - return recordHeaderLen + maxFragmentLen + return d.headerLen() + maxFragmentLen } func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { - return (int(hdr[3]) << 8) | int(hdr[4]), nil + return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil } -func NewRecordLayer(conn io.ReadWriter) *RecordLayer { +func newCipherStateNull() *cipherState { + return &cipherState{EpochClear, 0, 0, nil, nil} +} + +func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { + cipher, err := factory(key) + if err != nil { + return nil, err + } + + return &cipherState{epoch, len(iv), 0, iv, cipher}, nil +} + +func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{}) - r.ivLength = 0 + r.frame = newFrameReader(recordLayerFrameDetails{false}) + r.cipher = newCipherStateNull() + r.version = tls10Version return &r } -func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { - var err error - r.cipher, err = cipher(key) +func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { + r := RecordLayer{} + r.label = "" + r.direction = dir + r.conn = conn + r.frame = newFrameReader(recordLayerFrameDetails{true}) + r.cipher = newCipherStateNull() + r.readCiphers = make(map[Epoch]*cipherState, 0) + r.readCiphers[0] = r.cipher + r.datagram = true + return &r +} + +func (r *RecordLayer) SetVersion(v uint16) { + r.version = v +} + +func (r *RecordLayer) ResetClear(seq uint64) { + r.cipher = newCipherStateNull() + r.cipher.seq = seq +} + +func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { + cipher, err := newCipherStateAead(epoch, factory, key, iv) if err != nil { return err } - - r.ivLength = len(iv) - r.seq = bytes.Repeat([]byte{0}, r.ivLength) - r.nonce = make([]byte, r.ivLength) - copy(r.nonce, iv) + r.cipher = cipher + if r.datagram && r.direction == directionRead { + r.readCiphers[epoch] = cipher + } return nil } -func (r *RecordLayer) incrementSequenceNumber() { - if r.ivLength == 0 { +// TODO(ekr@rtfm.com): This is never used, which is a bug. +func (r *RecordLayer) DiscardReadKey(epoch Epoch) { + if !r.datagram { return } - for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { - r.seq[i]++ - r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] - if r.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") + _, ok := r.readCiphers[epoch] + assert(ok) + delete(r.readCiphers, epoch) } -func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { +func (c *cipherState) combineSeq(datagram bool) uint64 { + seq := c.seq + if datagram { + seq |= uint64(c.epoch) << 48 + } + return seq +} + +func (c *cipherState) computeNonce(seq uint64) []byte { + nonce := make([]byte, len(c.iv)) + copy(nonce, c.iv) + + s := seq + + offset := len(c.iv) + for i := 0; i < 8; i++ { + nonce[(offset-i)-1] ^= byte(s & 0xff) + s >>= 8 + } + logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) + + return nonce +} + +func (c *cipherState) incrementSequenceNumber() { + if c.seq >= (1<<48 - 1) { + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. This is the + // DTLS limit. + panic("TLS: sequence number wraparound") + } + c.seq++ +} + +func (c *cipherState) overhead() int { + if c.cipher == nil { + return 0 + } + return c.cipher.Overhead() +} + +func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { + assert(r.direction == directionWrite) + logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) // Expand the fragment to hold contentType, padding, and overhead originalLen := len(pt.fragment) plaintextLen := originalLen + 1 + padLen - ciphertextLen := plaintextLen + r.cipher.Overhead() + ciphertextLen := plaintextLen + cipher.overhead() // Assemble the revised plaintext out := &TLSPlaintext{ + contentType: RecordTypeApplicationData, fragment: make([]byte, ciphertextLen), } @@ -122,25 +219,28 @@ func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { // Encrypt the fragment payload := out.fragment[:plaintextLen] - r.cipher.Seal(payload[:0], r.nonce, payload, nil) + cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil) return out } -func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { - if len(pt.fragment) < r.cipher.Overhead() { - msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) +func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { + assert(r.direction == directionRead) + logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) + if len(pt.fragment) < r.cipher.overhead() { + msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) return nil, 0, DecryptError(msg) } - decryptLen := len(pt.fragment) - r.cipher.Overhead() + decryptLen := len(pt.fragment) - r.cipher.overhead() out := &TLSPlaintext{ contentType: pt.contentType, fragment: make([]byte, decryptLen), } // Decrypt - _, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) + _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) if err != nil { + logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") } @@ -155,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { // Truncate the message to remove contentType, padding, overhead out.fragment = out.fragment[:newLen] + out.seq = seq return out, padLen, nil } @@ -163,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { var err error for { - pt, err = r.nextRecord() + pt, err = r.nextRecord(false) if err == nil { break } - if !block || err != WouldBlock { + if !block || err != AlertWouldBlock { return 0, err } } @@ -175,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { } func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord() + pt, err := r.nextRecord(false) // Consume the cached record if there was one r.cachedRecord = nil @@ -184,9 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { return pt, err } -func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { +func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { + pt, err := r.nextRecord(true) + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { + cipher := r.cipher if r.cachedRecord != nil { - logf(logTypeIO, "Returning cached record") + logf(logTypeIO, "%s Returning cached record", r.label) return r.cachedRecord, r.cachedError } @@ -194,34 +306,35 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { // // 1. We get a frame // 2. We try to read off the socket and get nothing, in which case - // return WouldBlock + // returnAlertWouldBlock // 3. We get an error. - err := WouldBlock + var err error + err = AlertWouldBlock var header, body []byte for err != nil { if r.frame.needed() > 0 { - buf := make([]byte, recordHeaderLen+maxFragmentLen) + buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { - logf(logTypeIO, "Error reading, %v", err) + logf(logTypeIO, "%s Error reading, %v", r.label, err) return nil, err } if n == 0 { - return nil, WouldBlock + return nil, AlertWouldBlock } - logf(logTypeIO, "Read %v bytes", n) + logf(logTypeIO, "%s Read %v bytes", r.label, n) buf = buf[:n] r.frame.addChunk(buf) } header, body, err = r.frame.process() - // Loop around on WouldBlock to see if some + // Loop around onAlertWouldBlock to see if some // data is now available. - if err != nil && err != WouldBlock { + if err != nil && err != AlertWouldBlock { return nil, err } } @@ -231,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { switch RecordType(header[0]) { default: return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: pt.contentType = RecordType(header[0]) } @@ -241,7 +354,8 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { } // Validate size < max - size := (int(header[3]) << 8) + int(header[4]) + size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1]) + if size > maxFragmentLen+256 { return nil, fmt.Errorf("tls.record: Ciphertext size too big") } @@ -249,33 +363,67 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { pt.fragment = make([]byte, size) copy(pt.fragment, body) + // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. + // Attempt to decrypt fragment - if r.cipher != nil { - pt, _, err = r.decrypt(pt) + seq := cipher.seq + if r.datagram { + // TODO(ekr@rtfm.com): Handle duplicates. + seq, _ = decodeUint(header[3:11], 8) + epoch := Epoch(seq >> 48) + + // Look up the cipher suite from the epoch + c, ok := r.readCiphers[epoch] + if !ok { + logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) + return nil, AlertWouldBlock + } + + if epoch != cipher.epoch { + logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch, + cipher.epoch, allowOldEpoch) + if !allowOldEpoch { + return nil, AlertWouldBlock + } + cipher = c + } + } + + if cipher.cipher != nil { + logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) + pt, _, err = r.decrypt(pt, seq) if err != nil { + logf(logTypeIO, "%s Decryption failed", r.label) return nil, err } } + pt.epoch = cipher.epoch // Check that plaintext length is not too long if len(pt.fragment) > maxFragmentLen { return nil, fmt.Errorf("tls.record: Plaintext size too big") } - logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) r.cachedRecord = pt - r.incrementSequenceNumber() + cipher.incrementSequenceNumber() return pt, nil } func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { - return r.WriteRecordWithPadding(pt, 0) + return r.writeRecordWithPadding(pt, r.cipher, 0) } func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { - if r.cipher != nil { - pt = r.encrypt(pt, padLen) + return r.writeRecordWithPadding(pt, r.cipher, padLen) +} + +func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { + seq := cipher.combineSeq(r.datagram) + if cipher.cipher != nil { + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + pt = r.encrypt(cipher, seq, pt, padLen) } else if padLen > 0 { return fmt.Errorf("tls.record: Padding can only be done on encrypted records") } @@ -285,12 +433,26 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error } length := len(pt.fragment) - header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} + var header []byte + + if !r.datagram { + header = []byte{byte(pt.contentType), + byte(r.version >> 8), byte(r.version & 0xff), + byte(length >> 8), byte(length)} + } else { + header = make([]byte, 13) + version := dtlsConvertVersion(r.version) + copy(header, []byte{byte(pt.contentType), + byte(version >> 8), byte(version & 0xff), + }) + encodeUint(seq, 8, header[3:]) + encodeUint(uint64(length), 2, header[11:]) + } record := append(header, pt.fragment...) - logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) - r.incrementSequenceNumber() + cipher.incrementSequenceNumber() _, err := r.conn.Write(record) return err } diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go index 60df9b6..f91b22e 100644 --- a/vendor/github.com/bifurcation/mint/server-state-machine.go +++ b/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -2,8 +2,12 @@ package mint import ( "bytes" + "crypto/x509" + "fmt" "hash" "reflect" + + "github.com/bifurcation/mint/syntax" ) // Server State Machine @@ -20,14 +24,17 @@ import ( // | [Send CertificateRequest] // Can send | [Send Certificate + CertificateVerify] // app data --> | Send Finished -// after +--------+--------+ -// here No 0-RTT | | 0-RTT -// | v -// | WAIT_EOED <---+ -// | Recv | | | Recv -// | EndOfEarlyData | | | early data -// | | +-----+ -// +> WAIT_FLIGHT2 <-+ +// after here | +// +-----------+--------+ +// | | | +// Rejected 0-RTT | No | | 0-RTT +// | 0-RTT | | +// | | v +// +---->READ_PAST | WAIT_EOED <---+ +// Decrypt | | | Decrypt | Recv | | | Recv +// error | | | OK + HS | EOED | | | early data +// +-----+ | V | +-----+ +// +---> WAIT_FLIGHT2 <-+ // | // +--------+--------+ // No auth | | Client auth @@ -46,43 +53,66 @@ import ( // // NB: Not using state RECVD_CH // -// State Instructions -// START {} -// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] -// WAIT_EOED RekeyIn; -// WAIT_FLIGHT2 {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// READ_PAST {} +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) -type ServerStateStart struct { - Caps Capabilities - conn *Conn +// A cookie can be sent to the client in a HRR. +type cookie struct { + // The CipherSuite that was selected when the client sent the first ClientHello + CipherSuite CipherSuite + ClientHelloHash []byte `tls:"head=2"` - cookieSent bool - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage + // The ApplicationCookie can be provided by the application (by setting a Config.CookieHandler) + ApplicationCookie []byte `tls:"head=2"` } -func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +type serverStateStart struct { + Config *Config + conn *Conn + hsCtx *HandshakeContext +} + +var _ HandshakeState = &serverStateStart{} + +func (state serverStateStart) State() State { + return StateServerStart +} + +func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeClientHello { logf(logTypeHandshake, "[ServerStateStart] unexpected message") return nil, nil, AlertUnexpectedMessage } - ch := &ClientHelloBody{} - _, err := ch.Unmarshal(hm.body) - if err != nil { + ch := &ClientHelloBody{LegacyVersion: wireVersion(state.hsCtx.hIn)} + if err := safeUnmarshal(ch, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) return nil, nil, AlertDecodeError } + // We are strict about these things because we only support 1.3 + if ch.LegacyVersion != wireVersion(state.hsCtx.hIn) { + logf(logTypeHandshake, "[ServerStateStart] Invalid version number: %v", ch.LegacyVersion) + return nil, nil, AlertDecodeError + } + clientHello := hm connParams := ConnectionParameters{} - supportedVersions := new(SupportedVersionsExtension) + supportedVersions := &SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello} serverName := new(ServerNameExtension) supportedGroups := new(SupportedGroupsExtension) signatureAlgorithms := new(SignatureAlgorithmsExtension) @@ -94,32 +124,42 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand clientCookie := new(CookieExtension) // Handle external extensions. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) if err != nil { logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) return nil, nil, AlertInternalError } } - gotSupportedVersions := ch.Extensions.Find(supportedVersions) - gotServerName := ch.Extensions.Find(serverName) - gotSupportedGroups := ch.Extensions.Find(supportedGroups) - gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) - gotEarlyData := ch.Extensions.Find(clientEarlyData) - ch.Extensions.Find(clientKeyShares) - ch.Extensions.Find(clientPSK) - ch.Extensions.Find(clientALPN) - ch.Extensions.Find(clientPSKModes) - ch.Extensions.Find(clientCookie) + foundExts, err := ch.Extensions.Parse( + []ExtensionBody{ + supportedVersions, + serverName, + supportedGroups, + signatureAlgorithms, + clientEarlyData, + clientKeyShares, + clientPSK, + clientALPN, + clientPSKModes, + clientCookie, + }) - if gotServerName { + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error parsing extensions [%v]", err) + return nil, nil, AlertDecodeError + } + + clientSentCookie := len(clientCookie.Cookie) > 0 + + if foundExts[ExtensionTypeServerName] { connParams.ServerName = string(*serverName) } // If the client didn't send supportedVersions or doesn't support 1.3, // then we're done here. - if !gotSupportedVersions { + if !foundExts[ExtensionTypeSupportedVersions] { logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") return nil, nil, AlertProtocolVersion } @@ -129,36 +169,72 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand return nil, nil, AlertProtocolVersion } - if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) { - logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") - return nil, nil, AlertAccessDenied + // The client sent a cookie. So this is probably the second ClientHello (sent as a response to a HRR) + var firstClientHello *HandshakeMessage + var initialCipherSuite CipherSuiteParams // the cipher suite that was negotiated when sending the HelloRetryRequest + if clientSentCookie { + plainCookie, err := state.Config.CookieProtector.DecodeToken(clientCookie.Cookie) + if err != nil { + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error decoding token [%v]", err)) + return nil, nil, AlertDecryptError + } + cookie := &cookie{} + if rb, err := syntax.Unmarshal(plainCookie, cookie); err != nil && rb != len(plainCookie) { // this should never happen + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err)) + return nil, nil, AlertInternalError + } + // restore the hash of initial ClientHello from the cookie + firstClientHello = &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: cookie.ClientHelloHash, + } + // have the application validate its part of the cookie + if state.Config.CookieHandler != nil && !state.Config.CookieHandler.Validate(state.conn, cookie.ApplicationCookie) { + logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") + return nil, nil, AlertAccessDenied + } + var ok bool + initialCipherSuite, ok = cipherSuiteMap[cookie.CipherSuite] + if !ok { + logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Cookie contained invalid cipher suite: %#x", cookie.CipherSuite)) + return nil, nil, AlertInternalError + } + } + + if len(ch.LegacySessionID) != 0 && len(ch.LegacySessionID) != 32 { + logf(logTypeHandshake, "[ServerStateStart] invalid session ID") + return nil, nil, AlertIllegalParameter } // Figure out if we can do DH - canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) + canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Config.Groups) // Figure out if we can do PSK - canDoPSK := false + var canDoPSK bool var selectedPSK int - var psk *PreSharedKey var params CipherSuiteParams + var psk *PreSharedKey if len(clientPSK.Identities) > 0 { contextBase := []byte{} - if state.helloRetryRequest != nil { - chBytes := state.firstClientHello.Marshal() - hrrBytes := state.helloRetryRequest.Marshal() - contextBase = append(chBytes, hrrBytes...) + if clientSentCookie { + contextBase = append(contextBase, firstClientHello.Marshal()...) + // fill in the cookie sent by the client. Needed to calculate the correct hash + cookieExt := &CookieExtension{Cookie: clientCookie.Cookie} + hrr, err := state.generateHRR(params.Suite, + ch.LegacySessionID, cookieExt) + if err != nil { + return nil, nil, AlertInternalError + } + contextBase = append(contextBase, hrr.Marshal()...) } - chTrunc, err := ch.Truncated() if err != nil { logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) return nil, nil, AlertDecodeError } - context := append(contextBase, chTrunc...) - canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs) + canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Config.PSKs) if err != nil { logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) return nil, nil, AlertInternalError @@ -169,67 +245,81 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) // Select a ciphersuite - connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) + connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Config.CipherSuites) if err != nil { logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) return nil, nil, AlertHandshakeFailure } - - // Send a cookie if required - // NB: Need to do this here because it's after ciphersuite selection, which - // has to be after PSK selection. - // XXX: Doing this statefully for now, could be stateless - var cookieData []byte - if state.Caps.RequireCookie && !state.cookieSent { - var err error - cookieData, err = state.Caps.CookieHandler.Generate(state.conn) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) - return nil, nil, AlertInternalError - } + if clientSentCookie && initialCipherSuite.Suite != connParams.CipherSuite { + logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") + return nil, nil, AlertInternalError } - if cookieData != nil { + + var helloRetryRequest *HandshakeMessage + if state.Config.RequireCookie { + // Send a cookie if required + // NB: Need to do this here because it's after ciphersuite selection, which + // has to be after PSK selection. + var shouldSendHRR bool + var cookieExt *CookieExtension + if !clientSentCookie { // this is the first ClientHello that we receive + var appCookie []byte + if state.Config.CookieHandler == nil { // if Config.RequireCookie is set, but no CookieHandler was provided, we definitely need to send a cookie + shouldSendHRR = true + } else { // if the CookieHandler was set, we just send a cookie when the application provides one + var err error + appCookie, err = state.Config.CookieHandler.Generate(state.conn) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) + return nil, nil, AlertInternalError + } + shouldSendHRR = appCookie != nil + } + if shouldSendHRR { + params := cipherSuiteMap[connParams.CipherSuite] + h := params.Hash.New() + h.Write(clientHello.Marshal()) + plainCookie, err := syntax.Marshal(cookie{ + CipherSuite: connParams.CipherSuite, + ClientHelloHash: h.Sum(nil), + ApplicationCookie: appCookie, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshalling cookie [%v]", err) + return nil, nil, AlertInternalError + } + cookieData, err := state.Config.CookieProtector.NewToken(plainCookie) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error encoding cookie [%v]", err) + return nil, nil, AlertInternalError + } + cookieExt = &CookieExtension{Cookie: cookieData} + } + } else { + cookieExt = &CookieExtension{Cookie: clientCookie.Cookie} + } + + // Generate a HRR. We will need it in both of the two cases: + // 1. We need to send a Cookie. Then this HRR will be sent on the wire + // 2. We need to validate a cookie. Then we need its hash // Ignoring errors because everything here is newly constructed, so there // shouldn't be marshal errors - hrr := &HelloRetryRequestBody{ - Version: supportedVersion, - CipherSuite: connParams.CipherSuite, - } - hrr.Extensions.Add(&CookieExtension{Cookie: cookieData}) - - // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if shouldSendHRR || clientSentCookie { + helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, + ch.LegacySessionID, cookieExt) if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) return nil, nil, AlertInternalError } } - helloRetryRequest, err := HandshakeMessageFromBody(hrr) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) - return nil, nil, AlertInternalError + if shouldSendHRR { + toSend := []HandshakeAction{ + QueueHandshakeMessage{helloRetryRequest}, + SendQueuedHandshake{}, + } + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") + return state, toSend, AlertStatelessRetry } - - params := cipherSuiteMap[connParams.CipherSuite] - h := params.Hash.New() - h.Write(clientHello.Marshal()) - firstClientHello := &HandshakeMessage{ - msgType: HandshakeTypeMessageHash, - body: h.Sum(nil), - } - - nextState := ServerStateStart{ - Caps: state.Caps, - conn: state.conn, - cookieSent: true, - firstClientHello: firstClientHello, - helloRetryRequest: helloRetryRequest, - } - toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}} - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") - return nextState, toSend, AlertNoAlert } // If we've got no entropy to make keys from, fail @@ -247,16 +337,17 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand psk = nil // If we're not using a PSK mode, then we need to have certain extensions - if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { - logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", - gotServerName, gotSupportedGroups, gotSignatureAlgorithms) + if !(foundExts[ExtensionTypeServerName] && + foundExts[ExtensionTypeSupportedGroups] && + foundExts[ExtensionTypeSignatureAlgorithms]) { + logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v)", foundExts) return nil, nil, AlertMissingExtension } // Select a certificate name := string(*serverName) var err error - cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates) + cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Config.Certificates) if err != nil { logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) return nil, nil, AlertAccessDenied @@ -269,10 +360,9 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand // Figure out if we're going to do early data var clientEarlyTrafficSecret []byte - connParams.ClientSendingEarlyData = gotEarlyData - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) + connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] + connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) if connParams.UsingEarlyData { - h := params.Hash.New() h.Write(clientHello.Marshal()) chHash := h.Sum(nil) @@ -283,17 +373,20 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand } // Select a next protocol - connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos) + connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Config.NextProtos) if err != nil { logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) return nil, nil, AlertNoApplicationProtocol } - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") - return ServerStateNegotiated{ - Caps: state.Caps, - Params: connParams, + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") + state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. + return serverStateNegotiated{ + Config: state.Config, + Params: connParams, + hsCtx: state.hsCtx, dhGroup: dhGroup, dhPublic: dhPublic, dhSecret: dhSecret, @@ -301,18 +394,60 @@ func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []Hand selectedPSK: selectedPSK, cert: cert, certScheme: certScheme, + legacySessionId: ch.LegacySessionID, clientEarlyTrafficSecret: clientEarlyTrafficSecret, - firstClientHello: state.firstClientHello, - helloRetryRequest: state.helloRetryRequest, + firstClientHello: firstClientHello, + helloRetryRequest: helloRetryRequest, clientHello: clientHello, - }.Next(nil) + }, nil, AlertNoAlert } -type ServerStateNegotiated struct { - Caps Capabilities - Params ConnectionParameters +func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byte, + cookieExt *CookieExtension) (*HandshakeMessage, error) { + var helloRetryRequest *HandshakeMessage + hrr := &ServerHelloBody{ + Version: tls12Version, + Random: hrrRandomSentinel, + CipherSuite: cs, + LegacySessionID: legacySessionId, + LegacyCompressionMethod: 0, + } + sv := &SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + } + + if err := hrr.Extensions.Add(sv); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error adding SupportedVersion [%v]", err) + return nil, err + } + + if err := hrr.Extensions.Add(cookieExt); err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err) + return nil, err + } + // Run the external extension handler. + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) + return nil, err + } + } + helloRetryRequest, err := state.hsCtx.hOut.HandshakeMessageFromBody(hrr) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) + return nil, err + } + return helloRetryRequest, nil +} + +type serverStateNegotiated struct { + Config *Config + Params ConnectionParameters + hsCtx *HandshakeContext dhGroup NamedGroup dhPublic []byte dhSecret []byte @@ -321,31 +456,42 @@ type ServerStateNegotiated struct { selectedPSK int cert *Certificate certScheme SignatureScheme - - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage + legacySessionId []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage } -func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } +var _ HandshakeState = &serverStateNegotiated{} +func (state serverStateNegotiated) State() State { + return StateServerNegotiated +} + +func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { // Create the ServerHello sh := &ServerHelloBody{ - Version: supportedVersion, - CipherSuite: state.Params.CipherSuite, + Version: tls12Version, + CipherSuite: state.Params.CipherSuite, + LegacySessionID: state.legacySessionId, + LegacyCompressionMethod: 0, } - _, err := prng.Read(sh.Random[:]) - if err != nil { + if _, err := prng.Read(sh.Random[:]); err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) return nil, nil, AlertInternalError } + + err := sh.Extensions.Add(&SupportedVersionsExtension{ + HandshakeType: HandshakeTypeServerHello, + Versions: []uint16{supportedVersion}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported_versions extension [%v]", err) + return nil, nil, AlertInternalError + } if state.Params.UsingDH { logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") - err = sh.Extensions.Add(&KeyShareExtension{ + err := sh.Extensions.Add(&KeyShareExtension{ HandshakeType: HandshakeTypeServerHello, Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, }) @@ -356,7 +502,7 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ } if state.Params.UsingPSK { logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") - err = sh.Extensions.Add(&PreSharedKeyExtension{ + err := sh.Extensions.Add(&PreSharedKeyExtension{ HandshakeType: HandshakeTypeServerHello, SelectedIdentity: uint16(state.selectedPSK), }) @@ -367,15 +513,15 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ } // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) return nil, nil, AlertInternalError } } - serverHello, err := HandshakeMessageFromBody(sh) + serverHello, err := state.hsCtx.hOut.HandshakeMessageFromBody(sh) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) return nil, nil, AlertInternalError @@ -448,15 +594,15 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ ee := &EncryptedExtensionsBody{eeList} // Run the external extension handler. - if state.Caps.ExtensionHandler != nil { - err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if state.Config.ExtensionHandler != nil { + err := state.Config.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) return nil, nil, AlertInternalError } } - eem, err := HandshakeMessageFromBody(ee) + eem, err := state.hsCtx.hOut.HandshakeMessageFromBody(ee) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) return nil, nil, AlertInternalError @@ -465,35 +611,35 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ handshakeHash.Write(eem.Marshal()) toSend := []HandshakeAction{ - SendHandshakeMessage{serverHello}, - RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys}, - SendHandshakeMessage{eem}, + QueueHandshakeMessage{serverHello}, + RekeyOut{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, + QueueHandshakeMessage{eem}, } // Authenticate with a certificate if required if !state.Params.UsingPSK { // Send a CertificateRequest message if we want client auth - if state.Caps.RequireClientAuth { + if state.Config.RequireClientAuth { state.Params.UsingClientAuth = true // XXX: We don't support sending any constraints besides a list of // supported signature algorithms cr := &CertificateRequestBody{} - schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + schemes := &SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} err := cr.Extensions.Add(schemes) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) return nil, nil, AlertInternalError } - crm, err := HandshakeMessageFromBody(cr) + crm, err := state.hsCtx.hOut.HandshakeMessageFromBody(cr) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) return nil, nil, AlertInternalError } //TODO state.state.serverCertificateRequest = cr - toSend = append(toSend, SendHandshakeMessage{crm}) + toSend = append(toSend, QueueHandshakeMessage{crm}) handshakeHash.Write(crm.Marshal()) } @@ -504,13 +650,13 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ for i, entry := range state.cert.Chain { certificate.CertificateList[i] = CertificateEntry{CertData: entry} } - certm, err := HandshakeMessageFromBody(certificate) + certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) return nil, nil, AlertInternalError } - toSend = append(toSend, SendHandshakeMessage{certm}) + toSend = append(toSend, QueueHandshakeMessage{certm}) handshakeHash.Write(certm.Marshal()) certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} @@ -524,13 +670,13 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) return nil, nil, AlertInternalError } - certvm, err := HandshakeMessageFromBody(certificateVerify) + certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) return nil, nil, AlertInternalError } - toSend = append(toSend, SendHandshakeMessage{certvm}) + toSend = append(toSend, QueueHandshakeMessage{certvm}) handshakeHash.Write(certvm.Marshal()) } @@ -547,10 +693,11 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ VerifyDataLen: len(serverFinishedData), VerifyData: serverFinishedData, } - finm, _ := HandshakeMessageFromBody(fin) + finm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(fin) - toSend = append(toSend, SendHandshakeMessage{finm}) + toSend = append(toSend, QueueHandshakeMessage{finm}) handshakeHash.Write(finm.Marshal()) + toSend = append(toSend, SendQueuedHandshake{}) // Compute traffic secrets h4 := handshakeHash.Sum(nil) @@ -563,7 +710,7 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) - toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys}) + toSend = append(toSend, RekeyOut{epoch: EpochApplicationData, KeySet: serverTrafficKeys}) exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) @@ -572,9 +719,10 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") - nextState := ServerStateWaitEOED{ - AuthCertificate: state.Caps.AuthCertificate, + nextState := serverStateWaitEOED{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: params, handshakeHash: handshakeHash, masterSecret: masterSecret, @@ -584,20 +732,20 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ exporterSecret: exporterSecret, } toSend = append(toSend, []HandshakeAction{ - RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys}, - ReadEarlyData{}, + RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, }...) return nextState, toSend, AlertNoAlert } logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") toSend = append(toSend, []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, - ReadPastEarlyData{}, + RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, }...) - waitFlight2 := ServerStateWaitFlight2{ - AuthCertificate: state.Caps.AuthCertificate, + var nextState HandshakeState + nextState = serverStateWaitFlight2{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: params, handshakeHash: handshakeHash, masterSecret: masterSecret, @@ -606,14 +754,19 @@ func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, [ serverTrafficSecret: serverTrafficSecret, exporterSecret: exporterSecret, } - nextState, moreToSend, alert := waitFlight2.Next(nil) - toSend = append(toSend, moreToSend...) - return nextState, toSend, alert + if state.Params.RejectedEarlyData { + nextState = serverStateReadPastEarlyData{ + hsCtx: state.hsCtx, + next: &nextState, + } + } + return nextState, toSend, AlertNoAlert } -type ServerStateWaitEOED struct { - AuthCertificate func(chain []CertificateEntry) error +type serverStateWaitEOED struct { + Config *Config Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -623,7 +776,49 @@ type ServerStateWaitEOED struct { exporterSecret []byte } -func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &serverStateWaitEOED{} + +func (state serverStateWaitEOED) State() State { + return StateServerWaitEOED +} + +func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading early data...") + assert(state.hsCtx.hIn.conn.cipher.epoch == EpochEarlyData) + t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + if err != nil { + logf(logTypeHandshake, "Server Error reading record type (1): %v", err) + return nil, nil, AlertBadRecordMAC + } + + logf(logTypeHandshake, "Server got record type(1): %v", t) + + if t != RecordTypeApplicationData { + break + } + + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in + // PeekRecordType. + pt, err := state.hsCtx.hIn.conn.ReadRecord() + if err != nil { + logf(logTypeHandshake, "Server error reading early data record: %v", err) + return nil, nil, AlertInternalError + } + + logf(logTypeHandshake, "Server read early data: %x", pt.fragment) + state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...) + } + + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") return nil, nil, AlertUnexpectedMessage @@ -640,11 +835,12 @@ func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []H logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") toSend := []HandshakeAction{ - RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, + RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, } - waitFlight2 := ServerStateWaitFlight2{ - AuthCertificate: state.AuthCertificate, + waitFlight2 := serverStateWaitFlight2{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, masterSecret: state.masterSecret, @@ -653,14 +849,47 @@ func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []H serverTrafficSecret: state.serverTrafficSecret, exporterSecret: state.exporterSecret, } - nextState, moreToSend, alert := waitFlight2.Next(nil) - toSend = append(toSend, moreToSend...) - return nextState, toSend, alert + return waitFlight2, toSend, AlertNoAlert } -type ServerStateWaitFlight2 struct { - AuthCertificate func(chain []CertificateEntry) error +var _ HandshakeState = &serverStateReadPastEarlyData{} + +type serverStateReadPastEarlyData struct { + hsCtx *HandshakeContext + next *HandshakeState +} + +func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading past early data...") + // Scan past all records that fail to decrypt + _, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == nil { + break + } + + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + // Continue on DecryptError + _, ok := err.(DecryptError) + if !ok { + return nil, nil, AlertInternalError // Really need something else. + } + } + + return *state.next, nil, AlertNoAlert +} + +func (state serverStateReadPastEarlyData) State() State { + return StateServerReadPastEarlyData +} + +type serverStateWaitFlight2 struct { + Config *Config Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -670,17 +899,19 @@ type ServerStateWaitFlight2 struct { exporterSecret []byte } -func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm != nil { - logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } +var _ HandshakeState = &serverStateWaitFlight2{} +func (state serverStateWaitFlight2) State() State { + return StateServerWaitFlight2 +} + +func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { if state.Params.UsingClientAuth { logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") - nextState := ServerStateWaitCert{ - AuthCertificate: state.AuthCertificate, + nextState := serverStateWaitCert{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, handshakeHash: state.handshakeHash, masterSecret: state.masterSecret, @@ -693,8 +924,9 @@ func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, } logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ + nextState := serverStateWaitFinished{ Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -706,9 +938,10 @@ func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, return nextState, nil, AlertNoAlert } -type ServerStateWaitCert struct { - AuthCertificate func(chain []CertificateEntry) error +type serverStateWaitCert struct { + Config *Config Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -718,15 +951,24 @@ type ServerStateWaitCert struct { exporterSecret []byte } -func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &serverStateWaitCert{} + +func (state serverStateWaitCert) State() State { + return StateServerWaitCert +} + +func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeCertificate { logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") return nil, nil, AlertUnexpectedMessage } cert := &CertificateBody{} - _, err := cert.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(cert, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") return nil, nil, AlertDecodeError } @@ -737,8 +979,9 @@ func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ + nextState := serverStateWaitFinished{ Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -751,9 +994,10 @@ func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H } logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") - nextState := ServerStateWaitCV{ - AuthCertificate: state.AuthCertificate, + nextState := serverStateWaitCV{ + Config: state.Config, Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -766,10 +1010,11 @@ func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []H return nextState, nil, AlertNoAlert } -type ServerStateWaitCV struct { - AuthCertificate func(chain []CertificateEntry) error - Params ConnectionParameters - cryptoParams CipherSuiteParams +type serverStateWaitCV struct { + Config *Config + Params ConnectionParameters + hsCtx *HandshakeContext + cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -782,19 +1027,35 @@ type ServerStateWaitCV struct { clientCertificate *CertificateBody } -func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &serverStateWaitCV{} + +func (state serverStateWaitCV) State() State { + return StateServerWaitCV +} + +func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) return nil, nil, AlertUnexpectedMessage } certVerify := &CertificateVerifyBody{} - _, err := certVerify.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(certVerify, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) return nil, nil, AlertDecodeError } + rawCerts := make([][]byte, len(state.clientCertificate.CertificateList)) + certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList)) + for i, certEntry := range state.clientCertificate.CertificateList { + certs[i] = certEntry.CertData + rawCerts[i] = certEntry.CertData.Raw + } + // Verify client signature over handshake hash hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) @@ -805,22 +1066,21 @@ func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []Han return nil, nil, AlertHandshakeFailure } - if state.AuthCertificate != nil { - err := state.AuthCertificate(state.clientCertificate.CertificateList) - if err != nil { - logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate") + if state.Config.VerifyPeerCertificate != nil { + // TODO(#171): pass in the verified chains, once we support different client auth types + if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err) return nil, nil, AlertBadCertificate } - } else { - logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate") } // If it passes, record the certificateVerify in the transcript hash state.handshakeHash.Write(hm.Marshal()) logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") - nextState := ServerStateWaitFinished{ + nextState := serverStateWaitFinished{ Params: state.Params, + hsCtx: state.hsCtx, cryptoParams: state.cryptoParams, masterSecret: state.masterSecret, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, @@ -828,16 +1088,21 @@ func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []Han clientTrafficSecret: state.clientTrafficSecret, serverTrafficSecret: state.serverTrafficSecret, exporterSecret: state.exporterSecret, + peerCertificates: certs, + verifiedChains: nil, // TODO(#171): set this value } return nextState, nil, AlertNoAlert } -type ServerStateWaitFinished struct { +type serverStateWaitFinished struct { Params ConnectionParameters + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate handshakeHash hash.Hash clientTrafficSecret []byte @@ -845,15 +1110,24 @@ type ServerStateWaitFinished struct { exporterSecret []byte } -func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +var _ HandshakeState = &serverStateWaitFinished{} + +func (state serverStateWaitFinished) State() State { + return StateServerWaitFinished +} + +func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + hm, alert := hr.ReadMessage() + if alert != AlertNoAlert { + return nil, nil, alert + } if hm == nil || hm.msgType != HandshakeTypeFinished { logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") return nil, nil, AlertUnexpectedMessage } fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} - _, err := fin.Unmarshal(hm.body) - if err != nil { + if err := safeUnmarshal(fin, hm.body); err != nil { logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) return nil, nil, AlertDecodeError } @@ -881,18 +1155,23 @@ func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, // Compute client traffic keys clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + state.hsCtx.receivedFinalFlight() + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") - nextState := StateConnected{ + nextState := stateConnected{ Params: state.Params, + hsCtx: state.hsCtx, isClient: false, cryptoParams: state.cryptoParams, resumptionSecret: resumptionSecret, clientTrafficSecret: state.clientTrafficSecret, serverTrafficSecret: state.serverTrafficSecret, exporterSecret: state.exporterSecret, + peerCertificates: state.peerCertificates, + verifiedChains: state.verifiedChains, } toSend := []HandshakeAction{ - RekeyIn{Label: "application", KeySet: clientTrafficKeys}, + RekeyIn{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, } return nextState, toSend, AlertNoAlert } diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go index 4eb468c..558b76c 100644 --- a/vendor/github.com/bifurcation/mint/state-machine.go +++ b/vendor/github.com/bifurcation/mint/state-machine.go @@ -1,6 +1,7 @@ package mint import ( + "crypto/x509" "time" ) @@ -8,32 +9,35 @@ import ( // state transitions. type HandshakeAction interface{} -type SendHandshakeMessage struct { +type QueueHandshakeMessage struct { Message *HandshakeMessage } +type SendQueuedHandshake struct{} + type SendEarlyData struct{} -type ReadEarlyData struct{} - -type ReadPastEarlyData struct{} - type RekeyIn struct { - Label string + epoch Epoch KeySet keySet } type RekeyOut struct { - Label string + epoch Epoch KeySet keySet } +type ResetOut struct { + seq uint64 +} + type StorePSK struct { PSK PreSharedKey } type HandshakeState interface { - Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) + Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) + State() State } type AppExtensionHandler interface { @@ -41,35 +45,11 @@ type AppExtensionHandler interface { Receive(hs HandshakeType, el *ExtensionList) error } -// Capabilities objects represent the capabilities of a TLS client or server, -// as an input to TLS negotiation -type Capabilities struct { - // For both client and server - CipherSuites []CipherSuite - Groups []NamedGroup - SignatureSchemes []SignatureScheme - PSKs PreSharedKeyCache - Certificates []*Certificate - AuthCertificate func(chain []CertificateEntry) error - ExtensionHandler AppExtensionHandler - - // For client - PSKModes []PSKKeyExchangeMode - - // For server - NextProtos []string - AllowEarlyData bool - RequireCookie bool - CookieHandler CookieHandler - RequireClientAuth bool -} - // ConnectionOptions objects represent per-connection settings for a client // initiating a connection type ConnectionOptions struct { ServerName string NextProtos []string - EarlyData []byte } // ConnectionParameters objects represent the parameters negotiated for a @@ -79,6 +59,7 @@ type ConnectionParameters struct { UsingDH bool ClientSendingEarlyData bool UsingEarlyData bool + RejectedEarlyData bool UsingClientAuth bool CipherSuite CipherSuite @@ -86,18 +67,50 @@ type ConnectionParameters struct { NextProto string } -// StateConnected is symmetric between client and server -type StateConnected struct { +// Working state for the handshake. +type HandshakeContext struct { + timeoutMS uint32 + timers *timerSet + recvdRecords []uint64 + sentFragments []*SentHandshakeFragment + hIn, hOut *HandshakeLayer + waitingNextFlight bool + earlyData []byte +} + +func (hc *HandshakeContext) SetVersion(version uint16) { + if hc.hIn.conn != nil { + hc.hIn.conn.SetVersion(version) + } + if hc.hOut.conn != nil { + hc.hOut.conn.SetVersion(version) + } +} + +// stateConnected is symmetric between client and server +type stateConnected struct { Params ConnectionParameters + hsCtx *HandshakeContext isClient bool cryptoParams CipherSuiteParams resumptionSecret []byte clientTrafficSecret []byte serverTrafficSecret []byte exporterSecret []byte + peerCertificates []*x509.Certificate + verifiedChains [][]*x509.Certificate } -func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { +var _ HandshakeState = &stateConnected{} + +func (state stateConnected) State() State { + if state.isClient { + return StateClientConnected + } + return StateServerConnected +} + +func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { var trafficKeys keySet if state.isClient { state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, @@ -109,20 +122,21 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) } - kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) + kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) if err != nil { logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) return nil, AlertInternalError } toSend := []HandshakeAction{ - SendHandshakeMessage{kum}, - RekeyOut{Label: "update", KeySet: trafficKeys}, + QueueHandshakeMessage{kum}, + SendQueuedHandshake{}, + RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys}, } return toSend, AlertNoAlert } -func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { +func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { tkt, err := NewSessionTicket(length, lifetime) if err != nil { logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) @@ -149,7 +163,7 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif TicketAgeAdd: tkt.TicketAgeAdd, } - tktm, err := HandshakeMessageFromBody(tkt) + tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt) if err != nil { logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) return nil, AlertInternalError @@ -157,12 +171,18 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif toSend := []HandshakeAction{ StorePSK{newPSK}, - SendHandshakeMessage{tktm}, + QueueHandshakeMessage{tktm}, + SendQueuedHandshake{}, } return toSend, AlertNoAlert } -func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { +// Next does nothing for this state. +func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + return state, nil, AlertNoAlert +} + +func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { if hm == nil { logf(logTypeHandshake, "[StateConnected] Unexpected message") return nil, nil, AlertUnexpectedMessage @@ -187,20 +207,18 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) } - toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} + toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}} // If requested, roll outbound keys and send a KeyUpdate if body.KeyUpdateRequest == KeyUpdateRequested { + logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest) moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) if alert != AlertNoAlert { return nil, nil, alert } - toSend = append(toSend, moreToSend...) } - return state, toSend, AlertNoAlert - case *NewSessionTicketBody: // XXX: Allow NewSessionTicket in both directions? if !state.isClient { @@ -209,7 +227,6 @@ func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []Handsh resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) - psk := PreSharedKey{ CipherSuite: state.cryptoParams.Suite, IsResumption: true, diff --git a/vendor/github.com/bifurcation/mint/syntax/README.md b/vendor/github.com/bifurcation/mint/syntax/README.md index dbf4ec2..537b9b4 100644 --- a/vendor/github.com/bifurcation/mint/syntax/README.md +++ b/vendor/github.com/bifurcation/mint/syntax/README.md @@ -72,3 +72,13 @@ The available annotations right now are all related to vectors: fragment[TLSPlaintext.length]`. Note, however, that in cases where the length immediately preceds the array, these can be reframed as vectors with appropriate sizes. + + +QUIC Extensions Syntax +====================== +syntax also supports some minor extensions to allow implementing QUIC. + +* The `varint` annotation describes a QUIC-style varint +* `head=none` means no header, i.e., the bytes are encoded directly on the wire. + On reading, the decoder will consume all available data. +* `head=varint` means to encode the header as a varint diff --git a/vendor/github.com/bifurcation/mint/syntax/decode.go b/vendor/github.com/bifurcation/mint/syntax/decode.go index cd5aada..1735840 100644 --- a/vendor/github.com/bifurcation/mint/syntax/decode.go +++ b/vendor/github.com/bifurcation/mint/syntax/decode.go @@ -16,12 +16,22 @@ func Unmarshal(data []byte, v interface{}) (int, error) { return d.unmarshal(v) } +// Unmarshaler is the interface implemented by types that can +// unmarshal a TLS description of themselves. Note that unlike the +// JSON unmarshaler interface, it is not known a priori how much of +// the input data will be consumed. So the Unmarshaler must state +// how much of the input data it consumed. +type Unmarshaler interface { + UnmarshalTLS([]byte) (int, error) +} + // These are the options that can be specified in the struct tag. Right now, // all of them apply to variable-length vectors and nothing else type decOpts struct { - head uint // length of length in bytes - min uint // minimum size in bytes - max uint // maximum size in bytes + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes + varint bool // whether to decode as a varint } type decodeState struct { @@ -65,8 +75,14 @@ func typeDecoder(t reflect.Type) decoderFunc { return newTypeDecoder(t) } +var ( + unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem() +) + func newTypeDecoder(t reflect.Type) decoderFunc { - // Note: Does not support Marshaler, so don't need the allowAddr argument + if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) { + return unmarshalerDecoder + } switch t.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -77,6 +93,8 @@ func newTypeDecoder(t reflect.Type) decoderFunc { return newSliceDecoder(t) case reflect.Struct: return newStructDecoder(t) + case reflect.Ptr: + return newPointerDecoder(t) default: panic(fmt.Errorf("Unsupported type (%s)", t)) } @@ -84,35 +102,87 @@ func newTypeDecoder(t reflect.Type) decoderFunc { ///// Specific decoders below -func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { - var uintLen int - switch v.Elem().Kind() { - case reflect.Uint8: - uintLen = 1 - case reflect.Uint16: - uintLen = 2 - case reflect.Uint32: - uintLen = 4 - case reflect.Uint64: - uintLen = 8 +func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + um, ok := v.Interface().(Unmarshaler) + if !ok { + panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder")) } - buf := make([]byte, uintLen) - n, err := d.Read(buf) + read, err := um.UnmarshalTLS(d.Bytes()) if err != nil { panic(err) } - if n != uintLen { + + if read > d.Len() { + panic(fmt.Errorf("Invalid return value from UnmarshalTLS")) + } + + d.Next(read) + return read +} + +////////// + +func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + if opts.varint { + return varintDecoder(d, v, opts) + } + + uintLen := int(v.Elem().Type().Size()) + buf := d.Next(uintLen) + if len(buf) != uintLen { panic(fmt.Errorf("Insufficient data to read uint")) } + return setUintFromBuffer(v, buf) +} + +func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + l, val := readVarint(d) + + uintLen := int(v.Elem().Type().Size()) + if uintLen < l { + panic(fmt.Errorf("Uint too small to fit varint: %d < %d", uintLen, l)) + } + + v.Elem().SetUint(val) + + return l +} + +func readVarint(d *decodeState) (int, uint64) { + // Read the first octet and decide the size of the presented varint + first := d.Next(1) + if len(first) != 1 { + panic(fmt.Errorf("Insufficient data to read varint length")) + } + + twoBits := uint(first[0] >> 6) + varintLen := 1 << twoBits + + rest := d.Next(varintLen - 1) + if len(rest) != varintLen-1 { + panic(fmt.Errorf("Insufficient data to read varint")) + } + + buf := append(first, rest...) + buf[0] &= 0x3f + + return len(buf), decodeUintFromBuffer(buf) +} + +func decodeUintFromBuffer(buf []byte) uint64 { val := uint64(0) for _, b := range buf { val = (val << 8) + uint64(b) } - v.Elem().SetUint(val) - return uintLen + return val +} + +func setUintFromBuffer(v reflect.Value, buf []byte) int { + v.Elem().SetUint(decodeUintFromBuffer(buf)) + return len(buf) } ////////// @@ -143,44 +213,57 @@ type sliceDecoder struct { } func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + var length uint64 + var read int + var data []byte + if opts.head == 0 { panic(fmt.Errorf("Cannot decode a slice without a header length")) } - lengthBytes := make([]byte, opts.head) - n, err := d.Read(lengthBytes) - if err != nil { - panic(err) - } - if uint(n) != opts.head { - panic(fmt.Errorf("Not enough data to read header")) - } + // If the caller indicated there is no header, then read everything from the buffer + if opts.head == headValueNoHead { + for { + chunk := d.Next(1024) + data = append(data, chunk...) + if len(chunk) != 1024 { + break + } + } + length = uint64(len(data)) + if opts.max > 0 && length > uint64(opts.max) { + panic(fmt.Errorf("Length of vector exceeds declared max")) + } + if length < uint64(opts.min) { + panic(fmt.Errorf("Length of vector below declared min")) + } + } else { + if opts.head != headValueVarint { + lengthBytes := d.Next(int(opts.head)) + if len(lengthBytes) != int(opts.head) { + panic(fmt.Errorf("Not enough data to read header")) + } + read = len(lengthBytes) + length = decodeUintFromBuffer(lengthBytes) + } else { + read, length = readVarint(d) + } + if opts.max > 0 && length > uint64(opts.max) { + panic(fmt.Errorf("Length of vector exceeds declared max")) + } + if length < uint64(opts.min) { + panic(fmt.Errorf("Length of vector below declared min")) + } - length := uint(0) - for _, b := range lengthBytes { - length = (length << 8) + uint(b) - } - - if opts.max > 0 && length > opts.max { - panic(fmt.Errorf("Length of vector exceeds declared max")) - } - if length < opts.min { - panic(fmt.Errorf("Length of vector below declared min")) - } - - data := make([]byte, length) - n, err = d.Read(data) - if err != nil { - panic(err) - } - if uint(n) != length { - panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length)) + data = d.Next(int(length)) + if len(data) != int(length) { + panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length)) + } } elemBuf := &decodeState{} elemBuf.Write(data) elems := []reflect.Value{} - read := int(opts.head) for elemBuf.Len() > 0 { elem := reflect.New(sd.elementType) read += sd.elementDec(elemBuf, elem, opts) @@ -231,9 +314,10 @@ func newStructDecoder(t reflect.Type) decoderFunc { tagOpts := parseTag(tag) sd.fieldOpts[i] = decOpts{ - head: tagOpts["head"], - max: tagOpts["max"], - min: tagOpts["min"], + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + varint: tagOpts[varintOption] > 0, } sd.fieldDecs[i] = typeDecoder(f.Type) @@ -241,3 +325,20 @@ func newStructDecoder(t reflect.Type) decoderFunc { return sd.decode } + +////////// + +type pointerDecoder struct { + base decoderFunc +} + +func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + v.Elem().Set(reflect.New(v.Elem().Type().Elem())) + return pd.base(d, v.Elem(), opts) +} + +func newPointerDecoder(t reflect.Type) decoderFunc { + baseDecoder := typeDecoder(t.Elem()) + pd := pointerDecoder{base: baseDecoder} + return pd.decode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/encode.go b/vendor/github.com/bifurcation/mint/syntax/encode.go index 2874f40..54fea59 100644 --- a/vendor/github.com/bifurcation/mint/syntax/encode.go +++ b/vendor/github.com/bifurcation/mint/syntax/encode.go @@ -16,12 +16,19 @@ func Marshal(v interface{}) ([]byte, error) { return e.Bytes(), nil } +// Marshaler is the interface implemented by types that +// have a defined TLS encoding. +type Marshaler interface { + MarshalTLS() ([]byte, error) +} + // These are the options that can be specified in the struct tag. Right now, // all of them apply to variable-length vectors and nothing else type encOpts struct { - head uint // length of length in bytes - min uint // minimum size in bytes - max uint // maximum size in bytes + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes + varint bool // whether to encode as a varint } type encodeState struct { @@ -62,8 +69,14 @@ func typeEncoder(t reflect.Type) encoderFunc { return newTypeEncoder(t) } +var ( + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() +) + func newTypeEncoder(t reflect.Type) encoderFunc { - // Note: Does not support Marshaler, so don't need the allowAddr argument + if t.Implements(marshalerType) { + return marshalerEncoder + } switch t.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -74,6 +87,8 @@ func newTypeEncoder(t reflect.Type) encoderFunc { return newSliceEncoder(t) case reflect.Struct: return newStructEncoder(t) + case reflect.Ptr: + return newPointerEncoder(t) default: panic(fmt.Errorf("Unsupported type (%s)", t)) } @@ -81,19 +96,65 @@ func newTypeEncoder(t reflect.Type) encoderFunc { ///// Specific encoders below -func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { - u := v.Uint() - switch v.Type().Kind() { - case reflect.Uint8: - e.WriteByte(byte(u)) - case reflect.Uint16: - e.Write([]byte{byte(u >> 8), byte(u)}) - case reflect.Uint32: - e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) - case reflect.Uint64: - e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32), - byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) +func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Ptr && v.IsNil() { + panic(fmt.Errorf("Cannot encode nil pointer")) } + + m, ok := v.Interface().(Marshaler) + if !ok { + panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder")) + } + + b, err := m.MarshalTLS() + if err == nil { + _, err = e.Write(b) + } + + if err != nil { + panic(err) + } +} + +////////// + +func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if opts.varint { + varintEncoder(e, v, opts) + return + } + + writeUint(e, v.Uint(), int(v.Type().Size())) +} + +func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + writeVarint(e, v.Uint()) +} + +func writeVarint(e *encodeState, u uint64) { + if (u >> 62) > 0 { + panic(fmt.Errorf("uint value is too big for varint")) + } + + var varintLen int + for _, len := range []uint{1, 2, 4, 8} { + if u < (uint64(1) << (8*len - 2)) { + varintLen = int(len) + break + } + } + + twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen] + shift := uint(8*varintLen - 2) + writeUint(e, u|(twoBits<> uint(8*(len-i-1))) + } + e.Write(data) } ////////// @@ -121,27 +182,34 @@ type sliceEncoder struct { } func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { - if opts.head == 0 { - panic(fmt.Errorf("Cannot encode a slice without a header length")) - } - arrayState := &encodeState{} se.ae.encode(arrayState, v, opts) n := uint(arrayState.Len()) + if opts.head == 0 { + panic(fmt.Errorf("Cannot encode a slice without a header length")) + } + if opts.max > 0 && n > opts.max { panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) } - if n>>(8*opts.head) > 0 { - panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) - } if n < opts.min { panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) } - for i := int(opts.head - 1); i >= 0; i -= 1 { - e.WriteByte(byte(n >> (8 * uint(i)))) + switch opts.head { + case headValueNoHead: + // None. + case headValueVarint: + writeVarint(e, uint64(n)) + default: + if n>>(8*opts.head) > 0 { + panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) + } + + writeUint(e, uint64(n), int(opts.head)) } + e.Write(arrayState.Bytes()) } @@ -176,12 +244,33 @@ func newStructEncoder(t reflect.Type) encoderFunc { tagOpts := parseTag(tag) se.fieldOpts[i] = encOpts{ - head: tagOpts["head"], - max: tagOpts["max"], - min: tagOpts["min"], + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + varint: tagOpts[varintOption] > 0, } se.fieldEncs[i] = typeEncoder(f.Type) } return se.encode } + +////////// + +type pointerEncoder struct { + base encoderFunc +} + +func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + if v.IsNil() { + panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer")) + } + + pe.base(e, v.Elem(), opts) +} + +func newPointerEncoder(t reflect.Type) encoderFunc { + baseEncoder := typeEncoder(t.Elem()) + pe := pointerEncoder{base: baseEncoder} + return pe.encode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/tags.go b/vendor/github.com/bifurcation/mint/syntax/tags.go index a6c9c88..e06f8ec 100644 --- a/vendor/github.com/bifurcation/mint/syntax/tags.go +++ b/vendor/github.com/bifurcation/mint/syntax/tags.go @@ -5,16 +5,27 @@ import ( "strings" ) -// `tls:"head=2,min=2,max=255"` +// `tls:"head=2,min=2,max=255,varint"` type tagOptions map[string]uint +var ( + varintOption = "varint" + + headOptionNone = "none" + headOptionVarint = "varint" + headValueNoHead = uint(255) + headValueVarint = uint(254) +) + // parseTag parses a struct field's "tls" tag as a comma-separated list of -// name=value pairs, where the values MUST be unsigned integers +// name=value pairs, where the values MUST be unsigned integers, or in +// the special case of head, "none" or "varint" func parseTag(tag string) tagOptions { opts := tagOptions{} for _, token := range strings.Split(tag, ",") { - if strings.Index(token, "=") == -1 { + if token == varintOption { + opts[varintOption] = 1 continue } @@ -22,7 +33,16 @@ func parseTag(tag string) tagOptions { if len(parts[0]) == 0 { continue } - if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { + + if len(parts) == 1 { + continue + } + + if parts[0] == "head" && parts[1] == headOptionNone { + opts[parts[0]] = headValueNoHead + } else if parts[0] == "head" && parts[1] == headOptionVarint { + opts[parts[0]] = headValueVarint + } else if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { opts[parts[0]] = uint(val) } } diff --git a/vendor/github.com/bifurcation/mint/timer.go b/vendor/github.com/bifurcation/mint/timer.go new file mode 100644 index 0000000..0b7f7af --- /dev/null +++ b/vendor/github.com/bifurcation/mint/timer.go @@ -0,0 +1,122 @@ +package mint + +import ( + "time" +) + +// This is a simple timer implementation. Timers are stored in a sorted +// list. +// TODO(ekr@rtfm.com): Add a way to uncouple these from the system +// clock. +type timerCb func() error + +type timer struct { + label string + cb timerCb + deadline time.Time + duration uint32 +} + +type timerSet struct { + ts []*timer +} + +func newTimerSet() *timerSet { + return &timerSet{} +} + +func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer { + now := time.Now() + t := timer{ + label, + cb, + now.Add(time.Millisecond * time.Duration(delayMs)), + delayMs, + } + logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline) + + var i int + ntimers := len(ts.ts) + for i = 0; i < ntimers; i++ { + if t.deadline.Before(ts.ts[i].deadline) { + break + } + } + + tmp := make([]*timer, 0, ntimers+1) + tmp = append(tmp, ts.ts[:i]...) + tmp = append(tmp, &t) + tmp = append(tmp, ts.ts[i:]...) + ts.ts = tmp + + return &t +} + +// TODO(ekr@rtfm.com): optimize this now that the list is sorted. +// We should be able to do just one list manipulation, as long +// as we're careful about how we handle inserts during callbacks. +func (ts *timerSet) check(now time.Time) error { + for i, t := range ts.ts { + if now.After(t.deadline) { + ts.ts = append(ts.ts[:i], ts.ts[:i+1]...) + if t.cb != nil { + logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline) + cb := t.cb + t.cb = nil + err := cb() + if err != nil { + return err + } + } + } else { + break + } + } + return nil +} + +// Returns the next time any of the timers would fire. +func (ts *timerSet) remaining() (bool, time.Duration) { + for _, t := range ts.ts { + if t.cb != nil { + return true, time.Until(t.deadline) + } + } + + return false, time.Duration(0) +} + +func (ts *timerSet) cancel(label string) { + for _, t := range ts.ts { + if t.label == label { + t.cancel() + } + } +} + +func (ts *timerSet) getTimer(label string) *timer { + for _, t := range ts.ts { + if t.label == label && t.cb != nil { + return t + } + } + return nil +} + +func (ts *timerSet) getAllTimers() []string { + var ret []string + + for _, t := range ts.ts { + if t.cb != nil { + ret = append(ret, t.label) + } + } + + return ret +} + +func (t *timer) cancel() { + logf(logTypeHandshake, "Timer %s cancelled", t.label) + t.cb = nil + t.label = "" +} diff --git a/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/bifurcation/mint/tls.go index 0c57aba..4d22869 100644 --- a/vendor/github.com/bifurcation/mint/tls.go +++ b/vendor/github.com/bifurcation/mint/tls.go @@ -51,11 +51,14 @@ func (l *Listener) Accept() (c net.Conn, err error) { // Listener and wraps each connection with Server. // The configuration config must be non-nil and must include // at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config) net.Listener { +func NewListener(inner net.Listener, config *Config) (net.Listener, error) { + if config != nil && config.NonBlocking { + return nil, errors.New("listening not possible in non-blocking mode") + } l := new(Listener) l.Listener = inner l.config = config - return l + return l, nil } // Listen creates a TLS listener accepting connections on the @@ -70,7 +73,7 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) { if err != nil { return nil, err } - return NewListener(l, config), nil + return NewListener(l, config) } type TimeoutError struct{} @@ -87,6 +90,10 @@ func (TimeoutError) Temporary() bool { return true } // DialWithDialer interprets a nil configuration as equivalent to the zero // configuration; see the documentation of Config for the defaults. func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + if config != nil && config.NonBlocking { + return nil, errors.New("dialing not possible in non-blocking mode") + } + // We want the Timeout and Deadline values from dialer to cover the // whole process: TCP connection and TLS handshake. This means that we // also need to start our own timers now. @@ -121,16 +128,20 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* if config == nil { config = &Config{} + } else { + config = config.Clone() } + // If no ServerName is set, infer the ServerName // from the hostname we're connecting to. if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - c := config.Clone() - c.ServerName = hostname - config = c + config.ServerName = hostname + } + // Set up DTLS as needed. + config.UseDTLS = (network == "udp") + conn := Client(rawConn, config) if timeout == 0 { diff --git a/vendor/github.com/cheekybits/genny/LICENSE b/vendor/github.com/cheekybits/genny/LICENSE new file mode 100644 index 0000000..519d7f2 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2014 cheekybits + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/vendor/github.com/cheekybits/genny/generic/doc.go b/vendor/github.com/cheekybits/genny/generic/doc.go new file mode 100644 index 0000000..3bd6c86 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/generic/doc.go @@ -0,0 +1,2 @@ +// Package generic contains the generic marker types. +package generic diff --git a/vendor/github.com/cheekybits/genny/generic/generic.go b/vendor/github.com/cheekybits/genny/generic/generic.go new file mode 100644 index 0000000..04a2306 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/generic/generic.go @@ -0,0 +1,13 @@ +package generic + +// Type is the placeholder type that indicates a generic value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Type +type Type interface{} + +// Number is the placehoder type that indiccates a generic numerical value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Number +type Number float64 diff --git a/vendor/github.com/hashicorp/golang-lru/2q.go b/vendor/github.com/hashicorp/golang-lru/2q.go index 337d963..e474cd0 100644 --- a/vendor/github.com/hashicorp/golang-lru/2q.go +++ b/vendor/github.com/hashicorp/golang-lru/2q.go @@ -30,9 +30,9 @@ type TwoQueueCache struct { size int recentSize int - recent *simplelru.LRU - frequent *simplelru.LRU - recentEvict *simplelru.LRU + recent simplelru.LRUCache + frequent simplelru.LRUCache + recentEvict simplelru.LRUCache lock sync.RWMutex } @@ -84,7 +84,8 @@ func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCa return c, nil } -func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) { +// Get looks up a key's value from the cache. +func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() @@ -105,6 +106,7 @@ func (c *TwoQueueCache) Get(key interface{}) (interface{}, bool) { return nil, false } +// Add adds a value to the cache. func (c *TwoQueueCache) Add(key, value interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -160,12 +162,15 @@ func (c *TwoQueueCache) ensureSpace(recentEvict bool) { c.frequent.RemoveOldest() } +// Len returns the number of items in the cache. func (c *TwoQueueCache) Len() int { c.lock.RLock() defer c.lock.RUnlock() return c.recent.Len() + c.frequent.Len() } +// Keys returns a slice of the keys in the cache. +// The frequently used keys are first in the returned slice. func (c *TwoQueueCache) Keys() []interface{} { c.lock.RLock() defer c.lock.RUnlock() @@ -174,6 +179,7 @@ func (c *TwoQueueCache) Keys() []interface{} { return append(k1, k2...) } +// Remove removes the provided key from the cache. func (c *TwoQueueCache) Remove(key interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -188,6 +194,7 @@ func (c *TwoQueueCache) Remove(key interface{}) { } } +// Purge is used to completely clear the cache. func (c *TwoQueueCache) Purge() { c.lock.Lock() defer c.lock.Unlock() @@ -196,13 +203,17 @@ func (c *TwoQueueCache) Purge() { c.recentEvict.Purge() } +// Contains is used to check if the cache contains a key +// without updating recency or frequency. func (c *TwoQueueCache) Contains(key interface{}) bool { c.lock.RLock() defer c.lock.RUnlock() return c.frequent.Contains(key) || c.recent.Contains(key) } -func (c *TwoQueueCache) Peek(key interface{}) (interface{}, bool) { +// Peek is used to inspect the cache value of a key +// without updating recency or frequency. +func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() if val, ok := c.frequent.Peek(key); ok { diff --git a/vendor/github.com/hashicorp/golang-lru/arc.go b/vendor/github.com/hashicorp/golang-lru/arc.go index a2a2528..555225a 100644 --- a/vendor/github.com/hashicorp/golang-lru/arc.go +++ b/vendor/github.com/hashicorp/golang-lru/arc.go @@ -18,11 +18,11 @@ type ARCCache struct { size int // Size is the total capacity of the cache p int // P is the dynamic preference towards T1 or T2 - t1 *simplelru.LRU // T1 is the LRU for recently accessed items - b1 *simplelru.LRU // B1 is the LRU for evictions from t1 + t1 simplelru.LRUCache // T1 is the LRU for recently accessed items + b1 simplelru.LRUCache // B1 is the LRU for evictions from t1 - t2 *simplelru.LRU // T2 is the LRU for frequently accessed items - b2 *simplelru.LRU // B2 is the LRU for evictions from t2 + t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items + b2 simplelru.LRUCache // B2 is the LRU for evictions from t2 lock sync.RWMutex } @@ -60,11 +60,11 @@ func NewARC(size int) (*ARCCache, error) { } // Get looks up a key's value from the cache. -func (c *ARCCache) Get(key interface{}) (interface{}, bool) { +func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() - // Ff the value is contained in T1 (recent), then + // If the value is contained in T1 (recent), then // promote it to T2 (frequent) if val, ok := c.t1.Peek(key); ok { c.t1.Remove(key) @@ -153,7 +153,7 @@ func (c *ARCCache) Add(key, value interface{}) { // Remove from B2 c.b2.Remove(key) - // Add the key to the frequntly used list + // Add the key to the frequently used list c.t2.Add(key, value) return } @@ -247,7 +247,7 @@ func (c *ARCCache) Contains(key interface{}) bool { // Peek is used to inspect the cache value of a key // without updating recency or frequency. -func (c *ARCCache) Peek(key interface{}) (interface{}, bool) { +func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() if val, ok := c.t1.Peek(key); ok { diff --git a/vendor/github.com/hashicorp/golang-lru/doc.go b/vendor/github.com/hashicorp/golang-lru/doc.go new file mode 100644 index 0000000..2547df9 --- /dev/null +++ b/vendor/github.com/hashicorp/golang-lru/doc.go @@ -0,0 +1,21 @@ +// Package lru provides three different LRU caches of varying sophistication. +// +// Cache is a simple LRU cache. It is based on the +// LRU implementation in groupcache: +// https://github.com/golang/groupcache/tree/master/lru +// +// TwoQueueCache tracks frequently used and recently used entries separately. +// This avoids a burst of accesses from taking out frequently used entries, +// at the cost of about 2x computational overhead and some extra bookkeeping. +// +// ARCCache is an adaptive replacement cache. It tracks recent evictions as +// well as recent usage in both the frequent and recent caches. Its +// computational overhead is comparable to TwoQueueCache, but the memory +// overhead is linear with the size of the cache. +// +// ARC has been patented by IBM, so do not use it if that is problematic for +// your program. +// +// All caches in this package take locks while operating, and are therefore +// thread-safe for consumers. +package lru diff --git a/vendor/github.com/hashicorp/golang-lru/lru.go b/vendor/github.com/hashicorp/golang-lru/lru.go index a6285f9..c8d9b0a 100644 --- a/vendor/github.com/hashicorp/golang-lru/lru.go +++ b/vendor/github.com/hashicorp/golang-lru/lru.go @@ -1,6 +1,3 @@ -// This package provides a simple LRU cache. It is based on the -// LRU implementation in groupcache: -// https://github.com/golang/groupcache/tree/master/lru package lru import ( @@ -11,11 +8,11 @@ import ( // Cache is a thread-safe fixed size LRU cache. type Cache struct { - lru *simplelru.LRU + lru simplelru.LRUCache lock sync.RWMutex } -// New creates an LRU of the given size +// New creates an LRU of the given size. func New(size int) (*Cache, error) { return NewWithEvict(size, nil) } @@ -33,7 +30,7 @@ func NewWithEvict(size int, onEvicted func(key interface{}, value interface{})) return c, nil } -// Purge is used to completely clear the cache +// Purge is used to completely clear the cache. func (c *Cache) Purge() { c.lock.Lock() c.lru.Purge() @@ -41,30 +38,30 @@ func (c *Cache) Purge() { } // Add adds a value to the cache. Returns true if an eviction occurred. -func (c *Cache) Add(key, value interface{}) bool { +func (c *Cache) Add(key, value interface{}) (evicted bool) { c.lock.Lock() defer c.lock.Unlock() return c.lru.Add(key, value) } // Get looks up a key's value from the cache. -func (c *Cache) Get(key interface{}) (interface{}, bool) { +func (c *Cache) Get(key interface{}) (value interface{}, ok bool) { c.lock.Lock() defer c.lock.Unlock() return c.lru.Get(key) } -// Check if a key is in the cache, without updating the recent-ness -// or deleting it for being stale. +// Contains checks if a key is in the cache, without updating the +// recent-ness or deleting it for being stale. func (c *Cache) Contains(key interface{}) bool { c.lock.RLock() defer c.lock.RUnlock() return c.lru.Contains(key) } -// Returns the key value (or undefined if not found) without updating +// Peek returns the key value (or undefined if not found) without updating // the "recently used"-ness of the key. -func (c *Cache) Peek(key interface{}) (interface{}, bool) { +func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) { c.lock.RLock() defer c.lock.RUnlock() return c.lru.Peek(key) @@ -73,16 +70,15 @@ func (c *Cache) Peek(key interface{}) (interface{}, bool) { // ContainsOrAdd checks if a key is in the cache without updating the // recent-ness or deleting it for being stale, and if not, adds the value. // Returns whether found and whether an eviction occurred. -func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evict bool) { +func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) { c.lock.Lock() defer c.lock.Unlock() if c.lru.Contains(key) { return true, false - } else { - evict := c.lru.Add(key, value) - return false, evict } + evicted = c.lru.Add(key, value) + return false, evicted } // Remove removes the provided key from the cache. diff --git a/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go b/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go index cb416b3..5673773 100644 --- a/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go +++ b/vendor/github.com/hashicorp/golang-lru/simplelru/lru.go @@ -36,7 +36,7 @@ func NewLRU(size int, onEvict EvictCallback) (*LRU, error) { return c, nil } -// Purge is used to completely clear the cache +// Purge is used to completely clear the cache. func (c *LRU) Purge() { for k, v := range c.items { if c.onEvict != nil { @@ -48,7 +48,7 @@ func (c *LRU) Purge() { } // Add adds a value to the cache. Returns true if an eviction occurred. -func (c *LRU) Add(key, value interface{}) bool { +func (c *LRU) Add(key, value interface{}) (evicted bool) { // Check for existing item if ent, ok := c.items[key]; ok { c.evictList.MoveToFront(ent) @@ -78,17 +78,18 @@ func (c *LRU) Get(key interface{}) (value interface{}, ok bool) { return } -// Check if a key is in the cache, without updating the recent-ness +// Contains checks if a key is in the cache, without updating the recent-ness // or deleting it for being stale. func (c *LRU) Contains(key interface{}) (ok bool) { _, ok = c.items[key] return ok } -// Returns the key value (or undefined if not found) without updating +// Peek returns the key value (or undefined if not found) without updating // the "recently used"-ness of the key. func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) { - if ent, ok := c.items[key]; ok { + var ent *list.Element + if ent, ok = c.items[key]; ok { return ent.Value.(*entry).value, true } return nil, ok @@ -96,7 +97,7 @@ func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) { // Remove removes the provided key from the cache, returning if the // key was contained. -func (c *LRU) Remove(key interface{}) bool { +func (c *LRU) Remove(key interface{}) (present bool) { if ent, ok := c.items[key]; ok { c.removeElement(ent) return true @@ -105,7 +106,7 @@ func (c *LRU) Remove(key interface{}) bool { } // RemoveOldest removes the oldest item from the cache. -func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) { +func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) { ent := c.evictList.Back() if ent != nil { c.removeElement(ent) @@ -116,7 +117,7 @@ func (c *LRU) RemoveOldest() (interface{}, interface{}, bool) { } // GetOldest returns the oldest entry -func (c *LRU) GetOldest() (interface{}, interface{}, bool) { +func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) { ent := c.evictList.Back() if ent != nil { kv := ent.Value.(*entry) diff --git a/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go b/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go new file mode 100644 index 0000000..744cac0 --- /dev/null +++ b/vendor/github.com/hashicorp/golang-lru/simplelru/lru_interface.go @@ -0,0 +1,37 @@ +package simplelru + + +// LRUCache is the interface for simple LRU cache. +type LRUCache interface { + // Adds a value to the cache, returns true if an eviction occurred and + // updates the "recently used"-ness of the key. + Add(key, value interface{}) bool + + // Returns key's value from the cache and + // updates the "recently used"-ness of the key. #value, isFound + Get(key interface{}) (value interface{}, ok bool) + + // Check if a key exsists in cache without updating the recent-ness. + Contains(key interface{}) (ok bool) + + // Returns key's value without updating the "recently used"-ness of the key. + Peek(key interface{}) (value interface{}, ok bool) + + // Removes a key from the cache. + Remove(key interface{}) bool + + // Removes the oldest entry from cache. + RemoveOldest() (interface{}, interface{}, bool) + + // Returns the oldest entry from the cache. #key, value, isFound + GetOldest() (interface{}, interface{}, bool) + + // Returns a slice of the keys in the cache, from oldest to newest. + Keys() []interface{} + + // Returns the number of items in the cache. + Len() int + + // Clear all cache entries + Purge() +} diff --git a/vendor/github.com/isofew/go-stun/LICENSE b/vendor/github.com/isofew/go-stun/LICENSE new file mode 100644 index 0000000..12558c0 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Vasily Vasilyev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/isofew/go-stun/stun/agent.go b/vendor/github.com/isofew/go-stun/stun/agent.go new file mode 100644 index 0000000..65fd9be --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/agent.go @@ -0,0 +1,284 @@ +package stun + +import ( + "errors" + "math/rand" + "net" + "sync" + "time" +) + +var DefaultConfig = &Config{ + RetransmissionTimeout: 500 * time.Millisecond, + TransactionTimeout: 39500 * time.Millisecond, + Software: "pixelbender/go-stun", +} + +type Handler interface { + ServeSTUN(msg *Message, tr Transport) +} + +type HandlerFunc func(msg *Message, tr Transport) + +func (h HandlerFunc) ServeSTUN(msg *Message, tr Transport) { + h(msg, tr) +} + +type Config struct { + // AuthMethod returns a key for MESSAGE-INTEGRITY attribute + AuthMethod AuthMethod + // Retransmission timeout, default is 500 milliseconds + RetransmissionTimeout time.Duration + // Transaction timeout, default is 39.5 seconds + TransactionTimeout time.Duration + // Fingerprint, if true all outgoing messages contain FINGERPRINT attribute + Fingerprint bool + // Software is a SOFTWARE attribute value for outgoing messages, if not empty + Software string + // Logf, if set all sent and received messages printed using Logf + Logf func(format string, args ...interface{}) +} + +func (c *Config) attrs() []Attr { + if c == nil { + return nil + } + var a []Attr + if c.Software != "" { + a = append(a, String(AttrSoftware, c.Software)) + } + if c.Fingerprint { + a = append(a, Fingerprint) + } + return a +} + +func (c *Config) Clone() *Config { + r := *c + return &r +} + +type Agent struct { + config *Config + Handler Handler + m mux +} + +func NewAgent(config *Config) *Agent { + if config == nil { + config = DefaultConfig + } + return &Agent{ + config: config, + } +} + +func (a *Agent) Send(msg *Message, tr Transport) (err error) { + msg = &Message{ + msg.Type, + msg.Transaction, + append(a.config.attrs(), msg.Attributes...), + } + if log := a.config.Logf; log != nil { + log("%v → %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg) + } + b := msg.Marshal(getBuffer()[:0]) + _, err = tr.Write(b) + putBuffer(b) + return +} + +func (a *Agent) ServeConn(c net.Conn, stop chan struct{}) error { + if c, ok := c.(net.PacketConn); ok { + return a.ServePacket(c, stop) + } + var ( + b = getBuffer() + p int + ) + defer putBuffer(b) + for { + select { + case <-stop: + return nil + default: + } + if p >= len(b) { + return errBufferOverflow + } + n, err := c.Read(b[p:]) + if err != nil { + return err + } + p += n + n = 0 + for n < p { + r, err := a.ServeTransport(b[n:p], c) + if err != nil { + return err + } + n += r + } + if n > 0 { + if n < p { + p = copy(b, b[n:p]) + } else { + p = 0 + } + } + } +} + +func (a *Agent) ServePacket(c net.PacketConn, stop chan struct{}) error { + b := getBuffer() + defer putBuffer(b) + // don't close the connection since we're going to reuse it + // defer c.Close() + + for { + select { + case <-stop: + return nil + default: + } + n, addr, err := c.ReadFrom(b) + if err != nil { + return err + } + if n > 0 { + a.ServeTransport(b[:n], &packetConn{c, addr}) + } + } +} + +func (a *Agent) ServeTransport(b []byte, tr Transport) (n int, err error) { + msg := &Message{} + n, err = msg.Unmarshal(b) + if err != nil { + return + } + a.ServeSTUN(msg, tr) + return +} + +func (a *Agent) ServeSTUN(msg *Message, tr Transport) { + if log := a.config.Logf; log != nil { + log("%v ← %v %v", tr.LocalAddr(), tr.RemoteAddr(), msg) + } + if a.m.serve(msg, tr) { + return + } + if h := a.Handler; h != nil { + go h.ServeSTUN(msg, tr) + } +} + +func (a *Agent) RoundTrip(req *Message, to Transport) (res *Message, from Transport, err error) { + var ( + start = time.Now() + rto = a.config.RetransmissionTimeout + udp = to.LocalAddr().Network() == "udp" + tx = a.m.newTx() + ) + defer a.m.closeTx(tx) + req = &Message{req.Type, tx.id, req.Attributes} + if err = a.Send(req, to); err != nil { + return + } + for { + d := a.config.TransactionTimeout - time.Since(start) + if d < 0 { + err = errTimeout + return + } + if udp && d > rto { + d = rto + } + res, from, err = tx.Receive(d) + if udp && err == errTimeout && d == rto { + rto <<= 1 + a.Send(req, to) + continue + } + return + } +} + +type mux struct { + sync.RWMutex + t map[string]*transaction +} + +func (m *mux) serve(msg *Message, tr Transport) bool { + m.RLock() + tx, ok := m.t[string(msg.Transaction)] + m.RUnlock() + if ok { + tx.msg, tx.from = msg, tr + tx.Done() + return true + } + return false +} + +func (m *mux) newTx() *transaction { + tx := &transaction{id: NewTransaction()} + m.Lock() + if m.t == nil { + m.t = make(map[string]*transaction) + } else { + for m.t[string(tx.id)] != nil { + rand.Read(tx.id[4:]) + } + } + m.t[string(tx.id)] = tx + m.Unlock() + return tx +} + +func (m *mux) closeTx(tx *transaction) { + m.Lock() + delete(m.t, string(tx.id)) + m.Unlock() +} + +func (m *mux) Close() { + m.Lock() + defer m.Unlock() + for _, it := range m.t { + it.Close() + } + m.t = nil +} + +type transaction struct { + sync.WaitGroup + id []byte + from Transport + msg *Message + err error +} + +func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) { + tx.Add(1) + t := time.AfterFunc(d, tx.timeout) + tx.Wait() + t.Stop() + if err = tx.err; err != nil { + return + } + return tx.msg, tx.from, nil +} + +func (tx *transaction) timeout() { + tx.err = errTimeout + tx.Done() +} + +func (tx *transaction) Close() { + tx.err = errCanceled + tx.Done() +} + +var errCanceled = errors.New("stun: transaction canceled") +var errTimeout = errors.New("stun: transaction timeout") diff --git a/vendor/github.com/isofew/go-stun/stun/attribute.go b/vendor/github.com/isofew/go-stun/stun/attribute.go new file mode 100644 index 0000000..12b058f --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/attribute.go @@ -0,0 +1,438 @@ +package stun + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "fmt" + "hash/crc32" + "net" + "strconv" +) + +// Attribute represents a STUN attribute. +type Attr interface { + Type() uint16 + Marshal(p []byte) []byte + Unmarshal(b []byte) error +} + +// IP address family +const ( + IPv4 = 0x01 + IPv6 = 0x02 +) + +// IP address family +const ( + ChangeIP uint64 = 0x04 + ChangePort = 0x02 +) + +func newAttr(typ uint16) Attr { + switch typ { + case AttrMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress, + AttrXorMappedAddress, AttrAlternateServer, AttrResponseOrigin, AttrOtherAddress, + AttrResponseAddress, AttrSourceAddress, AttrChangedAddress, AttrReflectedFrom: + return &addr{typ: typ} + case AttrRequestedAddressFamily, AttrRequestedTransport: + return &number{typ: typ, size: 4, pad: 24} + case AttrChannelNumber, AttrResponsePort: + return &number{typ: typ, size: 4, pad: 16} + case AttrLifetime, AttrConnectionID, AttrCacheTimeout, + AttrBandwidth, AttrTimerVal, + AttrTransactionTransmitCounter, + AttrEcnCheck, AttrChangeRequest, AttrPriority: + return &number{typ: typ, size: 4} + case AttrIceControlled, AttrIceControlling: + return &number{typ: typ, size: 8} + case AttrUsername, AttrRealm, AttrNonce, AttrSoftware, AttrPassword, AttrThirdPartyAuthorization, + AttrData, AttrAccessToken, AttrReservationToken, AttrMobilityTicket, AttrPadding, AttrUnknownAttributes: + return &raw{typ: typ} + case AttrMessageIntegrity: + return &integrity{} + case AttrErrorCode: + return &Error{} + case AttrEvenPort: + return &number{typ: typ, size: 1} + case AttrDontFragment, AttrUseCandidate: + return flag(typ) + case AttrFingerprint: + return &fingerprint{} + } + return nil +} + +func AttrName(typ uint16) string { + if r, ok := attrNames[typ]; ok { + return r + } + return "0x" + strconv.FormatUint(uint64(typ), 16) +} + +func Int(typ uint16, v uint64) Attr { + switch typ { + case AttrRequestedAddressFamily, AttrRequestedTransport: + return &number{typ, 4, 24, v} + case AttrChannelNumber, AttrResponsePort: + return &number{typ, 4, 16, v} + case AttrIceControlled, AttrIceControlling: + return &number{typ, 8, 0, v} + case AttrEvenPort: + return &number{typ, 1, 0, v} + default: + return &number{typ, 4, 0, v} + } +} + +type number struct { + typ uint16 + size, pad uint8 + v uint64 +} + +func (a *number) Type() uint16 { return a.typ } + +func (a *number) Marshal(p []byte) []byte { + r, b := grow(p, int(a.size)) + switch a.size { + case 1: + b[0] = byte(a.v) + case 4: + be.PutUint32(b, uint32(a.v<> a.pad) + case 8: + a.v = be.Uint64(b) >> a.pad + } + return nil +} + +func (a *number) String() string { + return "0x" + strconv.FormatUint(a.v, 16) +} + +func Flag(v uint16) Attr { + return flag(v) +} + +type flag uint16 + +func (attr flag) Type() uint16 { return uint16(attr) } +func (flag) Marshal(p []byte) []byte { return p } +func (flag) Unmarshal(b []byte) error { return nil } + +// Error represents the ERROR-CODE attribute. +type Error struct { + Code int + Reason string +} + +func NewError(code int) *Error { + return &Error{code, ErrorText(code)} +} + +func (*Error) Type() uint16 { return AttrErrorCode } + +func (e *Error) Marshal(p []byte) []byte { + r, b := grow(p, 4+len(e.Reason)) + b[0] = 0 + b[1] = 0 + b[2] = byte(e.Code / 100) + b[3] = byte(e.Code % 100) + copy(b[4:], e.Reason) + return r +} + +func (e *Error) Unmarshal(b []byte) error { + if len(b) < 4 { + return errFormat + } + e.Code = int(b[2])*100 + int(b[3]) + e.Reason = getString(b[4:]) + return nil +} + +func (e *Error) Error() string { return e.String() } +func (e *Error) String() string { return fmt.Sprintf("%d %s", e.Code, e.Reason) } + +// ErrorText returns a text for the STUN error code. It returns the empty string if the code is unknown. +func ErrorText(code int) string { return errorText[code] } + +func getString(b []byte) string { + for i := len(b); i >= 0; i-- { + if b[i-1] > 0 { + return string(b[:i]) + } + } + return "" +} + +func Addr(typ uint16, v net.Addr) Attr { + ip, port := SockAddr(v) + return &addr{typ, ip, port} +} + +func SockAddr(v net.Addr) (net.IP, int) { + switch a := v.(type) { + case *net.UDPAddr: + return a.IP, a.Port + case *net.TCPAddr: + return a.IP, a.Port + case *net.IPAddr: + return a.IP, 0 + default: + return net.IPv4zero, 0 + } +} + +func sameAddr(a, b net.Addr) bool { + aip, aport := SockAddr(a) + bip, bport := SockAddr(b) + return aip.Equal(bip) && aport == bport +} + +func NewAddr(network string, ip net.IP, port int) net.Addr { + switch network { + case "udp", "udp4", "udp6": + return &net.UDPAddr{IP: ip, Port: port} + case "tcp", "tcp4", "tcp6": + return &net.TCPAddr{IP: ip, Port: port} + } + return &net.IPAddr{IP: ip} +} + +func IP(typ uint16, ip net.IP) Attr { return &addr{typ, ip, 0} } + +type addr struct { + typ uint16 + IP net.IP + Port int +} + +func (addr *addr) Type() uint16 { return addr.typ } + +func (addr *addr) Addr(network string) net.Addr { + return NewAddr(network, addr.IP, addr.Port) +} + +func (addr *addr) Xored() bool { + switch addr.typ { + case AttrXorMappedAddress, AttrXorPeerAddress, AttrXorRelayedAddress: + return true + default: + return false + } +} + +func (addr *addr) Marshal(p []byte) []byte { + return addr.MarshalAddr(p, nil) +} + +func (addr *addr) MarshalAddr(p, tx []byte) []byte { + fam, ip := IPv4, addr.IP.To4() + if ip == nil { + fam, ip = IPv6, addr.IP + } + r, b := grow(p, 4+len(ip)) + b[0] = 0 + b[1] = byte(fam) + if addr.Xored() && tx != nil { + be.PutUint16(b[2:], uint16(addr.Port)^0x2112) + b = b[4:] + for i, it := range ip { + b[i] = it ^ tx[i] + } + } else { + be.PutUint16(b[2:], uint16(addr.Port)) + copy(b[4:], ip) + } + return r +} + +func (addr *addr) Unmarshal(b []byte) error { + return addr.UnmarshalAddr(b, nil) +} + +func (addr *addr) UnmarshalAddr(b, tx []byte) error { + if len(b) < 4 { + return errFormat + } + n, port := net.IPv4len, int(be.Uint16(b[2:])) + if b[1] == IPv6 { + n = net.IPv6len + } + if b = b[4:]; len(b) < n { + return errFormat + } + addr.IP = make(net.IP, n) + if addr.Xored() && tx != nil { + for i, it := range b { + addr.IP[i] = it ^ tx[i] + } + addr.Port = port ^ 0x2112 + } else { + copy(addr.IP, b) + addr.Port = port + } + return nil +} + +func (addr *addr) Equal(a *addr) bool { + return addr == a || (addr != nil && a != nil && addr.IP.Equal(a.IP) && addr.Port == a.Port) +} + +func (addr *addr) String() string { + if addr.Port == 0 { + return addr.IP.String() + } + return net.JoinHostPort(addr.IP.String(), strconv.Itoa(addr.Port)) +} + +func Bytes(typ uint16, v []byte) Attr { return &raw{typ, v} } + +type raw struct { + typ uint16 + data []byte +} + +func (attr *raw) Type() uint16 { return attr.typ } +func (attr *raw) Marshal(p []byte) []byte { return append(p, attr.data...) } +func (attr *raw) Unmarshal(p []byte) error { + attr.data = p + return nil +} +func (attr *raw) String() string { return string(attr.data) } + +func String(typ uint16, v string) Attr { + return &str{typ, v} +} + +type str struct { + typ uint16 + data string +} + +func (attr *str) Type() uint16 { return attr.typ } +func (attr *str) Marshal(p []byte) []byte { return append(p, attr.data...) } +func (attr *str) Unmarshal(p []byte) error { + attr.data = string(p) + return nil +} +func (attr *str) String() string { return attr.data } + +func MessageIntegrity(key []byte) Attr { + return &integrity{key: key} +} + +type integrity struct { + key, sum, raw []byte +} + +func (*integrity) Type() uint16 { + return AttrMessageIntegrity +} + +func (attr *integrity) Marshal(p []byte) []byte { + return append(p, attr.sum...) +} + +func (attr *integrity) Unmarshal(b []byte) error { + if len(b) < 20 { + return errFormat + } + attr.sum = b + return nil +} + +func (attr *integrity) MarshalSum(p, raw []byte) []byte { + n := len(raw) - 4 + be.PutUint16(raw[2:], uint16(n+4)) + return attr.Sum(attr.key, raw[:n], p) +} + +func (attr *integrity) UnmarshalSum(p, raw []byte) error { + attr.raw = raw + return attr.Unmarshal(p) +} + +func (attr *integrity) Sum(key, data, p []byte) []byte { + h := hmac.New(sha1.New, key) + h.Write(data) + return h.Sum(p) +} + +func (attr *integrity) Check(key []byte) bool { + r := attr.raw + if len(r) < 44 { + return r == nil + } + be.PutUint16(r[2:], uint16(len(r)-20)) + h := attr.Sum(key, r[:len(r)-24], nil) + return bytes.Equal(h, attr.sum) +} + +var Fingerprint Attr = &fingerprint{} + +type fingerprint struct { + sum uint32 + raw []byte +} + +func (*fingerprint) Type() uint16 { + return AttrFingerprint +} + +func (attr *fingerprint) Marshal(p []byte) []byte { + r, b := grow(p, 4) + be.PutUint32(b, attr.sum) + return r +} + +func (attr *fingerprint) Unmarshal(b []byte) error { + if len(b) < 4 { + return errFormat + } + attr.sum = be.Uint32(b) + return nil +} + +func (attr *fingerprint) MarshalSum(p, raw []byte) []byte { + n := len(raw) - 4 + be.PutUint16(raw[2:], uint16(n-12)) + v := attr.Sum(raw[:n]) + r, b := grow(p, 4) + be.PutUint32(b, v) + return r +} + +func (attr *fingerprint) UnmarshalSum(p, raw []byte) error { + attr.raw = raw + return attr.Unmarshal(p) +} + +func (attr *fingerprint) Sum(p []byte) uint32 { + return crc32.ChecksumIEEE(p) ^ 0x5354554e +} + +func (attr *fingerprint) Check() bool { + r := attr.raw + if len(r) < 28 { + return r == nil + } + be.PutUint16(r[2:], uint16(len(r)-20)) + return attr.Sum(r[:len(r)-8]) == attr.sum +} diff --git a/vendor/github.com/isofew/go-stun/stun/conn.go b/vendor/github.com/isofew/go-stun/stun/conn.go new file mode 100644 index 0000000..9c562de --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/conn.go @@ -0,0 +1,110 @@ +package stun + +import ( + "github.com/pkg/errors" + "net" +) + +type Conn struct { + net.Conn + agent *Agent + sess *Session +} + +func NewConn(conn net.Conn, config *Config, stop chan struct{}) *Conn { + a := NewAgent(config) + go a.ServeConn(conn, stop) + return &Conn{conn, a, nil} +} + +func (c *Conn) Network() string { + return c.LocalAddr().Network() +} + +func (c *Conn) Discover() (net.Addr, error) { + res, err := c.Request(&Message{Type: MethodBinding}) + if err != nil { + return nil, err + } + mapped := res.GetAddr(c.Network(), AttrXorMappedAddress, AttrMappedAddress) + if mapped != nil { + return mapped, nil + } + return nil, errors.New("stun: bad response, no mapped address") +} + +func (c *Conn) Request(req *Message) (res *Message, err error) { + res, _, err = c.RequestTransport(req, c.Conn) + return +} + +func (c *Conn) RequestTransport(req *Message, to Transport) (res *Message, from Transport, err error) { + sess := c.sess + auth := c.agent.config.AuthMethod + if to == nil { + to = c.Conn + } + for { + msg := &Message{ + req.Type, + NewTransaction(), + append(sess.attrs(), req.Attributes...), + } + res, from, err = c.agent.RoundTrip(msg, to) + if err != nil { + return + } + code := res.GetError() + if code == nil { + // FIXME: authorize response... + if sess != nil { + c.sess = sess + } + return + } + err = code + switch code.Code { + case CodeUnauthorized, CodeStaleNonce: + if auth == nil { + return + } + sess = &Session{ + Realm: res.GetString(AttrRealm), + Nonce: res.GetString(AttrNonce), + } + if err = auth(sess); err != nil { + return + } + auth = nil + default: + return + } + } +} + +type Session struct { + Realm string + Nonce string + Username string + Key []byte +} + +func (s *Session) attrs() []Attr { + if s == nil { + return nil + } + var a []Attr + if s.Realm != "" { + a = append(a, String(AttrRealm, s.Realm)) + } + if s.Nonce != "" { + a = append(a, String(AttrNonce, s.Nonce)) + } + if s.Username != "" { + a = append(a, String(AttrUsername, s.Username)) + } + if s.Key != nil { + a = append(a, MessageIntegrity(s.Key)) + } + return a +} diff --git a/vendor/github.com/isofew/go-stun/stun/gen.go b/vendor/github.com/isofew/go-stun/stun/gen.go new file mode 100644 index 0000000..ace32bf --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/gen.go @@ -0,0 +1,203 @@ +//+build ignore + +//go:generate go run gen.go +// This program generates STUN parameters: methods, attributes and error codes by reading IANA registry. +package main + +import ( + "bytes" + "encoding/xml" + "fmt" + "go/format" + "io/ioutil" + "net/http" + "os" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +func main() { + err := generate("http://www.iana.org/assignments/stun-parameters/stun-parameters.xml", "registry.go") + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func loadRegistry(url string) (*Registry, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + r := &Registry{} + err = xml.Unmarshal(b, r) + if err != nil { + return nil, err + } + name := regexp.MustCompile("^([\\w-]+)(.*was\\s+([\\w-]+))?") + for _, reg := range r.Registry { + for _, r := range reg.Records { + m := name.FindStringSubmatch(r.Description) + if m != nil { + r.Value = strings.ToLower(r.Value) + r.Name = m[1] + if r.Name == "Reserved" && m[3] != "" { + r.Name = m[3] + r.Deprecated = true + } + r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "rfc") + r.Ref.Data = strings.TrimPrefix(r.Ref.Data, "RFC-") + } + } + } + return r, nil +} + +type Record struct { + Value string `xml:"value"` + Description string `xml:"description"` + Name string + Deprecated bool + Ref struct { + Type string `xml:"type,attr"` + Data string `xml:"data,attr"` + } `xml:"xref"` +} + +func (r *Record) IsValid() bool { + return r.Name != "Reserved" && r.Name != "Unassigned" && (r.Ref.Type == "rfc" || r.Ref.Type == "draft") +} + +type Registry struct { + Title string `xml:"title"` + Updated string `xml:"updated"` + Registry []struct { + Id string `xml:"id,attr"` + Records []*Record `xml:"record"` + } `xml:"registry"` +} + +func (reg *Registry) GetRecords(id string) []*Record { + for _, it := range reg.Registry { + if it.Id == id { + return it.Records + } + } + return nil +} + +func generate(url, file string) error { + reg, err := loadRegistry(url) + if err != nil { + return err + } + + b := &bytes.Buffer{} + fmt.Fprintf(b, "package stun\n\n") + fmt.Fprintf(b, "// Do not edit. This file is generated by 'go generate gen.go'\n") + fmt.Fprintf(b, "// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA).\n") + fmt.Fprintf(b, "// %s, Updated: %s.\n\n", reg.Title, reg.Updated) + + genMethods(reg.GetRecords("stun-parameters-2"), b) + genAttributes(reg.GetRecords("stun-parameters-4"), b) + genErrors(reg.GetRecords("stun-parameters-6"), b) + + src, err := format.Source(b.Bytes()) + if err != nil { + return err + } + if err = ioutil.WriteFile(file, src, 0644); err != nil { + return err + } + return nil +} + +func genMethods(records []*Record, b *bytes.Buffer) error { + c, m := &bytes.Buffer{}, &bytes.Buffer{} + ref := "" + for _, it := range records { + if it.IsValid() { + if ref == it.Ref.Data { + fmt.Fprintf(c, "Method%v uint16 = %s\n", it.Name, it.Value) + } else { + ref = it.Ref.Data + fmt.Fprintf(c, "Method%v uint16 = %s // RFC %s\n", it.Name, it.Value, ref) + } + fmt.Fprintf(m, "Method%v: \"%s\",\n", it.Name, it.Name) + } + } + fmt.Fprintf(b, "// STUN methods.\n") + fmt.Fprintf(b, "const (\n%s)\n", c.Bytes()) + fmt.Fprintf(b, "// STUN method names.\n") + fmt.Fprintf(b, "var methodNames = map[uint16]string{\n%s}\n", m.Bytes()) + return nil +} + +func genAttributes(records []*Record, b *bytes.Buffer) error { + c, d, m := &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{} + ref := "" + for _, it := range records { + if it.IsValid() { + a := c + if it.Deprecated { + a = d + } + v := strings.Replace(it.Name, "_", "-", -1) + n := strings.Replace(v, "-", " ", -1) + parts := strings.Fields(n) + for i, s := range parts { + switch s { + case "", "ID": + default: + r, n := utf8.DecodeRuneInString(s) + s = string(unicode.ToUpper(r)) + strings.ToLower(s[n:]) + } + parts[i] = s + } + n = strings.Join(parts, "") + if ref == it.Ref.Data || it.Deprecated { + fmt.Fprintf(a, "Attr%v uint16 = %s\n", n, it.Value) + } else { + ref = it.Ref.Data + fmt.Fprintf(a, "Attr%v uint16 = %s // RFC %s\n", n, it.Value, ref) + } + fmt.Fprintf(m, "Attr%v: \"%s\",\n", n, v) + } + } + fmt.Fprintf(b, "// STUN attributes.\n") + fmt.Fprintf(b, "const (\n%s)\n", c.Bytes()) + fmt.Fprintf(b, "// Deprecated: For backwards compatibility only.\n") + fmt.Fprintf(b, "const (\n%s)\n", d.Bytes()) + fmt.Fprintf(b, "// STUN attribute names.\n") + fmt.Fprintf(b, "var attrNames = map[uint16]string{\n%s}\n", m.Bytes()) + return nil +} + +func genErrors(records []*Record, b *bytes.Buffer) error { + c, m := &bytes.Buffer{}, &bytes.Buffer{} + ref := "" + for _, it := range records { + if it.IsValid() { + n := strings.Replace(strings.Title(it.Description), " ", "", -1) + if ref == it.Ref.Data { + fmt.Fprintf(c, "Code%v int = %s\n", n, it.Value) + } else { + ref = it.Ref.Data + fmt.Fprintf(c, "Code%v int = %s // RFC %s\n", n, it.Value, ref) + } + fmt.Fprintf(m, "Code%v: \"%s\",\n", n, it.Description) + } + } + fmt.Fprintf(b, "// STUN error codes.\n") + fmt.Fprintf(b, "const (\n%s)\n", c.Bytes()) + fmt.Fprintf(b, "// STUN error texts.\n") + fmt.Fprintf(b, "var errorText = map[int]string{\n%s}\n", m.Bytes()) + return nil +} diff --git a/vendor/github.com/isofew/go-stun/stun/message.go b/vendor/github.com/isofew/go-stun/stun/message.go new file mode 100644 index 0000000..63d1b69 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/message.go @@ -0,0 +1,351 @@ +package stun + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "math/rand" + "net" + "sort" + "strconv" +) + +const ( + KindRequest uint16 = 0x0000 + KindIndication uint16 = 0x0010 + KindResponse uint16 = 0x0100 + KindError uint16 = 0x0110 +) + +// Message represents a STUN message. +type Message struct { + Type uint16 + Transaction []byte + Attributes []Attr +} + +func (m *Message) Marshal(p []byte) []byte { + pos := len(p) + r, b := grow(p, 20) + be.PutUint16(b, m.Type) + + if m.Transaction != nil { + copy(b[4:], m.Transaction) + } else { + copy(b[4:], magicCookie) + rand.Read(b[8:20]) + } + + sort.Sort(byPosition(m.Attributes)) + for _, attr := range m.Attributes { + r = m.marshalAttr(r, attr, pos) + } + + be.PutUint16(r[pos+2:], uint16(len(r)-pos-20)) + return r +} + +func (m *Message) marshalAttr(p []byte, attr Attr, pos int) []byte { + h := len(p) + r, b := grow(p, 4) + be.PutUint16(b, attr.Type()) + + switch v := attr.(type) { + case *addr: + r = v.MarshalAddr(r, r[pos+4:]) + case *integrity: + r = v.MarshalSum(r, r[pos:]) + case *fingerprint: + r = v.MarshalSum(r, r[pos:]) + default: + r = v.Marshal(r) + } + n := len(r) - h - 4 + be.PutUint16(r[h+2:], uint16(n)) + + if pad := n & 3; pad != 0 { + r, b = grow(r, 4-pad) + for i := range b { + b[i] = 0 + } + } + return r +} + +func (m *Message) Unmarshal(b []byte) (n int, err error) { + if len(b) < 20 { + err = io.EOF + return + } + l := int(be.Uint16(b[2:])) + 20 + if len(b) < l { + err = io.EOF + return + } + pos, p := 20, make([]byte, l) + copy(p, b[:l]) + + m.Type = be.Uint16(p) + m.Transaction = p[4:20] + + for pos < len(p) { + s, attr, err := m.unmarshalAttr(p, pos) + if err != nil { + return 0, err + } + pos += s + if attr != nil { + m.Attributes = append(m.Attributes, attr) + } + } + + return l, nil +} + +func (m *Message) unmarshalAttr(p []byte, pos int) (n int, attr Attr, err error) { + b := p[pos:] + if len(b) < 4 { + err = errFormat + return + } + typ := be.Uint16(b) + attr, n = newAttr(typ), int(be.Uint16(b[2:]))+4 + if len(b) < n { + err = errFormat + return + } + + b = b[4:n] + if attr != nil { + switch v := attr.(type) { + case *addr: + err = v.UnmarshalAddr(b, m.Transaction) + case *integrity: + err = v.UnmarshalSum(b, p[:pos+n]) + case *fingerprint: + err = v.UnmarshalSum(b, p[:pos+n]) + default: + err = attr.Unmarshal(b) + } + } else if typ < 0x8000 { + err = errFormat + } + if err != nil { + err = &errAttribute{err, typ} + return + } + if pad := n & 3; pad != 0 { + n += 4 - pad + if len(p) < pos+n { + err = errFormat + } + } + return +} + +func (m *Message) Kind() uint16 { + return m.Type & 0x110 +} + +func (m *Message) Method() uint16 { + return m.Type &^ 0x110 +} + +func (m *Message) Add(attr Attr) { + m.Attributes = append(m.Attributes, attr) +} + +func (m *Message) Set(attr Attr) { + m.Del(attr.Type()) + m.Add(attr) +} + +func (m *Message) Del(typ uint16) { + n := 0 + for _, a := range m.Attributes { + if a.Type() != typ { + m.Attributes[n] = a + n++ + } + } + m.Attributes = m.Attributes[:n] +} + +func (m *Message) Get(typ uint16) (attr Attr) { + for _, attr = range m.Attributes { + if attr.Type() == typ { + return + } + } + return nil +} + +func (m *Message) Has(typ uint16) bool { + for _, attr := range m.Attributes { + if attr.Type() == typ { + return true + } + } + return false +} + +func (m *Message) GetString(typ uint16) string { + if str, ok := m.Get(typ).(fmt.Stringer); ok { + return str.String() + } + return "" +} + +func (m *Message) GetAddr(network string, typ ...uint16) net.Addr { + for _, t := range typ { + if addr, ok := m.Get(t).(*addr); ok { + return addr.Addr(network) + } + } + return nil +} + +func (m *Message) GetInt(typ uint16) (v uint64, ok bool) { + attr := m.Get(typ) + if r, ok := attr.(*number); ok { + return r.v, true + } + return +} + +func (m *Message) GetBytes(typ uint16) []byte { + if attr, ok := m.Get(typ).(*raw); ok { + return attr.data + } + return nil +} + +func (m *Message) GetError() *Error { + if err, ok := m.Get(AttrErrorCode).(*Error); ok { + return err + } + return nil +} + +func (m *Message) CheckIntegrity(key []byte) bool { + if attr, ok := m.Get(AttrMessageIntegrity).(*integrity); ok { + return attr.Check(key) + } + return false +} + +func (m *Message) CheckFingerprint() bool { + if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok { + return attr.Check() + } + return false +} + +func (m *Message) String() string { + sort.Sort(byPosition(m.Attributes)) + + // TODO: use sprintf + + b := &bytes.Buffer{} + b.WriteString(MethodName(m.Type)) + b.WriteByte('{') + tx := m.Transaction + if tx == nil { + b.WriteString("nil") + } else if bytes.Equal(magicCookie, tx[:4]) { + b.WriteString(hex.EncodeToString(tx[4:])) + } else { + b.WriteString(hex.EncodeToString(tx)) + } + for _, attr := range m.Attributes { + b.WriteString(", ") + b.WriteString(AttrName(attr.Type())) + switch v := attr.(type) { + case *raw: + b.WriteString(": \"") + b.Write(v.data) + b.WriteByte('"') + case *str: + b.WriteString(": \"") + b.WriteString(v.data) + b.WriteByte('"') + case flag, *integrity, *fingerprint: + default: + b.WriteString(fmt.Sprintf(": %v", attr)) + } + } + b.WriteByte('}') + return b.String() +} + +func MethodName(typ uint16) string { + if r, ok := methodNames[typ&^0x110]; ok { + switch typ & 0x110 { + case KindRequest: + return r + "Request" + case KindIndication: + return r + "Indication" + case KindResponse: + return r + "Response" + case KindError: + return r + "Error" + } + } + return "0x" + strconv.FormatUint(uint64(typ), 16) +} + +func UnmarshalMessage(b []byte) (*Message, error) { + m := &Message{} + if _, err := m.Unmarshal(b); err != nil { + return nil, err + } + return m, nil +} + +var magicCookie = []byte{0x21, 0x12, 0xa4, 0x42} +var alphanum = dict("01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +type dict []byte + +func (d dict) rand(n int) string { + m, b := len(d), make([]byte, n) + for i := range b { + b[i] = d[rand.Intn(m)] + } + return string(b) +} + +func NewTransaction() []byte { + id := make([]byte, 16) + copy(id, magicCookie) + rand.Read(id[4:]) // TODO: configure random source + return id +} + +type byPosition []Attr + +func (s byPosition) Len() int { return len(s) } +func (s byPosition) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byPosition) Less(i, j int) bool { + a, b := s[i].Type(), s[j].Type() + switch b { + case a: + return i < j + case AttrMessageIntegrity: + return a != AttrFingerprint + case AttrFingerprint: + return true + default: + return i < j + } +} + +type errAttribute struct { + error + typ uint16 +} + +func (err errAttribute) Error() string { + return "attribute " + AttrName(err.typ) + ": " + err.error.Error() +} diff --git a/vendor/github.com/isofew/go-stun/stun/nat.go b/vendor/github.com/isofew/go-stun/stun/nat.go new file mode 100644 index 0000000..b4f40d0 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/nat.go @@ -0,0 +1,172 @@ +package stun + +import ( + "errors" + "net" +) + +const ( + EndpointIndependent = "endpoint-independent" + AddressDependent = "address-dependent" + AddressPortDependent = "address-port-dependent" +) + +type Detector struct { + *Conn +} + +func NewDetector(c *Conn) *Detector { + d := &Detector{c} + d.agent.Handler = &Server{agent: c.agent} + return d +} + +func (d *Detector) Hairpinning() error { + mapped, err := d.Discover() + if err != nil { + return err + } + conn, err := net.Dial(d.Network(), mapped.String()) + if err != nil { + return err + } + // not using stop channel + c := NewConn(conn, d.agent.config, make(chan struct{})) + defer c.Close() + _, err = c.Discover() + return err +} + +func (d *Detector) DiscoverChange(change uint64) error { + req := &Message{Type: MethodBinding, Attributes: []Attr{Int(AttrChangeRequest, change)}} + _, from, err := d.RequestTransport(req, d) + if err != nil { + return err + } + ip, port := SockAddr(d.RemoteAddr()) + chip, chport := SockAddr(from.RemoteAddr()) + if change&ChangeIP != 0 { + if ip.Equal(chip) { + return errors.New("stun: bad response, ip address is not changed") + } + } else if change&ChangePort != 0 { + if port == chport { + return errors.New("stun: bad response, port is not changed") + } + } + return nil +} + +func (d *Detector) Filtering() (string, error) { + n := d.Network() + if n != "udp" { + return "", errors.New("stun: filtering test is not applicable to " + n) + } + _, err := d.Request(&Message{Type: MethodBinding}) + if err != nil { + return "", err + } + err = d.DiscoverChange(ChangeIP | ChangePort) + switch err { + case nil: + return EndpointIndependent, nil + case errTimeout: + err = d.DiscoverChange(ChangePort) + switch err { + case nil: + return AddressDependent, nil + case errTimeout: + return AddressPortDependent, nil + } + } + return "", err +} + +func (d *Detector) DiscoverOther(addr net.Addr) (net.Addr, error) { + n := addr.Network() + conn, err := net.Dial(n, addr.String()) + if err != nil { + return nil, err + } + defer conn.Close() + // not using stop channel + go d.agent.ServeConn(conn, make(chan struct{})) + res, _, err := d.RequestTransport(&Message{Type: MethodBinding}, conn) + if err != nil { + return nil, err + } + mapped := res.GetAddr(n, AttrXorMappedAddress, AttrMappedAddress) + if mapped != nil { + return mapped, nil + } + return nil, errors.New("stun: bad response, no mapped address") +} + +func (d *Detector) Mapping() (string, error) { + n := d.Network() + msg, err := d.Request(&Message{Type: MethodBinding}) + if err != nil { + return "", err + } + mapped, other := msg.GetAddr(n, AttrXorMappedAddress), msg.GetAddr(n, AttrOtherAddress) + if mapped == nil { + return "", errors.New("stun: bad response, no mapped address") + } + if other == nil { + return "", errors.New("stun: bad response, no other address") + } + ip, _ := SockAddr(mapped) + if ip.IsLoopback() { + return EndpointIndependent, nil + } + for _, it := range local { + if it.IP.Equal(ip) { + return EndpointIndependent, nil + } + } + ip, _ = SockAddr(other) + _, port := SockAddr(d.RemoteAddr()) + a, err := d.DiscoverOther(NewAddr(n, ip, port)) + if err != nil { + return "", err + } + if sameAddr(a, mapped) { + return EndpointIndependent, nil + } + b, err := d.DiscoverOther(other) + if err != nil { + return "", err + } + if sameAddr(b, a) { + return AddressDependent, nil + } + return AddressPortDependent, nil +} + +func LocalAddrs() []*net.IPAddr { + return local +} + +var local []*net.IPAddr + +func init() { + ifaces, err := net.Interfaces() + if err != nil { + return + } + for _, iface := range ifaces { + addrs, _ := iface.Addrs() + for _, it := range addrs { + var ip net.IP + switch it := it.(type) { + case *net.IPNet: + ip = it.IP + case *net.IPAddr: + ip = it.IP + } + if ip != nil && ip.IsGlobalUnicast() { + local = append(local, &net.IPAddr{ip, iface.Name}) + } + } + } +} diff --git a/vendor/github.com/isofew/go-stun/stun/registry.go b/vendor/github.com/isofew/go-stun/stun/registry.go new file mode 100644 index 0000000..2072ab8 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/registry.go @@ -0,0 +1,179 @@ +package stun + +// Do not edit. This file is generated by 'go generate gen.go' +// This file provides STUN parameters managed by the Internet Assigned Numbers Authority (IANA). +// Session Traversal Utilities for NAT (STUN) Parameters, Updated: 2016-11-14. + +// STUN methods. +const ( + MethodBinding uint16 = 0x001 // RFC 5389 + MethodSharedSecret uint16 = 0x002 + MethodAllocate uint16 = 0x003 // RFC 5766 + MethodRefresh uint16 = 0x004 + MethodSend uint16 = 0x006 + MethodData uint16 = 0x007 + MethodCreatePermission uint16 = 0x008 + MethodChannelBind uint16 = 0x009 + MethodConnect uint16 = 0x00a // RFC 6062 + MethodConnectionBind uint16 = 0x00b + MethodConnectionAttempt uint16 = 0x00c +) + +// STUN method names. +var methodNames = map[uint16]string{ + MethodBinding: "Binding", + MethodSharedSecret: "SharedSecret", + MethodAllocate: "Allocate", + MethodRefresh: "Refresh", + MethodSend: "Send", + MethodData: "Data", + MethodCreatePermission: "CreatePermission", + MethodChannelBind: "ChannelBind", + MethodConnect: "Connect", + MethodConnectionBind: "ConnectionBind", + MethodConnectionAttempt: "ConnectionAttempt", +} + +// STUN attributes. +const ( + AttrMappedAddress uint16 = 0x0001 // RFC 5389 + AttrChangeRequest uint16 = 0x0003 // RFC 5780 + AttrUsername uint16 = 0x0006 // RFC 5389 + AttrMessageIntegrity uint16 = 0x0008 + AttrErrorCode uint16 = 0x0009 + AttrUnknownAttributes uint16 = 0x000a + AttrChannelNumber uint16 = 0x000c // RFC 5766 + AttrLifetime uint16 = 0x000d + AttrXorPeerAddress uint16 = 0x0012 + AttrData uint16 = 0x0013 + AttrRealm uint16 = 0x0014 // RFC 5389 + AttrNonce uint16 = 0x0015 + AttrXorRelayedAddress uint16 = 0x0016 // RFC 5766 + AttrRequestedAddressFamily uint16 = 0x0017 // RFC 6156 + AttrEvenPort uint16 = 0x0018 // RFC 5766 + AttrRequestedTransport uint16 = 0x0019 + AttrDontFragment uint16 = 0x001a + AttrAccessToken uint16 = 0x001b // RFC 7635 + AttrXorMappedAddress uint16 = 0x0020 // RFC 5389 + AttrReservationToken uint16 = 0x0022 // RFC 5766 + AttrPriority uint16 = 0x0024 // RFC 5245 + AttrUseCandidate uint16 = 0x0025 + AttrPadding uint16 = 0x0026 // RFC 5780 + AttrResponsePort uint16 = 0x0027 + AttrConnectionID uint16 = 0x002a // RFC 6062 + AttrSoftware uint16 = 0x8022 // RFC 5389 + AttrAlternateServer uint16 = 0x8023 + AttrTransactionTransmitCounter uint16 = 0x8025 // RFC 7982 + AttrCacheTimeout uint16 = 0x8027 // RFC 5780 + AttrFingerprint uint16 = 0x8028 // RFC 5389 + AttrIceControlled uint16 = 0x8029 // RFC 5245 + AttrIceControlling uint16 = 0x802a + AttrResponseOrigin uint16 = 0x802b // RFC 5780 + AttrOtherAddress uint16 = 0x802c + AttrEcnCheck uint16 = 0x802d // RFC 6679 + AttrThirdPartyAuthorization uint16 = 0x802e // RFC 7635 + AttrMobilityTicket uint16 = 0x8030 // RFC 8016 +) + +// Deprecated: For backwards compatibility only. +const ( + AttrResponseAddress uint16 = 0x0002 + AttrSourceAddress uint16 = 0x0004 + AttrChangedAddress uint16 = 0x0005 + AttrPassword uint16 = 0x0007 + AttrReflectedFrom uint16 = 0x000b + AttrBandwidth uint16 = 0x0010 + AttrTimerVal uint16 = 0x0021 +) + +// STUN attribute names. +var attrNames = map[uint16]string{ + AttrMappedAddress: "MAPPED-ADDRESS", + AttrResponseAddress: "RESPONSE-ADDRESS", + AttrChangeRequest: "CHANGE-REQUEST", + AttrSourceAddress: "SOURCE-ADDRESS", + AttrChangedAddress: "CHANGED-ADDRESS", + AttrUsername: "USERNAME", + AttrPassword: "PASSWORD", + AttrMessageIntegrity: "MESSAGE-INTEGRITY", + AttrErrorCode: "ERROR-CODE", + AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES", + AttrReflectedFrom: "REFLECTED-FROM", + AttrChannelNumber: "CHANNEL-NUMBER", + AttrLifetime: "LIFETIME", + AttrBandwidth: "BANDWIDTH", + AttrXorPeerAddress: "XOR-PEER-ADDRESS", + AttrData: "DATA", + AttrRealm: "REALM", + AttrNonce: "NONCE", + AttrXorRelayedAddress: "XOR-RELAYED-ADDRESS", + AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY", + AttrEvenPort: "EVEN-PORT", + AttrRequestedTransport: "REQUESTED-TRANSPORT", + AttrDontFragment: "DONT-FRAGMENT", + AttrAccessToken: "ACCESS-TOKEN", + AttrXorMappedAddress: "XOR-MAPPED-ADDRESS", + AttrTimerVal: "TIMER-VAL", + AttrReservationToken: "RESERVATION-TOKEN", + AttrPriority: "PRIORITY", + AttrUseCandidate: "USE-CANDIDATE", + AttrPadding: "PADDING", + AttrResponsePort: "RESPONSE-PORT", + AttrConnectionID: "CONNECTION-ID", + AttrSoftware: "SOFTWARE", + AttrAlternateServer: "ALTERNATE-SERVER", + AttrTransactionTransmitCounter: "TRANSACTION-TRANSMIT-COUNTER", + AttrCacheTimeout: "CACHE-TIMEOUT", + AttrFingerprint: "FINGERPRINT", + AttrIceControlled: "ICE-CONTROLLED", + AttrIceControlling: "ICE-CONTROLLING", + AttrResponseOrigin: "RESPONSE-ORIGIN", + AttrOtherAddress: "OTHER-ADDRESS", + AttrEcnCheck: "ECN-CHECK", + AttrThirdPartyAuthorization: "THIRD-PARTY-AUTHORIZATION", + AttrMobilityTicket: "MOBILITY-TICKET", +} + +// STUN error codes. +const ( + CodeTryAlternate int = 300 // RFC 5389 + CodeBadRequest int = 400 + CodeUnauthorized int = 401 + CodeForbidden int = 403 // RFC 5766 + CodeMobilityForbidden int = 405 // RFC 8016 + CodeUnknownAttribute int = 420 // RFC 5389 + CodeAllocationMismatch int = 437 // RFC 5766 + CodeStaleNonce int = 438 // RFC 5389 + CodeAddressFamilyNotSupported int = 440 // RFC 6156 + CodeWrongCredentials int = 441 // RFC 5766 + CodeUnsupportedTransportProtocol int = 442 + CodePeerAddressFamilyMismatch int = 443 // RFC 6156 + CodeConnectionAlreadyExists int = 446 // RFC 6062 + CodeConnectionTimeoutOrFailure int = 447 + CodeAllocationQuotaReached int = 486 // RFC 5766 + CodeRoleConflict int = 487 // RFC 5245 + CodeServerError int = 500 // RFC 5389 + CodeInsufficientCapacity int = 508 // RFC 5766 +) + +// STUN error texts. +var errorText = map[int]string{ + CodeTryAlternate: "Try Alternate", + CodeBadRequest: "Bad Request", + CodeUnauthorized: "Unauthorized", + CodeForbidden: "Forbidden", + CodeMobilityForbidden: "Mobility Forbidden", + CodeUnknownAttribute: "Unknown Attribute", + CodeAllocationMismatch: "Allocation Mismatch", + CodeStaleNonce: "Stale Nonce", + CodeAddressFamilyNotSupported: "Address Family not Supported", + CodeWrongCredentials: "Wrong Credentials", + CodeUnsupportedTransportProtocol: "Unsupported Transport Protocol", + CodePeerAddressFamilyMismatch: "Peer Address Family Mismatch", + CodeConnectionAlreadyExists: "Connection Already Exists", + CodeConnectionTimeoutOrFailure: "Connection Timeout or Failure", + CodeAllocationQuotaReached: "Allocation Quota Reached", + CodeRoleConflict: "Role Conflict", + CodeServerError: "Server Error", + CodeInsufficientCapacity: "Insufficient Capacity", +} diff --git a/vendor/github.com/isofew/go-stun/stun/server.go b/vendor/github.com/isofew/go-stun/stun/server.go new file mode 100644 index 0000000..3331d85 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/server.go @@ -0,0 +1,126 @@ +package stun + +import ( + "net" + "sync" +) + +func ListenAndServe(network, laddr string, config *Config) error { + srv := NewServer(config) + return srv.ListenAndServe(network, laddr) +} + +type Server struct { + agent *Agent + + mu sync.RWMutex + conns []net.PacketConn +} + +func NewServer(config *Config) *Server { + srv := &Server{agent: NewAgent(config)} + srv.agent.Handler = srv + return srv +} + +func (srv *Server) ListenAndServe(network, laddr string) error { + c, err := net.ListenPacket(network, laddr) + if err != nil { + return err + } + srv.addConn(c) + defer srv.removeConn(c) + // not using stop channel + return srv.agent.ServePacket(c, make(chan struct{})) +} + +func (srv *Server) ServeSTUN(msg *Message, from Transport) { + if msg.Type == MethodBinding { + to := from + mapped := from.RemoteAddr() + ip, port := SockAddr(from.LocalAddr()) + + res := &Message{ + Type: MethodBinding | KindResponse, + Transaction: msg.Transaction, + Attributes: []Attr{ + Addr(AttrXorMappedAddress, mapped), + Addr(AttrMappedAddress, mapped), + }, + } + + srv.mu.RLock() + defer srv.mu.RUnlock() + + if ch, ok := msg.GetInt(AttrChangeRequest); ok && ch != 0 { + for _, c := range srv.conns { + chip, chport := SockAddr(c.LocalAddr()) + if chip.IsUnspecified() { + continue + } + if ch&ChangeIP != 0 { + if !ip.Equal(chip) { + to = &packetConn{c, mapped} + break + } + } else if ch&ChangePort != 0 { + if ip.Equal(chip) && port != chport { + to = &packetConn{c, mapped} + break + } + } + } + } + + if len(srv.conns) < 2 { + srv.agent.Send(res, to) + return + } + + other: + for _, a := range srv.conns { + aip, aport := SockAddr(a.LocalAddr()) + if aip.IsUnspecified() || !ip.Equal(aip) || port == aport { + continue + } + for _, b := range srv.conns { + bip, bport := SockAddr(b.LocalAddr()) + if bip.IsUnspecified() || bip.Equal(ip) || aport != bport { + continue + } + res.Set(Addr(AttrOtherAddress, b.LocalAddr())) + break other + } + } + + srv.agent.Send(res, to) + } +} + +func (srv *Server) addConn(c net.PacketConn) { + srv.mu.Lock() + srv.conns = append(srv.conns, c) + srv.mu.Unlock() +} + +func (srv *Server) removeConn(c net.PacketConn) { + srv.mu.Lock() + l := srv.conns + for i, it := range l { + if it == c { + srv.conns = append(l[:i], l[i+1:]...) + break + } + } + srv.mu.Unlock() +} + +func (srv *Server) Close() error { + srv.mu.RLock() + defer srv.mu.RUnlock() + for _, it := range srv.conns { + it.Close() + } + srv.conns = nil + return nil +} diff --git a/vendor/github.com/isofew/go-stun/stun/stun.go b/vendor/github.com/isofew/go-stun/stun/stun.go new file mode 100644 index 0000000..5901d75 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/stun.go @@ -0,0 +1,129 @@ +package stun + +import ( + "crypto/md5" + "crypto/tls" + "errors" + "net" + "net/url" + "strings" +) + +func Discover(uri string) (net.PacketConn, net.Addr, error) { + stop := make(chan struct{}) + conn, err := Dial(uri, nil, stop) + if err != nil { + return nil, nil, err + } + addr, err := conn.Discover() + if err != nil { + conn.Close() + return nil, nil, err + } + // TODO: hijack + // stop reading conn before returning it + close(stop) + // note. serveconn/packet func is blocked by read/readfrom at the time + // we send the signal, which means it will still consume one more + // packet and we can only read starting from the second packet. + // (not too much of a problem, since we'll punch a few packets anyway) + return conn.Conn.(net.PacketConn), addr, nil +} + +type AuthMethod func(sess *Session) error + +// LongTermAuthMethod returns AuthMethod for long-term credentials. +// Key = MD5(username ":" realm ":" SASLprep(password)). +// SASLprep is defined in RFC 4013. +func LongTermAuthMethod(username, password string) AuthMethod { + return func(sess *Session) error { + h := md5.New() + h.Write([]byte(username + ":" + sess.Realm + ":" + password)) + sess.Username = username + sess.Key = h.Sum(nil) + return nil + } +} + +// ShotTermAuthMethod returns AuthMethod for short-term credentials. +// Key = SASLprep(password). +// SASLprep is defined in RFC 4013. +func ShortTermAuthMethod(password string) AuthMethod { + key := []byte(password) + return func(sess *Session) error { + sess.Key = key + return nil + } +} + +func Dial(uri string, config *Config, stop chan struct{}) (*Conn, error) { + secure, network, addr, auth, err := parseURI(uri) + if err != nil { + return nil, err + } + var conn net.Conn + if secure { + conn, err = tls.Dial(network, addr, nil) + } else { + if strings.HasPrefix(network, "udp") { + conn, err = dialUDP(network, addr) + } else { + conn, err = dialTCP(network, addr) + } + } + if err != nil { + return nil, err + } + if auth != nil { + config = config.Clone() + config.AuthMethod = auth + } + return NewConn(conn, config, stop), nil +} + +func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, err error) { + var u *url.URL + if u, err = url.Parse(uri); err != nil { + return + } + host, port, e := net.SplitHostPort(u.Opaque) + if e != nil { + host = u.Opaque + } + if a := u.User; a != nil { + if password, ok := a.Password(); ok { + auth = LongTermAuthMethod(a.Username(), password) + } else { + auth = ShortTermAuthMethod(a.Username()) + } + } + network = u.Query().Get("transport") + if network == "" { + network = "udp" + } + switch u.Scheme { + case "stun", "turn": + if port == "" { + port = "3478" + } + switch network { + case "udp", "udp4", "udp6", "tcp", "tcp4", "tcp6": + default: + err = errors.New("stun: unsupported transport: " + network) + } + case "stuns", "turns": + if port == "" { + port = "5478" + } + secure = true + switch network { + case "tcp", "tcp4", "tcp6": + default: + err = errors.New("stun: unsupported transport: " + network) + } + default: + err = errors.New("stun: unsupported scheme " + u.Scheme) + } + addr = net.JoinHostPort(host, port) + return +} diff --git a/vendor/github.com/isofew/go-stun/stun/transport.go b/vendor/github.com/isofew/go-stun/stun/transport.go new file mode 100644 index 0000000..dff0a65 --- /dev/null +++ b/vendor/github.com/isofew/go-stun/stun/transport.go @@ -0,0 +1,96 @@ +package stun + +import ( + "encoding/binary" + "errors" + "net" +) + +type Listener interface { + Addr() net.Addr + Close() error +} + +type Transport interface { + LocalAddr() net.Addr + RemoteAddr() net.Addr + Write(p []byte) (int, error) + Close() error +} + +type Marshaler interface { + Marshal(b []byte) []byte +} + +type TransportHandler interface { + ServeTransport(b []byte, tr Transport) (int, error) +} + +func dialUDP(network, raddr string) (net.Conn, error) { + addr, err := net.ResolveUDPAddr(network, raddr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, err + } + return &packetConn{conn, addr}, nil +} + +func dialTCP(network, raddr string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, raddr) + if err != nil { + return nil, err + } + return net.DialTCP(network, nil, addr) +} + +type packetConn struct { + net.PacketConn + addr net.Addr +} + +func (t *packetConn) Read(p []byte) (n int, err error) { + n, _, err = t.ReadFrom(p) + return +} + +func (t *packetConn) Write(p []byte) (int, error) { + return t.WriteTo(p, t.addr) +} + +func (t *packetConn) RemoteAddr() net.Addr { + return t.addr +} + +var ( + errBufferOverflow = errors.New("stun: buffer overflow") + errFormat = errors.New("stun: format error") +) + +func getBuffer() []byte { + return make([]byte, 2048) +} + +func putBuffer(b []byte) { + if cap(b) >= 2048 { + } +} + +func grow(p []byte, n int) (b, a []byte) { + l := len(p) + r := l + n + if r > cap(p) { + b = make([]byte, (1+((r-1)>>10))<<10)[:r] + a = b[l:r] + if l > 0 { + copy(b, p[:l]) + } + } else { + return p[:r], p[l:r] + } + return +} + +var be = binary.BigEndian diff --git a/vendor/github.com/lucas-clemente/quic-go/Changelog.md b/vendor/github.com/lucas-clemente/quic-go/Changelog.md index 4725779..5a6a392 100644 --- a/vendor/github.com/lucas-clemente/quic-go/Changelog.md +++ b/vendor/github.com/lucas-clemente/quic-go/Changelog.md @@ -1,6 +1,18 @@ # Changelog -## v0.6.0 (unreleased) +## v0.8.0 (unreleased) + +- Add support for unidirectional streams (for IETF QUIC). +- Add a `quic.Config` option for the maximum number of incoming streams. + +## v0.7.0 (2018-02-03) + +- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. +- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`. +- Expose the `ConnectionState` in the `Session` (experimental API). +- Implement packet pacing. + +## v0.6.0 (2017-12-12) - Add support for QUIC 39, drop support for QUIC 35 - 37 - Added `quic.Config` options for maximal flow control windows diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go deleted file mode 100644 index 154515b..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/_gen.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import ( - _ "github.com/clipperhouse/linkedlist" - _ "github.com/clipperhouse/slice" - _ "github.com/clipperhouse/stringer" -) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go deleted file mode 100644 index 8492fd4..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go +++ /dev/null @@ -1,34 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -// SentPacketHandler handles ACKs received for outgoing packets -type SentPacketHandler interface { - // SentPacket may modify the packet - SentPacket(packet *Packet) error - ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error - SetHandshakeComplete() - - SendingAllowed() bool - GetStopWaitingFrame(force bool) *wire.StopWaitingFrame - ShouldSendRetransmittablePacket() bool - DequeuePacketForRetransmission() (packet *Packet) - GetLeastUnacked() protocol.PacketNumber - - GetAlarmTimeout() time.Time - OnAlarm() -} - -// ReceivedPacketHandler handles ACKs needed to send for incoming packets -type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error - SetLowerLimit(protocol.PacketNumber) - - GetAlarmTimeout() time.Time - GetAckFrame() *wire.AckFrame -} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go deleted file mode 100644 index 9c4ee30..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go +++ /dev/null @@ -1,34 +0,0 @@ -package ackhandler - -import ( - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -// A Packet is a packet -// +gen linkedlist -type Packet struct { - PacketNumber protocol.PacketNumber - Frames []wire.Frame - Length protocol.ByteCount - EncryptionLevel protocol.EncryptionLevel - - SendTime time.Time -} - -// GetFramesForRetransmission gets all the frames for retransmission -func (p *Packet) GetFramesForRetransmission() []wire.Frame { - var fs []wire.Frame - for _, frame := range p.Frames { - switch frame.(type) { - case *wire.AckFrame: - continue - case *wire.StopWaitingFrame: - continue - } - fs = append(fs, frame) - } - return fs -} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go deleted file mode 100644 index d0cf78d..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go +++ /dev/null @@ -1,141 +0,0 @@ -package ackhandler - -import ( - "errors" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") - -type receivedPacketHandler struct { - largestObserved protocol.PacketNumber - lowerLimit protocol.PacketNumber - largestObservedReceivedTime time.Time - - packetHistory *receivedPacketHistory - - ackSendDelay time.Duration - - packetsReceivedSinceLastAck int - retransmittablePacketsReceivedSinceLastAck int - ackQueued bool - ackAlarm time.Time - lastAck *wire.AckFrame - - version protocol.VersionNumber -} - -// NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { - return &receivedPacketHandler{ - packetHistory: newReceivedPacketHistory(), - ackSendDelay: protocol.AckSendDelay, - version: version, - } -} - -func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { - if packetNumber == 0 { - return errInvalidPacketNumber - } - - if packetNumber > h.largestObserved { - h.largestObserved = packetNumber - h.largestObservedReceivedTime = time.Now() - } - - if packetNumber <= h.lowerLimit { - return nil - } - - if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { - return err - } - h.maybeQueueAck(packetNumber, shouldInstigateAck) - return nil -} - -// SetLowerLimit sets a lower limit for acking packets. -// Packets with packet numbers smaller or equal than p will not be acked. -func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) { - h.lowerLimit = p - h.packetHistory.DeleteUpTo(p) -} - -func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { - h.packetsReceivedSinceLastAck++ - - if shouldInstigateAck { - h.retransmittablePacketsReceivedSinceLastAck++ - } - - // always ack the first packet - if h.lastAck == nil { - h.ackQueued = true - } - - if h.version < protocol.Version39 { - // Always send an ack every 20 packets in order to allow the peer to discard - // information from the SentPacketManager and provide an RTT measurement. - // From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet. - if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { - h.ackQueued = true - } - } - - // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK - // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() - if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { - h.ackQueued = true - } - - // check if a new missing range above the previously was created - if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { - h.ackQueued = true - } - - if !h.ackQueued && shouldInstigateAck { - if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck { - h.ackQueued = true - } else { - if h.ackAlarm.IsZero() { - h.ackAlarm = time.Now().Add(h.ackSendDelay) - } - } - } - - if h.ackQueued { - // cancel the ack alarm - h.ackAlarm = time.Time{} - } -} - -func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { - return nil - } - - ackRanges := h.packetHistory.GetAckRanges() - ack := &wire.AckFrame{ - LargestAcked: h.largestObserved, - LowestAcked: ackRanges[len(ackRanges)-1].First, - PacketReceivedTime: h.largestObservedReceivedTime, - } - - if len(ackRanges) > 1 { - ack.AckRanges = ackRanges - } - - h.lastAck = ack - h.ackAlarm = time.Time{} - h.ackQueued = false - h.packetsReceivedSinceLastAck = 0 - h.retransmittablePacketsReceivedSinceLastAck = 0 - - return ack -} - -func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go deleted file mode 100644 index 68267aa..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go +++ /dev/null @@ -1,455 +0,0 @@ -package ackhandler - -import ( - "errors" - "fmt" - "time" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/qerr" -) - -const ( - // Maximum reordering in time space before time based loss detection considers a packet lost. - // In fraction of an RTT. - timeReorderingFraction = 1.0 / 8 - // The default RTT used before an RTT sample is taken. - // Note: This constant is also defined in the congestion package. - defaultInitialRTT = 100 * time.Millisecond - // defaultRTOTimeout is the RTO time on new connections - defaultRTOTimeout = 500 * time.Millisecond - // Minimum time in the future a tail loss probe alarm may be set for. - minTPLTimeout = 10 * time.Millisecond - // Minimum time in the future an RTO alarm may be set for. - minRTOTimeout = 200 * time.Millisecond - // maxRTOTimeout is the maximum RTO time - maxRTOTimeout = 60 * time.Second -) - -var ( - // ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received - ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK") - // ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets - ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets") - // ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped - ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") - errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") -) - -var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number") - -type sentPacketHandler struct { - lastSentPacketNumber protocol.PacketNumber - skippedPackets []protocol.PacketNumber - - numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet - - LargestAcked protocol.PacketNumber - - largestReceivedPacketWithAck protocol.PacketNumber - - packetHistory *PacketList - stopWaitingManager stopWaitingManager - - retransmissionQueue []*Packet - - bytesInFlight protocol.ByteCount - - congestion congestion.SendAlgorithm - rttStats *congestion.RTTStats - - handshakeComplete bool - // The number of times the handshake packets have been retransmitted without receiving an ack. - handshakeCount uint32 - - // The number of times an RTO has been sent without receiving an ack. - rtoCount uint32 - - // The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time. - lossTime time.Time - - // The alarm timeout - alarm time.Time -} - -// NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { - congestion := congestion.NewCubicSender( - congestion.DefaultClock{}, - rttStats, - false, /* don't use reno since chromium doesn't (why?) */ - protocol.InitialCongestionWindow, - protocol.DefaultMaxCongestionWindow, - ) - - return &sentPacketHandler{ - packetHistory: NewPacketList(), - stopWaitingManager: stopWaitingManager{}, - rttStats: rttStats, - congestion: congestion, - } -} - -func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber { - if f := h.packetHistory.Front(); f != nil { - return f.Value.PacketNumber - 1 - } - return h.LargestAcked -} - -func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { - return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets -} - -func (h *sentPacketHandler) SetHandshakeComplete() { - h.handshakeComplete = true -} - -func (h *sentPacketHandler) SentPacket(packet *Packet) error { - if packet.PacketNumber <= h.lastSentPacketNumber { - return errPacketNumberNotIncreasing - } - - if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets { - return ErrTooManyTrackedSentPackets - } - - for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { - h.skippedPackets = append(h.skippedPackets, p) - - if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { - h.skippedPackets = h.skippedPackets[1:] - } - } - - h.lastSentPacketNumber = packet.PacketNumber - now := time.Now() - - packet.Frames = stripNonRetransmittableFrames(packet.Frames) - isRetransmittable := len(packet.Frames) != 0 - - if isRetransmittable { - packet.SendTime = now - h.bytesInFlight += packet.Length - h.packetHistory.PushBack(*packet) - h.numNonRetransmittablePackets = 0 - } else { - h.numNonRetransmittablePackets++ - } - - h.congestion.OnPacketSent( - now, - h.bytesInFlight, - packet.PacketNumber, - packet.Length, - isRetransmittable, - ) - - h.updateLossDetectionAlarm() - return nil -} - -func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { - if ackFrame.LargestAcked > h.lastSentPacketNumber { - return errAckForUnsentPacket - } - - // duplicate or out-of-order ACK - if withPacketNumber <= h.largestReceivedPacketWithAck { - return ErrDuplicateOrOutOfOrderAck - } - h.largestReceivedPacketWithAck = withPacketNumber - - // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) - if ackFrame.LargestAcked <= h.largestInOrderAcked() { - return nil - } - h.LargestAcked = ackFrame.LargestAcked - - if h.skippedPacketsAcked(ackFrame) { - return ErrAckForSkippedPacket - } - - rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime) - - if rttUpdated { - h.congestion.MaybeExitSlowStart() - } - - ackedPackets, err := h.determineNewlyAckedPackets(ackFrame) - if err != nil { - return err - } - - if len(ackedPackets) > 0 { - for _, p := range ackedPackets { - if encLevel < p.Value.EncryptionLevel { - return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) - } - h.onPacketAcked(p) - h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) - } - } - - h.detectLostPackets() - h.updateLossDetectionAlarm() - - h.garbageCollectSkippedPackets() - h.stopWaitingManager.ReceivedAck(ackFrame) - - return nil -} - -func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { - var ackedPackets []*PacketElement - ackRangeIndex := 0 - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - packetNumber := packet.PacketNumber - - // Ignore packets below the LowestAcked - if packetNumber < ackFrame.LowestAcked { - continue - } - // Break after LargestAcked is reached - if packetNumber > ackFrame.LargestAcked { - break - } - - if ackFrame.HasMissingRanges() { - ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - - for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { - ackRangeIndex++ - ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - } - - if packetNumber >= ackRange.First { // packet i contained in ACK range - if packetNumber > ackRange.Last { - return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last) - } - ackedPackets = append(ackedPackets, el) - } - } else { - ackedPackets = append(ackedPackets, el) - } - } - - return ackedPackets, nil -} - -func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool { - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - if packet.PacketNumber == largestAcked { - h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) - return true - } - // Packets are sorted by number, so we can stop searching - if packet.PacketNumber > largestAcked { - break - } - } - return false -} - -func (h *sentPacketHandler) updateLossDetectionAlarm() { - // Cancel the alarm if no packets are outstanding - if h.packetHistory.Len() == 0 { - h.alarm = time.Time{} - return - } - - // TODO(#497): TLP - if !h.handshakeComplete { - h.alarm = time.Now().Add(h.computeHandshakeTimeout()) - } else if !h.lossTime.IsZero() { - // Early retransmit timer or time loss detection. - h.alarm = h.lossTime - } else { - // RTO - h.alarm = time.Now().Add(h.computeRTOTimeout()) - } -} - -func (h *sentPacketHandler) detectLostPackets() { - h.lossTime = time.Time{} - now := time.Now() - - maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) - delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) - - var lostPackets []*PacketElement - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - packet := el.Value - - if packet.PacketNumber > h.LargestAcked { - break - } - - timeSinceSent := now.Sub(packet.SendTime) - if timeSinceSent > delayUntilLost { - lostPackets = append(lostPackets, el) - } else if h.lossTime.IsZero() { - // Note: This conditional is only entered once per call - h.lossTime = now.Add(delayUntilLost - timeSinceSent) - } - } - - if len(lostPackets) > 0 { - for _, p := range lostPackets { - h.queuePacketForRetransmission(p) - h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) - } - } -} - -func (h *sentPacketHandler) OnAlarm() { - // TODO(#497): TLP - if !h.handshakeComplete { - h.queueHandshakePacketsForRetransmission() - h.handshakeCount++ - } else if !h.lossTime.IsZero() { - // Early retransmit or time loss detection - h.detectLostPackets() - } else { - // RTO - h.retransmitOldestTwoPackets() - h.rtoCount++ - } - - h.updateLossDetectionAlarm() -} - -func (h *sentPacketHandler) GetAlarmTimeout() time.Time { - return h.alarm -} - -func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { - h.bytesInFlight -= packetElement.Value.Length - h.rtoCount = 0 - h.handshakeCount = 0 - // TODO(#497): h.tlpCount = 0 - h.packetHistory.Remove(packetElement) -} - -func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { - if len(h.retransmissionQueue) == 0 { - return nil - } - packet := h.retransmissionQueue[0] - // Shift the slice and don't retain anything that isn't needed. - copy(h.retransmissionQueue, h.retransmissionQueue[1:]) - h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil - h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1] - return packet -} - -func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { - return h.largestInOrderAcked() + 1 -} - -func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { - return h.stopWaitingManager.GetStopWaitingFrame(force) -} - -func (h *sentPacketHandler) SendingAllowed() bool { - congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow() - maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets - if congestionLimited { - utils.Debugf("Congestion limited: bytes in flight %d, window %d", - h.bytesInFlight, - h.congestion.GetCongestionWindow()) - } - // Workaround for #555: - // Always allow sending of retransmissions. This should probably be limited - // to RTOs, but we currently don't have a nice way of distinguishing them. - haveRetransmissions := len(h.retransmissionQueue) > 0 - return !maxTrackedLimited && (!congestionLimited || haveRetransmissions) -} - -func (h *sentPacketHandler) retransmitOldestTwoPackets() { - if p := h.packetHistory.Front(); p != nil { - h.queueRTO(p) - } - if p := h.packetHistory.Front(); p != nil { - h.queueRTO(p) - } -} - -func (h *sentPacketHandler) queueRTO(el *PacketElement) { - packet := &el.Value - utils.Debugf( - "\tQueueing packet 0x%x for retransmission (RTO), %d outstanding", - packet.PacketNumber, - h.packetHistory.Len(), - ) - h.queuePacketForRetransmission(el) - h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight) - h.congestion.OnRetransmissionTimeout(true) -} - -func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { - var handshakePackets []*PacketElement - for el := h.packetHistory.Front(); el != nil; el = el.Next() { - if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { - handshakePackets = append(handshakePackets, el) - } - } - for _, el := range handshakePackets { - h.queuePacketForRetransmission(el) - } -} - -func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { - packet := &packetElement.Value - h.bytesInFlight -= packet.Length - h.retransmissionQueue = append(h.retransmissionQueue, packet) - h.packetHistory.Remove(packetElement) - h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) -} - -func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { - duration := 2 * h.rttStats.SmoothedRTT() - if duration == 0 { - duration = 2 * defaultInitialRTT - } - duration = utils.MaxDuration(duration, minTPLTimeout) - // exponential backoff - // There's an implicit limit to this set by the handshake timeout. - return duration << h.handshakeCount -} - -func (h *sentPacketHandler) computeRTOTimeout() time.Duration { - rto := h.congestion.RetransmissionDelay() - if rto == 0 { - rto = defaultRTOTimeout - } - rto = utils.MaxDuration(rto, minRTOTimeout) - // Exponential backoff - rto = rto << h.rtoCount - return utils.MinDuration(rto, maxRTOTimeout) -} - -func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool { - for _, p := range h.skippedPackets { - if ackFrame.AcksPacket(p) { - return true - } - } - return false -} - -func (h *sentPacketHandler) garbageCollectSkippedPackets() { - lioa := h.largestInOrderAcked() - deleteIndex := 0 - for i, p := range h.skippedPackets { - if p <= lioa { - deleteIndex = i + 1 - } - } - h.skippedPackets = h.skippedPackets[deleteIndex:] -} diff --git a/vendor/github.com/lucas-clemente/quic-go/appveyor.yml b/vendor/github.com/lucas-clemente/quic-go/appveyor.yml index bcd3ac5..b7bb030 100644 --- a/vendor/github.com/lucas-clemente/quic-go/appveyor.yml +++ b/vendor/github.com/lucas-clemente/quic-go/appveyor.yml @@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go install: - rmdir c:\go /s /q - - appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip - - 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL + - appveyor DownloadFile https://storage.googleapis.com/golang/go1.10.2.windows-amd64.zip + - 7z x go1.10.2.windows-amd64.zip -y -oC:\ > NUL - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - echo %PATH% - echo %GOPATH% diff --git a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go index 5032ca7..6b23369 100644 --- a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go +++ b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go @@ -8,19 +8,20 @@ import ( var bufferPool sync.Pool -func getPacketBuffer() []byte { - return bufferPool.Get().([]byte) +func getPacketBuffer() *[]byte { + return bufferPool.Get().(*[]byte) } -func putPacketBuffer(buf []byte) { - if cap(buf) != int(protocol.MaxReceivePacketSize) { +func putPacketBuffer(buf *[]byte) { + if cap(*buf) != int(protocol.MaxReceivePacketSize) { panic("putPacketBuffer called with packet of wrong size!") } - bufferPool.Put(buf[:0]) + bufferPool.Put(buf) } func init() { bufferPool.New = func() interface{} { - return make([]byte, 0, protocol.MaxReceivePacketSize) + b := make([]byte, 0, protocol.MaxReceivePacketSize) + return &b } } diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index d13dd81..692ca3c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -22,24 +23,29 @@ type client struct { conn connection hostname string - handshakeChan <-chan handshakeEvent - versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version - versionNegotiated bool // has version negotiation completed yet + versionNegotiated bool // has the server accepted our version receivedVersionNegotiationPacket bool + negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet tlsConf *tls.Config config *Config + tls handshake.MintTLS // only used when using TLS - connectionID protocol.ConnectionID - version protocol.VersionNumber + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + + initialVersion protocol.VersionNumber + version protocol.VersionNumber session packetHandler + + logger utils.Logger } var ( // make it possible to mock connection ID generation in the tests - generateConnectionID = utils.GenerateConnectionID + generateConnectionID = protocol.GenerateConnectionID errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) @@ -57,69 +63,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) return Dial(udpConn, udpAddr, addr, tlsConf, config) } -// DialAddrNonFWSecure establishes a new QUIC connection to a server. -// The hostname for SNI is taken from the given address. -func DialAddrNonFWSecure( - addr string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config) -} - -// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. -// The host parameter is used for SNI. -func DialNonFWSecure( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (NonFWSession, error) { - connID, err := generateConnectionID() - if err != nil { - return nil, err - } - - var hostname string - if tlsConf != nil { - hostname = tlsConf.ServerName - } - - if hostname == "" { - hostname, _, err = net.SplitHostPort(host) - if err != nil { - return nil, err - } - } - - clientConfig := populateClientConfig(config) - c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: clientConfig.Versions[0], - versionNegotiationChan: make(chan struct{}), - } - - utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - - if err := c.establishSecureConnection(); err != nil { - return nil, err - } - return c.session.(NonFWSession), nil -} - // Dial establishes a new QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. func Dial( @@ -129,14 +72,57 @@ func Dial( tlsConf *tls.Config, config *Config, ) (Session, error) { - sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) + clientConfig := populateClientConfig(config) + version := clientConfig.Versions[0] + srcConnID, err := generateConnectionID() if err != nil { return nil, err } - if err := sess.WaitUntilHandshakeComplete(); err != nil { + destConnID := srcConnID + if version.UsesTLS() { + destConnID, err = generateConnectionID() + if err != nil { + return nil, err + } + } + + var hostname string + if tlsConf != nil { + hostname = tlsConf.ServerName + } + if hostname == "" { + hostname, _, err = net.SplitHostPort(host) + if err != nil { + return nil, err + } + } + + // check that all versions are actually supported + if config != nil { + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + } + c := &client{ + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + srcConnID: srcConnID, + destConnID: destConnID, + hostname: hostname, + tlsConf: tlsConf, + config: clientConfig, + version: version, + versionNegotiationChan: make(chan struct{}), + logger: utils.DefaultLogger, + } + + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + if err := c.dial(); err != nil { return nil, err } - return sess, nil + return c.session, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -167,6 +153,18 @@ func populateClientConfig(config *Config) *Config { if maxReceiveConnectionFlowControlWindow == 0 { maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient } + maxIncomingStreams := config.MaxIncomingStreams + if maxIncomingStreams == 0 { + maxIncomingStreams = protocol.DefaultMaxIncomingStreams + } else if maxIncomingStreams < 0 { + maxIncomingStreams = 0 + } + maxIncomingUniStreams := config.MaxIncomingUniStreams + if maxIncomingUniStreams == 0 { + maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams + } else if maxIncomingUniStreams < 0 { + maxIncomingUniStreams = 0 + } return &Config{ Versions: versions, @@ -175,29 +173,87 @@ func populateClientConfig(config *Config) *Config { RequestConnectionIDOmission: config.RequestConnectionIDOmission, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, - KeepAlive: config.KeepAlive, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + KeepAlive: config.KeepAlive, } } -// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) -func (c *client) establishSecureConnection() error { - if err := c.createNewSession(c.version, nil); err != nil { +func (c *client) dial() error { + var err error + if c.version.UsesTLS() { + err = c.dialTLS() + } else { + err = c.dialGQUIC() + } + if err == errCloseSessionForNewVersion { + return c.dial() + } + return err +} + +func (c *client) dialGQUIC() error { + if err := c.createNewGQUICSession(); err != nil { return err } go c.listen() + return c.establishSecureConnection() +} +func (c *client) dialTLS() error { + params := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: c.config.IdleTimeout, + OmitConnectionID: c.config.RequestConnectionIDOmission, + MaxBidiStreams: uint16(c.config.MaxIncomingStreams), + MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), + } + csc := handshake.NewCryptoStreamConn(nil) + extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger) + mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) + if err != nil { + return err + } + mintConf.ExtensionHandler = extHandler + mintConf.ServerName = c.hostname + c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient) + + if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { + return err + } + go c.listen() + if err := c.establishSecureConnection(); err != nil { + if err != handshake.ErrCloseSessionForRetry { + return err + } + c.logger.Infof("Received a Retry packet. Recreating session.") + if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { + return err + } + if err := c.establishSecureConnection(); err != nil { + return err + } + } + return nil +} + +// establishSecureConnection runs the session, and tries to establish a secure connection +// It returns: +// - errCloseSessionForNewVersion when the server sends a version negotiation packet +// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) +// - any other error that might occur +// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) +func (c *client) establishSecureConnection() error { var runErr error errorChan := make(chan struct{}) go func() { - // session.run() returns as soon as the session is closed - runErr = c.session.run() - if runErr == errCloseSessionForNewVersion { - // run the new session - runErr = c.session.run() - } + runErr = c.session.run() // returns as soon as the session is closed close(errorChan) - utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() + c.logger.Infof("Connection %s closed.", c.srcConnID) + if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { + c.conn.Close() + } }() // wait until the server accepts the QUIC version (or an error occurs) @@ -210,96 +266,95 @@ func (c *client) establishSecureConnection() error { select { case <-errorChan: return runErr - case ev := <-c.handshakeChan: - if ev.err != nil { - return ev.err - } - if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure { - return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) - } - return nil + case err := <-c.session.handshakeStatus(): + return err } } -// Listen listens +// Listen listens on the underlying connection and passes packets on for handling. +// It returns when the connection is closed. func (c *client) listen() { var err error for { var n int var addr net.Addr - data := getPacketBuffer() + data := *getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] // The packet size should not exceed protocol.MaxReceivePacketSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable n, addr, err = c.conn.Read(data) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { - c.session.Close(err) + c.mutex.Lock() + if c.session != nil { + c.session.Close(err) + } + c.mutex.Unlock() } break } - data = data[:n] - - c.handlePacket(addr, data) + if err := c.handlePacket(addr, data[:n]); err != nil { + c.logger.Errorf("error handling packet: %s", err.Error()) + } } } -func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { +func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) hdr, err := wire.ParseHeaderSentByServer(r, c.version) + // drop the packet if we can't parse the header if err != nil { - utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the header - return + return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) } // reject packets with truncated connection id if we didn't request truncation if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { - return - } - // reject packets with the wrong connection ID - if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { - return + return errors.New("received packet with truncated connection ID, but didn't request truncation") } hdr.Raw = packet[:len(packet)-r.Len()] + packetData := packet[len(packet)-r.Len():] c.mutex.Lock() defer c.mutex.Unlock() - if hdr.ResetFlag { - cr := c.conn.RemoteAddr() - // check if the remote address and the connection ID match - // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection - if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { - utils.Infof("Received a spoofed Public Reset. Ignoring.") - return - } - pr, err := wire.ParsePublicReset(r) - if err != nil { - utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) - return - } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) - c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) - return - } - - isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */ - // handle Version Negotiation Packets - if isVersionNegotiationPacket { + if hdr.IsVersionNegotiation { // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated { - return + return errors.New("received a delayed Version Negotiation Packet") } // version negotiation packets have no payload if err := c.handleVersionNegotiationPacket(hdr); err != nil { c.session.Close(err) } - return + return nil + } + + if hdr.IsPublicHeader { + return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime) + } + return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime) +} + +func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { + // TODO(#1003): add support for server-chosen connection IDs + // reject packets with the wrong connection ID + if !hdr.DestConnectionID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) + } + if hdr.IsLongHeader { + if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake { + return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) + } + c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen) + if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + } + packetData = packetData[:int(hdr.PayloadLen)] + // TODO(#1312): implement parsing of compound packets } // this is the first packet we are receiving @@ -312,9 +367,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { c.session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, header: hdr, - data: packet[len(packet)-r.Len():], + data: packetData, rcvTime: rcvTime, }) + return nil +} + +func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { + // reject packets with the wrong connection ID + if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) + } + + if hdr.ResetFlag { + cr := c.conn.RemoteAddr() + // check if the remote address and the connection ID match + // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection + if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) { + return errors.New("Received a spoofed Public Reset") + } + pr, err := wire.ParsePublicReset(r) + if err != nil { + return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err) + } + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) + c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber) + return nil + } + + // this is the first packet we are receiving + // since it is not a Version Negotiation Packet, this means the server supports the suggested version + if !c.versionNegotiated { + c.versionNegotiated = true + close(c.versionNegotiationChan) + } + + c.session.handlePacket(&receivedPacket{ + remoteAddr: remoteAddr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) + return nil } func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { @@ -327,42 +421,66 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } } - c.receivedVersionNegotiationPacket = true + c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { return qerr.InvalidVersion } + c.receivedVersionNegotiationPacket = true + c.negotiatedVersions = hdr.SupportedVersions // switch to negotiated version - initialVersion := c.version + c.initialVersion = c.version c.version = newVersion var err error - c.connectionID, err = utils.GenerateConnectionID() + c.destConnID, err = generateConnectionID() if err != nil { return err } - utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) - - // create a new session and close the old one - // the new session must be created first to update client member variables - oldSession := c.session - defer oldSession.Close(errCloseSessionForNewVersion) - return c.createNewSession(initialVersion, hdr.SupportedVersions) + // in gQUIC, there's only one connection ID + if !c.version.UsesTLS() { + c.srcConnID = c.destConnID + } + c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) + c.session.Close(errCloseSessionForNewVersion) + return nil } -func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { - var err error - utils.Debugf("createNewSession with initial version %s", initialVersion) - c.session, c.handshakeChan, err = newClientSession( +func (c *client) createNewGQUICSession() (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.session, err = newClientSession( c.conn, c.hostname, c.version, - c.connectionID, + c.destConnID, c.tlsConf, c.config, - initialVersion, - negotiatedVersions, + c.initialVersion, + c.negotiatedVersions, + c.logger, + ) + return err +} + +func (c *client) createNewTLSSession( + paramsChan <-chan handshake.TransportParameters, + version protocol.VersionNumber, +) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.session, err = newTLSClientSession( + c.conn, + c.hostname, + c.version, + c.destConnID, + c.srcConnID, + c.config, + c.tls, + paramsChan, + 1, + c.logger, ) return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/codecov.yml b/vendor/github.com/lucas-clemente/quic-go/codecov.yml index d85e781..f077c1a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/codecov.yml +++ b/vendor/github.com/lucas-clemente/quic-go/codecov.yml @@ -1,11 +1,16 @@ coverage: round: nearest ignore: - - ackhandler/packet_linkedlist.go + - streams_map_incoming_bidi.go + - streams_map_incoming_uni.go + - streams_map_outgoing_bidi.go + - streams_map_outgoing_uni.go - h2quic/gzipreader.go - h2quic/response.go + - internal/ackhandler/packet_linkedlist.go - internal/utils/byteinterval_linkedlist.go - internal/utils/packetinterval_linkedlist.go + - internal/utils/linkedlist/linkedlist.go status: project: default: diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go deleted file mode 100644 index 624957c..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go +++ /dev/null @@ -1,183 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/lucas-clemente/quic-go/internal/utils" -) - -const ( - // Note: This constant is also defined in the ackhandler package. - initialRTTus = 100 * 1000 - rttAlpha float32 = 0.125 - oneMinusAlpha float32 = (1 - rttAlpha) - rttBeta float32 = 0.25 - oneMinusBeta float32 = (1 - rttBeta) - halfWindow float32 = 0.5 - quarterWindow float32 = 0.25 -) - -type rttSample struct { - rtt time.Duration - time time.Time -} - -// RTTStats provides round-trip statistics -type RTTStats struct { - initialRTTus int64 - - recentMinRTTwindow time.Duration - minRTT time.Duration - latestRTT time.Duration - smoothedRTT time.Duration - meanDeviation time.Duration - - numMinRTTsamplesRemaining uint32 - - newMinRTT rttSample - recentMinRTT rttSample - halfWindowRTT rttSample - quarterWindowRTT rttSample -} - -// NewRTTStats makes a properly initialized RTTStats object -func NewRTTStats() *RTTStats { - return &RTTStats{ - initialRTTus: initialRTTus, - recentMinRTTwindow: utils.InfDuration, - } -} - -// InitialRTTus is the initial RTT in us -func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus } - -// MinRTT Returns the minRTT for the entire connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } - -// LatestRTT returns the most recent rtt measurement. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } - -// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the -// minRTT for the entire connection if SampleNewMinRtt was never called. -func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt } - -// SmoothedRTT returns the EWMA smoothed RTT for the connection. -// May return Zero if no valid updates have occurred. -func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } - -// GetQuarterWindowRTT gets the quarter window RTT -func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt } - -// GetHalfWindowRTT gets the half window RTT -func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt } - -// MeanDeviation gets the mean deviation -func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } - -// SetRecentMinRTTwindow sets how old a recent min rtt sample can be. -func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) { - r.recentMinRTTwindow = recentMinRTTwindow -} - -// UpdateRTT updates the RTT based on a new sample. -func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { - if sendDelta == utils.InfDuration || sendDelta <= 0 { - utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond) - return - } - - // Update r.minRTT first. r.minRTT does not use an rttSample corrected for - // ackDelay but the raw observed sendDelta, since poor clock granularity at - // the client may cause a high ackDelay to result in underestimation of the - // r.minRTT. - if r.minRTT == 0 || r.minRTT > sendDelta { - r.minRTT = sendDelta - } - r.updateRecentMinRTT(sendDelta, now) - - // Correct for ackDelay if information received from the peer results in a - // positive RTT sample. Otherwise, we use the sendDelta as a reasonable - // measure for smoothedRTT. - sample := sendDelta - if sample > ackDelay { - sample -= ackDelay - } - r.latestRTT = sample - // First time call. - if r.smoothedRTT == 0 { - r.smoothedRTT = sample - r.meanDeviation = sample / 2 - } else { - r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond - r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond - } -} - -func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update. - if r.numMinRTTsamplesRemaining > 0 { - r.numMinRTTsamplesRemaining-- - if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt { - r.newMinRTT = rttSample{rtt: sample, time: now} - } - if r.numMinRTTsamplesRemaining == 0 { - r.recentMinRTT = r.newMinRTT - r.halfWindowRTT = r.newMinRTT - r.quarterWindowRTT = r.newMinRTT - } - } - - // Update the three recent rtt samples. - if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt { - r.recentMinRTT = rttSample{rtt: sample, time: now} - r.halfWindowRTT = r.recentMinRTT - r.quarterWindowRTT = r.recentMinRTT - } else if sample <= r.halfWindowRTT.rtt { - r.halfWindowRTT = rttSample{rtt: sample, time: now} - r.quarterWindowRTT = r.halfWindowRTT - } else if sample <= r.quarterWindowRTT.rtt { - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } - - // Expire old min rtt samples. - if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) { - r.recentMinRTT = r.halfWindowRTT - r.halfWindowRTT = r.quarterWindowRTT - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) { - r.halfWindowRTT = r.quarterWindowRTT - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) { - r.quarterWindowRTT = rttSample{rtt: sample, time: now} - } -} - -// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next -// |numSamples| UpdateRTT calls. -func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) { - r.numMinRTTsamplesRemaining = numSamples - r.newMinRTT = rttSample{} -} - -// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. -func (r *RTTStats) OnConnectionMigration() { - r.latestRTT = 0 - r.minRTT = 0 - r.smoothedRTT = 0 - r.meanDeviation = 0 - r.initialRTTus = initialRTTus - r.numMinRTTsamplesRemaining = 0 - r.recentMinRTTwindow = utils.InfDuration - r.recentMinRTT = rttSample{} - r.halfWindowRTT = rttSample{} - r.quarterWindowRTT = rttSample{} -} - -// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt -// is larger. The mean deviation is increased to the most recent deviation if -// it's larger. -func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go new file mode 100644 index 0000000..8e96ec1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go @@ -0,0 +1,41 @@ +package quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoStreamI interface { + StreamID() protocol.StreamID + io.Reader + io.Writer + handleStreamFrame(*wire.StreamFrame) error + popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + setReadOffset(protocol.ByteCount) + // methods needed for flow control + getWindowUpdate() protocol.ByteCount + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type cryptoStream struct { + *stream +} + +var _ cryptoStreamI = &cryptoStream{} + +func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { + str := newStream(version.CryptoStreamID(), sender, flowController, version) + return &cryptoStream{str} +} + +// SetReadOffset sets the read offset. +// It is only needed for the crypto stream. +// It must not be called concurrently with any other stream methods, especially Read and Write. +func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) { + s.receiveStream.readOffset = offset + s.receiveStream.frameQueue.readPosition = offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go index 87bf9ea..3ab64af 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -16,23 +16,48 @@ type StreamID = protocol.StreamID // A VersionNumber is a QUIC version number. type VersionNumber = protocol.VersionNumber +// VersionGQUIC39 is gQUIC version 39. +const VersionGQUIC39 = protocol.Version39 + // A Cookie can be used to verify the ownership of the client address. type Cookie = handshake.Cookie +// ConnectionState records basic details about the QUIC connection. +type ConnectionState = handshake.ConnectionState + +// An ErrorCode is an application-defined error code. +type ErrorCode = protocol.ApplicationErrorCode + // Stream is the interface implemented by QUIC streams type Stream interface { + // StreamID returns the stream ID. + StreamID() StreamID // Read reads data from the stream. // Read can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Reader // Write writes data to the stream. // Write can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. io.Writer + // Close closes the write-direction of the stream. + // Future calls to Write are not permitted after calling Close. + // It must not be called concurrently with Write. + // It must not be called after calling CancelWrite. io.Closer - StreamID() StreamID - // Reset closes the stream with an error. - Reset(error) + // CancelWrite aborts sending on this stream. + // It must not be called after Close. + // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. + // Write will unblock immediately, and future calls to Write will fail. + CancelWrite(ErrorCode) error + // CancelRead aborts receiving on this stream. + // It will ask the peer to stop transmitting stream data. + // Read will unblock immediately, and future Read calls will fail. + CancelRead(ErrorCode) error // The context is canceled as soon as the write-side of the stream is closed. // This happens when Close() is called, or when the stream is reset (either locally or remotely). // Warning: This API should not be considered stable and might change soon. @@ -53,18 +78,63 @@ type Stream interface { SetDeadline(t time.Time) error } +// A ReceiveStream is a unidirectional Receive Stream. +type ReceiveStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Read + io.Reader + // see Stream.CancelRead + CancelRead(ErrorCode) error + // see Stream.SetReadDealine + SetReadDeadline(t time.Time) error +} + +// A SendStream is a unidirectional Send Stream. +type SendStream interface { + // see Stream.StreamID + StreamID() StreamID + // see Stream.Write + io.Writer + // see Stream.Close + io.Closer + // see Stream.CancelWrite + CancelWrite(ErrorCode) error + // see Stream.Context + Context() context.Context + // see Stream.SetWriteDeadline + SetWriteDeadline(t time.Time) error +} + +// StreamError is returned by Read and Write when the peer cancels the stream. +type StreamError interface { + error + Canceled() bool + ErrorCode() ErrorCode +} + // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. - // Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server). AcceptStream() (Stream, error) - // OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached. - // New streams always have the smallest possible stream ID. - // TODO: Enable testing for the special error + // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. + AcceptUniStream() (ReceiveStream, error) + // OpenStream opens a new bidirectional QUIC stream. + // It returns a special error when the peer's concurrent stream limit is reached. + // There is no signaling to the peer about new streams: + // The peer can only accept the stream after data has been sent on the stream. + // TODO(#1152): Enable testing for the special error OpenStream() (Stream, error) - // OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened. - // It always picks the smallest possible stream ID. + // OpenStreamSync opens a new bidirectional QUIC stream. + // It blocks until the peer's concurrent stream limit allows a new stream to be opened. OpenStreamSync() (Stream, error) + // OpenUniStream opens a new outgoing unidirectional QUIC stream. + // It returns a special error when the peer's concurrent stream limit is reached. + // TODO(#1152): Enable testing for the special error + OpenUniStream() (SendStream, error) + // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. + // It blocks until the peer's concurrent stream limit allows a new stream to be opened. + OpenUniStreamSync() (SendStream, error) // LocalAddr returns the local address. LocalAddr() net.Addr // RemoteAddr returns the address of the peer. @@ -74,13 +144,9 @@ type Session interface { // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. Context() context.Context -} - -// A NonFWSession is a QUIC connection between two peers half-way through the handshake. -// The communication is encrypted, but not yet forward secure. -type NonFWSession interface { - Session - WaitUntilHandshakeComplete() error + // ConnectionState returns basic details about the QUIC connection. + // Warning: This API should not be considered stable and might change soon. + ConnectionState() ConnectionState } // Config contains all configuration data needed for a QUIC server or client. @@ -113,6 +179,17 @@ type Config struct { // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. MaxReceiveConnectionFlowControlWindow uint64 + // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any bidirectional streams. + // Values larger than 65535 (math.MaxUint16) are invalid. + MaxIncomingStreams int + // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. + // This value doesn't have any effect in Google QUIC. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any unidirectional streams. + // Values larger than 65535 (math.MaxUint16) are invalid. + MaxIncomingUniStreams int // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. KeepAlive bool } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go new file mode 100644 index 0000000..32235f8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go new file mode 100644 index 0000000..43027dc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go @@ -0,0 +1,46 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// SentPacketHandler handles ACKs received for outgoing packets +type SentPacketHandler interface { + // SentPacket may modify the packet + SentPacket(packet *Packet) + SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) + ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error + SetHandshakeComplete() + + // The SendMode determines if and what kind of packets can be sent. + SendMode() SendMode + // TimeUntilSend is the time when the next packet should be sent. + // It is used for pacing packets. + TimeUntilSend() time.Time + // ShouldSendNumPackets returns the number of packets that should be sent immediately. + // It always returns a number greater or equal than 1. + // A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay. + // Note that the number of packets is only calculated based on the pacing algorithm. + // Before sending any packet, SendingAllowed() must be called to learn if we can actually send it. + ShouldSendNumPackets() int + + GetStopWaitingFrame(force bool) *wire.StopWaitingFrame + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + DequeuePacketForRetransmission() (packet *Packet) + GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen + + GetAlarmTimeout() time.Time + OnAlarm() error +} + +// ReceivedPacketHandler handles ACKs needed to send for incoming packets +type ReceivedPacketHandler interface { + ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error + IgnoreBelow(protocol.PacketNumber) + + GetAlarmTimeout() time.Time + GetAckFrame() *wire.AckFrame +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go new file mode 100644 index 0000000..9673a85 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet.go @@ -0,0 +1,29 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// A Packet is a packet +type Packet struct { + PacketNumber protocol.PacketNumber + PacketType protocol.PacketType + Frames []wire.Frame + Length protocol.ByteCount + EncryptionLevel protocol.EncryptionLevel + SendTime time.Time + + largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK + + // There are two reasons why a packet cannot be retransmitted: + // * it was already retransmitted + // * this packet is a retransmission, and we already received an ACK for the original packet + canBeRetransmitted bool + includedInBytesInFlight bool + retransmittedAs []protocol.PacketNumber + isRetransmission bool // we need a separate bool here because 0 is a valid packet number + retransmissionOf protocol.PacketNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go similarity index 80% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/packet_linkedlist.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go index a827b21..bb74f4e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet_linkedlist.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go @@ -1,13 +1,10 @@ -// Generated by: main -// TypeWriter: linkedlist -// Directive: +gen on Packet +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny package ackhandler -// List is a modification of http://golang.org/pkg/container/list/ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. +// Linked list implementation from the Go standard library. // PacketElement is an element of a linked list. type PacketElement struct { @@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement { return nil } -// PacketList represents a doubly linked list. -// The zero value for PacketList is an empty list ready to use. +// PacketList is a linked list of Packets. type PacketList struct { root PacketElement // sentinel list element, only &root, root.prev, and root.next are used len int // current list length excluding (this) sentinel element @@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() } // The complexity is O(1). func (l *PacketList) Len() int { return l.len } -// Front returns the first element of list l or nil. +// Front returns the first element of list l or nil if the list is empty. func (l *PacketList) Front() *PacketElement { if l.len == 0 { return nil @@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement { return l.root.next } -// Back returns the last element of list l or nil. +// Back returns the last element of list l or nil if the list is empty. func (l *PacketList) Back() *PacketElement { if l.len == 0 { return nil @@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement { return l.root.prev } -// lazyInit lazily initializes a zero PacketList value. +// lazyInit lazily initializes a zero List value. func (l *PacketList) lazyInit() { if l.root.next == nil { l.Init() @@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement { return e } -// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at). +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { return l.insert(&PacketElement{Value: v}, at) } @@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement { // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. +// The element must not be nil. func (l *PacketList) Remove(e *PacketElement) Packet { if e.list == l { // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero PacketElement) and l.remove will crash + // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) } return e.Value @@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement { // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { if mark.list != l { return nil } - // see comment in PacketList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { if mark.list != l { return nil } - // see comment in PacketList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *PacketList) MoveToFront(e *PacketElement) { if e.list != l || l.root.next == e { return } - // see comment in PacketList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *PacketList) MoveToBack(e *PacketElement) { if e.list != l || l.root.prev == e { return } - // see comment in PacketList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *PacketList) MoveBefore(e, mark *PacketElement) { if e.list != l || e == mark || mark.list != l { return @@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) { } // MoveAfter moves element e to its new position after mark. -// If e is not an element of l, or e == mark, the list is not modified. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *PacketList) MoveAfter(e, mark *PacketElement) { if e.list != l || e == mark || mark.list != l { return @@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) { } // PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *PacketList) PushBackList(other *PacketList) { l.lazyInit() for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { @@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) { } // PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *PacketList) PushFrontList(other *PacketList) { l.lazyInit() for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go new file mode 100644 index 0000000..8af2132 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go @@ -0,0 +1,215 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receivedPacketHandler struct { + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + largestObservedReceivedTime time.Time + + packetHistory *receivedPacketHistory + + ackSendDelay time.Duration + rttStats *congestion.RTTStats + + packetsReceivedSinceLastAck int + retransmittablePacketsReceivedSinceLastAck int + ackQueued bool + ackAlarm time.Time + lastAck *wire.AckFrame + + logger utils.Logger + + version protocol.VersionNumber +} + +const ( + // maximum delay that can be applied to an ACK for a retransmittable packet + ackSendDelay = 25 * time.Millisecond + // initial maximum number of retransmittable packets received before sending an ack. + initialRetransmittablePacketsBeforeAck = 2 + // number of retransmittable that an ACK is sent for + retransmittablePacketsBeforeAck = 10 + // 1/5 RTT delay when doing ack decimation + ackDecimationDelay = 1.0 / 4 + // 1/8 RTT delay when doing ack decimation + shortAckDecimationDelay = 1.0 / 8 + // Minimum number of packets received before ack decimation is enabled. + // This intends to avoid the beginning of slow start, when CWNDs may be + // rapidly increasing. + minReceivedBeforeAckDecimation = 100 + // Maximum number of packets to ack immediately after a missing packet for + // fast retransmission to kick in at the sender. This limit is created to + // reduce the number of acks sent that have no benefit for fast retransmission. + // Set to the number of nacks needed for fast retransmit plus one for protection + // against an ack loss + maxPacketsAfterNewMissing = 4 +) + +// NewReceivedPacketHandler creates a new receivedPacketHandler +func NewReceivedPacketHandler( + rttStats *congestion.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) ReceivedPacketHandler { + return &receivedPacketHandler{ + packetHistory: newReceivedPacketHistory(), + ackSendDelay: ackSendDelay, + rttStats: rttStats, + logger: logger, + version: version, + } +} + +func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { + if packetNumber < h.ignoreBelow { + return nil + } + + isMissing := h.isMissing(packetNumber) + if packetNumber > h.largestObserved { + h.largestObserved = packetNumber + h.largestObservedReceivedTime = rcvTime + } + + if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { + return err + } + h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing) + return nil +} + +// IgnoreBelow sets a lower limit for acking packets. +// Packets with packet numbers smaller than p will not be acked. +func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { + if p <= h.ignoreBelow { + return + } + h.ignoreBelow = p + h.packetHistory.DeleteBelow(p) + if h.logger.Debug() { + h.logger.Debugf("\tIgnoring all packets below %#x.", p) + } +} + +// isMissing says if a packet was reported missing in the last ACK. +func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool { + if h.lastAck == nil || p < h.ignoreBelow { + return false + } + return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) +} + +func (h *receivedPacketHandler) hasNewMissingPackets() bool { + if h.lastAck == nil { + return false + } + highestRange := h.packetHistory.GetHighestAckRange() + return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing +} + +// maybeQueueAck queues an ACK, if necessary. +// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck() +// in ACK_DECIMATION_WITH_REORDERING mode. +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) { + h.packetsReceivedSinceLastAck++ + + // always ack the first packet + if h.lastAck == nil { + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + h.ackQueued = true + return + } + + // Send an ACK if this packet was reported missing in an ACK sent before. + // Ack decimation with reordering relies on the timer to send an ACK, but if + // missing packets we reported in the previous ack, send an ACK immediately. + if wasMissing { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber) + } + h.ackQueued = true + } + + if !h.ackQueued && shouldInstigateAck { + h.retransmittablePacketsReceivedSinceLastAck++ + + if packetNumber > minReceivedBeforeAckDecimation { + // ack up to 10 packets at once + if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck { + h.ackQueued = true + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck) + } + } else if h.ackAlarm.IsZero() { + // wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack + ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay))) + h.ackAlarm = rcvTime.Add(ackDelay) + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm)) + } + } + } else { + // send an ACK every 2 retransmittable packets + if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck) + } + h.ackQueued = true + } else if h.ackAlarm.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay) + } + h.ackAlarm = rcvTime.Add(ackSendDelay) + } + } + // If there are new missing packets to report, set a short timer to send an ACK. + if h.hasNewMissingPackets() { + // wait the minimum of 1/8 min RTT and the existing ack time + ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)) + ackTime := rcvTime.Add(ackDelay) + if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) { + h.ackAlarm = ackTime + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm)) + } + } + } + } + + if h.ackQueued { + // cancel the ack alarm + h.ackAlarm = time.Time{} + } +} + +func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { + now := time.Now() + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + return nil + } + if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + h.logger.Debugf("Sending ACK because the ACK timer expired.") + } + + ack := &wire.AckFrame{ + AckRanges: h.packetHistory.GetAckRanges(), + DelayTime: now.Sub(h.largestObservedReceivedTime), + } + + h.lastAck = ack + h.ackAlarm = time.Time{} + h.ackQueued = false + h.packetsReceivedSinceLastAck = 0 + h.retransmittablePacketsReceivedSinceLastAck = 0 + return ack +} + +func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go index 14bdfd5..758286d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go @@ -74,17 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { return nil } -// DeleteUpTo deletes all entries up to (and including) p -func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) { - h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1) +// DeleteBelow deletes all entries below (but not including) p +func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { + if p <= h.lowestInReceivedPacketNumbers { + return + } + h.lowestInReceivedPacketNumbers = p nextEl := h.ranges.Front() for el := h.ranges.Front(); nextEl != nil; el = nextEl { nextEl = el.Next() - if p >= el.Value.Start && p < el.Value.End { - el.Value.Start = p + 1 - } else if el.Value.End <= p { // delete a whole range + if p > el.Value.Start && p <= el.Value.End { + el.Value.Start = p + } else if el.Value.End < p { // delete a whole range h.ranges.Remove(el) } else { // no ranges affected. Nothing to do return @@ -101,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { ackRanges := make([]wire.AckRange, h.ranges.Len()) i := 0 for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End} + ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} i++ } return ackRanges @@ -111,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { ackRange := wire.AckRange{} if h.ranges.Len() > 0 { r := h.ranges.Back().Value - ackRange.First = r.Start - ackRange.Last = r.End + ackRange.Smallest = r.Start + ackRange.Largest = r.End } return ackRange } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go new file mode 100644 index 0000000..76c833c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go @@ -0,0 +1,40 @@ +package ackhandler + +import "fmt" + +// The SendMode says what kind of packets can be sent. +type SendMode uint8 + +const ( + // SendNone means that no packets should be sent + SendNone SendMode = iota + // SendAck means an ACK-only packet should be sent + SendAck + // SendRetransmission means that retransmissions should be sent + SendRetransmission + // SendRTO means that an RTO probe packet should be sent + SendRTO + // SendTLP means that a TLP probe packet should be sent + SendTLP + // SendAny means that any packet should be sent + SendAny +) + +func (s SendMode) String() string { + switch s { + case SendNone: + return "none" + case SendAck: + return "ack" + case SendRetransmission: + return "retransmission" + case SendRTO: + return "rto" + case SendTLP: + return "tlp" + case SendAny: + return "any" + default: + return fmt.Sprintf("invalid send mode: %d", s) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go new file mode 100644 index 0000000..274607b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go @@ -0,0 +1,649 @@ +package ackhandler + +import ( + "fmt" + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +const ( + // Maximum reordering in time space before time based loss detection considers a packet lost. + // In fraction of an RTT. + timeReorderingFraction = 1.0 / 8 + // defaultRTOTimeout is the RTO time on new connections + defaultRTOTimeout = 500 * time.Millisecond + // Minimum time in the future a tail loss probe alarm may be set for. + minTPLTimeout = 10 * time.Millisecond + // Maximum number of tail loss probes before an RTO fires. + maxTLPs = 2 + // Minimum time in the future an RTO alarm may be set for. + minRTOTimeout = 200 * time.Millisecond + // maxRTOTimeout is the maximum RTO time + maxRTOTimeout = 60 * time.Second +) + +type sentPacketHandler struct { + lastSentPacketNumber protocol.PacketNumber + lastSentRetransmittablePacketTime time.Time + lastSentHandshakePacketTime time.Time + + nextPacketSendTime time.Time + skippedPackets []protocol.PacketNumber + + largestAcked protocol.PacketNumber + largestReceivedPacketWithAck protocol.PacketNumber + // lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101 + lowestPacketNotConfirmedAcked protocol.PacketNumber + largestSentBeforeRTO protocol.PacketNumber + + packetHistory *sentPacketHistory + stopWaitingManager stopWaitingManager + + retransmissionQueue []*Packet + + bytesInFlight protocol.ByteCount + + congestion congestion.SendAlgorithm + rttStats *congestion.RTTStats + + handshakeComplete bool + // The number of times the handshake packets have been retransmitted without receiving an ack. + handshakeCount uint32 + + // The number of times a TLP has been sent without receiving an ack. + tlpCount uint32 + allowTLP bool + + // The number of times an RTO has been sent without receiving an ack. + rtoCount uint32 + // The number of RTO probe packets that should be sent. + numRTOs int + + // The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time. + lossTime time.Time + + // The alarm timeout + alarm time.Time + + logger utils.Logger +} + +// NewSentPacketHandler creates a new sentPacketHandler +func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { + congestion := congestion.NewCubicSender( + congestion.DefaultClock{}, + rttStats, + false, /* don't use reno since chromium doesn't (why?) */ + protocol.InitialCongestionWindow, + protocol.DefaultMaxCongestionWindow, + ) + + return &sentPacketHandler{ + packetHistory: newSentPacketHistory(), + stopWaitingManager: stopWaitingManager{}, + rttStats: rttStats, + congestion: congestion, + logger: logger, + } +} + +func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber { + if p := h.packetHistory.FirstOutstanding(); p != nil { + return p.PacketNumber + } + return h.largestAcked + 1 +} + +func (h *sentPacketHandler) SetHandshakeComplete() { + h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.") + var queue []*Packet + for _, packet := range h.retransmissionQueue { + if packet.EncryptionLevel == protocol.EncryptionForwardSecure { + queue = append(queue, packet) + } + } + var handshakePackets []*Packet + h.packetHistory.Iterate(func(p *Packet) (bool, error) { + if p.EncryptionLevel != protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, p) + } + return true, nil + }) + for _, p := range handshakePackets { + h.packetHistory.Remove(p.PacketNumber) + } + h.retransmissionQueue = queue + h.handshakeComplete = true +} + +func (h *sentPacketHandler) SentPacket(packet *Packet) { + if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable { + h.packetHistory.SentPacket(packet) + h.updateLossDetectionAlarm() + } +} + +func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) { + var p []*Packet + for _, packet := range packets { + if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable { + p = append(p, packet) + } + } + h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf) + h.updateLossDetectionAlarm() +} + +func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ { + for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { + h.skippedPackets = append(h.skippedPackets, p) + if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets { + h.skippedPackets = h.skippedPackets[1:] + } + } + + h.lastSentPacketNumber = packet.PacketNumber + + if len(packet.Frames) > 0 { + if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok { + packet.largestAcked = ackFrame.LargestAcked() + } + } + + packet.Frames = stripNonRetransmittableFrames(packet.Frames) + isRetransmittable := len(packet.Frames) != 0 + + if isRetransmittable { + if packet.EncryptionLevel < protocol.EncryptionForwardSecure { + h.lastSentHandshakePacketTime = packet.SendTime + } + h.lastSentRetransmittablePacketTime = packet.SendTime + packet.includedInBytesInFlight = true + h.bytesInFlight += packet.Length + packet.canBeRetransmitted = true + if h.numRTOs > 0 { + h.numRTOs-- + } + h.allowTLP = false + } + h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable) + + h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight)) + return isRetransmittable +} + +func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { + largestAcked := ackFrame.LargestAcked() + if largestAcked > h.lastSentPacketNumber { + return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") + } + + // duplicate or out of order ACK + if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck { + h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).") + return nil + } + h.largestReceivedPacketWithAck = withPacketNumber + h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked) + + if h.skippedPacketsAcked(ackFrame) { + return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number") + } + + if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated { + h.congestion.MaybeExitSlowStart() + } + + ackedPackets, err := h.determineNewlyAckedPackets(ackFrame) + if err != nil { + return err + } + + priorInFlight := h.bytesInFlight + for _, p := range ackedPackets { + if encLevel < p.EncryptionLevel { + return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel) + } + // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 + // It is safe to ignore the corner case of packets that just acked packet 0, because + // the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send. + if p.largestAcked != 0 { + h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1) + } + if err := h.onPacketAcked(p, rcvTime); err != nil { + return err + } + if p.includedInBytesInFlight { + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + } + } + + if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil { + return err + } + h.updateLossDetectionAlarm() + + h.garbageCollectSkippedPackets() + h.stopWaitingManager.ReceivedAck(ackFrame) + + return nil +} + +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestPacketNotConfirmedAcked +} + +func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) { + var ackedPackets []*Packet + ackRangeIndex := 0 + lowestAcked := ackFrame.LowestAcked() + largestAcked := ackFrame.LargestAcked() + err := h.packetHistory.Iterate(func(p *Packet) (bool, error) { + // Ignore packets below the lowest acked + if p.PacketNumber < lowestAcked { + return true, nil + } + // Break after largest acked is reached + if p.PacketNumber > largestAcked { + return false, nil + } + + if ackFrame.HasMissingRanges() { + ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] + + for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 { + ackRangeIndex++ + ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] + } + + if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range + if p.PacketNumber > ackRange.Largest { + return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest) + } + ackedPackets = append(ackedPackets, p) + } + } else { + ackedPackets = append(ackedPackets, p) + } + return true, nil + }) + if h.logger.Debug() && len(ackedPackets) > 0 { + pns := make([]protocol.PacketNumber, len(ackedPackets)) + for i, p := range ackedPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns) + } + return ackedPackets, err +} + +func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool { + if p := h.packetHistory.GetPacket(largestAcked); p != nil { + h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } + return true + } + return false +} + +func (h *sentPacketHandler) updateLossDetectionAlarm() { + // Cancel the alarm if no packets are outstanding + if h.packetHistory.Len() == 0 { + h.alarm = time.Time{} + return + } + + if !h.handshakeComplete { + h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout()) + } else if !h.lossTime.IsZero() { + // Early retransmit timer or time loss detection. + h.alarm = h.lossTime + } else { + // RTO or TLP alarm + alarmDuration := h.computeRTOTimeout() + if h.tlpCount < maxTLPs { + tlpAlarm := h.computeTLPTimeout() + // if the RTO duration is shorter than the TLP duration, use the RTO duration + alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm) + } + h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration) + } +} + +func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error { + h.lossTime = time.Time{} + + maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) + + var lostPackets []*Packet + h.packetHistory.Iterate(func(packet *Packet) (bool, error) { + if packet.PacketNumber > h.largestAcked { + return false, nil + } + + timeSinceSent := now.Sub(packet.SendTime) + if timeSinceSent > delayUntilLost { + lostPackets = append(lostPackets, packet) + } else if h.lossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent) + } + // Note: This conditional is only entered once per call + h.lossTime = now.Add(delayUntilLost - timeSinceSent) + } + return true, nil + }) + if h.logger.Debug() && len(lostPackets) > 0 { + pns := make([]protocol.PacketNumber, len(lostPackets)) + for i, p := range lostPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns) + } + + for _, p := range lostPackets { + // the bytes in flight need to be reduced no matter if this packet will be retransmitted + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } + if p.canBeRetransmitted { + // queue the packet for retransmission, and report the loss to the congestion controller + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } + } + h.packetHistory.Remove(p.PacketNumber) + } + return nil +} + +func (h *sentPacketHandler) OnAlarm() error { + now := time.Now() + + var err error + if !h.handshakeComplete { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in handshake mode") + } + h.handshakeCount++ + err = h.queueHandshakePacketsForRetransmission() + } else if !h.lossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in loss timer mode") + } + // Early retransmit or time loss detection + err = h.detectLostPackets(now, h.bytesInFlight) + } else if h.tlpCount < maxTLPs { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in TLP mode") + } + h.allowTLP = true + h.tlpCount++ + } else { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in RTO mode") + } + // RTO + h.rtoCount++ + h.numRTOs += 2 + err = h.queueRTOs() + } + if err != nil { + return err + } + h.updateLossDetectionAlarm() + return nil +} + +func (h *sentPacketHandler) GetAlarmTimeout() time.Time { + return h.alarm +} + +func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error { + // This happens if a packet and its retransmissions is acked in the same ACK. + // As soon as we process the first one, this will remove all the retransmissions, + // so we won't find the retransmitted packet number later. + if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil { + return nil + } + + // only report the acking of this packet to the congestion controller if: + // * it is a retransmittable packet + // * this packet wasn't retransmitted yet + if p.isRetransmission { + // that the parent doesn't exist is expected to happen every time the original packet was already acked + if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil { + if len(parent.retransmittedAs) == 1 { + parent.retransmittedAs = nil + } else { + // remove this packet from the slice of retransmission + retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1) + for _, pn := range parent.retransmittedAs { + if pn != p.PacketNumber { + retransmittedAs = append(retransmittedAs, pn) + } + } + parent.retransmittedAs = retransmittedAs + } + } + } + // this also applies to packets that have been retransmitted as probe packets + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + if h.rtoCount > 0 { + h.verifyRTO(p.PacketNumber) + } + if err := h.stopRetransmissionsFor(p); err != nil { + return err + } + h.rtoCount = 0 + h.tlpCount = 0 + h.handshakeCount = 0 + return h.packetHistory.Remove(p.PacketNumber) +} + +func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error { + if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil { + return err + } + for _, r := range p.retransmittedAs { + packet := h.packetHistory.GetPacket(r) + if packet == nil { + return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber) + } + h.stopRetransmissionsFor(packet) + } + return nil +} + +func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) { + if pn <= h.largestSentBeforeRTO { + h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO) + // Replace SRTT with latest_rtt and increase the variance to prevent + // a spurious RTO from happening again. + h.rttStats.ExpireSmoothedMetrics() + return + } + h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO) + h.congestion.OnRetransmissionTimeout(true) +} + +func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { + if len(h.retransmissionQueue) == 0 { + return nil + } + packet := h.retransmissionQueue[0] + // Shift the slice and don't retain anything that isn't needed. + copy(h.retransmissionQueue, h.retransmissionQueue[1:]) + h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil + h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1] + return packet +} + +func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen { + return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked()) +} + +func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { + return h.stopWaitingManager.GetStopWaitingFrame(force) +} + +func (h *sentPacketHandler) SendMode() SendMode { + numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len() + + // Don't send any packets if we're keeping track of the maximum number of packets. + // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, + // we will stop sending out new data when reaching MaxOutstandingSentPackets, + // but still allow sending of retransmissions and ACKs. + if numTrackedPackets >= protocol.MaxTrackedSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + } + return SendNone + } + if h.allowTLP { + return SendTLP + } + if h.numRTOs > 0 { + return SendRTO + } + // Only send ACKs if we're congestion limited. + if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd { + if h.logger.Debug() { + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) + } + return SendAck + } + // Send retransmissions first, if there are any. + if len(h.retransmissionQueue) > 0 { + return SendRetransmission + } + if numTrackedPackets >= protocol.MaxOutstandingSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + } + return SendAck + } + return SendAny +} + +func (h *sentPacketHandler) TimeUntilSend() time.Time { + return h.nextPacketSendTime +} + +func (h *sentPacketHandler) ShouldSendNumPackets() int { + if h.numRTOs > 0 { + // RTO probes should not be paced, but must be sent immediately. + return h.numRTOs + } + delay := h.congestion.TimeUntilSend(h.bytesInFlight) + if delay == 0 || delay > protocol.MinPacingDelay { + return 1 + } + return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) +} + +// retransmit the oldest two packets +func (h *sentPacketHandler) queueRTOs() error { + h.largestSentBeforeRTO = h.lastSentPacketNumber + // Queue the first two outstanding packets for retransmission. + // This does NOT declare this packets as lost: + // They are still tracked in the packet history and count towards the bytes in flight. + for i := 0; i < 2; i++ { + if p := h.packetHistory.FirstOutstanding(); p != nil { + h.logger.Debugf("Queueing packet %#x for retransmission (RTO)", p.PacketNumber) + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } + } + } + return nil +} + +func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error { + var handshakePackets []*Packet + h.packetHistory.Iterate(func(p *Packet) (bool, error) { + if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, p) + } + return true, nil + }) + for _, p := range handshakePackets { + h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber) + if err := h.queuePacketForRetransmission(p); err != nil { + return err + } + } + return nil +} + +func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error { + if !p.canBeRetransmitted { + return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber) + } + if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil { + return err + } + h.retransmissionQueue = append(h.retransmissionQueue, p) + h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber) + return nil +} + +func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { + duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout) + // exponential backoff + // There's an implicit limit to this set by the handshake timeout. + return duration << h.handshakeCount +} + +func (h *sentPacketHandler) computeTLPTimeout() time.Duration { + // TODO(#1236): include the max_ack_delay + return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout) +} + +func (h *sentPacketHandler) computeRTOTimeout() time.Duration { + var rto time.Duration + rtt := h.rttStats.SmoothedRTT() + if rtt == 0 { + rto = defaultRTOTimeout + } else { + rto = rtt + 4*h.rttStats.MeanDeviation() + } + rto = utils.MaxDuration(rto, minRTOTimeout) + // Exponential backoff + rto = rto << h.rtoCount + return utils.MinDuration(rto, maxRTOTimeout) +} + +func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool { + for _, p := range h.skippedPackets { + if ackFrame.AcksPacket(p) { + return true + } + } + return false +} + +func (h *sentPacketHandler) garbageCollectSkippedPackets() { + lowestUnacked := h.lowestUnacked() + deleteIndex := 0 + for i, p := range h.skippedPackets { + if p < lowestUnacked { + deleteIndex = i + 1 + } + } + h.skippedPackets = h.skippedPackets[deleteIndex:] +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go new file mode 100644 index 0000000..38a2a0e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go @@ -0,0 +1,127 @@ +package ackhandler + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type sentPacketHistory struct { + packetList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement + + firstOutstanding *PacketElement +} + +func newSentPacketHistory() *sentPacketHistory { + return &sentPacketHistory{ + packetList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + } +} + +func (h *sentPacketHistory) SentPacket(p *Packet) { + h.sentPacketImpl(p) +} + +func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement { + el := h.packetList.PushBack(*p) + h.packetMap[p.PacketNumber] = el + if h.firstOutstanding == nil { + h.firstOutstanding = el + } + return el +} + +func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) { + retransmission, ok := h.packetMap[retransmissionOf] + // The retransmitted packet is not present anymore. + // This can happen if it was acked in between dequeueing of the retransmission and sending. + // Just treat the retransmissions as normal packets. + // TODO: This won't happen if we clear packets queued for retransmission on new ACKs. + if !ok { + for _, packet := range packets { + h.sentPacketImpl(packet) + } + return + } + retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets)) + for i, packet := range packets { + retransmission.Value.retransmittedAs[i] = packet.PacketNumber + el := h.sentPacketImpl(packet) + el.Value.isRetransmission = true + el.Value.retransmissionOf = retransmissionOf + } +} + +func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet { + if el, ok := h.packetMap[p]; ok { + return &el.Value + } + return nil +} + +// Iterate iterates through all packets. +// The callback must not modify the history. +func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { + cont := true + for el := h.packetList.Front(); cont && el != nil; el = el.Next() { + var err error + cont, err = cb(&el.Value) + if err != nil { + return err + } + } + return nil +} + +// FirstOutStanding returns the first outstanding packet. +// It must not be modified (e.g. retransmitted). +// Use DequeueFirstPacketForRetransmission() to retransmit it. +func (h *sentPacketHistory) FirstOutstanding() *Packet { + if h.firstOutstanding == nil { + return nil + } + return &h.firstOutstanding.Value +} + +// QueuePacketForRetransmission marks a packet for retransmission. +// A packet can only be queued once. +func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error { + el, ok := h.packetMap[pn] + if !ok { + return fmt.Errorf("sent packet history: packet %d not found", pn) + } + el.Value.canBeRetransmitted = false + if el == h.firstOutstanding { + h.readjustFirstOutstanding() + } + return nil +} + +// readjustFirstOutstanding readjusts the pointer to the first outstanding packet. +// This is necessary every time the first outstanding packet is deleted or retransmitted. +func (h *sentPacketHistory) readjustFirstOutstanding() { + el := h.firstOutstanding.Next() + for el != nil && !el.Value.canBeRetransmitted { + el = el.Next() + } + h.firstOutstanding = el +} + +func (h *sentPacketHistory) Len() int { + return len(h.packetMap) +} + +func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { + el, ok := h.packetMap[p] + if !ok { + return fmt.Errorf("packet %d not found in sent packet history", p) + } + if el == h.firstOutstanding { + h.readjustFirstOutstanding() + } + h.packetList.Remove(el) + delete(h.packetMap, p) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go rename to vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go index 04cb61f..40ad88c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go @@ -30,8 +30,9 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFr } func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) { - if ack.LargestAcked >= s.nextLeastUnacked { - s.nextLeastUnacked = ack.LargestAcked + 1 + largestAcked := ack.LargestAcked() + if largestAcked >= s.nextLeastUnacked { + s.nextLeastUnacked = largestAcked + 1 } } diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/clock.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/congestion/clock.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go similarity index 56% rename from vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go index 3922f47..dcf91fc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go @@ -16,11 +16,10 @@ import ( // allow a 10 shift right to divide. // 1024*1024^3 (first 1024 is from 0.100^3) -// where 0.100 is 100 ms which is the scaling -// round trip time. +// where 0.100 is 100 ms which is the scaling round trip time. const cubeScale = 40 const cubeCongestionWindowScale = 410 -const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale +const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS const defaultNumConnections = 2 @@ -32,39 +31,35 @@ const beta float32 = 0.7 // new concurrent flows and speed up convergence. const betaLastMax float32 = 0.85 -// If true, Cubic's epoch is shifted when the sender is application-limited. -const shiftQuicCubicEpochWhenAppLimited = true - -const maxCubicTimeInterval = 30 * time.Millisecond - // Cubic implements the cubic algorithm from TCP type Cubic struct { clock Clock + // Number of connections to simulate. numConnections int + // Time when this cycle started, after last loss event. epoch time.Time - // Time when sender went into application-limited period. Zero if not in - // application-limited period. - appLimitedStartTime time.Time - // Time when we updated last_congestion_window. - lastUpdateTime time.Time - // Last congestion window (in packets) used. - lastCongestionWindow protocol.PacketNumber - // Max congestion window (in packets) used just before last loss event. + + // Max congestion window used just before last loss event. // Note: to improve fairness to other streams an additional back off is // applied to this value if the new value is below our latest value. - lastMaxCongestionWindow protocol.PacketNumber - // Number of acked packets since the cycle started (epoch). - ackedPacketsCount protocol.PacketNumber + lastMaxCongestionWindow protocol.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount protocol.ByteCount + // TCP Reno equivalent congestion window in packets. - estimatedTCPcongestionWindow protocol.PacketNumber + estimatedTCPcongestionWindow protocol.ByteCount + // Origin point of cubic function. - originPointCongestionWindow protocol.PacketNumber + originPointCongestionWindow protocol.ByteCount + // Time to origin point of cubic function in 2^10 fractions of a second. timeToOriginPoint uint32 + // Last congestion window in packets computed by cubic function. - lastTargetCongestionWindow protocol.PacketNumber + lastTargetCongestionWindow protocol.ByteCount } // NewCubic returns a new Cubic instance @@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic { // Reset is called after a timeout to reset the cubic state func (c *Cubic) Reset() { c.epoch = time.Time{} - c.appLimitedStartTime = time.Time{} - c.lastUpdateTime = time.Time{} - c.lastCongestionWindow = 0 c.lastMaxCongestionWindow = 0 - c.ackedPacketsCount = 0 + c.ackedBytesCount = 0 c.estimatedTCPcongestionWindow = 0 c.originPointCongestionWindow = 0 c.timeToOriginPoint = 0 @@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 { return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) } +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + // OnApplicationLimited is called on ack arrival when sender is unable to use // the available congestion window. Resets Cubic state during quiescence. func (c *Cubic) OnApplicationLimited() { - if shiftQuicCubicEpochWhenAppLimited { - // When sender is not using the available congestion window, Cubic's epoch - // should not continue growing. Record the time when sender goes into an - // app-limited period here, to compensate later when cwnd growth happens. - if c.appLimitedStartTime.IsZero() { - c.appLimitedStartTime = c.clock.Now() - } - } else { - // When sender is not using the available congestion window, Cubic's epoch - // should not continue growing. Reset the epoch when in such a period. - c.epoch = time.Time{} - } + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} } // CongestionWindowAfterPacketLoss computes a new congestion window to use after // a loss event. Returns the new congestion window in packets. The new // congestion window is a multiplicative decrease of our current window. -func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber { - if currentCongestionWindow < c.lastMaxCongestionWindow { +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { + if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow { // We never reached the old max, so assume we are competing with another // flow. Use our extra back off factor to allow the other flow to go up. - c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow)) + c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) } else { c.lastMaxCongestionWindow = currentCongestionWindow } c.epoch = time.Time{} // Reset time. - return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta()) + return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) } // CongestionWindowAfterAck computes a new congestion window to use after a received ACK. // Returns the new congestion window in packets. The new congestion window // follows a cubic function that depends on the time passed since last // packet loss. -func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber { - c.ackedPacketsCount++ // Packets acked. - currentTime := c.clock.Now() - - // Cubic is "independent" of RTT, the update is limited by the time elapsed. - if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) { - return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow) - } - c.lastCongestionWindow = currentCongestionWindow - c.lastUpdateTime = currentTime +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes protocol.ByteCount, + currentCongestionWindow protocol.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) protocol.ByteCount { + c.ackedBytesCount += ackedBytes if c.epoch.IsZero() { // First ACK after a loss event. - c.epoch = currentTime // Start of epoch. - c.ackedPacketsCount = 1 // Reset count. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. // Reset estimated_tcp_congestion_window_ to be in sync with cubic. c.estimatedTCPcongestionWindow = currentCongestionWindow if c.lastMaxCongestionWindow <= currentCongestionWindow { @@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) c.originPointCongestionWindow = c.lastMaxCongestionWindow } - } else { - // If sender was app-limited, then freeze congestion window growth during - // app-limited period. Continue growth now by shifting the epoch-start - // through the app-limited period. - if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() { - shift := currentTime.Sub(c.appLimitedStartTime) - c.epoch = c.epoch.Add(shift) - c.appLimitedStartTime = time.Time{} - } } // Change the time unit from microseconds to 2^10 fractions per second. Take // the round trip time in account. This is done to allow us to use shift as a // divide operator. - elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000 + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. offset := int64(c.timeToOriginPoint) - elapsedTime - // Right-shifts of negative, signed numbers have - // implementation-dependent behavior. Force the offset to be - // positive, similar to the kernel implementation. if offset < 0 { offset = -offset } - deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale) - var targetCongestionWindow protocol.PacketNumber + + deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale + var targetCongestionWindow protocol.ByteCount if elapsedTime > int64(c.timeToOriginPoint) { targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow } else { targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow } - // With dynamic beta/alpha based on number of active streams, it is possible - // for the required_ack_count to become much lower than acked_packets_count_ - // suddenly, leading to more than one iteration through the following loop. - for { - // Update estimated TCP congestion_window. - requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha()) - if c.ackedPacketsCount < requiredAckCount { - break - } - c.ackedPacketsCount -= requiredAckCount - c.estimatedTCPcongestionWindow++ - } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 // We have a new cubic congestion window. c.lastTargetCongestionWindow = targetCongestionWindow @@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet if targetCongestionWindow < c.estimatedTCPcongestionWindow { targetCongestionWindow = c.estimatedTCPcongestionWindow } - return targetCongestionWindow } diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go similarity index 67% rename from vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go index f2c8c2d..b9f67e6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go @@ -8,9 +8,9 @@ import ( ) const ( - maxBurstBytes = 3 * protocol.DefaultTCPMSS - defaultMinimumCongestionWindow protocol.PacketNumber = 2 - renoBeta float32 = 0.7 // Reno backoff factor. + maxBurstBytes = 3 * protocol.DefaultTCPMSS + renoBeta float32 = 0.7 // Reno backoff factor. + defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS ) type cubicSender struct { @@ -31,12 +31,6 @@ type cubicSender struct { // Track the largest packet number outstanding when a CWND cutback occurs. largestSentAtLastCutback protocol.PacketNumber - // Congestion window in packets. - congestionWindow protocol.PacketNumber - - // Slow start congestion window in packets, aka ssthresh. - slowstartThreshold protocol.PacketNumber - // Whether the last loss event caused us to exit slowstart. // Used for stats collection of slowstartPacketsLost lastCutbackExitedSlowstart bool @@ -44,24 +38,35 @@ type cubicSender struct { // When true, exit slow start with large cutback of congestion window. slowStartLargeReduction bool - // Minimum congestion window in packets. - minCongestionWindow protocol.PacketNumber + // Congestion window in packets. + congestionWindow protocol.ByteCount - // Maximum number of outstanding packets for tcp. - maxTCPCongestionWindow protocol.PacketNumber + // Minimum congestion window in packets. + minCongestionWindow protocol.ByteCount + + // Maximum congestion window. + maxCongestionWindow protocol.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowstartThreshold protocol.ByteCount // Number of connections to simulate. numConnections int // ACK counter for the Reno implementation. - congestionWindowCount protocol.ByteCount + numAckedPackets uint64 - initialCongestionWindow protocol.PacketNumber - initialMaxCongestionWindow protocol.PacketNumber + initialCongestionWindow protocol.ByteCount + initialMaxCongestionWindow protocol.ByteCount + + minSlowStartExitWindow protocol.ByteCount } +var _ SendAlgorithm = &cubicSender{} +var _ SendAlgorithmWithDebugInfo = &cubicSender{} + // NewCubicSender makes a new cubic sender -func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo { +func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo { return &cubicSender{ rttStats: rttStats, initialCongestionWindow: initialCongestionWindow, @@ -69,28 +74,37 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio congestionWindow: initialCongestionWindow, minCongestionWindow: defaultMinimumCongestionWindow, slowstartThreshold: initialMaxCongestionWindow, - maxTCPCongestionWindow: initialMaxCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, numConnections: defaultNumConnections, cubic: NewCubic(clock), reno: reno, } } -func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration { +// TimeUntilSend returns when the next packet should be sent. +func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration { if c.InRecovery() { // PRR is used when in recovery. - return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) + if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) { + return 0 + } } - if c.GetCongestionWindow() > bytesInFlight { - return 0 + delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()) + if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt + delay = delay * 8 / 5 } - return utils.InfDuration + return delay } -func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool { - // Only update bytesInFlight for data packets. +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + bytesInFlight protocol.ByteCount, + packetNumber protocol.PacketNumber, + bytes protocol.ByteCount, + isRetransmittable bool, +) { if !isRetransmittable { - return false + return } if c.InRecovery() { // PRR is used when in recovery. @@ -98,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By } c.largestSentPacketNumber = packetNumber c.hybridSlowStart.OnPacketSent(packetNumber) - return true } func (c *cubicSender) InRecovery() bool { @@ -110,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool { } func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { - return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS + return c.congestionWindow } func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount { - return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS + return c.slowstartThreshold } func (c *cubicSender) ExitSlowstart() { c.slowstartThreshold = c.congestionWindow } -func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber { +func (c *cubicSender) SlowstartThreshold() protocol.ByteCount { return c.slowstartThreshold } @@ -131,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() { } } -func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { // PRR is used when in recovery. c.prr.OnPacketAcked(ackedBytes) return } - c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight) + c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) if c.InSlowStart() { c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) } } -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) OnPacketLost( + packetNumber protocol.PacketNumber, + lostBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, +) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { @@ -152,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.stats.slowstartPacketsLost++ c.stats.slowstartBytesLost += lostBytes if c.slowStartLargeReduction { - if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS { - // Reduce congestion window by 1 for every mss of bytes lost. - c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow) - } + // Reduce congestion window by lost_bytes for every loss. + c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow) c.slowstartThreshold = c.congestionWindow } } @@ -166,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.stats.slowstartPacketsLost++ } - c.prr.OnPacketLost(bytesInFlight) + c.prr.OnPacketLost(priorInFlight) // TODO(chromium): Separate out all of slow start into a separate class. if c.slowStartLargeReduction && c.InSlowStart() { - c.congestionWindow = c.congestionWindow - 1 + if c.congestionWindow >= 2*c.initialCongestionWindow { + c.minSlowStartExitWindow = c.congestionWindow / 2 + } + c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS } else if c.reno { - c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta()) + c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta()) } else { c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) } - // Enforce a minimum congestion window. if c.congestionWindow < c.minCongestionWindow { c.congestionWindow = c.minCongestionWindow } @@ -184,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes c.largestSentAtLastCutback = c.largestSentPacketNumber // reset packet count from congestion avoidance mode. We start // counting again when we're out of recovery. - c.congestionWindowCount = 0 + c.numAckedPackets = 0 } func (c *cubicSender) RenoBeta() float32 { @@ -197,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 { // Called when we receive an ack. Normal TCP tracks how many packets one ack // represents, but quic has a separate ack for each packet. -func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) { +func (c *cubicSender) maybeIncreaseCwnd( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { // Do not increase the congestion window unless the sender is close to using // the current window. - if !c.isCwndLimited(bytesInFlight) { + if !c.isCwndLimited(priorInFlight) { c.cubic.OnApplicationLimited() return } - if c.congestionWindow >= c.maxTCPCongestionWindow { + if c.congestionWindow >= c.maxCongestionWindow { return } if c.InSlowStart() { // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow++ + c.congestionWindow += protocol.DefaultTCPMSS return } + // Congestion avoidance if c.reno { // Classic Reno congestion avoidance. - c.congestionWindowCount++ + c.numAckedPackets++ // Divide by num_connections to smoothly increase the CWND at a faster // rate than conventional Reno. - if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow { - c.congestionWindow++ - c.congestionWindowCount = 0 + if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) { + c.congestionWindow += protocol.DefaultTCPMSS + c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT())) + c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } @@ -278,21 +306,13 @@ func (c *cubicSender) OnConnectionMigration() { c.largestSentAtLastCutback = 0 c.lastCutbackExitedSlowstart = false c.cubic.Reset() - c.congestionWindowCount = 0 + c.numAckedPackets = 0 c.congestionWindow = c.initialCongestionWindow c.slowstartThreshold = c.initialMaxCongestionWindow - c.maxTCPCongestionWindow = c.initialMaxCongestionWindow + c.maxCongestionWindow = c.initialMaxCongestionWindow } // SetSlowStartLargeReduction allows enabling the SSLR experiment func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) { c.slowStartLargeReduction = enabled } - -// RetransmissionDelay gives the time to retransmission -func (c *cubicSender) RetransmissionDelay() time.Duration { - if c.rttStats.SmoothedRTT() == 0 { - return 0 - } - return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4 -} diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/congestion/interface.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go index 411a5f2..7c27da6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go @@ -8,16 +8,15 @@ import ( // A SendAlgorithm performs congestion control and calculates the congestion window type SendAlgorithm interface { - TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration - OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool + TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration + OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) GetCongestionWindow() protocol.ByteCount MaybeExitSlowStart() - OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) + OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) + OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) SetNumEmulatedConnections(n int) OnRetransmissionTimeout(packetsRetransmitted bool) OnConnectionMigration() - RetransmissionDelay() time.Duration // Experiments SetSlowStartLargeReduction(enabled bool) @@ -31,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface { // Stuff only used in testing HybridSlowStart() *HybridSlowStart - SlowstartThreshold() protocol.PacketNumber + SlowstartThreshold() protocol.ByteCount RenoBeta() float32 InRecovery() bool } diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go similarity index 72% rename from vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go index 18a3736..5c807d1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/prr_sender.go @@ -1,10 +1,7 @@ package congestion import ( - "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 @@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) { // OnPacketLost should be called on the first loss that triggers a recovery // period and all other methods in this class should only be called when in // recovery. -func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) { +func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) { p.bytesSentSinceLoss = 0 - p.bytesInFlightBeforeLoss = bytesInFlight + p.bytesInFlightBeforeLoss = priorInFlight p.bytesDeliveredSinceLoss = 0 p.ackCountSinceLoss = 0 } @@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) { p.ackCountSinceLoss++ } -// TimeUntilSend calculates the time until a packet can be sent -func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration { +// CanSend returns if packets can be sent +func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool { // Return QuicTime::Zero In order to ensure limited transmit always works. if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS { - return 0 + return true } if congestionWindow > bytesInFlight { // During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead // of sending the entire available window. This prevents burst retransmits // when more packets are lost than the CWND reduction. // limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS - if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss { - return utils.InfDuration - } - return 0 + return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss } // Implement Proportional Rate Reduction (RFC6937). // Checks a simplified version of the PRR formula that doesn't use division: // AvailableSendWindow = // CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent - if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss { - return 0 - } - return utils.InfDuration + return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go new file mode 100644 index 0000000..f0ebbb2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/rtt_stats.go @@ -0,0 +1,101 @@ +package congestion + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/utils" +) + +const ( + rttAlpha float32 = 0.125 + oneMinusAlpha float32 = (1 - rttAlpha) + rttBeta float32 = 0.25 + oneMinusBeta float32 = (1 - rttBeta) + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond +) + +// RTTStats provides round-trip statistics +type RTTStats struct { + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + meanDeviation time.Duration +} + +// NewRTTStats makes a properly initialized RTTStats object +func NewRTTStats() *RTTStats { + return &RTTStats{} +} + +// MinRTT Returns the minRTT for the entire connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } + +// LatestRTT returns the most recent rtt measurement. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } + +// SmoothedRTT returns the EWMA smoothed RTT for the connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } + +// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection. +// If no valid updates have occurred, it returns the initial RTT. +func (r *RTTStats) SmoothedOrInitialRTT() time.Duration { + if r.smoothedRTT != 0 { + return r.smoothedRTT + } + return defaultInitialRTT +} + +// MeanDeviation gets the mean deviation +func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } + +// UpdateRTT updates the RTT based on a new sample. +func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { + if sendDelta == utils.InfDuration || sendDelta <= 0 { + return + } + + // Update r.minRTT first. r.minRTT does not use an rttSample corrected for + // ackDelay but the raw observed sendDelta, since poor clock granularity at + // the client may cause a high ackDelay to result in underestimation of the + // r.minRTT. + if r.minRTT == 0 || r.minRTT > sendDelta { + r.minRTT = sendDelta + } + + // Correct for ackDelay if information received from the peer results in a + // an RTT sample at least as large as minRTT. Otherwise, only use the + // sendDelta. + sample := sendDelta + if sample-r.minRTT >= ackDelay { + sample -= ackDelay + } + r.latestRTT = sample + // First time call. + if r.smoothedRTT == 0 { + r.smoothedRTT = sample + r.meanDeviation = sample / 2 + } else { + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond + } +} + +// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. +func (r *RTTStats) OnConnectionMigration() { + r.latestRTT = 0 + r.minRTT = 0 + r.smoothedRTT = 0 + r.meanDeviation = 0 +} + +// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt +// is larger. The mean deviation is increased to the most recent deviation if +// it's larger. +func (r *RTTStats) ExpireSmoothedMetrics() { + r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT)) + r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/stats.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/congestion/stats.go rename to vendor/github.com/lucas-clemente/quic-go/internal/congestion/stats.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go index f3bc9fb..0c728fd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go @@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) { return cert.Certificate[0], nil } -func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { - c := cc.config - c, err := maybeGetConfigForClient(c, sni) +func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { + conf := c.config + conf, err := maybeGetConfigForClient(conf, sni) if err != nil { return nil, err } // The rest of this function is mostly copied from crypto/tls.getCertificate - if c.GetCertificate != nil { - cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) + if conf.GetCertificate != nil { + cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) if cert != nil || err != nil { return cert, err } } - if len(c.Certificates) == 0 { + if len(conf.Certificates) == 0 { return nil, errNoMatchingCertificate } - if len(c.Certificates) == 1 || c.NameToCertificate == nil { + if len(conf.Certificates) == 1 || conf.NameToCertificate == nil { // There's only one choice, so no point doing any work. - return &c.Certificates[0], nil + return &conf.Certificates[0], nil } name := strings.ToLower(sni) @@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { name = name[:len(name)-1] } - if cert, ok := c.NameToCertificate[name]; ok { + if cert, ok := conf.NameToCertificate[name]; ok { return cert, nil } @@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".") - if cert, ok := c.NameToCertificate[candidate]; ok { + if cert, ok := conf.NameToCertificate[candidate]; ok { return cert, nil } } // If nothing matches, return the first certificate. - return &c.Certificates[0], nil + return &conf.Certificates[0], nil } func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go index 5aaa187..8b8c9fa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go @@ -18,6 +18,7 @@ type CertManager interface { GetLeafCertHash() (uint64, error) VerifyServerProof(proof, chlo, serverConfigData []byte) bool Verify(hostname string) error + GetChain() []*x509.Certificate } type certManager struct { @@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error { return nil } +func (c *certManager) GetChain() []*x509.Certificate { + return c.chain +} + func (c *certManager) GetCommonCertificateHashes() []byte { return getCommonCertificateHashes() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go deleted file mode 100644 index 5d2e36f..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go +++ /dev/null @@ -1,61 +0,0 @@ -// +build ignore - -package crypto - -import ( - "crypto/cipher" - "encoding/binary" - "errors" - - "github.com/aead/chacha20" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -type aeadChacha20Poly1305 struct { - otherIV []byte - myIV []byte - encrypter cipher.AEAD - decrypter cipher.AEAD -} - -// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305 -func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { - if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 { - return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs") - } - // copy because ChaCha20Poly1305 expects array pointers - var MyKey, OtherKey [32]byte - copy(MyKey[:], myKey) - copy(OtherKey[:], otherKey) - - encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12) - if err != nil { - return nil, err - } - decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12) - if err != nil { - return nil, err - } - return &aeadChacha20Poly1305{ - otherIV: otherIV, - myIV: myIV, - encrypter: encrypter, - decrypter: decrypter, - }, nil -} - -func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) -} - -func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) -} - -func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { - res := make([]byte, 12) - copy(res[0:4], iv) - binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) - return res -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go deleted file mode 100644 index 9d5197b..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// +build ignore - -package crypto - -import ( - "crypto/rand" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Chacha20poly1305", func() { - var ( - alice, bob AEAD - keyAlice, keyBob, ivAlice, ivBob []byte - ) - - BeforeEach(func() { - keyAlice = make([]byte, 32) - keyBob = make([]byte, 32) - ivAlice = make([]byte, 4) - ivBob = make([]byte, 4) - rand.Reader.Read(keyAlice) - rand.Reader.Read(keyBob) - rand.Reader.Read(ivAlice) - rand.Reader.Read(ivBob) - var err error - alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice) - Expect(err).ToNot(HaveOccurred()) - bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob) - Expect(err).ToNot(HaveOccurred()) - }) - - It("seals and opens", func() { - b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) - text, err := bob.Open(nil, b, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(text).To(Equal([]byte("foobar"))) - }) - - It("seals and opens reverse", func() { - b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) - text, err := alice.Open(nil, b, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(text).To(Equal([]byte("foobar"))) - }) - - It("has the proper length", func() { - b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) - Expect(b).To(HaveLen(6 + 12)) - }) - - It("fails with wrong aad", func() { - b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) - _, err := bob.Open(nil, b, 42, []byte("aad2")) - Expect(err).To(HaveOccurred()) - }) - - It("rejects wrong key and iv sizes", func() { - var err error - e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs" - _, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:]) - Expect(err).To(MatchError(e)) - }) -}) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go index a570d6b..fd25b00 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go @@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) { if _, err := rand.Read(c.secret[:]); err != nil { return nil, errors.New("Curve25519: could not create private key") } - // See https://cr.yp.to/ecdh.html - c.secret[0] &= 248 - c.secret[31] &= 127 - c.secret[31] |= 64 curve25519.ScalarBaseMult(&c.public, &c.secret) return c, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go index 316bd1b..89b9e1f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go @@ -1,13 +1,16 @@ package crypto import ( + "crypto" + "encoding/binary" + "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" ) const ( - clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" - serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" + clientExporterLabel = "EXPORTER-QUIC client 1rtt" + serverExporterLabel = "EXPORTER-QUIC server 1rtt" ) // A TLSExporter gets the negotiated ciphersuite and computes exporter @@ -16,6 +19,14 @@ type TLSExporter interface { ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) } +func qhkdfExpand(secret []byte, label string, length int) []byte { + qlabel := make([]byte, 2+1+5+len(label)) + binary.BigEndian.PutUint16(qlabel[0:2], uint16(length)) + qlabel[2] = uint8(5 + len(label)) + copy(qlabel[3:], []byte("QUIC "+label)) + return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length) +} + // DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { var myLabel, otherLabel string @@ -43,7 +54,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) if err != nil { return nil, nil, err } - key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) - iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) + key = qhkdfExpand(secret, "key", cs.KeyLen) + iv = qhkdfExpand(secret, "iv", cs.IvLen) return key, iv, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go index 28f6c2c..6c29417 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go @@ -6,7 +6,6 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "golang.org/x/crypto/hkdf" ) @@ -42,7 +41,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol } else { info.Write([]byte("QUIC key expansion\x00")) } - utils.BigEndian.WriteUint64(&info, uint64(connID)) + info.Write(connID) info.Write(chlo) info.Write(scfg) info.Write(cert) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go index a647ad7..4abc622 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go @@ -2,13 +2,12 @@ package crypto import ( "crypto" - "encoding/binary" "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" ) -var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} +var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38} func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { clientSecret, serverSecret := computeSecrets(connectionID) @@ -28,17 +27,15 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) } -func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { - connID := make([]byte, 8) - binary.BigEndian.PutUint64(connID, uint64(connectionID)) - cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) - clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) - serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) +func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { + handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID) + clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size()) + serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size()) return } func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { - key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) - iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) + key = qhkdfExpand(secret, "key", 16) + iv = qhkdfExpand(secret, "iv", 12) return } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go index ecc4010..6c50ab9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go @@ -1,10 +1,11 @@ package crypto import ( - "encoding/binary" + "bytes" "errors" + "fmt" + "hash/fnv" - "github.com/lucas-clemente/fnv128a" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") } - hash := fnv128a.New() + hash := fnv.New128a() hash.Write(associatedData) hash.Write(src[12:]) if n.perspective == protocol.PerspectiveServer { @@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb } else { hash.Write([]byte("Server")) } - testHigh, testLow := hash.Sum128() + sum := make([]byte, 0, 16) + sum = hash.Sum(sum) + // The tag is written in little endian, so we need to reverse the slice. + reverse(sum) - low := binary.LittleEndian.Uint64(src) - high := binary.LittleEndian.Uint32(src[8:]) - - if uint32(testHigh&0xffffffff) != high || testLow != low { - return nil, errors.New("NullAEAD: failed to authenticate received data") + if !bytes.Equal(sum[:12], src[:12]) { + return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12]) } return src[12:], nil } @@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb dst = dst[:12+len(src)] } - hash := fnv128a.New() + hash := fnv.New128a() hash.Write(associatedData) hash.Write(src) @@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb } else { hash.Write([]byte("Client")) } - - high, low := hash.Sum128() + sum := make([]byte, 0, 16) + sum = hash.Sum(sum) + // The tag is written in little endian, so we need to reverse the slice. + reverse(sum) copy(dst[12:], src) - binary.LittleEndian.PutUint64(dst, low) - binary.LittleEndian.PutUint32(dst[8:], uint32(high)) + copy(dst, sum[:12]) return dst } func (n *nullAEADFNV128a) Overhead() int { return 12 } + +func reverse(a []byte) { + for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { + a[left], a[right] = a[right], a[left] + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go deleted file mode 100644 index 3dcb26a..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go +++ /dev/null @@ -1,76 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "fmt" - "io" - - "golang.org/x/crypto/hkdf" -) - -// StkSource is used to create and verify source address tokens -type StkSource interface { - // NewToken creates a new token - NewToken([]byte) ([]byte, error) - // DecodeToken decodes a token - DecodeToken([]byte) ([]byte, error) -} - -type stkSource struct { - aead cipher.AEAD -} - -const stkKeySize = 16 - -// Chrome currently sets this to 12, but discusses changing it to 16. We start -// at 16 :) -const stkNonceSize = 16 - -// NewStkSource creates a source for source address tokens -func NewStkSource() (StkSource, error) { - secret := make([]byte, 32) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - key, err := deriveKey(secret) - if err != nil { - return nil, err - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize) - if err != nil { - return nil, err - } - return &stkSource{aead: aead}, nil -} - -func (s *stkSource) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, stkNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - return s.aead.Seal(nonce, nonce, data, nil), nil -} - -func (s *stkSource) DecodeToken(p []byte) ([]byte, error) { - if len(p) < stkNonceSize { - return nil, fmt.Errorf("STK too short: %d", len(p)) - } - nonce := p[:stkNonceSize] - return s.aead.Open(nil, nonce, p[stkNonceSize:], nil) -} - -func deriveKey(secret []byte) ([]byte, error) { - r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key")) - key := make([]byte, stkKeySize) - if _, err := io.ReadFull(r, key); err != nil { - return nil, err - } - return key, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go index e74c1d1..fb92f08 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -4,41 +4,38 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) type baseFlowController struct { - mutex sync.RWMutex - - rttStats *congestion.RTTStats - + // for sending data bytesSent protocol.ByteCount sendWindow protocol.ByteCount - lastWindowUpdateTime time.Time + // for receiving data + mutex sync.RWMutex + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowSize protocol.ByteCount + maxReceiveWindowSize protocol.ByteCount - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveWindow protocol.ByteCount - receiveWindowIncrement protocol.ByteCount - maxReceiveWindowIncrement protocol.ByteCount + epochStartTime time.Time + epochStartOffset protocol.ByteCount + rttStats *congestion.RTTStats + + logger utils.Logger } func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.bytesSent += n } // UpdateSendWindow should be called after receiving a WindowUpdateFrame // it returns true if the window was actually updated func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { - c.mutex.Lock() - defer c.mutex.Unlock() - if offset > c.sendWindow { c.sendWindow = offset } @@ -57,52 +54,55 @@ func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { defer c.mutex.Unlock() // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window increment already works for the first WindowUpdate + // this way auto-tuning of the window size already works for the first WindowUpdate if c.bytesRead == 0 { - c.lastWindowUpdateTime = time.Now() + c.startNewAutoTuningEpoch() } c.bytesRead += n } +func (c *baseFlowController) hasWindowUpdate() bool { + bytesRemaining := c.receiveWindow - c.bytesRead + // update the window when more than the threshold was consumed + return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold)))) +} + // getWindowUpdate updates the receive window, if necessary // it returns the new offset func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { - diff := c.receiveWindow - c.bytesRead - // update the window when more than half of it was already consumed - if diff >= (c.receiveWindowIncrement / 2) { + if !c.hasWindowUpdate() { return 0 } - c.maybeAdjustWindowIncrement() - c.receiveWindow = c.bytesRead + c.receiveWindowIncrement - c.lastWindowUpdateTime = time.Now() + c.maybeAdjustWindowSize() + c.receiveWindow = c.bytesRead + c.receiveWindowSize return c.receiveWindow } -func (c *baseFlowController) IsBlocked() bool { - c.mutex.RLock() - defer c.mutex.RUnlock() - - return c.sendWindowSize() == 0 -} - -// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often -func (c *baseFlowController) maybeAdjustWindowIncrement() { - if c.lastWindowUpdateTime.IsZero() { +// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. +// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. +func (c *baseFlowController) maybeAdjustWindowSize() { + bytesReadInEpoch := c.bytesRead - c.epochStartOffset + // don't do anything if less than half the window has been consumed + if bytesReadInEpoch <= c.receiveWindowSize/2 { return } - rtt := c.rttStats.SmoothedRTT() if rtt == 0 { return } - timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) - // interval between the window updates is sufficiently large, no need to increase the increment - if timeSinceLastWindowUpdate >= 2*rtt { - return + fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) + if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { + // window is consumed too fast, try to increase the window size + c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) } - c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) + c.startNewAutoTuningEpoch() +} + +func (c *baseFlowController) startNewAutoTuningEpoch() { + c.epochStartTime = time.Now() + c.epochStartOffset = c.bytesRead } func (c *baseFlowController) checkFlowControlViolation() bool { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go index 934d646..ab565d2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -2,16 +2,18 @@ package flowcontrol import ( "fmt" - "time" - "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) type connectionFlowController struct { + lastBlockedAt protocol.ByteCount baseFlowController + + queueWindowUpdate func() } var _ ConnectionFlowController = &connectionFlowController{} @@ -21,25 +23,37 @@ var _ ConnectionFlowController = &connectionFlowController{} func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, + queueWindowUpdate func(), rttStats *congestion.RTTStats, + logger utils.Logger, ) ConnectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowIncrement: receiveWindow, - maxReceiveWindowIncrement: maxReceiveWindow, + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + logger: logger, }, + queueWindowUpdate: queueWindowUpdate, } } func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { - c.mutex.RLock() - defer c.mutex.RUnlock() - return c.baseFlowController.sendWindowSize() } +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + // IncrementHighestReceived adds an increment to the highestReceived value func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { c.mutex.Lock() @@ -52,26 +66,34 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B return nil } +func (c *connectionFlowController) MaybeQueueWindowUpdate() { + c.mutex.Lock() + hasWindowUpdate := c.hasWindowUpdate() + c.mutex.Unlock() + if hasWindowUpdate { + c.queueWindowUpdate() + } +} + func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { c.mutex.Lock() - defer c.mutex.Unlock() - - oldWindowIncrement := c.receiveWindowIncrement + oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() - if oldWindowIncrement < c.receiveWindowIncrement { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if oldWindowSize < c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) } + c.mutex.Unlock() return offset } -// EnsureMinimumWindowIncrement sets a minimum window increment +// EnsureMinimumWindowSize sets a minimum window size // it should make sure that the connection-level window is increased when a stream-level window grows -func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { c.mutex.Lock() - defer c.mutex.Unlock() - - if inc > c.receiveWindowIncrement { - c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) - c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update + if inc > c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) + c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize) + c.startNewAutoTuningEpoch() } + c.mutex.Unlock() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go index 75ec6fa..450d06a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -5,17 +5,19 @@ import "github.com/lucas-clemente/quic-go/internal/protocol" type flowController interface { // for sending SendWindowSize() protocol.ByteCount - IsBlocked() bool UpdateSendWindow(protocol.ByteCount) AddBytesSent(protocol.ByteCount) // for receiving AddBytesRead(protocol.ByteCount) GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + MaybeQueueWindowUpdate() // queues a window update, if necessary } // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController + // for sending + IsBlocked() (bool, protocol.ByteCount) // for receiving // UpdateHighestReceived should be called when a new highest offset is received // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame @@ -25,13 +27,15 @@ type StreamFlowController interface { // The ConnectionFlowController is the flow controller for the connection. type ConnectionFlowController interface { flowController + // for sending + IsNewlyBlocked() (bool, protocol.ByteCount) } type connectionFlowControllerI interface { ConnectionFlowController // The following two methods are not supposed to be called from outside this packet, but are needed internally // for sending - EnsureMinimumWindowIncrement(protocol.ByteCount) + EnsureMinimumWindowSize(protocol.ByteCount) // for receiving IncrementHighestReceived(protocol.ByteCount) error } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go index 96e13dc..a394de0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -3,7 +3,7 @@ package flowcontrol import ( "fmt" - "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" @@ -14,6 +14,8 @@ type streamFlowController struct { streamID protocol.StreamID + queueWindowUpdate func() + connection connectionFlowControllerI contributesToConnection bool // does the stream contribute to connection level flow control @@ -30,18 +32,22 @@ func NewStreamFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), rttStats *congestion.RTTStats, + logger utils.Logger, ) StreamFlowController { return &streamFlowController{ streamID: streamID, contributesToConnection: contributesToConnection, connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, baseFlowController: baseFlowController{ - rttStats: rttStats, - receiveWindow: receiveWindow, - receiveWindowIncrement: receiveWindow, - maxReceiveWindowIncrement: maxReceiveWindow, - sendWindow: initialSendWindow, + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + sendWindow: initialSendWindow, + logger: logger, }, } } @@ -102,9 +108,6 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - c.mutex.Lock() - defer c.mutex.Unlock() - window := c.baseFlowController.sendWindowSize() if c.contributesToConnection { window = utils.MinByteCount(window, c.connection.SendWindowSize()) @@ -112,17 +115,44 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount { return window } -func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { - c.mutex.Lock() - defer c.mutex.Unlock() +// IsBlocked says if it is blocked by stream-level flow control. +// If it is blocked, the offset is returned. +func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 { + return false, 0 + } + return true, c.sendWindow +} - oldWindowIncrement := c.receiveWindowIncrement +func (c *streamFlowController) MaybeQueueWindowUpdate() { + c.mutex.Lock() + hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate() + c.mutex.Unlock() + if hasWindowUpdate { + c.queueWindowUpdate() + } + if c.contributesToConnection { + c.connection.MaybeQueueWindowUpdate() + } +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + // don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler + c.mutex.Lock() + // if we already received the final offset for this stream, the peer won't need any additional flow control credit + if c.receivedFinalOffset { + c.mutex.Unlock() + return 0 + } + + oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() - if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size + c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) if c.contributesToConnection { - c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) } } + c.mutex.Unlock() return offset } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go index 10281fa..97accb7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go @@ -6,7 +6,7 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/bifurcation/mint" ) const ( @@ -29,17 +29,17 @@ type token struct { // A CookieGenerator generates Cookies type CookieGenerator struct { - cookieSource crypto.StkSource + cookieProtector mint.CookieProtector } // NewCookieGenerator initializes a new CookieGenerator func NewCookieGenerator() (*CookieGenerator, error) { - stkSource, err := crypto.NewStkSource() + cookieProtector, err := mint.NewDefaultCookieProtector() if err != nil { return nil, err } return &CookieGenerator{ - cookieSource: stkSource, + cookieProtector: cookieProtector, }, nil } @@ -52,7 +52,7 @@ func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { if err != nil { return nil, err } - return g.cookieSource.NewToken(data) + return g.cookieProtector.NewToken(data) } // DecodeToken decodes a Cookie @@ -62,7 +62,7 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { return nil, nil } - data, err := g.cookieSource.DecodeToken(encrypted) + data, err := g.cookieProtector.DecodeToken(encrypted) if err != nil { return nil, err } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go index 317f6e5..bc2bd8e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go @@ -7,36 +7,44 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -type cookieHandler struct { - callback func(net.Addr, *Cookie) bool - +// A CookieHandler generates and validates cookies. +// The cookie is sent in the TLS Retry. +// By including the cookie in its ClientHello, a client can proof ownership of its source address. +type CookieHandler struct { + callback func(net.Addr, *Cookie) bool cookieGenerator *CookieGenerator + + logger utils.Logger } -var _ mint.CookieHandler = &cookieHandler{} +var _ mint.CookieHandler = &CookieHandler{} -func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { +// NewCookieHandler creates a new CookieHandler. +func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) { cookieGenerator, err := NewCookieGenerator() if err != nil { return nil, err } - return &cookieHandler{ + return &CookieHandler{ callback: callback, cookieGenerator: cookieGenerator, + logger: logger, }, nil } -func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { +// Generate a new cookie for a mint connection. +func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { if h.callback(conn.RemoteAddr(), nil) { return nil, nil } return h.cookieGenerator.NewToken(conn.RemoteAddr()) } -func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { +// Validate a cookie. +func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { data, err := h.cookieGenerator.DecodeToken(token) if err != nil { - utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) + h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) return false } return h.callback(conn.RemoteAddr(), data) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go index c923bbc..1324d50 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go @@ -38,23 +38,24 @@ type cryptoSetupClient struct { lastSentCHLO []byte certManager crypto.CertManager - divNonceChan chan []byte + divNonceChan chan struct{} diversificationNonce []byte clientHelloCounter int serverVerified bool // has the certificate chain and the proof already been verified keyDerivation QuicCryptoKeyDerivationFunction - keyExchange KeyExchangeFunction receivedSecurePacket bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + paramsChan chan<- TransportParameters + handshakeEvent chan<- struct{} params *TransportParameters + + logger utils.Logger } var _ CryptoSetup = &cryptoSetupClient{} @@ -74,15 +75,17 @@ func NewCryptoSetupClient( tlsConfig *tls.Config, params *TransportParameters, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, + logger utils.Logger, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { return nil, err } - return &cryptoSetupClient{ + divNonceChan := make(chan struct{}) + cs := &cryptoSetupClient{ cryptoStream: cryptoStream, hostname: hostname, connID: connID, @@ -90,19 +93,20 @@ func NewCryptoSetupClient( certManager: crypto.NewCertManager(tlsConfig), params: params, keyDerivation: crypto.DeriveQuicCryptoAESKeys, - keyExchange: getEphermalKEX, nullAEAD: nullAEAD, paramsChan: paramsChan, - aeadChanged: aeadChanged, + handshakeEvent: handshakeEvent, initialVersion: initialVersion, negotiatedVersions: negotiatedVersions, - divNonceChan: make(chan []byte), - }, nil + divNonceChan: divNonceChan, + logger: logger, + } + return cs, nil } func (h *cryptoSetupClient) HandleCryptoStream() error { messageChan := make(chan HandshakeMessage) - errorChan := make(chan error) + errorChan := make(chan error, 1) go func() { for { @@ -116,37 +120,30 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { }() for { - err := h.maybeUpgradeCrypto() - if err != nil { + if err := h.maybeUpgradeCrypto(); err != nil { return err } h.mutex.RLock() sendCHLO := h.secureAEAD == nil h.mutex.RUnlock() - if sendCHLO { - err = h.sendCHLO() - if err != nil { + if err := h.sendCHLO(); err != nil { return err } } var message HandshakeMessage select { - case divNonce := <-h.divNonceChan: - if len(h.diversificationNonce) != 0 && !bytes.Equal(h.diversificationNonce, divNonce) { - return errConflictingDiversificationNonces - } - h.diversificationNonce = divNonce + case <-h.divNonceChan: // there's no message to process, but we should try upgrading the crypto again continue case message = <-messageChan: - case err = <-errorChan: + case err := <-errorChan: return err } - utils.Debugf("Got %s", message) + h.logger.Debugf("Got %s", message) switch message.Tag { case TagREJ: if err := h.handleREJMessage(message.Data); err != nil { @@ -159,8 +156,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } // blocks until the session has received the parameters h.paramsChan <- *params - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) default: return qerr.InvalidCryptoMessageType } @@ -211,7 +208,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { err = h.certManager.Verify(h.hostname) if err != nil { - utils.Infof("Certificate validation failed: %s", err.Error()) + h.logger.Infof("Certificate validation failed: %s", err.Error()) return qerr.ProofInvalid } } @@ -219,7 +216,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) if !validProof { - utils.Infof("Server proof verification failed") + h.logger.Infof("Server proof verification failed") return qerr.ProofInvalid } @@ -277,6 +274,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*Trans if err != nil { return nil, err } + h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.") params, err := readHelloMap(cryptoData) if err != nil { @@ -322,6 +320,7 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu if h.secureAEAD != nil { data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { + h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.") h.receivedSecurePacket = true return data, protocol.EncryptionSecure, nil } @@ -373,16 +372,28 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry return nil, errors.New("CryptoSetupClient: no encryption level specified") } -func (h *cryptoSetupClient) DiversificationNonce() []byte { - panic("not needed for cryptoSetupClient") +func (h *cryptoSetupClient) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + HandshakeComplete: h.forwardSecureAEAD != nil, + PeerCertificates: h.certManager.GetChain(), + } } -func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { - h.divNonceChan <- data -} - -func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType { - panic("not needed for cryptoSetupServer") +func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error { + h.mutex.Lock() + if len(h.diversificationNonce) > 0 { + defer h.mutex.Unlock() + if !bytes.Equal(h.diversificationNonce, divNonce) { + return errConflictingDiversificationNonces + } + return nil + } + h.diversificationNonce = divNonce + h.mutex.Unlock() + h.divNonceChan <- struct{}{} + return nil } func (h *cryptoSetupClient) sendCHLO() error { @@ -403,7 +414,7 @@ func (h *cryptoSetupClient) sendCHLO() error { Data: tags, } - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) message.Write(b) _, err = h.cryptoStream.Write(b.Bytes()) @@ -462,7 +473,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) { for _, tag := range tags { size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data } - paddingSize := protocol.ClientHelloMinimumSize - size + paddingSize := protocol.MinClientHelloSize - size if paddingSize > 0 { tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) } @@ -500,10 +511,9 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error { if err != nil { return err } - - h.aeadChanged <- protocol.EncryptionSecure + h.logger.Debugf("Creating AEAD for secure encryption.") + h.handshakeEvent <- struct{}{} } - return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go index 50e2618..952237e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go @@ -19,10 +19,12 @@ import ( type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX -type KeyExchangeFunction func() crypto.KeyExchange +type KeyExchangeFunction func() (crypto.KeyExchange, error) // The CryptoSetupServer handles all things crypto for the Session type cryptoSetupServer struct { + mutex sync.RWMutex + connID protocol.ConnectionID remoteAddr net.Addr scfg *ServerConfig @@ -42,7 +44,7 @@ type cryptoSetupServer struct { receivedParams bool paramsChan chan<- TransportParameters - aeadChanged chan<- protocol.EncryptionLevel + handshakeEvent chan<- struct{} keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction @@ -51,7 +53,9 @@ type cryptoSetupServer struct { params *TransportParameters - mutex sync.RWMutex + sni string // need to fill out the ConnectionState + + logger utils.Logger } var _ CryptoSetup = &cryptoSetupServer{} @@ -71,32 +75,36 @@ func NewCryptoSetup( connID protocol.ConnectionID, remoteAddr net.Addr, version protocol.VersionNumber, + divNonce []byte, scfg *ServerConfig, params *TransportParameters, supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, + handshakeEvent chan<- struct{}, + logger utils.Logger, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { return nil, err } return &cryptoSetupServer{ - cryptoStream: cryptoStream, - connID: connID, - remoteAddr: remoteAddr, - version: version, - supportedVersions: supportedVersions, - scfg: scfg, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - keyExchange: getEphermalKEX, - nullAEAD: nullAEAD, - params: params, - acceptSTKCallback: acceptSTK, - sentSHLO: make(chan struct{}), - paramsChan: paramsChan, - aeadChanged: aeadChanged, + cryptoStream: cryptoStream, + connID: connID, + remoteAddr: remoteAddr, + version: version, + supportedVersions: supportedVersions, + diversificationNonce: divNonce, + scfg: scfg, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: nullAEAD, + params: params, + acceptSTKCallback: acceptSTK, + sentSHLO: make(chan struct{}), + paramsChan: paramsChan, + handshakeEvent: handshakeEvent, + logger: logger, }, nil } @@ -112,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error { return qerr.InvalidCryptoMessageType } - utils.Debugf("Got %s", message) + h.logger.Debugf("Got %s", message) done, err := h.handleMessage(chloData.Bytes(), message.Data) if err != nil { return err @@ -139,6 +147,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if sni == "" { return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") } + h.sni = sni // prevent version downgrade attacks // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples @@ -182,7 +191,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } - h.aeadChanged <- protocol.EncryptionForwardSecure + h.handshakeEvent <- struct{}{} close(h.sentSHLO) return true, nil } @@ -205,10 +214,11 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client + h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.") h.receivedForwardSecurePacket = true - // wait until protocol.EncryptionForwardSecure was sent on the aeadChan + // wait for the send on the handshakeEvent chan <-h.sentSHLO - close(h.aeadChanged) + close(h.handshakeEvent) } return res, protocol.EncryptionForwardSecure, nil } @@ -219,6 +229,7 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if h.secureAEAD != nil { res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { + h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.") h.receivedSecurePacket = true return res, protocol.EncryptionSecure, nil } @@ -294,17 +305,13 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt func (h *cryptoSetupServer) acceptSTK(token []byte) bool { stk, err := h.scfg.cookieGenerator.DecodeToken(token) if err != nil { - utils.Debugf("STK invalid: %s", err.Error()) + h.logger.Debugf("STK invalid: %s", err.Error()) return false } return h.acceptSTKCallback(h.remoteAddr, stk) } func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { - if len(chlo) < protocol.ClientHelloMinimumSize { - return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") - } - token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr) if err != nil { return nil, err @@ -341,7 +348,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa var serverReply bytes.Buffer message.Write(&serverReply) - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) return serverReply.Bytes(), nil } @@ -365,11 +372,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } - h.diversificationNonce = make([]byte, 32) - if _, err = rand.Read(h.diversificationNonce); err != nil { - return nil, err - } - clientNonce := cryptoData[TagNONC] err = h.validateClientNonce(clientNonce) if err != nil { @@ -400,14 +402,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } - - h.aeadChanged <- protocol.EncryptionSecure + h.logger.Debugf("Creating AEAD for secure encryption.") + h.handshakeEvent <- struct{}{} // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer fsNonce.Write(clientNonce) fsNonce.Write(serverNonce) - ephermalKex := h.keyExchange() + ephermalKex, err := h.keyExchange() + if err != nil { + return nil, err + } ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { return nil, err @@ -427,6 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T if err != nil { return nil, err } + h.logger.Debugf("Creating AEAD for forward-secure encryption.") replyMap := h.params.getHelloMap() // add crypto parameters @@ -445,21 +451,17 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T } var reply bytes.Buffer message.Write(&reply) - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) return reply.Bytes(), nil } -// DiversificationNonce returns the diversification nonce -func (h *cryptoSetupServer) DiversificationNonce() []byte { - return h.diversificationNonce -} - -func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { - panic("not needed for cryptoSetupServer") -} - -func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType { - panic("not needed for cryptoSetupServer") +func (h *cryptoSetupServer) ConnectionState() ConnectionState { + h.mutex.Lock() + defer h.mutex.Unlock() + return ConnectionState{ + ServerName: h.sni, + HandshakeComplete: h.receivedForwardSecurePacket, + } } func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go index e14e7ad..43f8540 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go @@ -1,10 +1,9 @@ package handshake import ( - "crypto/tls" + "errors" "fmt" "io" - "net" "sync" "github.com/bifurcation/mint" @@ -12,6 +11,9 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) +// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry +var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry") + // KeyDerivationFunction is used for key derivation type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) @@ -20,64 +22,33 @@ type cryptoSetupTLS struct { perspective protocol.Perspective - tls mintTLS - conn *fakeConn - - nextPacketType protocol.PacketType - keyDerivation KeyDerivationFunction nullAEAD crypto.AEAD aead crypto.AEAD - aeadChanged chan<- protocol.EncryptionLevel + tls MintTLS + cryptoStream *CryptoStreamConn + handshakeEvent chan<- struct{} } +var _ CryptoSetupTLS = &cryptoSetupTLS{} + // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server func NewCryptoSetupTLSServer( - cryptoStream io.ReadWriter, - connID protocol.ConnectionID, - tlsConfig *tls.Config, - remoteAddr net.Addr, - params *TransportParameters, - paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, - checkCookie func(net.Addr, *Cookie) bool, - supportedVersions []protocol.VersionNumber, + tls MintTLS, + cryptoStream *CryptoStreamConn, + nullAEAD crypto.AEAD, + handshakeEvent chan<- struct{}, version protocol.VersionNumber, -) (CryptoSetup, error) { - mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) - if err != nil { - return nil, err - } - mintConf.RequireCookie = true - mintConf.CookieHandler, err = newCookieHandler(checkCookie) - if err != nil { - return nil, err - } - conn := &fakeConn{ - stream: cryptoStream, - pers: protocol.PerspectiveServer, - remoteAddr: remoteAddr, - } - mintConn := mint.Server(conn, mintConf) - eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) - if err := mintConn.SetExtensionHandler(eh); err != nil { - return nil, err - } - - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) - if err != nil { - return nil, err - } - +) CryptoSetupTLS { return &cryptoSetupTLS{ - perspective: protocol.PerspectiveServer, - tls: &mintController{mintConn}, - conn: conn, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - }, nil + tls: tls, + cryptoStream: cryptoStream, + nullAEAD: nullAEAD, + perspective: protocol.PerspectiveServer, + keyDerivation: crypto.DeriveAESKeys, + handshakeEvent: handshakeEvent, + } } // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client @@ -85,59 +56,44 @@ func NewCryptoSetupTLSClient( cryptoStream io.ReadWriter, connID protocol.ConnectionID, hostname string, - tlsConfig *tls.Config, - params *TransportParameters, - paramsChan chan<- TransportParameters, - aeadChanged chan<- protocol.EncryptionLevel, - initialVersion protocol.VersionNumber, - supportedVersions []protocol.VersionNumber, + handshakeEvent chan<- struct{}, + tls MintTLS, version protocol.VersionNumber, -) (CryptoSetup, error) { - mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) - if err != nil { - return nil, err - } - mintConf.ServerName = hostname - conn := &fakeConn{ - stream: cryptoStream, - pers: protocol.PerspectiveClient, - } - mintConn := mint.Client(conn, mintConf) - eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) - if err := mintConn.SetExtensionHandler(eh); err != nil { - return nil, err - } - +) (CryptoSetupTLS, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { return nil, err } return &cryptoSetupTLS{ - conn: conn, perspective: protocol.PerspectiveClient, - tls: &mintController{mintConn}, + tls: tls, nullAEAD: nullAEAD, keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, - nextPacketType: protocol.PacketTypeInitial, + handshakeEvent: handshakeEvent, }, nil } func (h *cryptoSetupTLS) HandleCryptoStream() error { + if h.perspective == protocol.PerspectiveServer { + // mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer + // send out that data now + if _, err := h.cryptoStream.Flush(); err != nil { + return err + } + } + handshakeLoop: for { - switch alert := h.tls.Handshake(); alert { - case mint.AlertNoAlert: // handshake complete - break handshakeLoop - case mint.AlertWouldBlock: - h.determineNextPacketType() - if err := h.conn.Continue(); err != nil { - return err - } - default: + if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } + switch h.tls.State() { + case mint.StateClientStart: // this happens if a stateless retry is performed + return ErrCloseSessionForRetry + case mint.StateClientConnected, mint.StateServerConnected: + break handshakeLoop + } } aead, err := h.keyDerivation(h.tls, h.perspective) @@ -148,28 +104,23 @@ handshakeLoop: h.aead = aead h.mutex.Unlock() - // signal to the outside world that the handshake completed - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) + h.handshakeEvent <- struct{}{} + close(h.handshakeEvent) return nil } -func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { +func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return h.nullAEAD.Open(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { h.mutex.RLock() defer h.mutex.RUnlock() - if h.aead != nil { - data, err := h.aead.Open(dst, src, packetNumber, associatedData) - if err != nil { - return nil, protocol.EncryptionUnspecified, err - } - return data, protocol.EncryptionForwardSecure, nil + if h.aead == nil { + return nil, errors.New("no 1-RTT sealer") } - data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData) - if err != nil { - return nil, protocol.EncryptionUnspecified, err - } - return data, protocol.EncryptionUnencrypted, nil + return h.aead.Open(dst, src, packetNumber, associatedData) } func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { @@ -204,39 +155,13 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S return protocol.EncryptionUnencrypted, h.nullAEAD } -func (h *cryptoSetupTLS) determineNextPacketType() error { +func (h *cryptoSetupTLS) ConnectionState() ConnectionState { h.mutex.Lock() defer h.mutex.Unlock() - state := h.tls.State().HandshakeState - if h.perspective == protocol.PerspectiveServer { - switch state { - case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest - h.nextPacketType = protocol.PacketTypeRetry - case "ServerStateWaitFinished": - h.nextPacketType = protocol.PacketTypeHandshake - default: - // TODO: accept 0-RTT data - return fmt.Errorf("Unexpected handshake state: %s", state) - } - return nil + mintConnState := h.tls.ConnectionState() + return ConnectionState{ + // TODO: set the ServerName, once mint exports it + HandshakeComplete: h.aead != nil, + PeerCertificates: mintConnState.PeerCertificates, } - // client - if state != "ClientStateWaitSH" { - h.nextPacketType = protocol.PacketTypeHandshake - } - return nil -} - -func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.nextPacketType -} - -func (h *cryptoSetupTLS) DiversificationNonce() []byte { - panic("diversification nonce not needed for TLS") -} - -func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) { - panic("diversification nonce not needed for TLS") } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go new file mode 100644 index 0000000..03825c4 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "bytes" + "io" + "net" + "time" +) + +// The CryptoStreamConn is used as the net.Conn passed to mint. +// It has two operating modes: +// 1. It can read and write to bytes.Buffers. +// 2. It can use a quic.Stream for reading and writing. +// The buffer-mode is only used by the server, in order to statelessly handle retries. +type CryptoStreamConn struct { + remoteAddr net.Addr + + // the buffers are used before the session is initialized + readBuf bytes.Buffer + writeBuf bytes.Buffer + + // stream will be set once the session is initialized + stream io.ReadWriter +} + +var _ net.Conn = &CryptoStreamConn{} + +// NewCryptoStreamConn creates a new CryptoStreamConn +func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn { + return &CryptoStreamConn{remoteAddr: remoteAddr} +} + +func (c *CryptoStreamConn) Read(b []byte) (int, error) { + if c.stream != nil { + return c.stream.Read(b) + } + return c.readBuf.Read(b) +} + +// AddDataForReading adds data to the read buffer. +// This data will ONLY be read when the stream has not been set. +func (c *CryptoStreamConn) AddDataForReading(data []byte) { + c.readBuf.Write(data) +} + +func (c *CryptoStreamConn) Write(p []byte) (int, error) { + if c.stream != nil { + return c.stream.Write(p) + } + return c.writeBuf.Write(p) +} + +// GetDataForWriting returns all data currently in the write buffer, and resets this buffer. +func (c *CryptoStreamConn) GetDataForWriting() []byte { + defer c.writeBuf.Reset() + data := make([]byte, c.writeBuf.Len()) + copy(data, c.writeBuf.Bytes()) + return data +} + +// SetStream sets the stream. +// After setting the stream, the read and write buffer won't be used any more. +func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) { + c.stream = stream +} + +// Flush copies the contents of the write buffer to the stream +func (c *CryptoStreamConn) Flush() (int, error) { + n, err := io.Copy(c.stream, &c.writeBuf) + return int(n), err +} + +// Close is not implemented +func (c *CryptoStreamConn) Close() error { + return nil +} + +// LocalAddr is not implemented +func (c *CryptoStreamConn) LocalAddr() net.Addr { + return nil +} + +// RemoteAddr returns the remote address +func (c *CryptoStreamConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// SetReadDeadline is not implemented +func (c *CryptoStreamConn) SetReadDeadline(time.Time) error { + return nil +} + +// SetWriteDeadline is not implemented +func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error { + return nil +} + +// SetDeadline is not implemented +func (c *CryptoStreamConn) SetDeadline(time.Time) error { + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go index 3bccbef..eb1824d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go @@ -6,7 +6,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) var ( @@ -24,27 +23,26 @@ var ( // used for all connections for 60 seconds is negligible. Thus we can amortise // the Diffie-Hellman key generation at the server over all the connections in a // small time span. -func getEphermalKEX() (res crypto.KeyExchange) { +func getEphermalKEX() (crypto.KeyExchange, error) { kexMutex.RLock() - res = kexCurrent + res := kexCurrent t := kexCurrentTime kexMutex.RUnlock() if res != nil && time.Since(t) < kexLifetime { - return res + return res, nil } kexMutex.Lock() defer kexMutex.Unlock() // Check if still unfulfilled - if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { + if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime { kex, err := crypto.NewCurve25519KEX() if err != nil { - utils.Errorf("could not set KEX: %s", err.Error()) - return kexCurrent + return nil, err } kexCurrent = kex kexCurrentTime = time.Now() - return kexCurrent + return kexCurrent, nil } - return kexCurrent + return kexCurrent, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go index c09db26..cfbd219 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go @@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { offset := uint32(0) for i, t := range h.getTagsSorted() { - v := data[Tag(t)] + v := data[t] b.Write(v) offset += uint32(len(v)) binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t)) @@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag { func (h HandshakeMessage) String() string { var pad string res := tagToString(h.Tag) + ":\n" - for _, t := range h.getTagsSorted() { - tag := Tag(t) + for _, tag := range h.getTagsSorted() { if tag == TagPAD { pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag])) } else { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go index c34c8f1..8d8fd54 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -1,6 +1,11 @@ package handshake import ( + "crypto/x509" + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -10,16 +15,54 @@ type Sealer interface { Overhead() int } -// CryptoSetup is a crypto setup -type CryptoSetup interface { - Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) +// A TLSExtensionHandler sends and received the QUIC TLS extension. +// It provides the parameters sent by the peer on a channel. +type TLSExtensionHandler interface { + Send(mint.HandshakeType, *mint.ExtensionList) error + Receive(mint.HandshakeType, *mint.ExtensionList) error + GetPeerParams() <-chan TransportParameters +} + +// MintTLS combines some methods needed to interact with mint. +type MintTLS interface { + crypto.TLSExporter + + // additional methods + Handshake() mint.Alert + State() mint.State + ConnectionState() mint.ConnectionState + + SetCryptoStream(io.ReadWriter) +} + +type baseCryptoSetup interface { HandleCryptoStream() error - // TODO: clean up this interface - DiversificationNonce() []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) // only needed for cryptoSetupClient - GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer + ConnectionState() ConnectionState GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) } + +// CryptoSetup is the crypto setup used by gQUIC +type CryptoSetup interface { + baseCryptoSetup + + Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) +} + +// CryptoSetupTLS is the crypto setup used by IETF QUIC +type CryptoSetupTLS interface { + baseCryptoSetup + + OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + +// ConnectionState records basic details about the QUIC connection. +// Warning: This API should not be considered stable and might change soon. +type ConnectionState struct { + HandshakeComplete bool // handshake is complete + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go deleted file mode 100644 index 8c3a83b..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go +++ /dev/null @@ -1,127 +0,0 @@ -package handshake - -import ( - "bytes" - gocrypto "crypto" - "crypto/tls" - "crypto/x509" - "io" - "net" - "time" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { - mconf := &mint.Config{ - NonBlocking: true, - CipherSuites: []mint.CipherSuite{ - mint.TLS_AES_128_GCM_SHA256, - mint.TLS_AES_256_GCM_SHA384, - }, - } - if tlsConf != nil { - mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) - for i, certChain := range tlsConf.Certificates { - mconf.Certificates[i] = &mint.Certificate{ - Chain: make([]*x509.Certificate, len(certChain.Certificate)), - PrivateKey: certChain.PrivateKey.(gocrypto.Signer), - } - for j, cert := range certChain.Certificate { - c, err := x509.ParseCertificate(cert) - if err != nil { - return nil, err - } - mconf.Certificates[i].Chain[j] = c - } - } - } - if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { - return nil, err - } - return mconf, nil -} - -type mintTLS interface { - // These two methods are the same as the crypto.TLSExporter interface. - // Cannot use embedding here, because mockgen source mode refuses to generate mocks then. - GetCipherSuite() mint.CipherSuiteParams - ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) - // additional methods - Handshake() mint.Alert - State() mint.ConnectionState -} - -var _ crypto.TLSExporter = (mintTLS)(nil) - -type mintController struct { - conn *mint.Conn -} - -var _ mintTLS = &mintController{} - -func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { - return mc.conn.State().CipherSuite -} - -func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - return mc.conn.ComputeExporter(label, context, keyLength) -} - -func (mc *mintController) Handshake() mint.Alert { - return mc.conn.Handshake() -} - -func (mc *mintController) State() mint.ConnectionState { - return mc.conn.State() -} - -// mint expects a net.Conn, but we're doing the handshake on a stream -// so we wrap a stream such that implements a net.Conn -type fakeConn struct { - stream io.ReadWriter - pers protocol.Perspective - remoteAddr net.Addr - - blockRead bool - writeBuffer bytes.Buffer -} - -var _ net.Conn = &fakeConn{} - -func (c *fakeConn) Read(b []byte) (int, error) { - if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock - return 0, nil - } - c.blockRead = true // block the next Read call - return c.stream.Read(b) -} - -func (c *fakeConn) Write(p []byte) (int, error) { - if c.pers == protocol.PerspectiveClient { - return c.stream.Write(p) - } - // Buffer all writes by the server. - // Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data. - return c.writeBuffer.Write(p) -} - -func (c *fakeConn) Continue() error { - c.blockRead = false - if c.pers == protocol.PerspectiveClient { - return nil - } - // write all contents of the write buffer to the stream. - _, err := c.stream.Write(c.writeBuffer.Bytes()) - c.writeBuffer.Reset() - return err -} - -func (c *fakeConn) Close() error { return nil } -func (c *fakeConn) LocalAddr() net.Addr { return nil } -func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *fakeConn) SetReadDeadline(time.Time) error { return nil } -func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil } -func (c *fakeConn) SetDeadline(time.Time) error { return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go index 2b7fba6..b015750 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go @@ -9,10 +9,10 @@ import ( // ServerConfig is a server config type ServerConfig struct { - kex crypto.KeyExchange - certChain crypto.CertChain - ID []byte - obit []byte + kex crypto.KeyExchange + certChain crypto.CertChain + ID []byte + obit []byte cookieGenerator *CookieGenerator } @@ -36,10 +36,10 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve } return &ServerConfig{ - kex: kex, - certChain: certChain, - ID: id, - obit: obit, + kex: kex, + certChain: certChain, + ID: id, + obit: obit, cookieGenerator: cookieGenerator, }, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go index eb042f6..0d6521a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go @@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") } - var pubs_kexs []struct{Length uint32; Value []byte} - var last_len uint32 - - for i := 0; i < len(pubs)-3; i += int(last_len)+3 { + var pubsKexs []struct { + Length uint32 + Value []byte + } + var lastLen uint32 + for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 { // the PUBS value is always prepended by 3 byte little endian length field - err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len); + err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen) if err != nil { return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable") } - if last_len == 0 { + if lastLen == 0 { return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") } - if i+3+int(last_len) > len(pubs) { + if i+3+int(lastLen) > len(pubs) { return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") } - pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]}) + pubsKexs = append(pubsKexs, struct { + Length uint32 + Value []byte + }{lastLen, pubs[i+3 : i+3+int(lastLen)]}) } - if c255Foundat >= len(pubs_kexs) { + if c255Foundat >= len(pubsKexs) { return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS") } - if pubs_kexs[c255Foundat].Length != 32 { + if pubsKexs[c255Foundat].Length != 32 { return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") } @@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { return err } - - s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value) + s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value) if err != nil { return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go index 7e56e92..e40ec3d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go @@ -9,13 +9,13 @@ type transportParameterID uint16 const quicTLSExtensionType = 26 const ( - initialMaxStreamDataParameterID transportParameterID = iota - initialMaxDataParameterID - initialMaxStreamIDParameterID - idleTimeoutParameterID - omitConnectionIDParameterID - maxPacketSizeParameterID - statelessResetTokenParameterID + initialMaxStreamDataParameterID transportParameterID = 0x0 + initialMaxDataParameterID transportParameterID = 0x1 + initialMaxBidiStreamsParameterID transportParameterID = 0x2 + idleTimeoutParameterID transportParameterID = 0x3 + maxPacketSizeParameterID transportParameterID = 0x5 + statelessResetTokenParameterID transportParameterID = 0x6 + initialMaxUniStreamsParameterID transportParameterID = 0x8 ) type transportParameter struct { @@ -24,12 +24,12 @@ type transportParameter struct { } type clientHelloTransportParameters struct { - NegotiatedVersion uint32 // actually a protocol.VersionNumber - InitialVersion uint32 // actually a protocol.VersionNumber - Parameters []transportParameter `tls:"head=2"` + InitialVersion uint32 // actually a protocol.VersionNumber + Parameters []transportParameter `tls:"head=2"` } type encryptedExtensionsTransportParameters struct { + NegotiatedVersion uint32 // actually a protocol.VersionNumber SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber Parameters []transportParameter `tls:"head=2"` } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go index 4187804..8e711be 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go @@ -3,39 +3,48 @@ package handshake import ( "errors" "fmt" - "math" "github.com/lucas-clemente/quic-go/qerr" "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) type extensionHandlerClient struct { - params *TransportParameters - paramsChan chan<- TransportParameters + ourParams *TransportParameters + paramsChan chan TransportParameters initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber version protocol.VersionNumber + + logger utils.Logger } var _ mint.AppExtensionHandler = &extensionHandlerClient{} +var _ TLSExtensionHandler = &extensionHandlerClient{} -func newExtensionHandlerClient( +// NewExtensionHandlerClient creates a new extension handler for the client. +func NewExtensionHandlerClient( params *TransportParameters, - paramsChan chan<- TransportParameters, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) *extensionHandlerClient { + logger utils.Logger, +) TLSExtensionHandler { + // The client reads the transport parameters from the Encrypted Extensions message. + // The paramsChan is used in the session's run loop's select statement. + // We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately. + paramsChan := make(chan TransportParameters) return &extensionHandlerClient{ - params: params, + ourParams: params, paramsChan: paramsChan, initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, + logger: logger, } } @@ -44,10 +53,10 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi return nil } + h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) data, err := syntax.Marshal(clientHelloTransportParameters{ - NegotiatedVersion: uint32(h.version), - InitialVersion: uint32(h.initialVersion), - Parameters: h.params.getTransportParameters(), + InitialVersion: uint32(h.initialVersion), + Parameters: h.ourParams.getTransportParameters(), }) if err != nil { return err @@ -57,19 +66,17 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } - if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { + if hType != mint.HandshakeTypeEncryptedExtensions { if found { return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) } return nil } - if hType == mint.HandshakeTypeNewSessionTicket { - // the extension it's optional in the NewSessionTicket message - // TODO: handle this - return nil - } // hType == mint.HandshakeTypeEncryptedExtensions if !found { @@ -84,6 +91,10 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte for i, v := range eetp.SupportedVersions { serverSupportedVersions[i] = protocol.VersionNumber(v) } + // check that the negotiated_version is the current version + if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version { + return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version") + } // check that the current version is included in the supported versions if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) { return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions") @@ -111,12 +122,15 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte // TODO: return the right error here return errors.New("server didn't sent stateless_reset_token") } - params, err := readTransportParamters(eetp.Parameters) + params, err := readTransportParameters(eetp.Parameters) if err != nil { return err } - // TODO(#878): remove this when implementing the MAX_STREAM_ID frame - params.MaxStreams = math.MaxUint32 + h.logger.Debugf("Received Transport Parameters: %s", params) h.paramsChan <- *params return nil } + +func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go index 49830d8..138fc21 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go @@ -4,36 +4,44 @@ import ( "bytes" "errors" "fmt" - "math" "github.com/lucas-clemente/quic-go/qerr" "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) type extensionHandlerServer struct { - params *TransportParameters - paramsChan chan<- TransportParameters + ourParams *TransportParameters + paramsChan chan TransportParameters version protocol.VersionNumber supportedVersions []protocol.VersionNumber + + logger utils.Logger } var _ mint.AppExtensionHandler = &extensionHandlerServer{} +var _ TLSExtensionHandler = &extensionHandlerServer{} -func newExtensionHandlerServer( +// NewExtensionHandlerServer creates a new extension handler for the server +func NewExtensionHandlerServer( params *TransportParameters, - paramsChan chan<- TransportParameters, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) *extensionHandlerServer { + logger utils.Logger, +) TLSExtensionHandler { + // Processing the ClientHello is performed statelessly (and from a single go-routine). + // Therefore, we have to use a buffered chan to pass the transport parameters to that go routine. + paramsChan := make(chan TransportParameters, 1) return &extensionHandlerServer{ - params: params, + ourParams: params, paramsChan: paramsChan, - version: version, supportedVersions: supportedVersions, + version: version, + logger: logger, } } @@ -43,16 +51,19 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi } transportParams := append( - h.params.getTransportParameters(), + h.ourParams.getTransportParameters(), // TODO(#855): generate a real token transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, ) - supportedVersions := make([]uint32, len(h.supportedVersions)) - for i, v := range h.supportedVersions { - supportedVersions[i] = uint32(v) + supportedVersions := protocol.GetGreasedVersions(h.supportedVersions) + versions := make([]uint32, len(supportedVersions)) + for i, v := range supportedVersions { + versions[i] = uint32(v) } + h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ - SupportedVersions: supportedVersions, + NegotiatedVersion: uint32(h.version), + SupportedVersions: versions, Parameters: transportParams, }) if err != nil { @@ -63,7 +74,10 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { ext := &tlsExtensionBody{} - found := el.Find(ext) + found, err := el.Find(ext) + if err != nil { + return err + } if hType != mint.HandshakeTypeClientHello { if found { @@ -80,15 +94,11 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte return err } initialVersion := protocol.VersionNumber(chtp.InitialVersion) - negotiatedVersion := protocol.VersionNumber(chtp.NegotiatedVersion) - // check that the negotiated version is the version we're currently using - if negotiatedVersion != h.version { - return qerr.Error(qerr.VersionNegotiationMismatch, "Inconsistent negotiated version") - } + // perform the stateless version negotiation validation: // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version - // this is the case when the initial version is not contained in the supported versions - if initialVersion != negotiatedVersion && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { + // this is the case if and only if the initial version is not contained in the supported versions + if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") } @@ -98,12 +108,15 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte return errors.New("client sent a stateless reset token") } } - params, err := readTransportParamters(chtp.Parameters) + params, err := readTransportParameters(chtp.Parameters) if err != nil { return err } - // TODO(#878): remove this when implementing the MAX_STREAM_ID frame - params.MaxStreams = math.MaxUint32 + h.logger.Debugf("Received Transport Parameters: %s", params) h.paramsChan <- *params return nil } + +func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters { + return h.paramsChan +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go index bda12c2..7cfd52e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "fmt" - "math" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -21,9 +20,13 @@ type TransportParameters struct { StreamFlowControlWindow protocol.ByteCount ConnectionFlowControlWindow protocol.ByteCount - MaxStreams uint32 + MaxPacketSize protocol.ByteCount - OmitConnectionID bool + MaxUniStreams uint16 // only used for IETF QUIC + MaxBidiStreams uint16 // only used for IETF QUIC + MaxStreams uint32 // only used for gQUIC + + OmitConnectionID bool // only used for gQUIC IdleTimeout time.Duration } @@ -92,12 +95,11 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte { } // readTransportParameters reads the transport parameters sent in the QUIC TLS extension -func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) { +func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) { params := &TransportParameters{} var foundInitialMaxStreamData bool var foundInitialMaxData bool - var foundInitialMaxStreamID bool var foundIdleTimeout bool for _, p := range paramsList { @@ -114,41 +116,51 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) } params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) - case initialMaxStreamIDParameterID: - foundInitialMaxStreamID = true - if len(p.Value) != 4 { - return nil, fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value)) + case initialMaxBidiStreamsParameterID: + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value)) } - // TODO: handle this value + params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value) + case initialMaxUniStreamsParameterID: + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value)) + } + params.MaxUniStreams = binary.BigEndian.Uint16(p.Value) case idleTimeoutParameterID: foundIdleTimeout = true if len(p.Value) != 2 { return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) } params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) - case omitConnectionIDParameterID: - if len(p.Value) != 0 { - return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) + case maxPacketSizeParameterID: + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value)) } - params.OmitConnectionID = true + maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value)) + if maxPacketSize < 1200 { + return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize) + } + params.MaxPacketSize = maxPacketSize } } - if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) { + if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) { return nil, errors.New("missing parameter") } return params, nil } // GetTransportParameters gets the parameters needed for the TLS handshake. +// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams. func (p *TransportParameters) getTransportParameters() []transportParameter { initialMaxStreamData := make([]byte, 4) binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) initialMaxData := make([]byte, 4) binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) - initialMaxStreamID := make([]byte, 4) - // TODO: use a reasonable value here - binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32) + initialMaxBidiStreamID := make([]byte, 2) + binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams) + initialMaxUniStreamID := make([]byte, 2) + binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams) idleTimeout := make([]byte, 2) binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) maxPacketSize := make([]byte, 2) @@ -156,12 +168,16 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { params := []transportParameter{ {initialMaxStreamDataParameterID, initialMaxStreamData}, {initialMaxDataParameterID, initialMaxData}, - {initialMaxStreamIDParameterID, initialMaxStreamID}, + {initialMaxBidiStreamsParameterID, initialMaxBidiStreamID}, + {initialMaxUniStreamsParameterID, initialMaxUniStreamID}, {idleTimeoutParameterID, idleTimeout}, {maxPacketSizeParameterID, maxPacketSize}, } - if p.OmitConnectionID { - params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) - } return params } + +// String returns a string representation, intended for logging. +// It should only used for IETF QUIC. +func (p *TransportParameters) String() string { + return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go new file mode 100644 index 0000000..dca4bcd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go @@ -0,0 +1,56 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID() (ConnectionID, error) { + b := make([]byte, ConnectionIDLen) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%#x", c.Bytes()) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go index 6aa3b70..948e371 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go @@ -8,3 +8,14 @@ const ( PerspectiveServer Perspective = 1 PerspectiveClient Perspective = 2 ) + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "Server" + case PerspectiveClient: + return "Client" + default: + return "invalid perspective" + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index dadbf32..e89b222 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -1,6 +1,8 @@ package protocol -import "math" +import ( + "fmt" +) // A PacketNumber in QUIC type PacketNumber uint64 @@ -25,29 +27,39 @@ const ( type PacketType uint8 const ( - // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet - PacketTypeVersionNegotiation PacketType = 1 - // PacketTypeInitial is the packet type of a Initial packet - PacketTypeInitial PacketType = 2 + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = 0x7f // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry PacketType = 3 - // PacketTypeHandshake is the packet type of a Cleartext packet - PacketTypeHandshake PacketType = 4 + PacketTypeRetry PacketType = 0x7e + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake PacketType = 0x7d // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT PacketType = 5 + PacketType0RTT PacketType = 0x7c ) -// A ConnectionID in QUIC -type ConnectionID uint64 - -// A StreamID in QUIC -type StreamID uint32 +func (t PacketType) String() string { + switch t { + case PacketTypeInitial: + return "Initial" + case PacketTypeRetry: + return "Retry" + case PacketTypeHandshake: + return "Handshake" + case PacketType0RTT: + return "0-RTT Protected" + default: + return fmt.Sprintf("unknown packet type: %d", t) + } +} // A ByteCount in QUIC type ByteCount uint64 // MaxByteCount is the maximum value of a ByteCount -const MaxByteCount = ByteCount(math.MaxUint64) +const MaxByteCount = ByteCount(1<<62 - 1) + +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint16 // MaxReceivePacketSize maximum packet size of any QUIC packet, based on // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, @@ -59,11 +71,14 @@ const MaxReceivePacketSize ByteCount = 1452 // Used in QUIC for congestion window computations in bytes. const DefaultTCPMSS ByteCount = 1460 -// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. -const ClientHelloMinimumSize = 1024 +// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC) +const MinClientHelloSize = 1024 + +// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have. +const MinInitialPacketSize = 1200 // MaxClientHellos is the maximum number of times we'll send a client hello // The value 3 accounts for: // * one failure due to an incorrect or missing source-address token -// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token +// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token const MaxClientHellos = 3 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go index 697d787..961986a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go @@ -2,19 +2,23 @@ package protocol import "time" -// MaxPacketSize is the maximum packet size that we use for sending packets. -// It includes the QUIC packet header, but excludes the UDP and IP header. -const MaxPacketSize ByteCount = 1200 +// MaxPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. +const MaxPacketSizeIPv4 = 1252 + +// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. +const MaxPacketSizeIPv6 = 1232 // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames const NonForwardSecurePacketSizeReduction = 50 +const defaultMaxCongestionWindowPackets = 1000 + // DefaultMaxCongestionWindow is the default for the max congestion window -const DefaultMaxCongestionWindow = 1000 +const DefaultMaxCongestionWindow ByteCount = defaultMaxCongestionWindowPackets * DefaultTCPMSS // InitialCongestionWindow is the initial congestion window in QUIC packets -const InitialCongestionWindow = 32 +const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS // MaxUndecryptablePackets limits the number of undecryptable packets that a // session queues for later until it sends a public reset. @@ -24,10 +28,6 @@ const MaxUndecryptablePackets = 10 // This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto const PublicResetTimeout = 500 * time.Millisecond -// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet -// This is the value Chromium is using -const AckSendDelay = 25 * time.Millisecond - // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // This is the value that Google servers are using const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB @@ -56,8 +56,14 @@ const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 // This is the value that Chromium is using const ConnectionFlowControlMultiplier = 1.5 -// MaxIncomingStreams is the maximum number of streams that a peer may open -const MaxIncomingStreams = 100 +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 + +// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open +const DefaultMaxIncomingStreams = 100 + +// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open +const DefaultMaxIncomingUniStreams = 100 // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. const MaxStreamsMultiplier = 1.1 @@ -65,12 +71,8 @@ const MaxStreamsMultiplier = 1.1 // MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used. const MaxStreamsMinimumIncrement = 10 -// MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened -// note that the number of streams is half this value, since the client can only open streams with open StreamID -const MaxNewStreamIDDelta = 4 * MaxIncomingStreams - // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. -const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow +const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets // SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack const SkipPacketAveragePeriodLength PacketNumber = 500 @@ -81,20 +83,21 @@ const MaxTrackedSkippedPackets = 10 // CookieExpiryTime is the valid time of a cookie const CookieExpiryTime = 24 * time.Hour -// MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation -const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow +// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. +// When reached, it imposes a soft limit on sending new packets: +// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. +const MaxOutstandingSentPackets = 2 * defaultMaxCongestionWindowPackets + +// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. +// When reached, no more packets will be sent. +// This value *must* be larger than MaxOutstandingSentPackets. +const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked -const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow +const MaxTrackedReceivedAckRanges = defaultMaxCongestionWindowPackets -// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent -const MaxPacketsReceivedBeforeAckSend = 20 - -// MaxNonRetransmittablePackets is the maximum number of non-retransmittable packets that we send in a row -const MaxNonRetransmittablePackets = 19 - -// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for -const RetransmittablePacketsBeforeAck = 2 +// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row +const MaxNonRetransmittableAcks = 19 // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // prevents DoS attacks against the streamFrameSorter @@ -125,3 +128,26 @@ const ClosedSessionDeleteTimeout = time.Minute // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space const NumCachedCertificates = 128 + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 + +// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + +// MinPacingDelay is the minimum duration that is used for packet pacing +// If the packet packing frequency is higher, multiple packets might be sent at once. +// Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth. +const MinPacingDelay time.Duration = 100 * time.Microsecond + +// ConnectionIDLen is the length of the source Connection ID used on IETF QUIC packets. +// The Short Header contains the connection ID, but not the length, +// so we need to know this value in advance (or encode it into the connection ID). +// TODO: make this configurable +const ConnectionIDLen = 8 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go new file mode 100644 index 0000000..a0dced0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream_id.go @@ -0,0 +1,36 @@ +package protocol + +// A StreamID in QUIC +type StreamID uint64 + +// MaxBidiStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams bidirectional streams. +// It is only valid for IETF QUIC. +func MaxBidiStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 1 + } else { + first = 4 + } + return first + 4*StreamID(numStreams-1) +} + +// MaxUniStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams unidirectional streams. +// It is only valid for IETF QUIC. +func MaxUniStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 3 + } else { + first = 2 + } + return first + 4*StreamID(numStreams-1) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go index 5ad04f0..d5f2f37 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go @@ -1,11 +1,14 @@ package protocol import ( + "crypto/rand" + "encoding/binary" "fmt" + "math" ) // VersionNumber is a version number as int -type VersionNumber int +type VersionNumber uint32 // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions const ( @@ -18,7 +21,7 @@ const ( Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota VersionTLS VersionNumber = 101 VersionWhatever VersionNumber = 0 // for when the version doesn't matter - VersionUnknown VersionNumber = -1 + VersionUnknown VersionNumber = math.MaxUint32 ) // SupportedVersions lists the versions that the server supports @@ -27,6 +30,11 @@ var SupportedVersions = []VersionNumber{ Version39, } +// IsValidVersion says if the version is known to quic-go +func IsValidVersion(v VersionNumber) bool { + return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) +} + // UsesTLS says if this QUIC version uses TLS 1.3 for the handshake func (vn VersionNumber) UsesTLS() bool { return vn == VersionTLS @@ -44,7 +52,7 @@ func (vn VersionNumber) String() string { if vn.isGQUIC() { return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) } - return fmt.Sprintf("%d", vn) + return fmt.Sprintf("%#x", uint32(vn)) } } @@ -64,9 +72,14 @@ func (vn VersionNumber) CryptoStreamID() StreamID { return 0 } -// UsesMaxDataFrame tells if this version uses MAX_DATA, MAX_STREAM_DATA, BLOCKED and STREAM_BLOCKED instead of WINDOW_UDPATE and BLOCKED frames -func (vn VersionNumber) UsesMaxDataFrame() bool { - return vn.CryptoStreamID() == 0 +// UsesIETFFrameFormat tells if this version uses the IETF frame format +func (vn VersionNumber) UsesIETFFrameFormat() bool { + return vn != Version39 +} + +// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames +func (vn VersionNumber) UsesStopWaitingFrames() bool { + return vn == Version39 } // StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control @@ -112,3 +125,22 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) } return 0, false } + +// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() VersionNumber { + b := make([]byte, 4) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +} + +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position +func GetGreasedVersions(supported []VersionNumber) []VersionNumber { + b := make([]byte, 1) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + randPos := int(b[0]) % (len(supported) + 1) + greased := make([]VersionNumber, len(supported)+1) + copy(greased, supported[:randPos]) + greased[randPos] = generateReservedVersion() + copy(greased[randPos+1:], supported[randPos:]) + return greased +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/_gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/_gen.go deleted file mode 100644 index 154515b..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/_gen.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import ( - _ "github.com/clipperhouse/linkedlist" - _ "github.com/clipperhouse/slice" - _ "github.com/clipperhouse/stringer" -) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go index 545fc20..096023e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go @@ -1,13 +1,10 @@ -// Generated by: main -// TypeWriter: linkedlist -// Directive: +gen on ByteInterval +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny package utils -// List is a modification of http://golang.org/pkg/container/list/ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. +// Linked list implementation from the Go standard library. // ByteIntervalElement is an element of a linked list. type ByteIntervalElement struct { @@ -41,8 +38,7 @@ func (e *ByteIntervalElement) Prev() *ByteIntervalElement { return nil } -// ByteIntervalList represents a doubly linked list. -// The zero value for ByteIntervalList is an empty list ready to use. +// ByteIntervalList is a linked list of ByteIntervals. type ByteIntervalList struct { root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used len int // current list length excluding (this) sentinel element @@ -63,7 +59,7 @@ func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init // The complexity is O(1). func (l *ByteIntervalList) Len() int { return l.len } -// Front returns the first element of list l or nil. +// Front returns the first element of list l or nil if the list is empty. func (l *ByteIntervalList) Front() *ByteIntervalElement { if l.len == 0 { return nil @@ -71,7 +67,7 @@ func (l *ByteIntervalList) Front() *ByteIntervalElement { return l.root.next } -// Back returns the last element of list l or nil. +// Back returns the last element of list l or nil if the list is empty. func (l *ByteIntervalList) Back() *ByteIntervalElement { if l.len == 0 { return nil @@ -79,7 +75,7 @@ func (l *ByteIntervalList) Back() *ByteIntervalElement { return l.root.prev } -// lazyInit lazily initializes a zero ByteIntervalList value. +// lazyInit lazily initializes a zero List value. func (l *ByteIntervalList) lazyInit() { if l.root.next == nil { l.Init() @@ -98,7 +94,7 @@ func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalEleme return e } -// insertValue is a convenience wrapper for insert(&ByteIntervalElement{Value: v}, at). +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { return l.insert(&ByteIntervalElement{Value: v}, at) } @@ -116,10 +112,11 @@ func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. +// The element must not be nil. func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { if e.list == l { // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero ByteIntervalElement) and l.remove will crash + // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) } return e.Value @@ -139,46 +136,51 @@ func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { if mark.list != l { return nil } - // see comment in ByteIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { if mark.list != l { return nil } - // see comment in ByteIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { if e.list != l || l.root.next == e { return } - // see comment in ByteIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { if e.list != l || l.root.prev == e { return } - // see comment in ByteIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { if e.list != l || e == mark || mark.list != l { return @@ -187,7 +189,8 @@ func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { } // MoveAfter moves element e to its new position after mark. -// If e is not an element of l, or e == mark, the list is not modified. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { if e.list != l || e == mark || mark.list != l { return @@ -196,7 +199,7 @@ func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { } // PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { l.lazyInit() for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { @@ -205,7 +208,7 @@ func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { } // PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { l.lazyInit() for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go index 35549f6..b45800a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go @@ -3,8 +3,6 @@ package utils import ( "bytes" "io" - - "github.com/lucas-clemente/quic-go/internal/protocol" ) // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. @@ -25,9 +23,3 @@ type ByteOrder interface { ReadUfloat16(io.ByteReader) (uint64, error) WriteUfloat16(*bytes.Buffer, uint64) } - -// GetByteOrder gets the ByteOrder to represent values on the wire -// from QUIC 39, values are encoded in big endian, before that in little endian -func GetByteOrder(v protocol.VersionNumber) ByteOrder { - return BigEndian -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go deleted file mode 100644 index b4af4e7..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go +++ /dev/null @@ -1,18 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/binary" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID() (protocol.ConnectionID, error) { - b := make([]byte, 8) - _, err := rand.Read(b) - if err != nil { - return 0, err - } - return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go new file mode 100644 index 0000000..bb839be --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go @@ -0,0 +1,4 @@ +package utils + +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go index 342d8dd..62a3d07 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go @@ -11,8 +11,6 @@ import ( // LogLevel of quic-go type LogLevel uint8 -const logEnv = "QUIC_GO_LOG_LEVEL" - const ( // LogLevelNothing disables LogLevelNothing LogLevel = iota @@ -24,72 +22,92 @@ const ( LogLevelDebug ) -var ( - logLevel = LogLevelNothing - timeFormat = "" -) +const logEnv = "QUIC_GO_LOG_LEVEL" + +// A Logger logs. +type Logger interface { + SetLogLevel(LogLevel) + SetLogTimeFormat(format string) + Debug() bool + + Errorf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Debugf(format string, args ...interface{}) +} + +// DefaultLogger is used by quic-go for logging. +var DefaultLogger Logger + +type defaultLogger struct { + logLevel LogLevel + timeFormat string +} + +var _ Logger = &defaultLogger{} // SetLogLevel sets the log level -func SetLogLevel(level LogLevel) { - logLevel = level +func (l *defaultLogger) SetLogLevel(level LogLevel) { + l.logLevel = level } // SetLogTimeFormat sets the format of the timestamp // an empty string disables the logging of timestamps -func SetLogTimeFormat(format string) { +func (l *defaultLogger) SetLogTimeFormat(format string) { log.SetFlags(0) // disable timestamp logging done by the log package - timeFormat = format + l.timeFormat = format } // Debugf logs something -func Debugf(format string, args ...interface{}) { - if logLevel == LogLevelDebug { - logMessage(format, args...) +func (l *defaultLogger) Debugf(format string, args ...interface{}) { + if l.logLevel == LogLevelDebug { + l.logMessage(format, args...) } } // Infof logs something -func Infof(format string, args ...interface{}) { - if logLevel >= LogLevelInfo { - logMessage(format, args...) +func (l *defaultLogger) Infof(format string, args ...interface{}) { + if l.logLevel >= LogLevelInfo { + l.logMessage(format, args...) } } // Errorf logs something -func Errorf(format string, args ...interface{}) { - if logLevel >= LogLevelError { - logMessage(format, args...) +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + if l.logLevel >= LogLevelError { + l.logMessage(format, args...) } } -func logMessage(format string, args ...interface{}) { - if len(timeFormat) > 0 { - log.Printf(time.Now().Format(timeFormat)+" "+format, args...) +func (l *defaultLogger) logMessage(format string, args ...interface{}) { + if len(l.timeFormat) > 0 { + log.Printf(time.Now().Format(l.timeFormat)+" "+format, args...) } else { log.Printf(format, args...) } } // Debug returns true if the log level is LogLevelDebug -func Debug() bool { - return logLevel == LogLevelDebug +func (l *defaultLogger) Debug() bool { + return l.logLevel == LogLevelDebug } func init() { - readLoggingEnv() + DefaultLogger = &defaultLogger{} + DefaultLogger.SetLogLevel(readLoggingEnv()) } -func readLoggingEnv() { +func readLoggingEnv() LogLevel { switch strings.ToLower(os.Getenv(logEnv)) { case "": - return + return LogLevelNothing case "debug": - logLevel = LogLevelDebug + return LogLevelDebug case "info": - logLevel = LogLevelInfo + return LogLevelInfo case "error": - logLevel = LogLevelError + return LogLevelError default: fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + return LogLevelNothing } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go index c984a3c..4394ab0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go @@ -82,6 +82,14 @@ func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { return b } +// MaxByteCount returns the maximum of two ByteCounts +func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return b + } + return a +} + // MaxDuration returns the max duration func MaxDuration(a, b time.Duration) time.Duration { if a > b { @@ -114,6 +122,14 @@ func MinTime(a, b time.Time) time.Time { return a } +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + // MaxPacketNumber returns the max packet number func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { if a > b { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go index f49b0c4..62cc8b9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go @@ -3,7 +3,6 @@ package utils import "github.com/lucas-clemente/quic-go/internal/protocol" // PacketInterval is an interval from one PacketNumber to the other -// +gen linkedlist type PacketInterval struct { Start protocol.PacketNumber End protocol.PacketNumber diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go index e3431d6..b461e85 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go @@ -1,13 +1,10 @@ -// Generated by: main -// TypeWriter: linkedlist -// Directive: +gen on PacketInterval +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny package utils -// List is a modification of http://golang.org/pkg/container/list/ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. +// Linked list implementation from the Go standard library. // PacketIntervalElement is an element of a linked list. type PacketIntervalElement struct { @@ -41,8 +38,7 @@ func (e *PacketIntervalElement) Prev() *PacketIntervalElement { return nil } -// PacketIntervalList represents a doubly linked list. -// The zero value for PacketIntervalList is an empty list ready to use. +// PacketIntervalList is a linked list of PacketIntervals. type PacketIntervalList struct { root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used len int // current list length excluding (this) sentinel element @@ -63,7 +59,7 @@ func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList // The complexity is O(1). func (l *PacketIntervalList) Len() int { return l.len } -// Front returns the first element of list l or nil. +// Front returns the first element of list l or nil if the list is empty. func (l *PacketIntervalList) Front() *PacketIntervalElement { if l.len == 0 { return nil @@ -71,7 +67,7 @@ func (l *PacketIntervalList) Front() *PacketIntervalElement { return l.root.next } -// Back returns the last element of list l or nil. +// Back returns the last element of list l or nil if the list is empty. func (l *PacketIntervalList) Back() *PacketIntervalElement { if l.len == 0 { return nil @@ -79,7 +75,7 @@ func (l *PacketIntervalList) Back() *PacketIntervalElement { return l.root.prev } -// lazyInit lazily initializes a zero PacketIntervalList value. +// lazyInit lazily initializes a zero List value. func (l *PacketIntervalList) lazyInit() { if l.root.next == nil { l.Init() @@ -98,7 +94,7 @@ func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketInterva return e } -// insertValue is a convenience wrapper for insert(&PacketIntervalElement{Value: v}, at). +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { return l.insert(&PacketIntervalElement{Value: v}, at) } @@ -116,10 +112,11 @@ func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalEle // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. +// The element must not be nil. func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { if e.list == l { // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero PacketIntervalElement) and l.remove will crash + // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) } return e.Value @@ -139,46 +136,51 @@ func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { if mark.list != l { return nil } - // see comment in PacketIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark.prev) } // InsertAfter inserts a new element e with value v immediately after mark and returns e. // If mark is not an element of l, the list is not modified. +// The mark must not be nil. func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { if mark.list != l { return nil } - // see comment in PacketIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l return l.insertValue(v, mark) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { if e.list != l || l.root.next == e { return } - // see comment in PacketIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. +// The element must not be nil. func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { if e.list != l || l.root.prev == e { return } - // see comment in PacketIntervalList.Remove about initialization of l + // see comment in List.Remove about initialization of l l.insert(l.remove(e), l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { if e.list != l || e == mark || mark.list != l { return @@ -187,7 +189,8 @@ func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { } // MoveAfter moves element e to its new position after mark. -// If e is not an element of l, or e == mark, the list is not modified. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { if e.list != l || e == mark || mark.list != l { return @@ -196,7 +199,7 @@ func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { } // PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { l.lazyInit() for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { @@ -205,7 +208,7 @@ func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { } // PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. +// The lists l and other may be the same. They must not be nil. func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { l.lazyInit() for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go index 3c8325b..ec16d25 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go @@ -3,7 +3,6 @@ package utils import "github.com/lucas-clemente/quic-go/internal/protocol" // ByteInterval is an interval from one ByteCount to the other -// +gen linkedlist type ByteInterval struct { Start protocol.ByteCount End protocol.ByteCount diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go index 695ad3e..20eaacd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go @@ -21,7 +21,7 @@ func (t *Timer) Chan() <-chan time.Time { // Reset the timer, no matter whether the value was read or not func (t *Timer) Reset(deadline time.Time) { - if deadline.Equal(t.deadline) { + if deadline.Equal(t.deadline) && !t.read { // No need to reset the timer return } @@ -31,7 +31,7 @@ func (t *Timer) Reset(deadline time.Time) { if !t.t.Stop() && !t.read { <-t.t.C } - t.t.Reset(deadline.Sub(time.Now())) + t.t.Reset(time.Until(deadline)) t.read = false t.deadline = deadline diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go new file mode 100644 index 0000000..35e8674 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/varint.go @@ -0,0 +1,101 @@ +package utils + +import ( + "bytes" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// taken from the QUIC draft +const ( + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// ReadVarInt reads a number in the QUIC varint format +func ReadVarInt(b io.ByteReader) (uint64, error) { + firstByte, err := b.ReadByte() + if err != nil { + return 0, err + } + // the first two bits of the first byte encode the length + len := 1 << ((firstByte & 0xc0) >> 6) + b1 := firstByte & (0xff - 0xc0) + if len == 1 { + return uint64(b1), nil + } + b2, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 2 { + return uint64(b2) + uint64(b1)<<8, nil + } + b3, err := b.ReadByte() + if err != nil { + return 0, err + } + b4, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 4 { + return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil + } + b5, err := b.ReadByte() + if err != nil { + return 0, err + } + b6, err := b.ReadByte() + if err != nil { + return 0, err + } + b7, err := b.ReadByte() + if err != nil { + return 0, err + } + b8, err := b.ReadByte() + if err != nil { + return 0, err + } + return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil +} + +// WriteVarInt writes a number in the QUIC varint format +func WriteVarInt(b *bytes.Buffer, i uint64) { + if i <= maxVarInt1 { + b.WriteByte(uint8(i)) + } else if i <= maxVarInt2 { + b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + } else if i <= maxVarInt4 { + b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + } else if i <= maxVarInt8 { + b.Write([]byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) + } else { + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) + } +} + +// VarIntLen determines the number of bytes that will be needed to write a number +func VarIntLen(i uint64) protocol.ByteCount { + if i <= maxVarInt1 { + return 1 + } + if i <= maxVarInt2 { + return 2 + } + if i <= maxVarInt4 { + return 4 + } + if i <= maxVarInt8 { + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go index 2d60baa..021cef5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -3,371 +3,180 @@ package wire import ( "bytes" "errors" + "sort" "time" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) -var ( - // ErrInvalidAckRanges occurs when a client sends inconsistent ACK ranges - ErrInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") - // ErrInvalidFirstAckRange occurs when the first ACK range contains no packets - ErrInvalidFirstAckRange = errors.New("AckFrame: ACK frame has invalid first ACK range") -) +// TODO: use the value sent in the transport parameters +const ackDelayExponent = 3 -var ( - errInconsistentAckLargestAcked = errors.New("internal inconsistency: LargestAcked does not match ACK ranges") - errInconsistentAckLowestAcked = errors.New("internal inconsistency: LowestAcked does not match ACK ranges") -) - -// An AckFrame is an ACK frame in QUIC +// An AckFrame is an ACK frame type AckFrame struct { - LargestAcked protocol.PacketNumber - LowestAcked protocol.PacketNumber - AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last - - // time when the LargestAcked was receiveid - // this field Will not be set for received ACKs frames - PacketReceivedTime time.Time - DelayTime time.Duration + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + DelayTime time.Duration } -// ParseAckFrame reads an ACK frame -func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { +// parseAckFrame reads an ACK frame +func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { + if !version.UsesIETFFrameFormat() { + return parseAckFrameLegacy(r, version) + } + + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &AckFrame{} - typeByte, err := r.ReadByte() + la, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + largestAcked := protocol.PacketNumber(la) + delay, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + frame.DelayTime = time.Duration(delay*1<> 2) - if largestAckedLen == 0 { - largestAckedLen = 1 - } - - missingSequenceNumberDeltaLen := 2 * (typeByte & 0x03) - if missingSequenceNumberDeltaLen == 0 { - missingSequenceNumberDeltaLen = 1 - } - - largestAcked, err := utils.GetByteOrder(version).ReadUintN(r, largestAckedLen) + // read the first ACK range + ab, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.LargestAcked = protocol.PacketNumber(largestAcked) - - delay, err := utils.GetByteOrder(version).ReadUfloat16(r) - if err != nil { - return nil, err + ackBlock := protocol.PacketNumber(ab) + if ackBlock > largestAcked { + return nil, errors.New("invalid first ACK range") } - frame.DelayTime = time.Duration(delay) * time.Microsecond + smallest := largestAcked - ackBlock - var numAckBlocks uint8 - if hasMissingRanges { - numAckBlocks, err = r.ReadByte() + // read all the other ACK ranges + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) + for i := uint64(0); i < numBlocks; i++ { + g, err := utils.ReadVarInt(r) if err != nil { return nil, err } - } - - if hasMissingRanges && numAckBlocks == 0 { - return nil, ErrInvalidAckRanges - } - - ackBlockLength, err := utils.GetByteOrder(version).ReadUintN(r, missingSequenceNumberDeltaLen) - if err != nil { - return nil, err - } - if frame.LargestAcked > 0 && ackBlockLength < 1 { - return nil, ErrInvalidFirstAckRange - } - - if ackBlockLength > largestAcked { - return nil, ErrInvalidAckRanges - } - - if hasMissingRanges { - ackRange := AckRange{ - First: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, - Last: frame.LargestAcked, + gap := protocol.PacketNumber(g) + if smallest < gap+2 { + return nil, errInvalidAckRanges } - frame.AckRanges = append(frame.AckRanges, ackRange) + largest := smallest - gap - 2 - var inLongBlock bool - var lastRangeComplete bool - for i := uint8(0); i < numAckBlocks; i++ { - var gap uint8 - gap, err = r.ReadByte() - if err != nil { - return nil, err - } - - ackBlockLength, err = utils.GetByteOrder(version).ReadUintN(r, missingSequenceNumberDeltaLen) - if err != nil { - return nil, err - } - - length := protocol.PacketNumber(ackBlockLength) - - if inLongBlock { - frame.AckRanges[len(frame.AckRanges)-1].First -= protocol.PacketNumber(gap) + length - frame.AckRanges[len(frame.AckRanges)-1].Last -= protocol.PacketNumber(gap) - } else { - lastRangeComplete = false - ackRange := AckRange{ - Last: frame.AckRanges[len(frame.AckRanges)-1].First - protocol.PacketNumber(gap) - 1, - } - ackRange.First = ackRange.Last - length + 1 - frame.AckRanges = append(frame.AckRanges, ackRange) - } - - if length > 0 { - lastRangeComplete = true - } - - inLongBlock = (ackBlockLength == 0) + ab, err := utils.ReadVarInt(r) + if err != nil { + return nil, err } + ackBlock := protocol.PacketNumber(ab) - // if the last range was not complete, First and Last make no sense - // remove the range from frame.AckRanges - if !lastRangeComplete { - frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1] - } - - frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].First - } else { - if frame.LargestAcked == 0 { - frame.LowestAcked = 0 - } else { - frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength) + if ackBlock > largest { + return nil, errInvalidAckRanges } + smallest = largest - ackBlock + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } if !frame.validateAckRanges() { - return nil, ErrInvalidAckRanges - } - - var numTimestamp byte - numTimestamp, err = r.ReadByte() - if err != nil { - return nil, err - } - - if numTimestamp > 0 { - // Delta Largest acked - _, err = r.ReadByte() - if err != nil { - return nil, err - } - // First Timestamp - _, err = utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - - for i := 0; i < int(numTimestamp)-1; i++ { - // Delta Largest acked - _, err = r.ReadByte() - if err != nil { - return nil, err - } - - // Time Since Previous Timestamp - _, err = utils.GetByteOrder(version).ReadUint16(r) - if err != nil { - return nil, err - } - } + return nil, errInvalidAckRanges } return frame, nil } // Write writes an ACK frame. func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - largestAckedLen := protocol.GetPacketNumberLength(f.LargestAcked) - - typeByte := uint8(0x40) - - if largestAckedLen != protocol.PacketNumberLen1 { - typeByte ^= (uint8(largestAckedLen / 2)) << 2 + if !version.UsesIETFFrameFormat() { + return f.writeLegacy(b, version) } - missingSequenceNumberDeltaLen := f.getMissingSequenceNumberDeltaLen() - if missingSequenceNumberDeltaLen != protocol.PacketNumberLen1 { - typeByte ^= (uint8(missingSequenceNumberDeltaLen / 2)) + b.WriteByte(0x0d) + utils.WriteVarInt(b, uint64(f.LargestAcked())) + utils.WriteVarInt(b, encodeAckDelay(f.DelayTime)) + + numRanges := f.numEncodableAckRanges() + utils.WriteVarInt(b, uint64(numRanges-1)) + + // write the first range + _, firstRange := f.encodeAckRange(0) + utils.WriteVarInt(b, firstRange) + + // write all the other range + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + utils.WriteVarInt(b, gap) + utils.WriteVarInt(b, len) } - - if f.HasMissingRanges() { - typeByte |= 0x20 - } - - b.WriteByte(typeByte) - - switch largestAckedLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(f.LargestAcked)) - case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(f.LargestAcked)) - case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(f.LargestAcked)) - case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, uint64(f.LargestAcked)&(1<<48-1)) - } - - f.DelayTime = time.Since(f.PacketReceivedTime) - utils.GetByteOrder(version).WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) - - var numRanges uint64 - var numRangesWritten uint64 - if f.HasMissingRanges() { - numRanges = f.numWritableNackRanges() - if numRanges > 0xFF { - panic("AckFrame: Too many ACK ranges") - } - b.WriteByte(uint8(numRanges - 1)) - } - - var firstAckBlockLength protocol.PacketNumber - if !f.HasMissingRanges() { - firstAckBlockLength = f.LargestAcked - f.LowestAcked + 1 - } else { - if f.LargestAcked != f.AckRanges[0].Last { - return errInconsistentAckLargestAcked - } - if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].First { - return errInconsistentAckLowestAcked - } - firstAckBlockLength = f.LargestAcked - f.AckRanges[0].First + 1 - numRangesWritten++ - } - - switch missingSequenceNumberDeltaLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(firstAckBlockLength)) - case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(firstAckBlockLength)) - case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(firstAckBlockLength)) - case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1)) - } - - for i, ackRange := range f.AckRanges { - if i == 0 { - continue - } - - length := ackRange.Last - ackRange.First + 1 - gap := f.AckRanges[i-1].First - ackRange.Last - 1 - - num := gap/0xFF + 1 - if gap%0xFF == 0 { - num-- - } - - if num == 1 { - b.WriteByte(uint8(gap)) - switch missingSequenceNumberDeltaLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(length)) - case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(length)) - case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(length)) - case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, uint64(length)&(1<<48-1)) - } - numRangesWritten++ - } else { - for i := 0; i < int(num); i++ { - var lengthWritten uint64 - var gapWritten uint8 - - if i == int(num)-1 { // last block - lengthWritten = uint64(length) - gapWritten = uint8(1 + ((gap - 1) % 255)) - } else { - lengthWritten = 0 - gapWritten = 0xFF - } - - b.WriteByte(gapWritten) - switch missingSequenceNumberDeltaLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(lengthWritten)) - case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(lengthWritten)) - case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(lengthWritten)) - case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, lengthWritten&(1<<48-1)) - } - - numRangesWritten++ - } - } - - // this is needed if not all AckRanges can be written to the ACK frame (if there are more than 0xFF) - if numRangesWritten >= numRanges { - break - } - } - - if numRanges != numRangesWritten { - return errors.New("BUG: Inconsistent number of ACK ranges written") - } - - b.WriteByte(0) // no timestamps return nil } -// MinLength of a written frame -func (f *AckFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp - length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked)) - - missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen()) - - if f.HasMissingRanges() { - length += (1 + missingSequenceNumberDeltaLen) * protocol.ByteCount(f.numWritableNackRanges()) - } else { - length += missingSequenceNumberDeltaLen +// Length of a written frame +func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.lengthLegacy(version) } - length += (1 + 2) * 0 /* TODO: num_timestamps */ + largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() - return length, nil + length := 1 + utils.VarIntLen(uint64(largestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) + + length += utils.VarIntLen(uint64(numRanges - 1)) + lowestInFirstRange := f.AckRanges[0].Smallest + length += utils.VarIntLen(uint64(largestAcked - lowestInFirstRange)) + + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += utils.VarIntLen(gap) + length += utils.VarIntLen(len) + } + return length +} + +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + utils.VarIntLen(uint64(f.LargestAcked())) + utils.VarIntLen(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := utils.VarIntLen(gap) + utils.VarIntLen(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) } // HasMissingRanges returns if this frame reports any missing packets func (f *AckFrame) HasMissingRanges() bool { - return len(f.AckRanges) > 0 + return len(f.AckRanges) > 1 } func (f *AckFrame) validateAckRanges() bool { if len(f.AckRanges) == 0 { - return true - } - - // if there are missing packets, there will always be at least 2 ACK ranges - if len(f.AckRanges) == 1 { - return false - } - - if f.AckRanges[0].Last != f.LargestAcked { return false } // check the validity of every single ACK range for _, ackRange := range f.AckRanges { - if ackRange.First > ackRange.Last { + if ackRange.Smallest > ackRange.Largest { return false } } @@ -378,10 +187,10 @@ func (f *AckFrame) validateAckRanges() bool { continue } lastAckRange := f.AckRanges[i-1] - if lastAckRange.First <= ackRange.First { + if lastAckRange.Smallest <= ackRange.Smallest { return false } - if lastAckRange.First <= ackRange.Last+1 { + if lastAckRange.Smallest <= ackRange.Largest+1 { return false } } @@ -389,78 +198,29 @@ func (f *AckFrame) validateAckRanges() bool { return true } -// numWritableNackRanges calculates the number of ACK blocks that are about to be written -// this number is different from len(f.AckRanges) for the case of long gaps (> 255 packets) -func (f *AckFrame) numWritableNackRanges() uint64 { - if len(f.AckRanges) == 0 { - return 0 - } - - var numRanges uint64 - for i, ackRange := range f.AckRanges { - if i == 0 { - continue - } - - lastAckRange := f.AckRanges[i-1] - gap := lastAckRange.First - ackRange.Last - 1 - rangeLength := 1 + uint64(gap)/0xFF - if uint64(gap)%0xFF == 0 { - rangeLength-- - } - - if numRanges+rangeLength < 0xFF { - numRanges += rangeLength - } else { - break - } - } - - return numRanges + 1 +// LargestAcked is the largest acked packet number +func (f *AckFrame) LargestAcked() protocol.PacketNumber { + return f.AckRanges[0].Largest } -func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { - var maxRangeLength protocol.PacketNumber - - if f.HasMissingRanges() { - for _, ackRange := range f.AckRanges { - rangeLength := ackRange.Last - ackRange.First + 1 - if rangeLength > maxRangeLength { - maxRangeLength = rangeLength - } - } - } else { - maxRangeLength = f.LargestAcked - f.LowestAcked + 1 - } - - if maxRangeLength <= 0xFF { - return protocol.PacketNumberLen1 - } - if maxRangeLength <= 0xFFFF { - return protocol.PacketNumberLen2 - } - if maxRangeLength <= 0xFFFFFFFF { - return protocol.PacketNumberLen4 - } - - return protocol.PacketNumberLen6 +// LowestAcked is the lowest acked packet number +func (f *AckFrame) LowestAcked() protocol.PacketNumber { + return f.AckRanges[len(f.AckRanges)-1].Smallest } // AcksPacket determines if this ACK frame acks a certain packet number func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { - if p < f.LowestAcked || p > f.LargestAcked { // this is just a performance optimization + if p < f.LowestAcked() || p > f.LargestAcked() { return false } - if f.HasMissingRanges() { - // TODO: this could be implemented as a binary search - for _, ackRange := range f.AckRanges { - if p >= ackRange.First && p <= ackRange.Last { - return true - } - } - return false - } - // if packet doesn't have missing ranges - return (p >= f.LowestAcked && p <= f.LargestAcked) + i := sort.Search(len(f.AckRanges), func(i int) bool { + return p >= f.AckRanges[i].Smallest + }) + // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked + return p <= f.AckRanges[i].Largest +} + +func encodeAckDelay(delay time.Duration) uint64 { + return uint64(delay.Nanoseconds() / (1000 * (1 << ackDelayExponent))) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go new file mode 100644 index 0000000..c2a71e0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame_legacy.go @@ -0,0 +1,364 @@ +package wire + +import ( + "bytes" + "errors" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") + +func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) { + frame := &AckFrame{} + + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + hasMissingRanges := typeByte&0x20 == 0x20 + largestAckedLen := 2 * ((typeByte & 0x0C) >> 2) + if largestAckedLen == 0 { + largestAckedLen = 1 + } + + missingSequenceNumberDeltaLen := 2 * (typeByte & 0x03) + if missingSequenceNumberDeltaLen == 0 { + missingSequenceNumberDeltaLen = 1 + } + + la, err := utils.BigEndian.ReadUintN(r, largestAckedLen) + if err != nil { + return nil, err + } + largestAcked := protocol.PacketNumber(la) + + delay, err := utils.BigEndian.ReadUfloat16(r) + if err != nil { + return nil, err + } + frame.DelayTime = time.Duration(delay) * time.Microsecond + + var numAckBlocks uint8 + if hasMissingRanges { + numAckBlocks, err = r.ReadByte() + if err != nil { + return nil, err + } + } + + if hasMissingRanges && numAckBlocks == 0 { + return nil, errInvalidAckRanges + } + + abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) + if err != nil { + return nil, err + } + ackBlockLength := protocol.PacketNumber(abl) + if largestAcked > 0 && ackBlockLength < 1 { + return nil, errors.New("invalid first ACK range") + } + + if ackBlockLength > largestAcked+1 { + return nil, errInvalidAckRanges + } + + if hasMissingRanges { + ackRange := AckRange{ + Smallest: largestAcked - ackBlockLength + 1, + Largest: largestAcked, + } + frame.AckRanges = append(frame.AckRanges, ackRange) + + var inLongBlock bool + var lastRangeComplete bool + for i := uint8(0); i < numAckBlocks; i++ { + var gap uint8 + gap, err = r.ReadByte() + if err != nil { + return nil, err + } + + abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen) + if err != nil { + return nil, err + } + ackBlockLength := protocol.PacketNumber(abl) + + if inLongBlock { + frame.AckRanges[len(frame.AckRanges)-1].Smallest -= protocol.PacketNumber(gap) + ackBlockLength + frame.AckRanges[len(frame.AckRanges)-1].Largest -= protocol.PacketNumber(gap) + } else { + lastRangeComplete = false + ackRange := AckRange{ + Largest: frame.AckRanges[len(frame.AckRanges)-1].Smallest - protocol.PacketNumber(gap) - 1, + } + ackRange.Smallest = ackRange.Largest - ackBlockLength + 1 + frame.AckRanges = append(frame.AckRanges, ackRange) + } + + if ackBlockLength > 0 { + lastRangeComplete = true + } + inLongBlock = (ackBlockLength == 0) + } + + // if the last range was not complete, First and Last make no sense + // remove the range from frame.AckRanges + if !lastRangeComplete { + frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1] + } + } else { + frame.AckRanges = make([]AckRange, 1) + if largestAcked != 0 { + frame.AckRanges[0].Largest = largestAcked + frame.AckRanges[0].Smallest = largestAcked + 1 - ackBlockLength + } + } + + if !frame.validateAckRanges() { + return nil, errInvalidAckRanges + } + + var numTimestamp byte + numTimestamp, err = r.ReadByte() + if err != nil { + return nil, err + } + + if numTimestamp > 0 { + // Delta Largest acked + _, err = r.ReadByte() + if err != nil { + return nil, err + } + // First Timestamp + _, err = utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + + for i := 0; i < int(numTimestamp)-1; i++ { + // Delta Largest acked + _, err = r.ReadByte() + if err != nil { + return nil, err + } + + // Time Since Previous Timestamp + _, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + } + } + return frame, nil +} + +func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error { + largestAcked := f.LargestAcked() + largestAckedLen := protocol.GetPacketNumberLength(largestAcked) + + typeByte := uint8(0x40) + + if largestAckedLen != protocol.PacketNumberLen1 { + typeByte ^= (uint8(largestAckedLen / 2)) << 2 + } + + missingSequenceNumberDeltaLen := f.getMissingSequenceNumberDeltaLen() + if missingSequenceNumberDeltaLen != protocol.PacketNumberLen1 { + typeByte ^= (uint8(missingSequenceNumberDeltaLen / 2)) + } + + if f.HasMissingRanges() { + typeByte |= 0x20 + } + + b.WriteByte(typeByte) + + switch largestAckedLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(largestAcked)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(largestAcked)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(largestAcked)) + case protocol.PacketNumberLen6: + utils.BigEndian.WriteUint48(b, uint64(largestAcked)&(1<<48-1)) + } + + utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) + + var numRanges uint64 + var numRangesWritten uint64 + if f.HasMissingRanges() { + numRanges = f.numWritableNackRanges() + if numRanges > 0xFF { + panic("AckFrame: Too many ACK ranges") + } + b.WriteByte(uint8(numRanges - 1)) + } + + var firstAckBlockLength protocol.PacketNumber + if !f.HasMissingRanges() { + firstAckBlockLength = largestAcked - f.LowestAcked() + 1 + } else { + firstAckBlockLength = largestAcked - f.AckRanges[0].Smallest + 1 + numRangesWritten++ + } + + switch missingSequenceNumberDeltaLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(firstAckBlockLength)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(firstAckBlockLength)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(firstAckBlockLength)) + case protocol.PacketNumberLen6: + utils.BigEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1)) + } + + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + + length := ackRange.Largest - ackRange.Smallest + 1 + gap := f.AckRanges[i-1].Smallest - ackRange.Largest - 1 + + num := gap/0xFF + 1 + if gap%0xFF == 0 { + num-- + } + + if num == 1 { + b.WriteByte(uint8(gap)) + switch missingSequenceNumberDeltaLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(length)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(length)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(length)) + case protocol.PacketNumberLen6: + utils.BigEndian.WriteUint48(b, uint64(length)&(1<<48-1)) + } + numRangesWritten++ + } else { + for i := 0; i < int(num); i++ { + var lengthWritten uint64 + var gapWritten uint8 + + if i == int(num)-1 { // last block + lengthWritten = uint64(length) + gapWritten = uint8(1 + ((gap - 1) % 255)) + } else { + lengthWritten = 0 + gapWritten = 0xFF + } + + b.WriteByte(gapWritten) + switch missingSequenceNumberDeltaLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(lengthWritten)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(lengthWritten)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(lengthWritten)) + case protocol.PacketNumberLen6: + utils.BigEndian.WriteUint48(b, lengthWritten&(1<<48-1)) + } + + numRangesWritten++ + } + } + + // this is needed if not all AckRanges can be written to the ACK frame (if there are more than 0xFF) + if numRangesWritten >= numRanges { + break + } + } + + if numRanges != numRangesWritten { + return errors.New("BUG: Inconsistent number of ACK ranges written") + } + + b.WriteByte(0) // no timestamps + return nil +} + +func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { + length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp + length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked())) + + missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen()) + + if f.HasMissingRanges() { + length += (1 + missingSequenceNumberDeltaLen) * protocol.ByteCount(f.numWritableNackRanges()) + } else { + length += missingSequenceNumberDeltaLen + } + // we don't write + return length +} + +// numWritableNackRanges calculates the number of ACK blocks that are about to be written +// this number is different from len(f.AckRanges) for the case of long gaps (> 255 packets) +func (f *AckFrame) numWritableNackRanges() uint64 { + if len(f.AckRanges) == 0 { + return 0 + } + + var numRanges uint64 + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + + lastAckRange := f.AckRanges[i-1] + gap := lastAckRange.Smallest - ackRange.Largest - 1 + rangeLength := 1 + uint64(gap)/0xFF + if uint64(gap)%0xFF == 0 { + rangeLength-- + } + + if numRanges+rangeLength < 0xFF { + numRanges += rangeLength + } else { + break + } + } + + return numRanges + 1 +} + +func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { + var maxRangeLength protocol.PacketNumber + + if f.HasMissingRanges() { + for _, ackRange := range f.AckRanges { + rangeLength := ackRange.Largest - ackRange.Smallest + 1 + if rangeLength > maxRangeLength { + maxRangeLength = rangeLength + } + } + } else { + maxRangeLength = f.LargestAcked() - f.LowestAcked() + 1 + } + + if maxRangeLength <= 0xFF { + return protocol.PacketNumberLen1 + } + if maxRangeLength <= 0xFFFF { + return protocol.PacketNumberLen2 + } + if maxRangeLength <= 0xFFFFFFFF { + return protocol.PacketNumberLen4 + } + + return protocol.PacketNumberLen6 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go index c561762..0f41858 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go @@ -4,6 +4,11 @@ import "github.com/lucas-clemente/quic-go/internal/protocol" // AckRange is an ACK range type AckRange struct { - First protocol.PacketNumber - Last protocol.PacketNumber + Smallest protocol.PacketNumber + Largest protocol.PacketNumber +} + +// Len returns the number of packets contained in this ACK range +func (r AckRange) Len() protocol.PacketNumber { + return r.Largest - r.Smallest + 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go index 08dc051..1d3e56e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go @@ -4,32 +4,42 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // A BlockedFrame is a BLOCKED frame -type BlockedFrame struct{} +type BlockedFrame struct { + Offset protocol.ByteCount +} -// ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { +// parseBlockedFrame parses a BLOCKED frame +func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) { if _, err := r.ReadByte(); err != nil { return nil, err } - return &BlockedFrame{}, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &BlockedFrame{ + Offset: protocol.ByteCount(offset), + }, nil } func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if !version.UsesMaxDataFrame() { + if !version.UsesIETFFrameFormat() { return (&blockedFrameLegacy{}).Write(b, version) } typeByte := uint8(0x08) b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } -// MinLength of a written frame -func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - if !version.UsesMaxDataFrame() { // writing this frame would result in a legacy BLOCKED being written, which is longer - return 1 + 4, nil +// Length of a written frame +func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return 1 + 4 } - return 1, nil + return 1 + utils.VarIntLen(uint64(f.Offset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go index d60ca4c..9943e16 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go @@ -11,16 +11,15 @@ type blockedFrameLegacy struct { StreamID protocol.StreamID } -// ParseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) +// parseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) // The frame returned is // * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream // * a BLOCKED frame, if the BLOCKED applies to the connection -func ParseBlockedFrameLegacy(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { - // read the TypeByte - if _, err := r.ReadByte(); err != nil { +func parseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } - streamID, err := utils.GetByteOrder(version).ReadUint32(r) + streamID, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } @@ -31,8 +30,8 @@ func ParseBlockedFrameLegacy(r *bytes.Reader, version protocol.VersionNumber) (F } //Write writes a BLOCKED frame -func (f *blockedFrameLegacy) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *blockedFrameLegacy) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { b.WriteByte(0x05) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go index 432c6a8..667ded7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go @@ -17,28 +17,38 @@ type ConnectionCloseFrame struct { ReasonPhrase string } -// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame -func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { - frame := &ConnectionCloseFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { +// parseConnectionCloseFrame reads a CONNECTION_CLOSE frame +func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } - errorCode, err := utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - frame.ErrorCode = qerr.ErrorCode(errorCode) - - reasonPhraseLen, err := utils.GetByteOrder(version).ReadUint16(r) - if err != nil { - return nil, err + var errorCode qerr.ErrorCode + var reasonPhraseLen uint64 + if version.UsesIETFFrameFormat() { + ec, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + reasonPhraseLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + } else { + ec, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = qerr.ErrorCode(ec) + length, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + reasonPhraseLen = uint64(length) } - // shortcut to prevent the unneccessary allocation of dataLen bytes + // shortcut to prevent the unnecessary allocation of dataLen bytes // if the dataLen is larger than the remaining length of the packet // reading the whole reason phrase would result in EOF when attempting to READ if int(reasonPhraseLen) > r.Len() { @@ -50,27 +60,36 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) // this should never happen, since we already checked the reasonPhraseLen earlier return nil, err } - frame.ReasonPhrase = string(reasonPhrase) - return frame, nil + return &ConnectionCloseFrame{ + ErrorCode: errorCode, + ReasonPhrase: string(reasonPhrase), + }, nil } -// MinLength of a written frame -func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil +// Length of a written frame +func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if version.UsesIETFFrameFormat() { + return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) + } + return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)) } // Write writes an CONNECTION_CLOSE frame. func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x02) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) if len(f.ReasonPhrase) > math.MaxUint16 { return errors.New("ConnectionFrame: ReasonPhrase too long") } - reasonPhraseLen := uint16(len(f.ReasonPhrase)) - utils.GetByteOrder(version).WriteUint16(b, reasonPhraseLen) + if version.UsesIETFFrameFormat() { + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(len(f.ReasonPhrase))) + } else { + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase))) + } b.WriteString(f.ReasonPhrase) return nil diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go index f31f5bf..835905a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go @@ -9,5 +9,5 @@ import ( // A Frame in QUIC type Frame interface { Write(b *bytes.Buffer, version protocol.VersionNumber) error - MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) + Length(version protocol.VersionNumber) protocol.ByteCount } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go new file mode 100644 index 0000000..67b191b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go @@ -0,0 +1,162 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/qerr" +) + +// ParseNextFrame parses the next frame +// It skips PADDING frames. +func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) { + for r.Len() != 0 { + typeByte, _ := r.ReadByte() + if typeByte == 0x0 { // PADDING frame + continue + } + r.UnreadByte() + + if !v.UsesIETFFrameFormat() { + return parseGQUICFrame(r, typeByte, hdr, v) + } + return parseIETFFrame(r, typeByte, v) + } + return nil, nil +} + +func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) { + var frame Frame + var err error + if typeByte&0xf8 == 0x10 { + frame, err = parseStreamFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidStreamData, err.Error()) + } + return frame, err + } + // TODO: implement all IETF QUIC frame types + switch typeByte { + case 0x1: + frame, err = parseRstStreamFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) + } + case 0x2: + frame, err = parseConnectionCloseFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) + } + case 0x4: + frame, err = parseMaxDataFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x5: + frame, err = parseMaxStreamDataFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x6: + frame, err = parseMaxStreamIDFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0x7: + frame, err = parsePingFrame(r, v) + case 0x8: + frame, err = parseBlockedFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0x9: + frame, err = parseStreamBlockedFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0xa: + frame, err = parseStreamIDBlockedFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xc: + frame, err = parseStopSendingFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xd: + frame, err = parseAckFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidAckData, err.Error()) + } + case 0xe: + frame, err = parsePathChallengeFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + case 0xf: + frame, err = parsePathResponseFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } + default: + err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + } + return frame, err +} + +func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.VersionNumber) (Frame, error) { + var frame Frame + var err error + if typeByte&0x80 == 0x80 { + frame, err = parseStreamFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidStreamData, err.Error()) + } + return frame, err + } else if typeByte&0xc0 == 0x40 { + frame, err = parseAckFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidAckData, err.Error()) + } + return frame, err + } + switch typeByte { + case 0x1: + frame, err = parseRstStreamFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) + } + case 0x2: + frame, err = parseConnectionCloseFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) + } + case 0x3: + frame, err = parseGoawayFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidGoawayData, err.Error()) + } + case 0x4: + frame, err = parseWindowUpdateFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + case 0x5: + frame, err = parseBlockedFrameLegacy(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + case 0x6: + frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v) + if err != nil { + err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) + } + case 0x7: + frame, err = parsePingFrame(r, v) + default: + err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + } + return frame, err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go index 5332210..86bf2b4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go @@ -16,32 +16,32 @@ type GoawayFrame struct { ReasonPhrase string } -// ParseGoawayFrame parses a GOAWAY frame -func ParseGoawayFrame(r *bytes.Reader, version protocol.VersionNumber) (*GoawayFrame, error) { +// parseGoawayFrame parses a GOAWAY frame +func parseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) { frame := &GoawayFrame{} if _, err := r.ReadByte(); err != nil { return nil, err } - errorCode, err := utils.GetByteOrder(version).ReadUint32(r) + errorCode, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } frame.ErrorCode = qerr.ErrorCode(errorCode) - lastGoodStream, err := utils.GetByteOrder(version).ReadUint32(r) + lastGoodStream, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } frame.LastGoodStream = protocol.StreamID(lastGoodStream) - reasonPhraseLen, err := utils.GetByteOrder(version).ReadUint16(r) + reasonPhraseLen, err := utils.BigEndian.ReadUint16(r) if err != nil { return nil, err } - if reasonPhraseLen > uint16(protocol.MaxPacketSize) { + if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) { return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long") } @@ -53,16 +53,16 @@ func ParseGoawayFrame(r *bytes.Reader, version protocol.VersionNumber) (*GoawayF return frame, nil } -func (f *GoawayFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { b.WriteByte(0x03) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.LastGoodStream)) - utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.ReasonPhrase))) + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + utils.BigEndian.WriteUint32(b, uint32(f.LastGoodStream)) + utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase))) b.WriteString(f.ReasonPhrase) return nil } -// MinLength of a written frame -func (f *GoawayFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)), nil +// Length of a written frame +func (f *GoawayFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go index 96066cc..4126133 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -4,18 +4,27 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // Header is the header of a QUIC packet. // It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. type Header struct { - Raw []byte - ConnectionID protocol.ConnectionID - OmitConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - Version protocol.VersionNumber // VersionNumber sent by the client - SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server + IsPublicHeader bool + + Raw []byte + + Version protocol.VersionNumber + + DestConnectionID protocol.ConnectionID + SrcConnectionID protocol.ConnectionID + OmitConnectionID bool + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + IsVersionNegotiation bool + SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server // only needed for the gQUIC Public Header VersionFlag bool @@ -26,9 +35,7 @@ type Header struct { Type protocol.PacketType IsLongHeader bool KeyPhase int - - // only needed for logging - isPublicHeader bool + PayloadLen protocol.ByteCount } // ParseHeaderSentByServer parses the header for a packet that was sent by the server. @@ -40,17 +47,15 @@ func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (* _ = b.UnreadByte() // unread the type byte var isPublicHeader bool - // As a client, we know the version of the packet that the server sent, except for Version Negotiation Packets. - if typeByte == 0x81 { // IETF draft Version Negotiation Packet + if typeByte&0x80 > 0 { // gQUIC always has 0x80 unset. IETF Long Header or Version Negotiation isPublicHeader = false } else if typeByte&0xcf == 0x9 { // gQUIC Version Negotiation Packet - // IETF QUIC Version Negotiation Packets are sent with the Long Header (indicated by the 0x80 bit) - // gQUIC always has 0x80 unset isPublicHeader = true - } else { // not a Version Negotiation Packet + } else { // the client knows the version that this packet was sent with isPublicHeader = !version.UsesTLS() } + return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) } @@ -62,12 +67,13 @@ func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { } _ = b.UnreadByte() // unread the type byte - // If this is a gQUIC header 0x80 and 0x40 will be set to 0. - // If this is an IETF QUIC header there are two options: - // * either 0x80 will be 1 (for the Long Header) - // * or 0x40 (the Connection ID Flag) will be 0 (for the Short Header), since we don't the client to omit it - isPublicHeader := typeByte&0xc0 == 0 - + // In an IETF QUIC packet header + // * either 0x80 is set (for the Long Header) + // * or 0x8 is unset (for the Short Header) + // In a gQUIC Public Header + // * 0x80 is always unset and + // * and 0x8 is always set (this is the Connection ID flag, which the client always sets) + isPublicHeader := typeByte&0x88 == 0x8 return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) } @@ -78,16 +84,16 @@ func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHea if err != nil { return nil, err } - hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + hdr.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later return hdr, nil } - return parseHeader(b, sentBy) + return parseHeader(b) } // Write writes the Header. func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { if !version.UsesTLS() { - h.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later return h.writePublicHeader(b, pers, version) } return h.writeHeader(b) @@ -102,10 +108,10 @@ func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNu } // Log logs the Header -func (h *Header) Log() { - if h.isPublicHeader { - h.logPublicHeader() +func (h *Header) Log(logger utils.Logger) { + if h.IsPublicHeader { + h.logPublicHeader(logger) } else { - h.logHeader() + h.logHeader(logger) } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go index 3db67cc..aa98226 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go @@ -2,7 +2,9 @@ package wire import ( "bytes" + "errors" "fmt" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -10,52 +12,50 @@ import ( ) // parseHeader parses the header. -func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { +func parseHeader(b *bytes.Reader) (*Header, error) { typeByte, err := b.ReadByte() if err != nil { return nil, err } if typeByte&0x80 > 0 { - return parseLongHeader(b, packetSentBy, typeByte) + return parseLongHeader(b, typeByte) } return parseShortHeader(b, typeByte) } -func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { - connID, err := utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err - } - pn, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, err - } +// parse long header and version negotiation packets +func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) { v, err := utils.BigEndian.ReadUint32(b) if err != nil { return nil, err } - packetType := protocol.PacketType(typeByte & 0x7f) - if sentBy == protocol.PerspectiveClient && (packetType != protocol.PacketTypeInitial && packetType != protocol.PacketTypeHandshake && packetType != protocol.PacketType0RTT) { - if packetType == protocol.PacketTypeVersionNegotiation { - return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "sent by the client") - } - return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType)) + + connIDLenByte, err := b.ReadByte() + if err != nil { + return nil, err } - if sentBy == protocol.PerspectiveServer && (packetType != protocol.PacketTypeVersionNegotiation && packetType != protocol.PacketTypeRetry && packetType != protocol.PacketTypeHandshake) { - return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType)) + dcil, scil := decodeConnIDLen(connIDLenByte) + destConnID, err := protocol.ReadConnectionID(b, dcil) + if err != nil { + return nil, err } + srcConnID, err := protocol.ReadConnectionID(b, scil) + if err != nil { + return nil, err + } + h := &Header{ - Type: packetType, - IsLongHeader: true, - ConnectionID: protocol.ConnectionID(connID), - PacketNumber: protocol.PacketNumber(pn), - PacketNumberLen: protocol.PacketNumberLen4, - Version: protocol.VersionNumber(v), + IsLongHeader: true, + Version: protocol.VersionNumber(v), + DestConnectionID: destConnID, + SrcConnectionID: srcConnID, } - if h.Type == protocol.PacketTypeVersionNegotiation { + + if v == 0 { // version negotiation packet if b.Len() == 0 { return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") } + h.IsVersionNegotiation = true h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) for i := 0; b.Len() > 0; i++ { v, err := utils.BigEndian.ReadUint32(b) @@ -64,31 +64,60 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte } h.SupportedVersions[i] = protocol.VersionNumber(v) } + return h, nil + } + + pl, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + h.PayloadLen = protocol.ByteCount(pl) + pn, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(pn) + h.PacketNumberLen = protocol.PacketNumberLen4 + h.Type = protocol.PacketType(typeByte & 0x7f) + + if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) } return h, nil } func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { - hasConnID := typeByte&0x40 > 0 - var connID uint64 - if hasConnID { - var err error - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { - return nil, err + connID := make(protocol.ConnectionID, 8) + if _, err := io.ReadFull(b, connID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF } + return nil, err + } + // bits 2 and 3 must be set, bit 4 must be unset + if typeByte&0x38 != 0x30 { + return nil, errors.New("invalid bits 3, 4 and 5") + } + var pnLen protocol.PacketNumberLen + switch typeByte & 0x3 { + case 0x0: + pnLen = protocol.PacketNumberLen1 + case 0x1: + pnLen = protocol.PacketNumberLen2 + case 0x2: + pnLen = protocol.PacketNumberLen4 + default: + return nil, errors.New("invalid short header type") } - pnLen := 1 << ((typeByte & 0x3) - 1) pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen)) if err != nil { return nil, err } return &Header{ - KeyPhase: int(typeByte&0x20) >> 5, - OmitConnectionID: !hasConnID, - ConnectionID: protocol.ConnectionID(connID), + KeyPhase: int(typeByte&0x40) >> 6, + DestConnectionID: connID, PacketNumber: protocol.PacketNumber(pn), - PacketNumberLen: protocol.PacketNumberLen(pnLen), + PacketNumberLen: pnLen, }, nil } @@ -102,33 +131,38 @@ func (h *Header) writeHeader(b *bytes.Buffer) error { // TODO: add support for the key phase func (h *Header) writeLongHeader(b *bytes.Buffer) error { - b.WriteByte(byte(0x80 ^ h.Type)) - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + if h.SrcConnectionID.Len() != protocol.ConnectionIDLen { + return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len()) + } + b.WriteByte(byte(0x80 | h.Type)) utils.BigEndian.WriteUint32(b, uint32(h.Version)) + connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) + if err != nil { + return err + } + b.WriteByte(connIDLen) + b.Write(h.DestConnectionID.Bytes()) + b.Write(h.SrcConnectionID.Bytes()) + utils.WriteVarInt(b, uint64(h.PayloadLen)) + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) return nil } func (h *Header) writeShortHeader(b *bytes.Buffer) error { - typeByte := byte(h.KeyPhase << 5) - if !h.OmitConnectionID { - typeByte ^= 0x40 - } + typeByte := byte(0x30) + typeByte |= byte(h.KeyPhase << 6) switch h.PacketNumberLen { case protocol.PacketNumberLen1: - typeByte ^= 0x1 case protocol.PacketNumberLen2: - typeByte ^= 0x2 + typeByte |= 0x1 case protocol.PacketNumberLen4: - typeByte ^= 0x3 + typeByte |= 0x2 default: return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } b.WriteByte(typeByte) - if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) - } + b.Write(h.DestConnectionID.Bytes()) switch h.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) @@ -140,16 +174,12 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error { return nil } -// getHeaderLength gets the length of the Header in bytes. func (h *Header) getHeaderLength() (protocol.ByteCount, error) { if h.IsLongHeader { - return 1 + 8 + 4 + 4, nil + return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + 4 /* packet number */, nil } - length := protocol.ByteCount(1) // type byte - if !h.OmitConnectionID { - length += 8 - } + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } @@ -157,14 +187,48 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) { return length, nil } -func (h *Header) logHeader() { +func (h *Header) logHeader(logger utils.Logger) { if h.IsLongHeader { - utils.Debugf(" Long Header{Type: %#x, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) - } else { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) + if h.Version == 0 { + logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions) + } else { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PayloadLen, h.Version) } - utils.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } + +func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { + dcil, err := encodeSingleConnIDLen(dest) + if err != nil { + return 0, err + } + scil, err := encodeSingleConnIDLen(src) + if err != nil { + return 0, err + } + return scil | dcil<<4, nil +} + +func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { + len := id.Len() + if len == 0 { + return 0, nil + } + if len < 4 || len > 18 { + return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) + } + return byte(len - 3), nil +} + +func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { + return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf) +} + +func decodeSingleConnIDLen(enc uint8) int { + if enc == 0 { + return 0 + } + return int(enc) + 3 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go index 0e72ea9..465e82a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go @@ -1,10 +1,15 @@ package wire -import "github.com/lucas-clemente/quic-go/internal/utils" +import ( + "fmt" + "strings" + + "github.com/lucas-clemente/quic-go/internal/utils" +) // LogFrame logs a frame, either sent or received -func LogFrame(frame Frame, sent bool) { - if !utils.Debug() { +func LogFrame(logger utils.Logger, frame Frame, sent bool) { + if !logger.Debug() { return } dir := "<-" @@ -13,16 +18,24 @@ func LogFrame(frame Frame, sent bool) { } switch f := frame.(type) { case *StreamFrame: - utils.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) case *StopWaitingFrame: if sent { - utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) + logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) } else { - utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) + logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) } case *AckFrame: - utils.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + if len(f.AckRanges) > 1 { + ackRanges := make([]string, len(f.AckRanges)) + for i, r := range f.AckRanges { + ackRanges[i] = fmt.Sprintf("{Largest: %#x, Smallest: %#x}", r.Largest, r.Smallest) + } + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %#x, LowestAcked: %#x, AckRanges: {%s}, DelayTime: %s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String()) + } else { + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %#x, LowestAcked: %#x, DelayTime: %s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String()) + } default: - utils.Debugf("\t%s %#v", dir, frame) + logger.Debugf("\t%s %#v", dir, frame) } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go index cd3ff65..0bca27d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go @@ -12,15 +12,15 @@ type MaxDataFrame struct { ByteOffset protocol.ByteCount } -// ParseMaxDataFrame parses a MAX_DATA frame -func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) { +// parseMaxDataFrame parses a MAX_DATA frame +func parseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) { // read the TypeByte if _, err := r.ReadByte(); err != nil { return nil, err } frame := &MaxDataFrame{} - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + byteOffset, err := utils.ReadVarInt(r) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDat //Write writes a MAX_STREAM_DATA frame func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if !version.UsesMaxDataFrame() { + if !version.UsesIETFFrameFormat() { // write a gQUIC WINDOW_UPDATE frame (with stream ID 0, which means connection-level there) return (&windowUpdateFrame{ StreamID: 0, @@ -38,14 +38,14 @@ func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) er }).Write(b, version) } b.WriteByte(0x4) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) return nil } -// MinLength of a written frame -func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - if !version.UsesMaxDataFrame() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer - return 1 + 4 + 8, nil +// Length of a written frame +func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer + return 1 + 4 + 8 } - return 1 + 8, nil + return 1 + utils.VarIntLen(uint64(f.ByteOffset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go index 56c44c9..6d8be23 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go @@ -13,8 +13,8 @@ type MaxStreamDataFrame struct { ByteOffset protocol.ByteCount } -// ParseMaxStreamDataFrame parses a MAX_STREAM_DATA frame -func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) { +// parseMaxStreamDataFrame parses a MAX_STREAM_DATA frame +func parseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) { frame := &MaxStreamDataFrame{} // read the TypeByte @@ -22,13 +22,13 @@ func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (* return nil, err } - sid, err := utils.GetByteOrder(version).ReadUint32(r) + sid, err := utils.ReadVarInt(r) if err != nil { return nil, err } frame.StreamID = protocol.StreamID(sid) - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + byteOffset, err := utils.ReadVarInt(r) if err != nil { return nil, err } @@ -38,19 +38,23 @@ func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (* // Write writes a MAX_STREAM_DATA frame func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if !version.UsesMaxDataFrame() { + if !version.UsesIETFFrameFormat() { return (&windowUpdateFrame{ StreamID: f.StreamID, ByteOffset: f.ByteOffset, }).Write(b, version) } b.WriteByte(0x5) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) return nil } -// MinLength of a written frame -func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8, nil +// Length of a written frame +func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which has a different length + if !version.UsesIETFFrameFormat() { + return 1 + 4 + 8 + } + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.ByteOffset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go new file mode 100644 index 0000000..9f5424d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_id_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamIDFrame is a MAX_STREAM_ID frame +type MaxStreamIDFrame struct { + StreamID protocol.StreamID +} + +// parseMaxStreamIDFrame parses a MAX_STREAM_ID frame +func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) { + // read the Type byte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x6) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// Length of a written frame +func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go new file mode 100644 index 0000000..f2a27d8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathChallengeFrame is a PATH_CHALLENGE frame +type PathChallengeFrame struct { + Data [8]byte +} + +func parsePathChallengeFrame(r *bytes.Reader, version protocol.VersionNumber) (*PathChallengeFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathChallengeFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0e) + b.WriteByte(typeByte) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go new file mode 100644 index 0000000..2ab2fcd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathResponseFrame is a PATH_RESPONSE frame +type PathResponseFrame struct { + Data [8]byte +} + +func parsePathResponseFrame(r *bytes.Reader, version protocol.VersionNumber) (*PathResponseFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathResponseFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0f) + b.WriteByte(typeByte) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go index 2a09c33..bc1deda 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go @@ -9,8 +9,8 @@ import ( // A PingFrame is a ping frame type PingFrame struct{} -// ParsePingFrame parses a Ping frame -func ParsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) { +// parsePingFrame parses a Ping frame +func parsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) { frame := &PingFrame{} _, err := r.ReadByte() @@ -27,7 +27,7 @@ func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error return nil } -// MinLength of a written frame -func (f *PingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1, nil +// Length of a written frame +func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go index ba5c8e6..33a0eba 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go @@ -19,10 +19,19 @@ var ( ) // writePublicHeader writes a Public Header. -func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { +func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { + if h.VersionFlag && pers == protocol.PerspectiveServer { + return errors.New("PublicHeader: Writing of Version Negotiation Packets not supported") + } if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + if !h.DestConnectionID.Equal(h.SrcConnectionID) { + return fmt.Errorf("PublicHeader: SrcConnectionID must be equal to DestConnectionID") + } + if len(h.DestConnectionID) != 8 { + return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID)) + } publicFlagByte := uint8(0x00) if h.VersionFlag { @@ -56,7 +65,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, v b.WriteByte(publicFlagByte) if !h.OmitConnectionID { - utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + b.Write(h.DestConnectionID) } if h.VersionFlag && pers == protocol.PerspectiveClient { utils.BigEndian.WriteUint32(b, uint32(h.Version)) @@ -73,11 +82,11 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, v case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(h.PacketNumber)) + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(h.PacketNumber)) + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) + utils.BigEndian.WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) default: return errors.New("PublicHeader: PacketNumberLen not set") } @@ -123,21 +132,23 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea // Connection ID if !header.OmitConnectionID { - var connID uint64 - connID, err = utils.BigEndian.ReadUint64(b) - if err != nil { + connID := make(protocol.ConnectionID, 8) + if _, err := io.ReadFull(b, connID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } return nil, err } - header.ConnectionID = protocol.ConnectionID(connID) - if header.ConnectionID == 0 { + if connID[0] == 0 && connID[1] == 0 && connID[2] == 0 && connID[3] == 0 && connID[4] == 0 && connID[5] == 0 && connID[6] == 0 && connID[7] == 0 { return nil, errInvalidConnectionID } + header.DestConnectionID = connID + header.SrcConnectionID = connID } + // Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server. + // It doesn't have any meaning when sent by the client. if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { - // TODO: remove the if once the Google servers send the correct value - // assume that a packet doesn't contain a diversification nonce if the version flag or the reset flag is set, no matter what the public flag says - // see https://github.com/lucas-clemente/quic-go/issues/232 if !header.VersionFlag && !header.ResetFlag { header.DiversificationNonce = make([]byte, 32) if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil { @@ -148,13 +159,14 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea // Version (optional) if !header.ResetFlag && header.VersionFlag { - if packetSentBy == protocol.PerspectiveServer { // parse the version negotiaton packet + if packetSentBy == protocol.PerspectiveServer { // parse the version negotiation packet if b.Len() == 0 { return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") } if b.Len()%4 != 0 { return nil, qerr.InvalidVersionNegotiationPacket } + header.IsVersionNegotiation = true header.SupportedVersions = make([]protocol.VersionNumber, 0) for { var versionTag uint32 @@ -228,14 +240,10 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { return true } -func (h *Header) logPublicHeader() { - connID := "(omitted)" - if !h.OmitConnectionID { - connID = fmt.Sprintf("%#x", h.ConnectionID) - } +func (h *Header) logPublicHeader(logger utils.Logger) { ver := "(unset)" if h.Version != 0 { - ver = fmt.Sprintf("%s", h.Version) + ver = h.Version.String() } - utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) + logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go index 6adc9f6..b57ea7a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go @@ -16,11 +16,11 @@ type PublicReset struct { Nonce uint64 } -// WritePublicReset writes a Public Reset +// WritePublicReset writes a PUBLIC_RESET func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) - utils.BigEndian.WriteUint64(b, uint64(connectionID)) + b.Write(connectionID) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagPRST)) utils.LittleEndian.WriteUint32(b, 2) utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRNON)) @@ -32,7 +32,7 @@ func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p return b.Bytes() } -// ParsePublicReset parses a Public Reset +// ParsePublicReset parses a PUBLIC_RESET func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { pr := PublicReset{} msg, err := handshake.ParseHandshakeMessage(r) @@ -44,7 +44,7 @@ func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { } // The RSEQ tag is mandatory according to the gQUIC wire spec. - // However, Google doesn't send RSEQ in their Public Resets. + // However, Google doesn't send RSEQ in their PUBLIC_RESETs. // Therefore, we'll treat RSEQ as an optional field. if rseq, ok := msg.Data[handshake.TagRSEQ]; ok { if len(rseq) != 8 { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go index 04086f8..209422c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go @@ -7,51 +7,83 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// A RstStreamFrame in QUIC +// A RstStreamFrame is a RST_STREAM frame in QUIC type RstStreamFrame struct { - StreamID protocol.StreamID - ErrorCode uint32 + StreamID protocol.StreamID + // The error code is a uint32 in gQUIC, but a uint16 in IETF QUIC. + // protocol.ApplicaitonErrorCode is a uint16, so larger values in gQUIC frames will be truncated. + ErrorCode protocol.ApplicationErrorCode ByteOffset protocol.ByteCount } +// parseRstStreamFrame parses a RST_STREAM frame +func parseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var errorCode uint16 + var byteOffset protocol.ByteCount + if version.UsesIETFFrameFormat() { + sid, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + errorCode, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + bo, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + } else { + sid, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + bo, err := utils.BigEndian.ReadUint64(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + ec, err := utils.BigEndian.ReadUint32(r) + if err != nil { + return nil, err + } + errorCode = uint16(ec) + } + + return &RstStreamFrame{ + StreamID: streamID, + ErrorCode: protocol.ApplicationErrorCode(errorCode), + ByteOffset: byteOffset, + }, nil +} + //Write writes a RST_STREAM frame func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x01) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) - utils.GetByteOrder(version).WriteUint32(b, f.ErrorCode) + if version.UsesIETFFrameFormat() { + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + utils.WriteVarInt(b, uint64(f.ByteOffset)) + } else { + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) + utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode)) + } return nil } -// MinLength of a written frame -func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8 + 4, nil -} - -// ParseRstStreamFrame parses a RST_STREAM frame -func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { - frame := &RstStreamFrame{} - - // read the TypeByte - if _, err := r.ReadByte(); err != nil { - return nil, err - } - - sid, err := utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) - if err != nil { - return nil, err - } - frame.ByteOffset = protocol.ByteCount(byteOffset) - - frame.ErrorCode, err = utils.GetByteOrder(version).ReadUint32(r) - if err != nil { - return nil, err - } - return frame, nil +// Length of a written frame +func (f *RstStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if version.UsesIETFFrameFormat() { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 + utils.VarIntLen(uint64(f.ByteOffset)) + } + return 1 + 4 + 8 + 4 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go new file mode 100644 index 0000000..b5e6980 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go @@ -0,0 +1,47 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode protocol.ApplicationErrorCode +} + +// parseStopSendingFrame parses a STOP_SENDING frame +func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + errorCode, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: protocol.ApplicationErrorCode(errorCode), + }, nil +} + +// Length of a written frame +func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) + 2 +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x0c) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go index 9eb068d..b87606a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go @@ -6,7 +6,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" ) // A StopWaitingFrame in QUIC @@ -23,7 +22,10 @@ var ( errPacketNumberLenNotSet = errors.New("StopWaitingFrame: PacketNumberLen not set") ) -func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *StopWaitingFrame) Write(b *bytes.Buffer, v protocol.VersionNumber) error { + if v.UsesIETFFrameFormat() { + return errors.New("STOP_WAITING not defined in IETF QUIC") + } // make sure the PacketNumber was set if f.PacketNumber == protocol.PacketNumber(0) { return errPacketNumberNotSet @@ -38,30 +40,24 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber case protocol.PacketNumberLen1: b.WriteByte(uint8(leastUnackedDelta)) case protocol.PacketNumberLen2: - utils.GetByteOrder(version).WriteUint16(b, uint16(leastUnackedDelta)) + utils.BigEndian.WriteUint16(b, uint16(leastUnackedDelta)) case protocol.PacketNumberLen4: - utils.GetByteOrder(version).WriteUint32(b, uint32(leastUnackedDelta)) + utils.BigEndian.WriteUint32(b, uint32(leastUnackedDelta)) case protocol.PacketNumberLen6: - utils.GetByteOrder(version).WriteUint48(b, leastUnackedDelta&(1<<48-1)) + utils.BigEndian.WriteUint48(b, leastUnackedDelta&(1<<48-1)) default: return errPacketNumberLenNotSet } return nil } -// MinLength of a written frame -func (f *StopWaitingFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - minLength := protocol.ByteCount(1) // typeByte - - if f.PacketNumberLen == protocol.PacketNumberLenInvalid { - return 0, errPacketNumberLenNotSet - } - minLength += protocol.ByteCount(f.PacketNumberLen) - return minLength, nil +// Length of a written frame +func (f *StopWaitingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + protocol.ByteCount(f.PacketNumberLen) } -// ParseStopWaitingFrame parses a StopWaiting frame -func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, version protocol.VersionNumber) (*StopWaitingFrame, error) { +// parseStopWaitingFrame parses a StopWaiting frame +func parseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, _ protocol.VersionNumber) (*StopWaitingFrame, error) { frame := &StopWaitingFrame{} // read the TypeByte @@ -69,12 +65,12 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, return nil, err } - leastUnackedDelta, err := utils.GetByteOrder(version).ReadUintN(r, uint8(packetNumberLen)) + leastUnackedDelta, err := utils.BigEndian.ReadUintN(r, uint8(packetNumberLen)) if err != nil { return nil, err } - if leastUnackedDelta >= uint64(packetNumber) { - return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta") + if leastUnackedDelta > uint64(packetNumber) { + return nil, errors.New("invalid LeastUnackedDelta") } frame.LeastUnacked = protocol.PacketNumber(uint64(packetNumber) - leastUnackedDelta) return frame, nil diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go index 981c0ec..a083a9f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go @@ -10,35 +10,43 @@ import ( // A StreamBlockedFrame in QUIC type StreamBlockedFrame struct { StreamID protocol.StreamID + Offset protocol.ByteCount } -// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame -func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { - frame := &StreamBlockedFrame{} - - // read the TypeByte - if _, err := r.ReadByte(); err != nil { +// parseStreamBlockedFrame parses a STREAM_BLOCKED frame +func parseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte return nil, err } - sid, err := utils.GetByteOrder(version).ReadUint32(r) + sid, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.StreamID = protocol.StreamID(sid) - return frame, nil + offset, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamBlockedFrame{ + StreamID: protocol.StreamID(sid), + Offset: protocol.ByteCount(offset), + }, nil } // Write writes a STREAM_BLOCKED frame func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - if !version.UsesMaxDataFrame() { + if !version.UsesIETFFrameFormat() { return (&blockedFrameLegacy{StreamID: f.StreamID}).Write(b, version) } b.WriteByte(0x09) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.WriteVarInt(b, uint64(f.StreamID)) + utils.WriteVarInt(b, uint64(f.Offset)) return nil } -// MinLength of a written frame -func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4, nil +// Length of a written frame +func (f *StreamBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return 1 + 4 + } + return 1 + utils.VarIntLen(uint64(f.StreamID)) + utils.VarIntLen(uint64(f.Offset)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go index 75be888..d848127 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -19,13 +19,12 @@ type StreamFrame struct { Data []byte } -var ( - errInvalidStreamIDLen = errors.New("StreamFrame: Invalid StreamID length") - errInvalidOffsetLen = errors.New("StreamFrame: Invalid offset length") -) +// parseStreamFrame reads a STREAM frame +func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { + if !version.UsesIETFFrameFormat() { + return parseLegacyStreamFrame(r, version) + } -// ParseStreamFrame reads a stream frame. The type byte must not have been read yet. -func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { frame := &StreamFrame{} typeByte, err := r.ReadByte() @@ -33,44 +32,39 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF return nil, err } - frame.FinBit = typeByte&0x40 > 0 - frame.DataLenPresent = typeByte&0x20 > 0 - offsetLen := typeByte & 0x1c >> 2 - if offsetLen != 0 { - offsetLen++ - } - streamIDLen := typeByte&0x3 + 1 + frame.FinBit = typeByte&0x1 > 0 + frame.DataLenPresent = typeByte&0x2 > 0 + hasOffset := typeByte&0x4 > 0 - sid, err := utils.GetByteOrder(version).ReadUintN(r, streamIDLen) + streamID, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.StreamID = protocol.StreamID(sid) - - offset, err := utils.GetByteOrder(version).ReadUintN(r, offsetLen) - if err != nil { - return nil, err - } - frame.Offset = protocol.ByteCount(offset) - - var dataLen uint16 - if frame.DataLenPresent { - dataLen, err = utils.GetByteOrder(version).ReadUint16(r) + frame.StreamID = protocol.StreamID(streamID) + if hasOffset { + offset, err := utils.ReadVarInt(r) if err != nil { return nil, err } + frame.Offset = protocol.ByteCount(offset) } - // shortcut to prevent the unneccessary allocation of dataLen bytes - // if the dataLen is larger than the remaining length of the packet - // reading the packet contents would result in EOF when attempting to READ - if int(dataLen) > r.Len() { - return nil, io.EOF - } - - if !frame.DataLenPresent { + var dataLen uint64 + if frame.DataLenPresent { + var err error + dataLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unnecessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + } else { // The rest of the packet is data - dataLen = uint16(r.Len()) + dataLen = uint64(r.Len()) } if dataLen != 0 { frame.Data = make([]byte, dataLen) @@ -79,128 +73,111 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF return nil, err } } - - if frame.Offset+frame.DataLen() < frame.Offset { + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") } - if !frame.FinBit && frame.DataLen() == 0 { + // empty frames are only allowed if they have offset 0 or the FIN bit set + if frame.DataLen() == 0 && !frame.FinBit && frame.Offset != 0 { return nil, qerr.EmptyStreamFrameNoFin } return frame, nil } -// WriteStreamFrame writes a stream frame. +// Write writes a STREAM frame func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return f.writeLegacy(b, version) + } + if len(f.Data) == 0 && !f.FinBit { return errors.New("StreamFrame: attempting to write empty frame without FIN") } - typeByte := uint8(0x80) // sets the leftmost bit to 1 + typeByte := byte(0x10) if f.FinBit { - typeByte ^= 0x40 + typeByte ^= 0x1 } + hasOffset := f.Offset != 0 if f.DataLenPresent { - typeByte ^= 0x20 + typeByte ^= 0x2 } - - offsetLength := f.getOffsetLength() - if offsetLength > 0 { - typeByte ^= (uint8(offsetLength) - 1) << 2 + if hasOffset { + typeByte ^= 0x4 } - - streamIDLen := f.calculateStreamIDLength() - typeByte ^= streamIDLen - 1 - b.WriteByte(typeByte) - - switch streamIDLen { - case 1: - b.WriteByte(uint8(f.StreamID)) - case 2: - utils.GetByteOrder(version).WriteUint16(b, uint16(f.StreamID)) - case 3: - utils.GetByteOrder(version).WriteUint24(b, uint32(f.StreamID)) - case 4: - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - default: - return errInvalidStreamIDLen + utils.WriteVarInt(b, uint64(f.StreamID)) + if hasOffset { + utils.WriteVarInt(b, uint64(f.Offset)) } - - switch offsetLength { - case 0: - case 2: - utils.GetByteOrder(version).WriteUint16(b, uint16(f.Offset)) - case 3: - utils.GetByteOrder(version).WriteUint24(b, uint32(f.Offset)) - case 4: - utils.GetByteOrder(version).WriteUint32(b, uint32(f.Offset)) - case 5: - utils.GetByteOrder(version).WriteUint40(b, uint64(f.Offset)) - case 6: - utils.GetByteOrder(version).WriteUint48(b, uint64(f.Offset)) - case 7: - utils.GetByteOrder(version).WriteUint56(b, uint64(f.Offset)) - case 8: - utils.GetByteOrder(version).WriteUint64(b, uint64(f.Offset)) - default: - return errInvalidOffsetLen - } - if f.DataLenPresent { - utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.Data))) + utils.WriteVarInt(b, uint64(f.DataLen())) } - b.Write(f.Data) return nil } -func (f *StreamFrame) calculateStreamIDLength() uint8 { - if f.StreamID < (1 << 8) { - return 1 - } else if f.StreamID < (1 << 16) { - return 2 - } else if f.StreamID < (1 << 24) { - return 3 +// Length returns the total length of the STREAM frame +func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.lengthLegacy(version) } - return 4 + length := 1 + utils.VarIntLen(uint64(f.StreamID)) + if f.Offset != 0 { + length += utils.VarIntLen(uint64(f.Offset)) + } + if f.DataLenPresent { + length += utils.VarIntLen(uint64(f.DataLen())) + } + return length + f.DataLen() } -func (f *StreamFrame) getOffsetLength() protocol.ByteCount { - if f.Offset == 0 { +// MaxDataLen returns the maximum data length +// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + if !version.UsesIETFFrameFormat() { + return f.maxDataLenLegacy(maxSize, version) + } + + headerLen := 1 + utils.VarIntLen(uint64(f.StreamID)) + if f.Offset != 0 { + headerLen += utils.VarIntLen(uint64(f.Offset)) + } + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { return 0 } - if f.Offset < (1 << 16) { - return 2 + maxDataLen := maxSize - headerLen + if f.DataLenPresent && utils.VarIntLen(uint64(maxDataLen)) != 1 { + maxDataLen-- } - if f.Offset < (1 << 24) { - return 3 - } - if f.Offset < (1 << 32) { - return 4 - } - if f.Offset < (1 << 40) { - return 5 - } - if f.Offset < (1 << 48) { - return 6 - } - if f.Offset < (1 << 56) { - return 7 - } - return 8 + return maxDataLen } -// MinLength returns the length of the header of a StreamFrame -// the total length of the StreamFrame is frame.MinLength() + frame.DataLen() -func (f *StreamFrame) MinLength(protocol.VersionNumber) (protocol.ByteCount, error) { - length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() - if f.DataLenPresent { - length += 2 +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// If n >= len(frame), nil is returned and nothing is modified. +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, error) { + if maxSize >= f.Length(version) { + return nil, nil } - return length, nil -} -// DataLen gives the length of data in bytes -func (f *StreamFrame) DataLen() protocol.ByteCount { - return protocol.ByteCount(len(f.Data)) + n := f.MaxDataLen(maxSize, version) + if n == 0 { + return nil, errors.New("too small") + } + newFrame := &StreamFrame{ + FinBit: false, + StreamID: f.StreamID, + Offset: f.Offset, + Data: f.Data[:n], + DataLenPresent: f.DataLenPresent, + } + + f.Data = f.Data[n:] + f.Offset += n + + return newFrame, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go new file mode 100644 index 0000000..a2b159d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame_legacy.go @@ -0,0 +1,209 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +var ( + errInvalidStreamIDLen = errors.New("StreamFrame: Invalid StreamID length") + errInvalidOffsetLen = errors.New("StreamFrame: Invalid offset length") +) + +// parseLegacyStreamFrame reads a stream frame. The type byte must not have been read yet. +func parseLegacyStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { + frame := &StreamFrame{} + + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + frame.FinBit = typeByte&0x40 > 0 + frame.DataLenPresent = typeByte&0x20 > 0 + offsetLen := typeByte & 0x1c >> 2 + if offsetLen != 0 { + offsetLen++ + } + streamIDLen := typeByte&0x3 + 1 + + sid, err := utils.BigEndian.ReadUintN(r, streamIDLen) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + + offset, err := utils.BigEndian.ReadUintN(r, offsetLen) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + + var dataLen uint16 + if frame.DataLenPresent { + dataLen, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return nil, err + } + } + + // shortcut to prevent the unnecessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if int(dataLen) > r.Len() { + return nil, io.EOF + } + + if !frame.DataLenPresent { + // The rest of the packet is data + dataLen = uint16(r.Len()) + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + + // MaxByteCount is the highest value that can be encoded with the IETF QUIC variable integer encoding (2^62-1). + // Note that this value is smaller than the maximum value that could be encoded in the gQUIC STREAM frame (2^64-1). + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") + } + if !frame.FinBit && frame.DataLen() == 0 { + return nil, qerr.EmptyStreamFrameNoFin + } + return frame, nil +} + +// writeLegacy writes a stream frame. +func (f *StreamFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error { + if len(f.Data) == 0 && !f.FinBit { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := uint8(0x80) // sets the leftmost bit to 1 + if f.FinBit { + typeByte ^= 0x40 + } + if f.DataLenPresent { + typeByte ^= 0x20 + } + + offsetLength := f.getOffsetLength() + if offsetLength > 0 { + typeByte ^= (uint8(offsetLength) - 1) << 2 + } + + streamIDLen := f.calculateStreamIDLength() + typeByte ^= streamIDLen - 1 + + b.WriteByte(typeByte) + + switch streamIDLen { + case 1: + b.WriteByte(uint8(f.StreamID)) + case 2: + utils.BigEndian.WriteUint16(b, uint16(f.StreamID)) + case 3: + utils.BigEndian.WriteUint24(b, uint32(f.StreamID)) + case 4: + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + default: + return errInvalidStreamIDLen + } + + switch offsetLength { + case 0: + case 2: + utils.BigEndian.WriteUint16(b, uint16(f.Offset)) + case 3: + utils.BigEndian.WriteUint24(b, uint32(f.Offset)) + case 4: + utils.BigEndian.WriteUint32(b, uint32(f.Offset)) + case 5: + utils.BigEndian.WriteUint40(b, uint64(f.Offset)) + case 6: + utils.BigEndian.WriteUint48(b, uint64(f.Offset)) + case 7: + utils.BigEndian.WriteUint56(b, uint64(f.Offset)) + case 8: + utils.BigEndian.WriteUint64(b, uint64(f.Offset)) + default: + return errInvalidOffsetLen + } + + if f.DataLenPresent { + utils.BigEndian.WriteUint16(b, uint16(len(f.Data))) + } + + b.Write(f.Data) + return nil +} + +func (f *StreamFrame) calculateStreamIDLength() uint8 { + if f.StreamID < (1 << 8) { + return 1 + } else if f.StreamID < (1 << 16) { + return 2 + } else if f.StreamID < (1 << 24) { + return 3 + } + return 4 +} + +func (f *StreamFrame) getOffsetLength() protocol.ByteCount { + if f.Offset == 0 { + return 0 + } + if f.Offset < (1 << 16) { + return 2 + } + if f.Offset < (1 << 24) { + return 3 + } + if f.Offset < (1 << 32) { + return 4 + } + if f.Offset < (1 << 40) { + return 5 + } + if f.Offset < (1 << 48) { + return 6 + } + if f.Offset < (1 << 56) { + return 7 + } + return 8 +} + +func (f *StreamFrame) headerLengthLegacy(_ protocol.VersionNumber) protocol.ByteCount { + length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() + if f.DataLenPresent { + length += 2 + } + return length +} + +func (f *StreamFrame) lengthLegacy(version protocol.VersionNumber) protocol.ByteCount { + return f.headerLengthLegacy(version) + f.DataLen() +} + +func (f *StreamFrame) maxDataLenLegacy(maxFrameSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := f.headerLengthLegacy(version) + if headerLen > maxFrameSize { + return 0 + } + return maxFrameSize - headerLen +} + +// DataLen gives the length of data in bytes +func (f *StreamFrame) DataLen() protocol.ByteCount { + return protocol.ByteCount(len(f.Data)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go new file mode 100644 index 0000000..6476eb9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_id_blocked_frame.go @@ -0,0 +1,37 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamIDBlockedFrame is a STREAM_ID_BLOCKED frame +type StreamIDBlockedFrame struct { + StreamID protocol.StreamID +} + +// parseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame +func parseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + return &StreamIDBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +func (f *StreamIDBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x0a) + b.WriteByte(typeByte) + utils.WriteVarInt(b, uint64(f.StreamID)) + return nil +} + +// Length of a written frame +func (f *StreamIDBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(uint64(f.StreamID)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go index 92afb3b..a19f276 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "crypto/rand" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -9,43 +10,32 @@ import ( // ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { - fullReply := &bytes.Buffer{} - ph := Header{ - ConnectionID: connID, - PacketNumber: 1, - VersionFlag: true, - } - if err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil { - utils.Errorf("error composing version negotiation packet: %s", err.Error()) - return nil - } + buf := bytes.NewBuffer(make([]byte, 0, 1+8+len(versions)*4)) + buf.Write([]byte{0x1 | 0x8}) // type byte + buf.Write(connID) for _, v := range versions { - utils.BigEndian.WriteUint32(fullReply, uint32(v)) + utils.BigEndian.WriteUint32(buf, uint32(v)) } - return fullReply.Bytes() + return buf.Bytes() } // ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft -func ComposeVersionNegotiation( - connID protocol.ConnectionID, - pn protocol.PacketNumber, - versionOffered protocol.VersionNumber, - versions []protocol.VersionNumber, -) []byte { - fullReply := &bytes.Buffer{} - ph := Header{ - IsLongHeader: true, - Type: protocol.PacketTypeVersionNegotiation, - ConnectionID: connID, - PacketNumber: pn, - Version: versionOffered, +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { + greasedVersions := protocol.GetGreasedVersions(versions) + buf := bytes.NewBuffer(make([]byte, 0, 1+8+4+len(greasedVersions)*4)) + r := make([]byte, 1) + _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf.WriteByte(r[0] | 0x80) + utils.BigEndian.WriteUint32(buf, 0) // version 0 + connIDLen, err := encodeConnIDLen(destConnID, srcConnID) + if err != nil { + return nil, err } - if err := ph.writeHeader(fullReply); err != nil { - utils.Errorf("error composing version negotiation packet: %s", err.Error()) - return nil + buf.WriteByte(connIDLen) + buf.Write(destConnID) + buf.Write(srcConnID) + for _, v := range greasedVersions { + utils.BigEndian.WriteUint32(buf, uint32(v)) } - for _, v := range versions { - utils.BigEndian.WriteUint32(fullReply, uint32(v)) - } - return fullReply.Bytes() + return buf.Bytes(), nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go index 20d7b66..606e25c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go @@ -12,24 +12,34 @@ type windowUpdateFrame struct { ByteOffset protocol.ByteCount } -// ParseWindowUpdateFrame parses a WINDOW_UPDATE frame +// parseWindowUpdateFrame parses a WINDOW_UPDATE frame // The frame returned is // * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream // * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection -func ParseWindowUpdateFrame(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { - f, err := ParseMaxStreamDataFrame(r, version) +func parseWindowUpdateFrame(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + streamID, err := utils.BigEndian.ReadUint32(r) if err != nil { return nil, err } - if f.StreamID == 0 { - return &MaxDataFrame{ByteOffset: f.ByteOffset}, nil + offset, err := utils.BigEndian.ReadUint64(r) + if err != nil { + return nil, err } - return f, nil + if streamID == 0 { + return &MaxDataFrame{ByteOffset: protocol.ByteCount(offset)}, nil + } + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(streamID), + ByteOffset: protocol.ByteCount(offset), + }, nil } -func (f *windowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *windowUpdateFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { b.WriteByte(0x4) - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.BigEndian.WriteUint32(b, uint32(f.StreamID)) + utils.BigEndian.WriteUint64(b, uint64(f.ByteOffset)) return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go new file mode 100644 index 0000000..36af76d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mint_utils.go @@ -0,0 +1,168 @@ +package quic + +import ( + "bytes" + gocrypto "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type mintController struct { + csc *handshake.CryptoStreamConn + conn *mint.Conn +} + +var _ handshake.MintTLS = &mintController{} + +func newMintController( + csc *handshake.CryptoStreamConn, + mconf *mint.Config, + pers protocol.Perspective, +) handshake.MintTLS { + var conn *mint.Conn + if pers == protocol.PerspectiveClient { + conn = mint.Client(csc, mconf) + } else { + conn = mint.Server(csc, mconf) + } + return &mintController{ + csc: csc, + conn: conn, + } +} + +func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { + return mc.conn.ConnectionState().CipherSuite +} + +func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + return mc.conn.ComputeExporter(label, context, keyLength) +} + +func (mc *mintController) Handshake() mint.Alert { + return mc.conn.Handshake() +} + +func (mc *mintController) State() mint.State { + return mc.conn.ConnectionState().HandshakeState +} + +func (mc *mintController) ConnectionState() mint.ConnectionState { + return mc.conn.ConnectionState() +} + +func (mc *mintController) SetCryptoStream(stream io.ReadWriter) { + mc.csc.SetStream(stream) +} + +func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { + mconf := &mint.Config{ + NonBlocking: true, + CipherSuites: []mint.CipherSuite{ + mint.TLS_AES_128_GCM_SHA256, + mint.TLS_AES_256_GCM_SHA384, + }, + } + if tlsConf != nil { + mconf.ServerName = tlsConf.ServerName + mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify + mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) + mconf.RootCAs = tlsConf.RootCAs + mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate + for i, certChain := range tlsConf.Certificates { + mconf.Certificates[i] = &mint.Certificate{ + Chain: make([]*x509.Certificate, len(certChain.Certificate)), + PrivateKey: certChain.PrivateKey.(gocrypto.Signer), + } + for j, cert := range certChain.Certificate { + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + mconf.Certificates[i].Chain[j] = c + } + } + switch tlsConf.ClientAuth { + case tls.NoClientCert: + case tls.RequireAnyClientCert: + mconf.RequireClientAuth = true + default: + return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert") + } + } + if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { + return nil, err + } + return mconf, nil +} + +// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets +// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. +func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) { + decrypted, err := aead.Open(data[:0], data, hdr.PacketNumber, hdr.Raw) + if err != nil { + return nil, err + } + var frame *wire.StreamFrame + r := bytes.NewReader(decrypted) + for { + f, err := wire.ParseNextFrame(r, hdr, version) + if err != nil { + return nil, err + } + var ok bool + if frame, ok = f.(*wire.StreamFrame); ok || frame == nil { + break + } + } + if frame == nil { + return nil, errors.New("Packet doesn't contain a STREAM_FRAME") + } + if frame.StreamID != version.CryptoStreamID() { + return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID) + } + // We don't need a check for the stream ID here. + // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. + if frame.Offset != 0 { + return nil, errors.New("received stream data with non-zero offset") + } + if logger.Debug() { + logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID) + hdr.Log(logger) + wire.LogFrame(logger, frame, false) + } + return frame, nil +} + +// packUnencryptedPacket provides a low-overhead way to pack a packet. +// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) { + raw := *getPacketBuffer() + buffer := bytes.NewBuffer(raw[:0]) + if err := hdr.Write(buffer, pers, hdr.Version); err != nil { + return nil, err + } + payloadStartIndex := buffer.Len() + if err := f.Write(buffer, hdr.Version); err != nil { + return nil, err + } + raw = raw[0:buffer.Len()] + _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) + raw = raw[0 : buffer.Len()+aead.Overhead()] + if logger.Debug() { + logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.SrcConnectionID, protocol.EncryptionUnencrypted) + hdr.Log(logger) + wire.LogFrame(logger, f, true) + } + return raw, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go new file mode 100644 index 0000000..65f3854 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen.go @@ -0,0 +1,16 @@ +package quic + +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" +//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" +//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker" +//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD" +//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" +//go:generate sh -c "goimports -w mock*_test.go" diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh new file mode 100755 index 0000000..7fbe68d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Mockgen refuses to generate mocks private types. +# This script copies the quic package to a temporary directory, and adds an public alias for the private type. +# It then creates a mock for this public (alias) type. + +TEMP_DIR=$(mktemp -d) +mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ + +# copy all .go files to a temporary directory +# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen +rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*' $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/ +echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go + +export GOPATH="$TEMP_DIR:$GOPATH" + +mockgen -package $1 -self_package $1 -destination $2 $3 $5 + +rm -r "$TEMP_DIR" diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go index 8ece95a..ac63577 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go @@ -17,9 +17,9 @@ type packetNumberGenerator struct { nextToSkip protocol.PacketNumber } -func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator { +func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { return &packetNumberGenerator{ - next: 1, + next: initial, averagePeriod: averagePeriod, } } diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index 1a63715..f616db1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -4,10 +4,14 @@ import ( "bytes" "errors" "fmt" + "net" + "sync" + "time" - "github.com/lucas-clemente/quic-go/ackhandler" + "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -18,35 +22,89 @@ type packedPacket struct { encryptionLevel protocol.EncryptionLevel } -type packetPacker struct { - connectionID protocol.ConnectionID - perspective protocol.Perspective - version protocol.VersionNumber - cryptoSetup handshake.CryptoSetup - - packetNumberGenerator *packetNumberGenerator - streamFramer *streamFramer - - controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber - omitConnectionID bool +func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { + return &ackhandler.Packet{ + PacketNumber: p.header.PacketNumber, + PacketType: p.header.Type, + Frames: p.frames, + Length: protocol.ByteCount(len(p.raw)), + EncryptionLevel: p.encryptionLevel, + SendTime: time.Now(), + } } -func newPacketPacker(connectionID protocol.ConnectionID, - cryptoSetup handshake.CryptoSetup, - streamFramer *streamFramer, +type sealingManager interface { + GetSealer() (protocol.EncryptionLevel, handshake.Sealer) + GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) + GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) +} + +type streamFrameSource interface { + HasCryptoStreamData() bool + PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame + PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame +} + +type packetPacker struct { + destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + divNonce []byte + cryptoSetup sealingManager + + packetNumberGenerator *packetNumberGenerator + getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen + streams streamFrameSource + + controlFrameMutex sync.Mutex + controlFrames []wire.Frame + + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + omitConnectionID bool + maxPacketSize protocol.ByteCount + hasSentPacket bool // has the packetPacker already sent a packet + numNonRetransmittableAcks int +} + +func newPacketPacker( + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, + getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, + remoteAddr net.Addr, // only used for determining the max packet size + divNonce []byte, + cryptoSetup sealingManager, + streamFramer streamFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { + maxPacketSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + if udpAddr.IP.To4() == nil { + maxPacketSize = protocol.MaxPacketSizeIPv6 + } else { + maxPacketSize = protocol.MaxPacketSizeIPv4 + } + } return &packetPacker{ cryptoSetup: cryptoSetup, - connectionID: connectionID, + divNonce: divNonce, + destConnID: destConnID, + srcConnID: srcConnID, perspective: perspective, version: version, - streamFramer: streamFramer, - packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), + streams: streamFramer, + getPacketNumberLen: getPacketNumberLen, + packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), + maxPacketSize: maxPacketSize, } } @@ -71,7 +129,7 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{p.ackFrame} - if p.stopWaiting != nil { + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC p.stopWaiting.PacketNumber = header.PacketNumber p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) @@ -87,23 +145,135 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { }, err } -// PackHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption -func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) { - if packet.EncryptionLevel == protocol.EncryptionForwardSecure { - return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment") +// PackRetransmission packs a retransmission +// For packets sent after completion of the handshake, it might happen that 2 packets have to be sent. +// This can happen e.g. when a longer packet number is used in the header. +func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) { + if packet.EncryptionLevel != protocol.EncryptionForwardSecure { + p, err := p.packHandshakeRetransmission(packet) + return []*packedPacket{p}, err } + + var controlFrames []wire.Frame + var streamFrames []*wire.StreamFrame + for _, f := range packet.Frames { + if sf, ok := f.(*wire.StreamFrame); ok { + sf.DataLenPresent = true + streamFrames = append(streamFrames, sf) + } else { + controlFrames = append(controlFrames, f) + } + } + + var packets []*packedPacket + encLevel, sealer := p.cryptoSetup.GetSealer() + for len(controlFrames) > 0 || len(streamFrames) > 0 { + var frames []wire.Frame + var payloadLength protocol.ByteCount + + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) + if err != nil { + return nil, err + } + maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength + + // for gQUIC: add a STOP_WAITING for *every* retransmission + if p.version.UsesStopWaitingFrames() { + if p.stopWaiting == nil { + return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") + } + // create a new StopWaitingFrame, since we might need to send more than one packet as a retransmission + swf := &wire.StopWaitingFrame{ + LeastUnacked: p.stopWaiting.LeastUnacked, + PacketNumber: header.PacketNumber, + PacketNumberLen: header.PacketNumberLen, + } + payloadLength += swf.Length(p.version) + frames = append(frames, swf) + } + + for len(controlFrames) > 0 { + frame := controlFrames[0] + length := frame.Length(p.version) + if payloadLength+length > maxSize { + break + } + payloadLength += length + frames = append(frames, frame) + controlFrames = controlFrames[1:] + } + + // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field + // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set + // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size + // for gQUIC STREAM frames, DataLen is always 2 bytes + // for IETF draft style STREAM frames, the length is encoded to either 1 or 2 bytes + if p.version.UsesIETFFrameFormat() { + maxSize++ + } else { + maxSize += 2 + } + for len(streamFrames) > 0 && payloadLength+protocol.MinStreamFrameSize < maxSize { + // TODO: optimize by setting DataLenPresent = false on all but the last STREAM frame + frame := streamFrames[0] + frameToAdd := frame + + sf, err := frame.MaybeSplitOffFrame(maxSize-payloadLength, p.version) + if err != nil { + return nil, err + } + if sf != nil { + frameToAdd = sf + } else { + streamFrames = streamFrames[1:] + } + payloadLength += frameToAdd.Length(p.version) + frames = append(frames, frameToAdd) + } + if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { + sf.DataLenPresent = false + } + raw, err := p.writeAndSealPacket(header, frames, sealer) + if err != nil { + return nil, err + } + packets = append(packets, &packedPacket{ + header: header, + raw: raw, + frames: frames, + encryptionLevel: encLevel, + }) + } + p.stopWaiting = nil + return packets, nil +} + +// packHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption +func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) { sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel) if err != nil { return nil, err } - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") + // make sure that the retransmission for an Initial packet is sent as an Initial packet + if packet.PacketType == protocol.PacketTypeInitial { + p.hasSentPacket = false } header := p.getHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) - p.stopWaiting = nil + header.Type = packet.PacketType + var frames []wire.Frame + if p.version.UsesStopWaitingFrames() { // for gQUIC: pack a STOP_WAITING first + if p.stopWaiting == nil { + return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") + } + swf := p.stopWaiting + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + p.stopWaiting = nil + frames = append([]wire.Frame{swf}, packet.Frames...) + } else { + frames = packet.Frames + } raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ header: header, @@ -116,7 +286,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - if p.streamFramer.HasCryptoStreamFrame() { + hasCryptoStreamFrame := p.streams.HasCryptoStreamData() + // if this is the first packet to be send, make sure it contains stream data + if !p.hasSentPacket && !hasCryptoStreamFrame { + return nil, nil + } + if hasCryptoStreamFrame { return p.packCryptoPacket() } @@ -132,7 +307,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.stopWaiting.PacketNumberLen = header.PacketNumberLen } - maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength + maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) if err != nil { return nil, err @@ -146,6 +321,19 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(payloadFrames) == 1 && p.stopWaiting != nil { return nil, nil } + if p.ackFrame != nil { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if len(payloadFrames) == 1 || (p.stopWaiting != nil && len(payloadFrames) == 2) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + payloadFrames = append(payloadFrames, &wire.PingFrame{}) + p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ + } + } else { + p.numNonRetransmittableAcks = 0 + } + } p.stopWaiting = nil p.ackFrame = nil @@ -168,8 +356,10 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { if err != nil { return nil, err } - maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} + maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength + sf := p.streams.PopCryptoStreamFrame(maxLen) + sf.DataLenPresent = false + frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err @@ -190,36 +380,28 @@ func (p *packetPacker) composeNextPacket( var payloadFrames []wire.Frame // STOP_WAITING and ACK will always fit - if p.stopWaiting != nil { - payloadFrames = append(payloadFrames, p.stopWaiting) - l, err := p.stopWaiting.MinLength(p.version) - if err != nil { - return nil, err - } + if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them + payloadFrames = append(payloadFrames, p.ackFrame) + l := p.ackFrame.Length(p.version) payloadLength += l } - if p.ackFrame != nil { - payloadFrames = append(payloadFrames, p.ackFrame) - l, err := p.ackFrame.MinLength(p.version) - if err != nil { - return nil, err - } - payloadLength += l + if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC + payloadFrames = append(payloadFrames, p.stopWaiting) + payloadLength += p.stopWaiting.Length(p.version) } + p.controlFrameMutex.Lock() for len(p.controlFrames) > 0 { frame := p.controlFrames[len(p.controlFrames)-1] - minLength, err := frame.MinLength(p.version) - if err != nil { - return nil, err - } - if payloadLength+minLength > maxFrameSize { + length := frame.Length(p.version) + if payloadLength+length > maxFrameSize { break } payloadFrames = append(payloadFrames, frame) - payloadLength += minLength + payloadLength += length p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] } + p.controlFrameMutex.Unlock() if payloadLength > maxFrameSize { return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) @@ -229,25 +411,25 @@ func (p *packetPacker) composeNextPacket( return payloadFrames, nil } - // temporarily increase the maxFrameSize by 2 bytes + // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set - // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size - maxFrameSize += 2 + // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size + // for gQUIC STREAM frames, DataLen is always 2 bytes + // for IETF draft style STREAM frames, the length is encoded to either 1 or 2 bytes + if p.version.UsesIETFFrameFormat() { + maxFrameSize++ + } else { + maxFrameSize += 2 + } - fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) + fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength) if len(fs) != 0 { fs[len(fs)-1].DataLenPresent = false } - // TODO: Simplify for _, f := range fs { payloadFrames = append(payloadFrames, f) } - - for b := p.streamFramer.PopBlockedFrame(); b != nil; b = p.streamFramer.PopBlockedFrame() { - p.controlFrames = append(p.controlFrames, b) - } - return payloadFrames, nil } @@ -258,26 +440,34 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) { case *wire.AckFrame: p.ackFrame = f default: + p.controlFrameMutex.Lock() p.controlFrames = append(p.controlFrames, f) + p.controlFrameMutex.Unlock() } } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() - packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) - - var isLongHeader bool - if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { - // TODO: set the Long Header type - packetNumberLen = protocol.PacketNumberLen4 - isLongHeader = true - } + packetNumberLen := p.getPacketNumberLen(pnum) header := &wire.Header{ - ConnectionID: p.connectionID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, - IsLongHeader: isLongHeader, + DestConnectionID: p.destConnID, + SrcConnectionID: p.srcConnID, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, + } + + if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { + header.PacketNumberLen = protocol.PacketNumberLen4 + header.IsLongHeader = true + // Set the payload len to maximum size. + // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. + header.PayloadLen = p.maxPacketSize + if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { + header.Type = protocol.PacketTypeInitial + } else { + header.Type = protocol.PacketTypeHandshake + } } if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { @@ -285,14 +475,13 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header } if !p.version.UsesTLS() { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { - header.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + header.DiversificationNonce = p.divNonce } if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { header.VersionFlag = true header.Version = p.version } } else { - header.Type = p.cryptoSetup.GetNextPacketType() if encLevel != protocol.EncryptionForwardSecure { header.Version = p.version } @@ -305,21 +494,51 @@ func (p *packetPacker) writeAndSealPacket( payloadFrames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { - raw := getPacketBuffer() - buffer := bytes.NewBuffer(raw) + raw := *getPacketBuffer() + buffer := bytes.NewBuffer(raw[:0]) + + // the payload length is only needed for Long Headers + if header.IsLongHeader { + if header.Type == protocol.PacketTypeInitial { + headerLen, _ := header.GetLength(p.perspective, p.version) + header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen + } else { + payloadLen := protocol.ByteCount(sealer.Overhead()) + for _, frame := range payloadFrames { + payloadLen += frame.Length(p.version) + } + header.PayloadLen = payloadLen + } + } if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } payloadStartIndex := buffer.Len() + + // the Initial packet needs to be padded, so the last STREAM frame must have the data length present + if header.Type == protocol.PacketTypeInitial { + lastFrame := payloadFrames[len(payloadFrames)-1] + if sf, ok := lastFrame.(*wire.StreamFrame); ok { + sf.DataLenPresent = true + } + } for _, frame := range payloadFrames { - err := frame.Write(buffer, p.version) - if err != nil { + if err := frame.Write(buffer, p.version); err != nil { return nil, err } } - if protocol.ByteCount(buffer.Len()+sealer.Overhead()) > protocol.MaxPacketSize { - return nil, errors.New("PacketPacker BUG: packet too large") + // if this is an IETF QUIC Initial packet, we need to pad it to fulfill the minimum size requirement + // in gQUIC, padding is handled in the CHLO + if header.Type == protocol.PacketTypeInitial { + paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() + if paddingLen > 0 { + buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) + } + } + + if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } raw = raw[0:buffer.Len()] @@ -330,7 +549,7 @@ func (p *packetPacker) writeAndSealPacket( if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - + p.hasSentPacket = true return raw, nil } @@ -341,10 +560,10 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { return encLevel == protocol.EncryptionForwardSecure } -func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { - p.leastUnacked = leastUnacked -} - func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } + +func (p *packetPacker) SetMaxPacketSize(size protocol.ByteCount) { + p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, size) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index f891e37..9949790 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -14,113 +13,111 @@ type unpackedPacket struct { frames []wire.Frame } -type quicAEAD interface { +type gQUICAEAD interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) } -type packetUnpacker struct { - version protocol.VersionNumber - aead quicAEAD +type quicAEAD interface { + OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) } -func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { - buf := getPacketBuffer() - defer putPacketBuffer(buf) - decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary) - if err != nil { - // Wrap err in quicError so that public reset is sent by session - return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) - } - r := bytes.NewReader(decrypted) +type packetUnpackerBase struct { + version protocol.VersionNumber +} +func (u *packetUnpackerBase) parseFrames(decrypted []byte, hdr *wire.Header) ([]wire.Frame, error) { + r := bytes.NewReader(decrypted) if r.Len() == 0 { return nil, qerr.MissingPayload } fs := make([]wire.Frame, 0, 2) - // Read all frames in the packet - for r.Len() > 0 { - typeByte, _ := r.ReadByte() - if typeByte == 0x0 { // PADDING frame - continue - } - r.UnreadByte() - - var frame wire.Frame - if typeByte&0x80 == 0x80 { - frame, err = wire.ParseStreamFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidStreamData, err.Error()) - } else { - streamID := frame.(*wire.StreamFrame).StreamID - if streamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { - err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID)) - } - } - } else if typeByte&0xc0 == 0x40 { - frame, err = wire.ParseAckFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidAckData, err.Error()) - } - } else if typeByte == 0x01 { - frame, err = wire.ParseRstStreamFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) - } - } else if typeByte == 0x02 { - frame, err = wire.ParseConnectionCloseFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) - } - } else if typeByte == 0x3 { - frame, err = wire.ParseGoawayFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidGoawayData, err.Error()) - } - } else if u.version.UsesMaxDataFrame() && typeByte == 0x4 { // in IETF QUIC, 0x4 is a MAX_DATA frame - frame, err = wire.ParseMaxDataFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } - } else if typeByte == 0x4 { // in gQUIC, 0x4 is a WINDOW_UPDATE frame - frame, err = wire.ParseWindowUpdateFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } - } else if u.version.UsesMaxDataFrame() && typeByte == 0x5 { // in IETF QUIC, 0x5 is a MAX_STREAM_DATA frame - frame, err = wire.ParseMaxStreamDataFrame(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } - } else if typeByte == 0x5 { // in gQUIC, 0x5 is a BLOCKED frame - frame, err = wire.ParseBlockedFrameLegacy(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } - } else if typeByte == 0x6 { - frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) - } - } else if typeByte == 0x7 { - frame, err = wire.ParsePingFrame(r, u.version) - } else if u.version.UsesMaxDataFrame() && typeByte == 0x8 { // in IETF QUIC, 0x4 is a BLOCKED frame - frame, err = wire.ParseBlockedFrame(r, u.version) - } else if u.version.UsesMaxDataFrame() && typeByte == 0x9 { // in IETF QUIC, 0x4 is a STREAM_BLOCKED frame - frame, err = wire.ParseBlockedFrameLegacy(r, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } - } else { - err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) - } + for { + frame, err := wire.ParseNextFrame(r, hdr, u.version) if err != nil { return nil, err } - if frame != nil { - fs = append(fs, frame) + if frame == nil { + break } + fs = append(fs, frame) + } + return fs, nil +} + +// The packetUnpackerGQUIC unpacks gQUIC packets. +type packetUnpackerGQUIC struct { + packetUnpackerBase + aead gQUICAEAD +} + +var _ unpacker = &packetUnpackerGQUIC{} + +func newPacketUnpackerGQUIC(aead gQUICAEAD, version protocol.VersionNumber) unpacker { + return &packetUnpackerGQUIC{ + packetUnpackerBase: packetUnpackerBase{version: version}, + aead: aead, + } +} + +func (u *packetUnpackerGQUIC) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { + decrypted, encryptionLevel, err := u.aead.Open(data[:0], data, hdr.PacketNumber, headerBinary) + if err != nil { + // Wrap err in quicError so that public reset is sent by session + return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) + } + + fs, err := u.parseFrames(decrypted, hdr) + if err != nil { + return nil, err + } + + return &unpackedPacket{ + encryptionLevel: encryptionLevel, + frames: fs, + }, nil +} + +// The packetUnpacker unpacks IETF QUIC packets. +type packetUnpacker struct { + packetUnpackerBase + aead quicAEAD +} + +var _ unpacker = &packetUnpacker{} + +func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { + return &packetUnpacker{ + packetUnpackerBase: packetUnpackerBase{version: version}, + aead: aead, + } +} + +func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { + buf := *getPacketBuffer() + buf = buf[:0] + defer putPacketBuffer(&buf) + + var decrypted []byte + var encryptionLevel protocol.EncryptionLevel + var err error + if hdr.IsLongHeader { + decrypted, err = u.aead.OpenHandshake(buf, data, hdr.PacketNumber, headerBinary) + encryptionLevel = protocol.EncryptionUnencrypted + } else { + decrypted, err = u.aead.Open1RTT(buf, data, hdr.PacketNumber, headerBinary) + encryptionLevel = protocol.EncryptionForwardSecure + } + if err != nil { + // Wrap err in quicError so that public reset is sent by session + return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) + } + + fs, err := u.parseFrames(decrypted, hdr) + if err != nil { + return nil, err } return &unpackedPacket{ diff --git a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go index 5a8e024..22d0c85 100644 --- a/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go +++ b/vendor/github.com/lucas-clemente/quic-go/qerr/errorcode_string.go @@ -1,8 +1,8 @@ -// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT +// Code generated by "stringer -type=ErrorCode"; DO NOT EDIT. package qerr -import "fmt" +import "strconv" const ( _ErrorCode_name_0 = "InternalErrorStreamDataAfterTerminationInvalidPacketHeaderInvalidFrameDataInvalidFecDataInvalidRstStreamDataInvalidConnectionCloseDataInvalidGoawayDataInvalidAckDataInvalidVersionNegotiationPacketInvalidPublicRstPacketDecryptionFailureEncryptionFailurePacketTooLarge" @@ -19,7 +19,6 @@ var ( _ErrorCode_index_2 = [...]uint16{0, 15, 37, 57, 75, 96, 112, 127, 147, 167, 191, 226, 250, 279, 309, 340, 366, 385, 410, 425, 445, 457, 475, 505, 530, 547} _ErrorCode_index_3 = [...]uint16{0, 14, 29, 50, 65, 90, 119, 158, 184, 208, 231, 249, 279, 301, 322, 340, 366, 390, 425} _ErrorCode_index_4 = [...]uint16{0, 16, 45, 78, 97, 114, 144, 169, 192, 215, 238, 256, 276, 292, 308, 346, 379, 410, 448, 459, 477, 498, 532} - _ErrorCode_index_5 = [...]uint8{0, 34} ) func (i ErrorCode) String() string { @@ -42,6 +41,6 @@ func (i ErrorCode) String() string { case i == 97: return _ErrorCode_name_5 default: - return fmt.Sprintf("ErrorCode(%d)", i) + return "ErrorCode(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/vendor/github.com/lucas-clemente/quic-go/qerr/quic_error.go b/vendor/github.com/lucas-clemente/quic-go/qerr/quic_error.go index 9e1956f..42d08c4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/qerr/quic_error.go +++ b/vendor/github.com/lucas-clemente/quic-go/qerr/quic_error.go @@ -2,8 +2,6 @@ package qerr import ( "fmt" - - "github.com/lucas-clemente/quic-go/internal/utils" ) // ErrorCode can be used as a normal error without reason. @@ -31,6 +29,7 @@ func (e *QuicError) Error() string { return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage) } +// Timeout says if this error is a timeout. func (e *QuicError) Timeout() bool { switch e.ErrorCode { case NetworkIdleTimeout, @@ -50,6 +49,5 @@ func ToQuicError(err error) *QuicError { case ErrorCode: return Error(e, "") } - utils.Errorf("Internal error: %v", err) return Error(InternalError, err.Error()) } diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go new file mode 100644 index 0000000..cec69f1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go @@ -0,0 +1,284 @@ +package quic + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receiveStreamI interface { + ReceiveStream + + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + closeForShutdown(error) + getWindowUpdate() protocol.ByteCount +} + +type receiveStream struct { + mutex sync.Mutex + + streamID protocol.StreamID + + sender streamSender + + frameQueue *streamFrameSorter + readPosInFrame int + readOffset protocol.ByteCount + + closeForShutdownErr error + cancelReadErr error + resetRemotelyErr StreamError + + closedForShutdown bool // set when CloseForShutdown() is called + finRead bool // set once we read a frame with a FinBit + canceledRead bool // set when CancelRead() is called + resetRemotely bool // set when HandleRstStreamFrame() is called + + readChan chan struct{} + readDeadline time.Time + + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber +} + +var _ ReceiveStream = &receiveStream{} +var _ receiveStreamI = &receiveStream{} + +func newReceiveStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *receiveStream { + return &receiveStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + version: version, + } +} + +func (s *receiveStream) StreamID() protocol.StreamID { + return s.streamID +} + +// Read implements io.Reader. It is not thread safe! +func (s *receiveStream) Read(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return 0, io.EOF + } + if s.canceledRead { + return 0, s.cancelReadErr + } + if s.resetRemotely { + return 0, s.resetRemotelyErr + } + if s.closedForShutdown { + return 0, s.closeForShutdownErr + } + + bytesRead := 0 + for bytesRead < len(p) { + frame := s.frameQueue.Head() + if frame == nil && bytesRead > 0 { + return bytesRead, s.closeForShutdownErr + } + + for { + // Stop waiting on errors + if s.closedForShutdown { + return bytesRead, s.closeForShutdownErr + } + if s.canceledRead { + return bytesRead, s.cancelReadErr + } + if s.resetRemotely { + return bytesRead, s.resetRemotelyErr + } + + deadline := s.readDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + return bytesRead, errDeadline + } + + if frame != nil { + s.readPosInFrame = int(s.readOffset - frame.Offset) + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-time.After(time.Until(deadline)): + } + } + s.mutex.Lock() + frame = s.frameQueue.Head() + } + + if bytesRead > len(p) { + return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + } + if s.readPosInFrame > int(frame.DataLen()) { + return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) + } + + s.mutex.Unlock() + + copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) + m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) + s.readPosInFrame += m + bytesRead += m + s.readOffset += protocol.ByteCount(m) + + s.mutex.Lock() + // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely { + s.flowController.AddBytesRead(protocol.ByteCount(m)) + } + // increase the flow control window, if necessary + s.flowController.MaybeQueueWindowUpdate() + + if s.readPosInFrame >= int(frame.DataLen()) { + s.frameQueue.Pop() + s.finRead = frame.FinBit + if frame.FinBit { + s.sender.onStreamCompleted(s.streamID) + return bytesRead, io.EOF + } + } + } + return bytesRead, nil +} + +func (s *receiveStream) CancelRead(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finRead { + return nil + } + if s.canceledRead { + return nil + } + s.canceledRead = true + s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.signalRead() + if s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + } + return nil +} + +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { + maxOffset := frame.Offset + frame.DataLen() + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { + return err + } + + s.mutex.Lock() + defer s.mutex.Unlock() + if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { + return err + } + s.signalRead() + return nil +} + +func (s *receiveStream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closedForShutdown { + return nil + } + if err := s.flowController.UpdateHighestReceived(frame.ByteOffset, true); err != nil { + return err + } + // In gQUIC, error code 0 has a special meaning. + // The peer will reliably continue transmitting, but is not interested in reading from the stream. + // We should therefore just continue reading from the stream, until we encounter the FIN bit. + if !s.version.UsesIETFFrameFormat() && frame.ErrorCode == 0 { + return nil + } + + // ignore duplicate RST_STREAM frames for this stream (after checking their final offset) + if s.resetRemotely { + return nil + } + s.resetRemotely = true + s.resetRemotelyErr = streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + s.signalRead() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { + s.handleStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) +} + +func (s *receiveStream) onClose(offset protocol.ByteCount) { + if s.canceledRead && !s.version.UsesIETFFrameFormat() { + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: offset, + ErrorCode: 0, + }) + } +} + +func (s *receiveStream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.readDeadline + s.readDeadline = t + s.mutex.Unlock() + // if the new deadline is before the currently set deadline, wake up Read() + if t.Before(oldDeadline) { + s.signalRead() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *receiveStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalRead() +} + +func (s *receiveStream) getWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} + +// signalRead performs a non-blocking send on the readChan +func (s *receiveStream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/lucas-clemente/quic-go/send_stream.go new file mode 100644 index 0000000..62ef445 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream.go @@ -0,0 +1,313 @@ +package quic + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type sendStreamI interface { + SendStream + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + closeForShutdown(error) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) +} + +type sendStream struct { + mutex sync.Mutex + + ctx context.Context + ctxCancel context.CancelFunc + + streamID protocol.StreamID + sender streamSender + + writeOffset protocol.ByteCount + + cancelWriteErr error + closeForShutdownErr error + + closedForShutdown bool // set when CloseForShutdown() is called + finishedWriting bool // set once Close() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received + finSent bool // set when a STREAM_FRAME with FIN bit has b + + dataForWriting []byte + writeChan chan struct{} + writeDeadline time.Time + + flowController flowcontrol.StreamFlowController + + version protocol.VersionNumber +} + +var _ SendStream = &sendStream{} +var _ sendStreamI = &sendStream{} + +func newSendStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *sendStream { + s := &sendStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + version: version, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s +} + +func (s *sendStream) StreamID() protocol.StreamID { + return s.streamID // same for receiveStream and sendStream +} + +func (s *sendStream) Write(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finishedWriting { + return 0, fmt.Errorf("write on closed stream %d", s.streamID) + } + if s.canceledWrite { + return 0, s.cancelWriteErr + } + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } + if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) { + return 0, errDeadline + } + if len(p) == 0 { + return 0, nil + } + + s.dataForWriting = make([]byte, len(p)) + copy(s.dataForWriting, p) + s.sender.onHasStreamData(s.streamID) + + var bytesWritten int + var err error + for { + bytesWritten = len(p) - len(s.dataForWriting) + deadline := s.writeDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + s.dataForWriting = nil + err = errDeadline + break + } + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-time.After(time.Until(deadline)): + } + } + s.mutex.Lock() + } + + if s.closeForShutdownErr != nil { + err = s.closeForShutdownErr + } else if s.cancelWriteErr != nil { + err = s.cancelWriteErr + } + return bytesWritten, err +} + +// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closeForShutdownErr != nil { + return nil, false + } + + frame := &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + } + maxDataLen := frame.MaxDataLen(maxBytes, s.version) + if maxDataLen == 0 { // a STREAM frame must have at least one byte of data + return nil, s.dataForWriting != nil + } + frame.Data, frame.FinBit = s.getDataForWriting(maxDataLen) + if len(frame.Data) == 0 && !frame.FinBit { + // this can happen if: + // - popStreamFrame is called but there's no data for writing + // - there's data for writing, but the stream is stream-level flow control blocked + // - there's data for writing, but the stream is connection-level flow control blocked + if s.dataForWriting == nil { + return nil, false + } + isBlocked, _ := s.flowController.IsBlocked() + return nil, !isBlocked + } + if frame.FinBit { + s.finSent = true + s.sender.onStreamCompleted(s.streamID) + } else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream + if isBlocked, offset := s.flowController.IsBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamBlockedFrame{ + StreamID: s.streamID, + Offset: offset, + }) + return frame, false + } + } + return frame, s.dataForWriting != nil +} + +func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { + if s.dataForWriting == nil { + return nil, s.finishedWriting && !s.finSent + } + + // TODO(#657): Flow control for the crypto stream + if s.streamID != s.version.CryptoStreamID() { + maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) + } + if maxBytes == 0 { + return nil, false + } + + var ret []byte + if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { + ret = s.dataForWriting[:maxBytes] + s.dataForWriting = s.dataForWriting[maxBytes:] + } else { + ret = s.dataForWriting + s.dataForWriting = nil + s.signalWrite() + } + s.writeOffset += protocol.ByteCount(len(ret)) + s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) + return ret, s.finishedWriting && s.dataForWriting == nil && !s.finSent +} + +func (s *sendStream) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.canceledWrite { + return fmt.Errorf("Close called for canceled stream %d", s.streamID) + } + s.finishedWriting = true + s.sender.onHasStreamData(s.streamID) // need to send the FIN + s.ctxCancel() + return nil +} + +func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) +} + +// must be called after locking the mutex +func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) error { + if s.canceledWrite { + return nil + } + if s.finishedWriting { + return fmt.Errorf("CancelWrite for closed stream %d", s.streamID) + } + s.canceledWrite = true + s.cancelWriteErr = writeErr + s.signalWrite() + s.sender.queueControlFrame(&wire.RstStreamFrame{ + StreamID: s.streamID, + ByteOffset: s.writeOffset, + ErrorCode: errorCode, + }) + // TODO(#991): cancel retransmissions for this stream + s.ctxCancel() + s.sender.onStreamCompleted(s.streamID) + return nil +} + +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.handleStopSendingFrameImpl(frame) +} + +func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { + s.flowController.UpdateSendWindow(frame.ByteOffset) + s.mutex.Lock() + if s.dataForWriting != nil { + s.sender.onHasStreamData(s.streamID) + } + s.mutex.Unlock() +} + +// must be called after locking the mutex +func (s *sendStream) handleStopSendingFrameImpl(frame *wire.StopSendingFrame) { + writeErr := streamCanceledError{ + errorCode: frame.ErrorCode, + error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + } + errorCode := errorCodeStopping + if !s.version.UsesIETFFrameFormat() { + errorCode = errorCodeStoppingGQUIC + } + s.cancelWriteImpl(errorCode, writeErr) +} + +func (s *sendStream) Context() context.Context { + return s.ctx +} + +func (s *sendStream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.writeDeadline + s.writeDeadline = t + s.mutex.Unlock() + if t.Before(oldDeadline) { + s.signalWrite() + } + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *sendStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalWrite() + s.ctxCancel() +} + +func (s *sendStream) getWriteOffset() protocol.ByteCount { + return s.writeOffset +} + +// signalWrite performs a non-blocking send on the writeChan +func (s *sendStream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index fb73ccb..31b7911 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "errors" + "fmt" "net" "sync" "time" @@ -19,6 +20,8 @@ import ( // packetHandler handles packets type packetHandler interface { Session + getCryptoStream() cryptoStreamI + handshakeStatus() <-chan error handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error @@ -32,24 +35,30 @@ type server struct { conn net.PacketConn + supportsTLS bool + serverTLS *serverTLS + certChain crypto.CertChain scfg *handshake.ServerConfig - sessions map[protocol.ConnectionID]packetHandler - sessionsMutex sync.RWMutex - deleteClosedSessionsAfter time.Duration + sessionsMutex sync.RWMutex + sessions map[string] /* string(ConnectionID)*/ packetHandler + closed bool serverError error sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error) + // set as members, so they can be set in the tests + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error) + deleteClosedSessionsAfter time.Duration + + logger utils.Logger } var _ Listener = &server{} // ListenAddr creates a QUIC server listening on a given address. -// The listener is not active until Serve() is called. // The tls.Config must not be nil, the quic.Config may be nil. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) @@ -64,7 +73,6 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err } // Listen listens for QUIC connections on a given net.PacketConn. -// The listener is not active until Serve() is called. // The tls.Config must not be nil, the quic.Config may be nil. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { certChain := crypto.NewCertChain(tlsConf) @@ -76,24 +84,77 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, if err != nil { return nil, err } + config = populateServerConfig(config) + + var supportsTLS bool + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + // check if any of the supported versions supports TLS + if v.UsesTLS() { + supportsTLS = true + break + } + } s := &server{ conn: conn, tlsConf: tlsConf, - config: populateServerConfig(config), + config: config, certChain: certChain, scfg: scfg, - sessions: map[protocol.ConnectionID]packetHandler{}, + sessions: map[string]packetHandler{}, newSession: newSession, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), + supportsTLS: supportsTLS, + logger: utils.DefaultLogger, + } + if supportsTLS { + if err := s.setupTLS(); err != nil { + return nil, err + } } go s.serve() - utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } +func (s *server) setupTLS() error { + cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger) + if err != nil { + return err + } + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger) + if err != nil { + return err + } + s.serverTLS = serverTLS + // handle TLS connection establishment statelessly + go func() { + for { + select { + case <-s.errorChan: + return + case tlsSession := <-sessionChan: + connID := tlsSession.connID + sess := tlsSession.sess + s.sessionsMutex.Lock() + if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists + s.sessionsMutex.Unlock() + continue + } + s.sessions[string(connID)] = sess + s.sessionsMutex.Unlock() + s.runHandshakeAndSession(sess, connID) + } + } + }() + return nil +} + var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { if cookie == nil { return false @@ -143,6 +204,18 @@ func populateServerConfig(config *Config) *Config { if maxReceiveConnectionFlowControlWindow == 0 { maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer } + maxIncomingStreams := config.MaxIncomingStreams + if maxIncomingStreams == 0 { + maxIncomingStreams = protocol.DefaultMaxIncomingStreams + } else if maxIncomingStreams < 0 { + maxIncomingStreams = 0 + } + maxIncomingUniStreams := config.MaxIncomingUniStreams + if maxIncomingUniStreams == 0 { + maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams + } else if maxIncomingUniStreams < 0 { + maxIncomingUniStreams = 0 + } return &Config{ Versions: versions, @@ -152,13 +225,15 @@ func populateServerConfig(config *Config) *Config { KeepAlive: config.KeepAlive, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, } } // serve listens on an existing PacketConn func (s *server) serve() { for { - data := getPacketBuffer() + data := *getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] // The packet size should not exceed protocol.MaxReceivePacketSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable @@ -170,8 +245,8 @@ func (s *server) serve() { return } data = data[:n] - if err := s.handlePacket(s.conn, remoteAddr, data); err != nil { - utils.Errorf("error handling packet: %s", err.Error()) + if err := s.handlePacket(remoteAddr, data); err != nil { + s.logger.Errorf("error handling packet: %s", err.Error()) } } } @@ -190,6 +265,12 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionsMutex.Lock() + if s.closed { + s.sessionsMutex.Unlock() + return nil + } + s.closed = true + var wg sync.WaitGroup for _, session := range s.sessions { if session != nil { @@ -204,10 +285,9 @@ func (s *server) Close() error { s.sessionsMutex.Unlock() wg.Wait() - if s.conn == nil { - return nil - } - return s.conn.Close() + err := s.conn.Close() + <-s.errorChan // wait for serve() to return + return err } // Addr returns the server's network address @@ -215,7 +295,7 @@ func (s *server) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet []byte) error { +func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error { rcvTime := time.Now() r := bytes.NewReader(packet) @@ -224,10 +304,62 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] - connID := hdr.ConnectionID + packetData := packet[len(packet)-r.Len():] + + if hdr.IsPublicHeader { + return s.handleGQUICPacket(hdr, packetData, remoteAddr, rcvTime) + } + return s.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime) +} + +func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { + if hdr.IsLongHeader { + if !s.supportsTLS { + return errors.New("Received an IETF QUIC Long Header") + } + if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + } + packetData = packetData[:int(hdr.PayloadLen)] + // TODO(#1312): implement parsing of compound packets + + switch hdr.Type { + case protocol.PacketTypeInitial: + go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) + return nil + case protocol.PacketTypeHandshake: + // nothing to do here. Packet will be passed to the session. + default: + // Note that this also drops 0-RTT packets. + return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) + } + } s.sessionsMutex.RLock() - session, sessionKnown := s.sessions[connID] + session, sessionKnown := s.sessions[string(hdr.DestConnectionID)] + s.sessionsMutex.RUnlock() + + if sessionKnown && session == nil { + // Late packet for closed session + return nil + } + if !sessionKnown { + s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID) + return nil + } + + session.handlePacket(&receivedPacket{ + remoteAddr: remoteAddr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) + return nil +} + +func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { + s.sessionsMutex.RLock() + session, sessionKnown := s.sessions[string(hdr.DestConnectionID)] s.sessionsMutex.RUnlock() if sessionKnown && session == nil { @@ -237,25 +369,14 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet // ignore all Public Reset packets if hdr.ResetFlag { - if sessionKnown { - var pr *wire.PublicReset - pr, err = wire.ParsePublicReset(r) - if err != nil { - utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") - } else { - utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) - } - } else { - utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) - } + s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID) return nil } // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. - // TODO(#943): implement sending of IETF draft style stateless resets - if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { - _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) + if !sessionKnown && !hdr.VersionFlag { + _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr) return err } @@ -270,79 +391,79 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { // drop packets that are too small to be valid first packets - if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) { + if len(packetData) < protocol.MinClientHelloSize { return errors.New("dropping small packet with unknown version") } - utils.Infof("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - if _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr); err != nil { - return err - } - } - // send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header - if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - _, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, hdr.Version, s.config.Versions), remoteAddr) + s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) + _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.SrcConnectionID, s.config.Versions), remoteAddr) return err } if !sessionKnown { + // This is (potentially) a Client Hello. + // Make sure it has the minimum required size before spending any more ressources on it. + if len(packetData) < protocol.MinClientHelloSize { + return errors.New("dropping small packet for unknown connection") + } + version := hdr.Version if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") } - utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) - var handshakeChan <-chan handshakeEvent - session, handshakeChan, err = s.newSession( - &conn{pconn: pconn, currentAddr: remoteAddr}, + s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr) + var err error + session, err = s.newSession( + &conn{pconn: s.conn, currentAddr: remoteAddr}, version, - hdr.ConnectionID, + hdr.DestConnectionID, s.scfg, s.tlsConf, s.config, + s.logger, ) if err != nil { return err } s.sessionsMutex.Lock() - s.sessions[connID] = session + s.sessions[string(hdr.DestConnectionID)] = session s.sessionsMutex.Unlock() - go func() { - // session.run() returns as soon as the session is closed - _ = session.run() - s.removeConnection(connID) - }() - - go func() { - for { - ev := <-handshakeChan - if ev.err != nil { - return - } - if ev.encLevel == protocol.EncryptionForwardSecure { - break - } - } - s.sessionQueue <- session - }() + s.runHandshakeAndSession(session, hdr.DestConnectionID) } + session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, header: hdr, - data: packet[len(packet)-r.Len():], + data: packetData, rcvTime: rcvTime, }) return nil } +func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) { + go func() { + _ = session.run() + // session.run() returns as soon as the session is closed + s.removeConnection(connID) + }() + + go func() { + if err := <-session.handshakeStatus(); err != nil { + return + } + s.sessionQueue <- session + }() +} + func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() - s.sessions[id] = nil + s.sessions[string(id)] = nil s.sessionsMutex.Unlock() time.AfterFunc(s.deleteClosedSessionsAfter, func() { s.sessionsMutex.Lock() - delete(s.sessions, id) + delete(s.sessions, string(id)) s.sessionsMutex.Unlock() }) } diff --git a/vendor/github.com/lucas-clemente/quic-go/server_tls.go b/vendor/github.com/lucas-clemente/quic-go/server_tls.go new file mode 100644 index 0000000..7424a40 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/server_tls.go @@ -0,0 +1,235 @@ +package quic + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type nullAEAD struct { + aead crypto.AEAD +} + +var _ quicAEAD = &nullAEAD{} + +func (n *nullAEAD) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return n.aead.Open(dst, src, packetNumber, associatedData) +} + +func (n *nullAEAD) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return nil, errors.New("no 1-RTT keys") +} + +type tlsSession struct { + connID protocol.ConnectionID + sess packetHandler +} + +type serverTLS struct { + conn net.PacketConn + config *Config + supportedVersions []protocol.VersionNumber + mintConf *mint.Config + params *handshake.TransportParameters + newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) + + sessionChan chan<- tlsSession + + logger utils.Logger +} + +func newServerTLS( + conn net.PacketConn, + config *Config, + cookieHandler *handshake.CookieHandler, + tlsConf *tls.Config, + logger utils.Logger, +) (*serverTLS, <-chan tlsSession, error) { + mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) + if err != nil { + return nil, nil, err + } + mconf.RequireCookie = true + cs, err := mint.NewDefaultCookieProtector() + if err != nil { + return nil, nil, err + } + mconf.CookieProtector = cs + mconf.CookieHandler = cookieHandler + + sessionChan := make(chan tlsSession) + s := &serverTLS{ + conn: conn, + config: config, + supportedVersions: config.Versions, + mintConf: mconf, + sessionChan: sessionChan, + params: &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: config.IdleTimeout, + MaxBidiStreams: uint16(config.MaxIncomingStreams), + MaxUniStreams: uint16(config.MaxIncomingUniStreams), + }, + logger: logger, + } + s.newMintConn = s.newMintConnImpl + return s, sessionChan, nil +} + +func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { + // TODO: add a check that DestConnID == SrcConnID + s.logger.Debugf("Received a Packet. Handling it statelessly.") + sess, err := s.handleInitialImpl(remoteAddr, hdr, data) + if err != nil { + s.logger.Errorf("Error occurred handling initial packet: %s", err) + return + } + if sess == nil { // a stateless reset was done + return + } + s.sessionChan <- tlsSession{ + connID: hdr.DestConnectionID, + sess: sess, + } +} + +// will be set to s.newMintConn by the constructor +func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { + extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v, s.logger) + conf := s.mintConf.Clone() + conf.ExtensionHandler = extHandler + return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil +} + +func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error { + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: qerr.HandshakeFailed, + ReasonPhrase: closeErr.Error(), + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + SrcConnectionID: clientHdr.DestConnectionID, + DestConnectionID: clientHdr.SrcConnectionID, + PacketNumber: 1, // random packet number + Version: clientHdr.Version, + } + data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger) + if err != nil { + return err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return err +} + +func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { + if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { + return nil, errors.New("dropping too small Initial packet") + } + // check version, if not matching send VNP + if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { + s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) + vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions) + if err != nil { + return nil, err + } + _, err = s.conn.WriteTo(vnp, remoteAddr) + return nil, err + } + + // unpack packet and check stream frame contents + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS) + if err != nil { + return nil, err + } + frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version) + if err != nil { + s.logger.Debugf("Error unpacking initial packet: %s", err) + return nil, nil + } + sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) + if err != nil { + if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { + s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) + } + return nil, err + } + return sess, nil +} + +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { + version := hdr.Version + bc := handshake.NewCryptoStreamConn(remoteAddr) + bc.AddDataForReading(frame.Data) + tls, paramsChan, err := s.newMintConn(bc, version) + if err != nil { + return nil, err + } + alert := tls.Handshake() + if alert == mint.AlertStatelessRetry { + // the HelloRetryRequest was written to the bufferConn + // Take that data and write send a Retry packet + f := &wire.StreamFrame{ + StreamID: version.CryptoStreamID(), + Data: bc.GetDataForWriting(), + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: hdr.SrcConnectionID, + SrcConnectionID: hdr.DestConnectionID, + PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()), + PacketNumber: hdr.PacketNumber, // echo the client's packet number + Version: version, + } + data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger) + if err != nil { + return nil, err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return nil, err + } + if alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerNegotiated { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) + } + if alert := tls.Handshake(); alert != mint.AlertNoAlert { + return nil, alert + } + if tls.State() != mint.StateServerWaitFlight2 { + return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) + } + params := <-paramsChan + sess, err := newTLSServerSession( + &conn{pconn: s.conn, currentAddr: remoteAddr}, + hdr.SrcConnectionID, + hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here + protocol.PacketNumber(1), // TODO: use a random packet number here + s.config, + tls, + bc, + aead, + ¶ms, + version, + s.logger, + ) + if err != nil { + return nil, err + } + cs := sess.getCryptoStream() + cs.setReadOffset(frame.DataLen()) + bc.SetStream(cs) + return sess, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index 06d6916..ad53499 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -2,6 +2,7 @@ package quic import ( "context" + "crypto/rand" "crypto/tls" "errors" "fmt" @@ -9,8 +10,9 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -23,6 +25,35 @@ type unpacker interface { Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) } +type streamGetter interface { + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) +} + +type streamManager interface { + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenUniStream() (SendStream, error) + OpenStreamSync() (Stream, error) + OpenUniStreamSync() (SendStream, error) + AcceptStream() (Stream, error) + AcceptUniStream() (ReceiveStream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*handshake.TransportParameters) + HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error + CloseWithError(error) +} + +type cryptoStreamHandler interface { + HandleCryptoStream() error + ConnectionState() handshake.ConnectionState +} + +type divNonceSetter interface { + SetDiversificationNonce([]byte) error +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -30,21 +61,11 @@ type receivedPacket struct { rcvTime time.Time } -var ( - errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") - errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") -) - var ( newCryptoSetup = handshake.NewCryptoSetup newCryptoSetupClient = handshake.NewCryptoSetupClient ) -type handshakeEvent struct { - encLevel protocol.EncryptionLevel - err error -} - type closeError struct { err error remote bool @@ -52,28 +73,30 @@ type closeError struct { // A Session is a QUIC session type session struct { - connectionID protocol.ConnectionID - perspective protocol.Perspective - version protocol.VersionNumber - config *Config + destConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + config *Config conn connection - streamsMap *streamsMap - cryptoStream streamI + streamsMap streamManager + cryptoStream cryptoStreamI rttStats *congestion.RTTStats sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - - connFlowController flowcontrol.ConnectionFlowController + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker - cryptoSetup handshake.CryptoSetup + cryptoStreamHandler cryptoStreamHandler receivedPackets chan *receivedPacket sendingScheduled chan struct{} @@ -91,25 +114,26 @@ type session struct { // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them paramsChan <-chan handshake.TransportParameters - // this channel is passed to the CryptoSetup and receives the current encryption level - // it is closed as soon as the handshake is complete - aeadChanged <-chan protocol.EncryptionLevel + // the handshakeEvent channel is passed to the CryptoSetup. + // It receives when it makes sense to try decrypting undecryptable packets. + handshakeEvent <-chan struct{} + // handshakeChan is returned by handshakeStatus. + // It receives any error that might occur during the handshake. + // It is closed when the handshake is complete. + handshakeChan chan error handshakeComplete bool - // will be closed as soon as the handshake completes, and receive any error that might occur until then - // it is used to block WaitUntilHandshakeComplete() - handshakeCompleteChan chan error - // handshakeChan receives handshake events and is closed as soon the handshake completes - // the receiving end of this channel is passed to the creator of the session - // it receives at most 3 handshake events: 2 when the encryption level changes, and one error - handshakeChan chan<- handshakeEvent - lastRcvdPacketNumber protocol.PacketNumber + receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this + receivedFirstForwardSecurePacket bool + lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets largestRcvdPacketNumber protocol.PacketNumber sessionCreationTime time.Time lastNetworkActivityTime time.Time + // pacingDeadline is the time when the next packet should be sent + pacingDeadline time.Time peerParams *handshake.TransportParameters @@ -117,30 +141,84 @@ type session struct { // keepAlivePingSent stores whether a Ping frame was sent to the peer or not // it is reset as soon as we receive a packet from the peer keepAlivePingSent bool + + logger utils.Logger } var _ Session = &session{} +var _ streamSender = &session{} // newSession makes a new session func newSession( conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, - sCfg *handshake.ServerConfig, + scfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, -) (packetHandler, <-chan handshakeEvent, error) { + logger utils.Logger, +) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveServer, - version: v, - config: config, + conn: conn, + srcConnID: connectionID, + destConnID: connectionID, + perspective: protocol.PerspectiveServer, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, + logger: logger, } - return s.setup(sCfg, "", tlsConf, v, nil) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: uint32(s.config.MaxIncomingStreams), + IdleTimeout: s.config.IdleTimeout, + } + divNonce := make([]byte, 32) + if _, err := rand.Read(divNonce); err != nil { + return nil, err + } + cs, err := newCryptoSetup( + s.cryptoStream, + connectionID, + s.conn.RemoteAddr(), + s.version, + divNonce, + scfg, + transportParams, + s.config.Versions, + s.config.AcceptCookie, + paramsChan, + handshakeEvent, + s.logger, + ) + if err != nil { + return nil, err + } + s.cryptoStreamHandler = cs + s.unpacker = newPacketUnpackerGQUIC(cs, s.version) + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker( + connectionID, + connectionID, + 1, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + divNonce, + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() } -// declare this as a variable, such that we can it mock it in the tests +// declare this as a variable, so that we can it mock it in the tests var newClientSession = func( conn connection, hostname string, @@ -149,32 +227,193 @@ var newClientSession = func( tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, - negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton -) (packetHandler, <-chan handshakeEvent, error) { + negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiation + logger utils.Logger, +) (packetHandler, error) { + paramsChan := make(chan handshake.TransportParameters) + handshakeEvent := make(chan struct{}, 1) s := &session{ - conn: conn, - connectionID: connectionID, - perspective: protocol.PerspectiveClient, - version: v, - config: config, + conn: conn, + srcConnID: connectionID, + destConnID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + config: config, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, + logger: logger, } - return s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) + s.preSetup() + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: uint32(s.config.MaxIncomingStreams), + IdleTimeout: s.config.IdleTimeout, + OmitConnectionID: s.config.RequestConnectionIDOmission, + } + cs, err := newCryptoSetupClient( + s.cryptoStream, + hostname, + connectionID, + s.version, + tlsConf, + transportParams, + paramsChan, + handshakeEvent, + initialVersion, + negotiatedVersions, + s.logger, + ) + if err != nil { + return nil, err + } + s.cryptoStreamHandler = cs + s.unpacker = newPacketUnpackerGQUIC(cs, s.version) + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker( + connectionID, + connectionID, + 1, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + nil, // no diversification nonce + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() } -func (s *session) setup( - scfg *handshake.ServerConfig, +func newTLSServerSession( + conn connection, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + initialPacketNumber protocol.PacketNumber, + config *Config, + tls handshake.MintTLS, + cryptoStreamConn *handshake.CryptoStreamConn, + nullAEAD crypto.AEAD, + peerParams *handshake.TransportParameters, + v protocol.VersionNumber, + logger utils.Logger, +) (packetHandler, error) { + handshakeEvent := make(chan struct{}, 1) + s := &session{ + conn: conn, + config: config, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveServer, + version: v, + handshakeEvent: handshakeEvent, + logger: logger, + } + s.preSetup() + cs := handshake.NewCryptoSetupTLSServer( + tls, + cryptoStreamConn, + nullAEAD, + handshakeEvent, + v, + ) + s.cryptoStreamHandler = cs + s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker( + s.destConnID, + s.srcConnID, + initialPacketNumber, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + nil, // no diversification nonce + cs, + s.streamFramer, + s.perspective, + s.version, + ) + if err := s.postSetup(); err != nil { + return nil, err + } + s.peerParams = peerParams + s.processTransportParameters(peerParams) + s.unpacker = newPacketUnpacker(cs, s.version) + return s, nil +} + +// declare this as a variable, such that we can it mock it in the tests +var newTLSClientSession = func( + conn connection, hostname string, - tlsConf *tls.Config, - initialVersion protocol.VersionNumber, - negotiatedVersions []protocol.VersionNumber, -) (packetHandler, <-chan handshakeEvent, error) { - aeadChanged := make(chan protocol.EncryptionLevel, 2) - paramsChan := make(chan handshake.TransportParameters) - s.aeadChanged = aeadChanged - s.paramsChan = paramsChan - handshakeChan := make(chan handshakeEvent, 3) - s.handshakeChan = handshakeChan - s.handshakeCompleteChan = make(chan error, 1) + v protocol.VersionNumber, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + config *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + initialPacketNumber protocol.PacketNumber, + logger utils.Logger, +) (packetHandler, error) { + handshakeEvent := make(chan struct{}, 1) + s := &session{ + conn: conn, + config: config, + srcConnID: srcConnID, + destConnID: destConnID, + perspective: protocol.PerspectiveClient, + version: v, + handshakeEvent: handshakeEvent, + paramsChan: paramsChan, + logger: logger, + } + s.preSetup() + tls.SetCryptoStream(s.cryptoStream) + cs, err := handshake.NewCryptoSetupTLSClient( + s.cryptoStream, + s.destConnID, + hostname, + handshakeEvent, + tls, + v, + ) + if err != nil { + return nil, err + } + s.cryptoStreamHandler = cs + s.unpacker = newPacketUnpacker(cs, s.version) + s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker( + s.destConnID, + s.srcConnID, + initialPacketNumber, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + nil, // no diversification nonce + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() +} + +func (s *session) preSetup() { + s.rttStats = &congestion.RTTStats{} + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ReceiveConnectionFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.onHasConnectionWindowUpdate, + s.rttStats, + s.logger, + ) + s.cryptoStream = s.newCryptoStream() +} + +func (s *session) postSetup() error { + s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -186,99 +425,9 @@ func (s *session) setup( s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.rttStats = &congestion.RTTStats{} - transportParams := &handshake.TransportParameters{ - StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - MaxStreams: protocol.MaxIncomingStreams, - IdleTimeout: s.config.IdleTimeout, - } - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.connFlowController = flowcontrol.NewConnectionFlowController( - protocol.ReceiveConnectionFlowControlWindow, - protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), - s.rttStats, - ) - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) - s.cryptoStream = s.newStream(s.version.CryptoStreamID()) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) - - var err error - if s.perspective == protocol.PerspectiveServer { - verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { - return s.config.AcceptCookie(clientAddr, cookie) - } - if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( - s.cryptoStream, - s.connectionID, - tlsConf, - s.conn.RemoteAddr(), - transportParams, - paramsChan, - aeadChanged, - verifySourceAddr, - s.config.Versions, - s.version, - ) - } else { - s.cryptoSetup, err = newCryptoSetup( - s.cryptoStream, - s.connectionID, - s.conn.RemoteAddr(), - s.version, - scfg, - transportParams, - s.config.Versions, - verifySourceAddr, - paramsChan, - aeadChanged, - ) - } - } else { - transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission - if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( - s.cryptoStream, - s.connectionID, - hostname, - tlsConf, - transportParams, - paramsChan, - aeadChanged, - initialVersion, - s.config.Versions, - s.version, - ) - } else { - s.cryptoSetup, err = newCryptoSetupClient( - s.cryptoStream, - hostname, - s.connectionID, - s.version, - tlsConf, - transportParams, - paramsChan, - aeadChanged, - initialVersion, - negotiatedVersions, - ) - } - } - if err != nil { - return nil, nil, err - } - - s.packer = newPacketPacker(s.connectionID, - s.cryptoSetup, - s.streamFramer, - s.perspective, - s.version, - ) - s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - - return s, handshakeChan, nil + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.connFlowController, s.packer.QueueControlFrame) + return nil } // run the session main loop @@ -286,20 +435,23 @@ func (s *session) run() error { defer s.ctxCancel() go func() { - if err := s.cryptoSetup.HandleCryptoStream(); err != nil { + if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil { s.Close(err) } }() var closeErr closeError - aeadChanged := s.aeadChanged runLoop: for { + // Close immediately if requested select { case closeErr = <-s.closeChan: break runLoop + case _, ok := <-s.handshakeEvent: + // when the handshake is completed, the channel will be closed + s.handleHandshakeEvent(!ok) default: } @@ -327,38 +479,43 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(p.header.Raw) + putPacketBuffer(&p.header.Raw) case p := <-s.paramsChan: s.processTransportParameters(&p) - case l, ok := <-aeadChanged: - if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. - s.handshakeComplete = true - aeadChanged = nil // prevent this case from ever being selected again - s.sentPacketHandler.SetHandshakeComplete() - close(s.handshakeChan) - close(s.handshakeCompleteChan) - } else { - s.tryDecryptingQueuedPackets() - s.handshakeChan <- handshakeEvent{encLevel: l} - } + case _, ok := <-s.handshakeEvent: + // when the handshake is completed, the channel will be closed + s.handleHandshakeEvent(!ok) } now := time.Now() if timeout := s.sentPacketHandler.GetAlarmTimeout(); !timeout.IsZero() && timeout.Before(now) { - // This could cause packets to be retransmitted, so check it before trying - // to send packets. - s.sentPacketHandler.OnAlarm() + // This could cause packets to be retransmitted. + // Check it before trying to send packets. + if err := s.sentPacketHandler.OnAlarm(); err != nil { + s.closeLocal(err) + } } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { + var pacingDeadline time.Time + if s.pacingDeadline.IsZero() { // the timer didn't have a pacing deadline set + pacingDeadline = s.sentPacketHandler.TimeUntilSend() + } + if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { // send the PING frame since there is no activity in the session s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true + } else if !pacingDeadline.IsZero() && now.Before(pacingDeadline) { + // If we get to this point before the pacing deadline, we should wait until that deadline. + // This can happen when scheduleSending is called, or a packet is received. + // Set the timer and restart the run loop. + s.pacingDeadline = pacingDeadline + continue } - if err := s.sendPacket(); err != nil { + if err := s.sendPackets(); err != nil { s.closeLocal(err) } + if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } @@ -368,17 +525,12 @@ runLoop: if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } - - if err := s.streamsMap.DeleteClosedStreams(); err != nil { - s.closeLocal(err) - } } // only send the error the handshakeChan when the handshake is not completed yet // otherwise this chan will already be closed if !s.handshakeComplete { - s.handshakeCompleteChan <- closeErr.err - s.handshakeChan <- handshakeEvent{err: closeErr.err} + s.handshakeChan <- closeErr.err } s.handleCloseError(closeErr) return closeErr.err @@ -388,6 +540,10 @@ func (s *session) Context() context.Context { return s.ctx } +func (s *session) ConnectionState() ConnectionState { + return s.cryptoStreamHandler.ConnectionState() +} + func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { @@ -409,15 +565,36 @@ func (s *session) maybeResetTimer() { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() { deadline = utils.MinTime(deadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) } + if !s.pacingDeadline.IsZero() { + deadline = utils.MinTime(deadline, s.pacingDeadline) + } s.timer.Reset(deadline) } +func (s *session) handleHandshakeEvent(completed bool) { + if !completed { + s.tryDecryptingQueuedPackets() + return + } + s.handshakeComplete = true + s.handshakeEvent = nil // prevent this case from ever being selected again + if !s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient { + // In gQUIC, there's no equivalent to the Finished message in TLS + // The server knows that the handshake is complete when it receives the first forward-secure packet sent by the client. + // We need to make sure that the client actually sends such a packet. + s.packer.QueueControlFrame(&wire.PingFrame{}) + s.scheduleSending() + } + close(s.handshakeChan) +} + func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { - diversificationNonce := p.header.DiversificationNonce - if len(diversificationNonce) > 0 { - s.cryptoSetup.SetDiversificationNonce(diversificationNonce) + if divNonce := p.header.DiversificationNonce; len(divNonce) > 0 { + if err := s.cryptoStreamHandler.(divNonceSetter).SetDiversificationNonce(divNonce); err != nil { + return err + } } } @@ -426,6 +603,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { p.rcvTime = time.Now() } + s.receivedFirstPacket = true s.lastNetworkActivityTime = p.rcvTime s.keepAlivePingSent = false hdr := p.header @@ -439,34 +617,39 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { ) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) - if utils.Debug() { + if s.logger.Debug() { if err != nil { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID) } else { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel) } - hdr.Log() + hdr.Log(s.logger) } // if the decryption failed, this might be a packet sent by an attacker - // don't update the remote address - if quicErr, ok := err.(*qerr.QuicError); ok && quicErr.ErrorCode == qerr.DecryptionFailure { - return err - } - if s.perspective == protocol.PerspectiveServer { - // update the remote address, even if unpacking failed for any other reason than a decryption error - s.conn.SetCurrentRemoteAddr(p.remoteAddr) - } if err != nil { return err } + // In TLS 1.3, the client considers the handshake complete as soon as + // it received the server's Finished message and sent its Finished. + // We have to wait for the first forward-secure packet from the server before + // deleting all handshake packets from the history. + if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.EncryptionForwardSecure { + s.receivedFirstForwardSecurePacket = true + s.sentPacketHandler.SetHandshakeComplete() + } + s.lastRcvdPacketNumber = hdr.PacketNumber // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) - isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) - if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil { - return err + // If this is a Retry packet, there's no need to send an ACK. + // The session will be closed and recreated as soon as the crypto setup processed the HRR. + if hdr.Type != protocol.PacketTypeRetry { + isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames) + if err := s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, p.rcvTime, isRetransmittable); err != nil { + return err + } } return s.handleFrames(packet.frames, packet.encryptionLevel) @@ -475,45 +658,42 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error - wire.LogFrame(ff, false) + wire.LogFrame(s.logger, ff, false) switch frame := ff.(type) { case *wire.StreamFrame: - err = s.handleStreamFrame(frame) + err = s.handleStreamFrame(frame, encLevel) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel) case *wire.ConnectionCloseFrame: s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *wire.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") - case *wire.StopWaitingFrame: - // LeastUnacked is guaranteed to have LeastUnacked > 0 - // therefore this will never underflow - s.receivedPacketHandler.SetLowerLimit(frame.LeastUnacked - 1) + case *wire.StopWaitingFrame: // ignore STOP_WAITINGs case *wire.RstStreamFrame: err = s.handleRstStreamFrame(frame) case *wire.MaxDataFrame: s.handleMaxDataFrame(frame) case *wire.MaxStreamDataFrame: err = s.handleMaxStreamDataFrame(frame) + case *wire.MaxStreamIDFrame: + err = s.handleMaxStreamIDFrame(frame) case *wire.BlockedFrame: case *wire.StreamBlockedFrame: + case *wire.StreamIDBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) case *wire.PingFrame: + case *wire.PathChallengeFrame: + s.handlePathChallengeFrame(frame) + case *wire.PathResponseFrame: + // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs + err = errors.New("unexpected PATH_RESPONSE frame") default: return errors.New("Session BUG: unexpected frame type") } if err != nil { - switch err { - case ackhandler.ErrDuplicateOrOutOfOrderAck: - // Can happen e.g. when packets thought missing arrive late - case errRstStreamOnInvalidStream: - // Can happen when RST_STREAMs arrive early or late (?) - utils.Errorf("Ignoring error in session: %s", err.Error()) - case errWindowUpdateOnClosedStream: - // Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream - default: - return err - } + return err } } return nil @@ -529,11 +709,16 @@ func (s *session) handlePacket(p *receivedPacket) { } } -func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { +func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error { if frame.StreamID == s.version.CryptoStreamID() { - return s.cryptoStream.AddStreamFrame(frame) + if frame.FinBit { + return errors.New("Received STREAM frame with FIN bit for the crypto stream") + } + return s.cryptoStream.handleStreamFrame(frame) + } else if encLevel <= protocol.EncryptionUnencrypted { + return qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", frame.StreamID)) } - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } @@ -542,7 +727,7 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { // ignore this StreamFrame return nil } - return str.AddStreamFrame(frame) + return str.handleStreamFrame(frame) } func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { @@ -550,30 +735,67 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { } func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if frame.StreamID == s.version.CryptoStreamID() { + s.cryptoStream.handleMaxStreamDataFrame(frame) + return nil + } + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err } if str == nil { - return errWindowUpdateOnClosedStream + // stream is closed and already garbage collected + return nil } - str.UpdateSendWindow(frame.ByteOffset) + str.handleMaxStreamDataFrame(frame) return nil } +func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error { + return s.streamsMap.HandleMaxStreamIDFrame(frame) +} + func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received RST_STREAM frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err } if str == nil { - return errRstStreamOnInvalidStream + // stream is closed and already garbage collected + return nil } - return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) + return str.handleRstStreamFrame(frame) +} + +func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return errors.New("Received a STOP_SENDING frame for the crypto stream") + } + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.handleStopSendingFrame(frame) + return nil +} + +func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { + s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) } func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime) + if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { + return err + } + s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) + return nil } func (s *session) closeLocal(e error) { @@ -608,15 +830,15 @@ func (s *session) handleCloseError(closeErr closeError) error { } // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { - utils.Infof("Closing connection %x", s.connectionID) + s.logger.Infof("Closing connection %s", s.srcConnID) } else { - utils.Errorf("Closing session with error: %s", closeErr.err.Error()) + s.logger.Errorf("Closing session with error: %s", closeErr.err.Error()) } - s.cryptoStream.Cancel(quicErr) + s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) - if closeErr.err == errCloseSessionForNewVersion { + if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry { return nil } @@ -635,130 +857,203 @@ func (s *session) handleCloseError(closeErr closeError) error { func (s *session) processTransportParameters(params *handshake.TransportParameters) { s.peerParams = params - s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) + s.streamsMap.UpdateLimits(params) if params.OmitConnectionID { s.packer.SetOmitConnectionID() } + if params.MaxPacketSize != 0 { + s.packer.SetMaxPacketSize(params.MaxPacketSize) + } s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) - s.streamsMap.Range(func(str streamI) { - str.UpdateSendWindow(params.StreamFlowControlWindow) - }) + // the crypto stream is the only open stream at this moment + // so we don't need to update stream flow control windows } -func (s *session) sendPacket() error { - s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) +func (s *session) sendPackets() error { + s.pacingDeadline = time.Time{} - // Get MAX_DATA and MAX_STREAM_DATA frames - // this call triggers the flow controller to increase the flow control windows, if necessary - windowUpdates := s.getWindowUpdates() - for _, f := range windowUpdates { - s.packer.QueueControlFrame(f) + sendMode := s.sentPacketHandler.SendMode() + if sendMode == ackhandler.SendNone { // shortcut: return immediately if there's nothing to send + return nil } - ack := s.receivedPacketHandler.GetAckFrame() - if ack != nil { - s.packer.QueueControlFrame(ack) - } - - // Repeatedly try sending until we don't have any more data, or run out of the congestion window + numPackets := s.sentPacketHandler.ShouldSendNumPackets() + var numPacketsSent int +sendLoop: for { - if !s.sentPacketHandler.SendingAllowed() { - if ack == nil { - return nil - } - // If we aren't allowed to send, at least try sending an ACK frame - swf := s.sentPacketHandler.GetStopWaitingFrame(false) - if swf != nil { - s.packer.QueueControlFrame(swf) - } - packet, err := s.packer.PackAckPacket() + switch sendMode { + case ackhandler.SendNone: + break sendLoop + case ackhandler.SendAck: + // We can at most send a single ACK only packet. + // There will only be a new ACK after receiving new packets. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. + return s.maybeSendAckOnlyPacket() + case ackhandler.SendRTO: + // try to send a retransmission first + sentPacket, err := s.maybeSendRetransmission() if err != nil { return err } - return s.sendPackedPacket(packet) - } - - // check for retransmissions first - for { - retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() - if retransmitPacket == nil { - break - } - - if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - if s.handshakeComplete { - // Don't retransmit handshake packets when the handshake is complete - continue - } - utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket) + if !sentPacket { + // In RTO mode, a probe packet has to be sent. + // Add a PING frame to make sure a (retransmittable) packet will be sent. + s.queueControlFrame(&wire.PingFrame{}) + sentPacket, err := s.sendPacket() if err != nil { return err } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - } else { - utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) - // resend the frames that were in the packet - for _, frame := range retransmitPacket.GetFramesForRetransmission() { - // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window - switch f := frame.(type) { - case *wire.StreamFrame: - s.streamFramer.AddFrameForRetransmission(f) - default: - s.packer.QueueControlFrame(frame) - } + if !sentPacket { + return errors.New("session BUG: expected a packet to be sent in RTO mode") } } - } - - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - if ack != nil || hasRetransmission { - swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) - if swf != nil { - s.packer.QueueControlFrame(swf) + numPacketsSent++ + case ackhandler.SendTLP: + // In TLP mode, a probe packet has to be sent. + // Add a PING frame to make sure a (retransmittable) packet will be sent. + s.queueControlFrame(&wire.PingFrame{}) + sentPacket, err := s.sendPacket() + if err != nil { + return err } + if !sentPacket { + return errors.New("session BUG: expected a packet to be sent in TLP mode") + } + return nil + case ackhandler.SendRetransmission: + sentPacket, err := s.maybeSendRetransmission() + if err != nil { + return err + } + if sentPacket { + numPacketsSent++ + // This can happen if a retransmission queued, but it wasn't necessary to send it. + // e.g. when an Initial is queued, but we already received a packet from the server. + } + case ackhandler.SendAny: + sentPacket, err := s.sendPacket() + if err != nil { + return err + } + if !sentPacket { + break sendLoop + } + numPacketsSent++ + default: + return fmt.Errorf("BUG: invalid send mode %d", sendMode) } - // add a retransmittable frame - if s.sentPacketHandler.ShouldSendRetransmittablePacket() { - s.packer.QueueControlFrame(&wire.PingFrame{}) + if numPacketsSent >= numPackets { + break } - packet, err := s.packer.PackPacket() - if err != nil || packet == nil { - return err - } - if err = s.sendPackedPacket(packet); err != nil { - return err - } - - // send every window update twice - for _, f := range windowUpdates { - s.packer.QueueControlFrame(f) - } - windowUpdates = nil - ack = nil + sendMode = s.sentPacketHandler.SendMode() } + // Only start the pacing timer if we sent as many packets as we were allowed. + // There will probably be more to send when calling sendPacket again. + if numPacketsSent == numPackets { + s.pacingDeadline = s.sentPacketHandler.TimeUntilSend() + } + return nil } -func (s *session) sendPackedPacket(packet *packedPacket) error { - defer putPacketBuffer(packet.raw) - err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ - PacketNumber: packet.header.PacketNumber, - Frames: packet.frames, - Length: protocol.ByteCount(len(packet.raw)), - EncryptionLevel: packet.encryptionLevel, - }) +func (s *session) maybeSendAckOnlyPacket() error { + ack := s.receivedPacketHandler.GetAckFrame() + if ack == nil { + return nil + } + s.packer.QueueControlFrame(ack) + + if s.version.UsesStopWaitingFrames() { // for gQUIC, maybe add a STOP_WAITING + if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + packet, err := s.packer.PackAckPacket() if err != nil { return err } + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) + return s.sendPackedPacket(packet) +} + +// maybeSendRetransmission sends retransmissions for at most one packet. +// It takes care that Initials aren't retransmitted, if a packet from the server was already received. +func (s *session) maybeSendRetransmission() (bool, error) { + var retransmitPacket *ackhandler.Packet + for { + retransmitPacket = s.sentPacketHandler.DequeuePacketForRetransmission() + if retransmitPacket == nil { + return false, nil + } + + // Don't retransmit Initial packets if we already received a response. + // An Initial might have been retransmitted multiple times before we receive a response. + // As soon as we receive one response, we don't need to send any more Initials. + if s.receivedFirstPacket && retransmitPacket.PacketType == protocol.PacketTypeInitial { + s.logger.Debugf("Skipping retransmission of packet %d. Already received a response to an Initial.", retransmitPacket.PacketNumber) + continue + } + break + } + + if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { + s.logger.Debugf("Dequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + } else { + s.logger.Debugf("Dequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + } + + if s.version.UsesStopWaitingFrames() { + s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) + } + packets, err := s.packer.PackRetransmission(retransmitPacket) + if err != nil { + return false, err + } + ackhandlerPackets := make([]*ackhandler.Packet, len(packets)) + for i, packet := range packets { + ackhandlerPackets[i] = packet.ToAckHandlerPacket() + } + s.sentPacketHandler.SentPacketsAsRetransmission(ackhandlerPackets, retransmitPacket.PacketNumber) + for _, packet := range packets { + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + } + return true, nil +} + +func (s *session) sendPacket() (bool, error) { + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset}) + } + s.windowUpdateQueue.QueueAll() + + if ack := s.receivedPacketHandler.GetAckFrame(); ack != nil { + s.packer.QueueControlFrame(ack) + if s.version.UsesStopWaitingFrames() { + if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { + s.packer.QueueControlFrame(swf) + } + } + } + + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) + if err := s.sendPackedPacket(packet); err != nil { + return false, err + } + return true, nil +} + +func (s *session) sendPackedPacket(packet *packedPacket) error { + defer putPacketBuffer(&packet.raw) s.logPacket(packet) return s.conn.Write(packet.raw) } func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { - s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) packet, err := s.packer.PackConnectionClose(&wire.ConnectionCloseFrame{ ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage, @@ -771,23 +1066,27 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { } func (s *session) logPacket(packet *packedPacket) { - if !utils.Debug() { + if !s.logger.Debug() { // We don't need to allocate the slices for calling the format functions return } - utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) - packet.header.Log() + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel) + packet.header.Log(s.logger) for _, frame := range packet.frames { - wire.LogFrame(frame, true) + wire.LogFrame(s.logger, frame, true) } } // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. -// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +// It is *only* needed for gQUIC's H2. +// It will be removed as soon as gQUIC moves towards the IETF H2/QUIC stream mapping. func (s *session) GetOrOpenStream(id protocol.StreamID) (Stream, error) { - str, err := s.streamsMap.GetOrOpenStream(id) + str, err := s.streamsMap.GetOrOpenSendStream(id) if str != nil { - return str, err + if bstr, ok := str.(Stream); ok { + return bstr, err + } + return nil, fmt.Errorf("Stream %d is not a bidirectional stream", id) } // make sure to return an actual nil value here, not an Stream with value nil return nil, err @@ -798,6 +1097,10 @@ func (s *session) AcceptStream() (Stream, error) { return s.streamsMap.AcceptStream() } +func (s *session) AcceptUniStream() (ReceiveStream, error) { + return s.streamsMap.AcceptUniStream() +} + // OpenStream opens a stream func (s *session) OpenStream() (Stream, error) { return s.streamsMap.OpenStream() @@ -807,38 +1110,56 @@ func (s *session) OpenStreamSync() (Stream, error) { return s.streamsMap.OpenStreamSync() } -func (s *session) WaitUntilHandshakeComplete() error { - return <-s.handshakeCompleteChan +func (s *session) OpenUniStream() (SendStream, error) { + return s.streamsMap.OpenUniStream() } -func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { - s.packer.QueueControlFrame(&wire.RstStreamFrame{ - StreamID: id, - ByteOffset: offset, - }) - s.scheduleSending() +func (s *session) OpenUniStreamSync() (SendStream, error) { + return s.streamsMap.OpenUniStreamSync() } func (s *session) newStream(id protocol.StreamID) streamI { + flowController := s.newFlowController(id) + return newStream(id, s, flowController, s.version) +} + +func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { var initialSendWindow protocol.ByteCount if s.peerParams != nil { initialSendWindow = s.peerParams.StreamFlowControlWindow } - flowController := flowcontrol.NewStreamFlowController( + return flowcontrol.NewStreamFlowController( id, s.version.StreamContributesToConnectionFlowControl(id), s.connFlowController, protocol.ReceiveStreamFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, + s.onHasStreamWindowUpdate, s.rttStats, + s.logger, ) - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) +} + +func (s *session) newCryptoStream() cryptoStreamI { + id := s.version.CryptoStreamID() + flowController := flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + 0, + s.onHasStreamWindowUpdate, + s.rttStats, + s.logger, + ) + return newCryptoStream(s, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { - utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) + s.logger.Infof("Sending public reset for connection %x, packet number %d", s.destConnID, rejectedPacketNumber) + return s.conn.Write(wire.WritePublicReset(s.destConnID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending @@ -851,7 +1172,7 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) + s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { @@ -860,10 +1181,10 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.receivedTooManyUndecrytablePacketsTime = time.Now() s.maybeResetTimer() } - utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) + s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) return } - utils.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) + s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.undecryptablePackets = append(s.undecryptablePackets, p) } @@ -874,33 +1195,48 @@ func (s *session) tryDecryptingQueuedPackets() { s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *session) getWindowUpdates() []wire.Frame { - var res []wire.Frame - s.streamsMap.Range(func(str streamI) { - if offset := str.GetWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxStreamDataFrame{ - StreamID: str.StreamID(), - ByteOffset: offset, - }) - } - }) - if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { - res = append(res, &wire.MaxDataFrame{ - ByteOffset: offset, - }) +func (s *session) queueControlFrame(f wire.Frame) { + s.packer.QueueControlFrame(f) + s.scheduleSending() +} + +func (s *session) onHasStreamWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.AddStream(id) + s.scheduleSending() +} + +func (s *session) onHasConnectionWindowUpdate() { + s.windowUpdateQueue.AddConnection() + s.scheduleSending() +} + +func (s *session) onHasStreamData(id protocol.StreamID) { + s.streamFramer.AddActiveStream(id) + s.scheduleSending() +} + +func (s *session) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.Close(err) } - return res } func (s *session) LocalAddr() net.Addr { return s.conn.LocalAddr() } -// RemoteAddr returns the net.Addr of the client func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } +func (s *session) handshakeStatus() <-chan error { + return s.handshakeChan +} + +func (s *session) getCryptoStream() cryptoStreamI { + return s.cryptoStream +} + func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go index 806e7fc..f8d851b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -1,85 +1,80 @@ package quic import ( - "context" - "fmt" - "io" "net" "sync" "time" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) +const ( + errorCodeStopping protocol.ApplicationErrorCode = 0 + errorCodeStoppingGQUIC protocol.ApplicationErrorCode = 7 +) + +// The streamSender is notified by the stream about various events. +type streamSender interface { + queueControlFrame(wire.Frame) + onHasStreamData(protocol.StreamID) + onStreamCompleted(protocol.StreamID) +} + +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + type streamI interface { Stream - - AddStreamFrame(*wire.StreamFrame) error - RegisterRemoteError(error, protocol.ByteCount) error - LenOfDataForWriting() protocol.ByteCount - GetDataForWriting(maxBytes protocol.ByteCount) []byte - GetWriteOffset() protocol.ByteCount - Finished() bool - Cancel(error) - ShouldSendFin() bool - SentFin() - // methods needed for flow control - GetWindowUpdate() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) - IsFlowControlBlocked() bool + closeForShutdown(error) + // for receiving + handleStreamFrame(*wire.StreamFrame) error + handleRstStreamFrame(*wire.RstStreamFrame) error + getWindowUpdate() protocol.ByteCount + // for sending + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) + handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) } +var _ receiveStreamI = (streamI)(nil) +var _ sendStreamI = (streamI)(nil) + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { - mutex sync.Mutex + receiveStream + sendStream - ctx context.Context - ctxCancel context.CancelFunc + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool - streamID protocol.StreamID - onData func() - // onReset is a callback that should send a RST_STREAM - onReset func(protocol.StreamID, protocol.ByteCount) - - readPosInFrame int - writeOffset protocol.ByteCount - readOffset protocol.ByteCount - - // Once set, the errors must not be changed! - err error - - // cancelled is set when Cancel() is called - cancelled utils.AtomicBool - // finishedReading is set once we read a frame with a FinBit - finishedReading utils.AtomicBool - // finisedWriting is set once Close() is called - finishedWriting utils.AtomicBool - // resetLocally is set if Reset() is called - resetLocally utils.AtomicBool - // resetRemotely is set if RegisterRemoteError() is called - resetRemotely utils.AtomicBool - - frameQueue *streamFrameSorter - readChan chan struct{} - readDeadline time.Time - - dataForWriting []byte - finSent utils.AtomicBool - rstSent utils.AtomicBool - writeChan chan struct{} - writeDeadline time.Time - - flowController flowcontrol.StreamFlowController - version protocol.VersionNumber + version protocol.VersionNumber } var _ Stream = &stream{} -var _ streamI = &stream{} type deadlineError struct{} @@ -89,293 +84,58 @@ func (deadlineError) Timeout() bool { return true } var errDeadline net.Error = &deadlineError{} +type streamCanceledError struct { + error + errorCode protocol.ApplicationErrorCode +} + +func (streamCanceledError) Canceled() bool { return true } +func (e streamCanceledError) ErrorCode() protocol.ApplicationErrorCode { return e.errorCode } + +var _ StreamError = &streamCanceledError{} + // newStream creates a new Stream -func newStream(StreamID protocol.StreamID, - onData func(), - onReset func(protocol.StreamID, protocol.ByteCount), +func newStream(streamID protocol.StreamID, + sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber, ) *stream { - s := &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowController: flowController, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), - version: version, + s := &stream{sender: sender} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) return s } -// Read implements io.Reader. It is not thread safe! -func (s *stream) Read(p []byte) (int, error) { - s.mutex.Lock() - err := s.err - s.mutex.Unlock() - if s.cancelled.Get() || s.resetLocally.Get() { - return 0, err - } - if s.finishedReading.Get() { - return 0, io.EOF - } - - bytesRead := 0 - for bytesRead < len(p) { - s.mutex.Lock() - frame := s.frameQueue.Head() - if frame == nil && bytesRead > 0 { - err = s.err - s.mutex.Unlock() - return bytesRead, err - } - - var err error - for { - // Stop waiting on errors - if s.resetLocally.Get() || s.cancelled.Get() { - err = s.err - break - } - - deadline := s.readDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - - if frame != nil { - s.readPosInFrame = int(s.readOffset - frame.Offset) - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.readChan - } else { - select { - case <-s.readChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - frame = s.frameQueue.Head() - } - s.mutex.Unlock() - - if err != nil { - return bytesRead, err - } - - m := utils.Min(len(p)-bytesRead, int(frame.DataLen())-s.readPosInFrame) - - if bytesRead > len(p) { - return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) - } - if s.readPosInFrame > int(frame.DataLen()) { - return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, frame.DataLen()) - } - copy(p[bytesRead:], frame.Data[s.readPosInFrame:]) - - s.readPosInFrame += m - bytesRead += m - s.readOffset += protocol.ByteCount(m) - - // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream - if !s.resetRemotely.Get() { - s.flowController.AddBytesRead(protocol.ByteCount(m)) - } - s.onData() // so that a possible WINDOW_UPDATE is sent - - if s.readPosInFrame >= int(frame.DataLen()) { - fin := frame.FinBit - s.mutex.Lock() - s.frameQueue.Pop() - s.mutex.Unlock() - if fin { - s.finishedReading.Set(true) - return bytesRead, io.EOF - } - } - } - - return bytesRead, nil +// need to define StreamID() here, since both receiveStream and readStream have a StreamID() +func (s *stream) StreamID() protocol.StreamID { + // the result is same for receiveStream and sendStream + return s.sendStream.StreamID() } -func (s *stream) Write(p []byte) (int, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.resetLocally.Get() || s.err != nil { - return 0, s.err - } - if s.finishedWriting.Get() { - return 0, fmt.Errorf("write on closed stream %d", s.streamID) - } - if len(p) == 0 { - return 0, nil - } - - s.dataForWriting = make([]byte, len(p)) - copy(s.dataForWriting, p) - s.onData() - - var err error - for { - deadline := s.writeDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = errDeadline - break - } - if s.dataForWriting == nil || s.err != nil { - break - } - - s.mutex.Unlock() - if deadline.IsZero() { - <-s.writeChan - } else { - select { - case <-s.writeChan: - case <-time.After(deadline.Sub(time.Now())): - } - } - s.mutex.Lock() - } - - if err != nil { - return 0, err - } - if s.err != nil { - return len(p) - len(s.dataForWriting), s.err - } - return len(p), nil -} - -func (s *stream) GetWriteOffset() protocol.ByteCount { - return s.writeOffset -} - -func (s *stream) LenOfDataForWriting() protocol.ByteCount { - s.mutex.Lock() - var l protocol.ByteCount - if s.err == nil { - l = protocol.ByteCount(len(s.dataForWriting)) - } - s.mutex.Unlock() - return l -} - -func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.err != nil || s.dataForWriting == nil { - return nil - } - - // TODO(#657): Flow control for the crypto stream - if s.streamID != s.version.CryptoStreamID() { - maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) - } - if maxBytes == 0 { - return nil - } - - var ret []byte - if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { - ret = s.dataForWriting[:maxBytes] - s.dataForWriting = s.dataForWriting[maxBytes:] - } else { - ret = s.dataForWriting - s.dataForWriting = nil - s.signalWrite() - } - s.writeOffset += protocol.ByteCount(len(ret)) - s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) - return ret -} - -// Close implements io.Closer func (s *stream) Close() error { - s.finishedWriting.Set(true) - s.ctxCancel() - s.onData() - return nil -} - -func (s *stream) shouldSendReset() bool { - if s.rstSent.Get() { - return false - } - return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() -} - -func (s *stream) ShouldSendFin() bool { - s.mutex.Lock() - res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil - s.mutex.Unlock() - return res -} - -func (s *stream) SentFin() { - s.finSent.Set(true) -} - -// AddStreamFrame adds a new stream frame -func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { - maxOffset := frame.Offset + frame.DataLen() - if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { + if err := s.sendStream.Close(); err != nil { return err } - - s.mutex.Lock() - defer s.mutex.Unlock() - if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { - return err - } - s.signalRead() - return nil -} - -// signalRead performs a non-blocking send on the readChan -func (s *stream) signalRead() { - select { - case s.readChan <- struct{}{}: - default: - } -} - -// signalRead performs a non-blocking send on the writeChan -func (s *stream) signalWrite() { - select { - case s.writeChan <- struct{}{}: - default: - } -} - -func (s *stream) SetReadDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.readDeadline - s.readDeadline = t - s.mutex.Unlock() - // if the new deadline is before the currently set deadline, wake up Read() - if t.Before(oldDeadline) { - s.signalRead() - } - return nil -} - -func (s *stream) SetWriteDeadline(t time.Time) error { - s.mutex.Lock() - oldDeadline := s.writeDeadline - s.writeDeadline = t - s.mutex.Unlock() - if t.Before(oldDeadline) { - s.signalWrite() - } + // in gQUIC, we need to send a RST_STREAM with the final offset if CancelRead() was called + s.receiveStream.onClose(s.sendStream.getWriteOffset()) return nil } @@ -385,99 +145,31 @@ func (s *stream) SetDeadline(t time.Time) error { return nil } -// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset -func (s *stream) CloseRemote(offset protocol.ByteCount) { - s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) +// CloseForShutdown closes a stream abruptly. +// It makes Read and Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *stream) closeForShutdown(err error) { + s.sendStream.closeForShutdown(err) + s.receiveStream.closeForShutdown(err) } -// Cancel is called by session to indicate that an error occurred -// The stream should will be closed immediately -func (s *stream) Cancel(err error) { - s.mutex.Lock() - s.cancelled.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() - } - s.mutex.Unlock() -} - -// resets the stream locally -func (s *stream) Reset(err error) { - if s.resetLocally.Get() { - return - } - s.mutex.Lock() - s.resetLocally.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalRead() - s.signalWrite() - } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) - } - s.mutex.Unlock() -} - -// resets the stream remotely -func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { - if s.resetRemotely.Get() { - return nil - } - s.mutex.Lock() - s.resetRemotely.Set(true) - s.ctxCancel() - // errors must not be changed! - if s.err == nil { - s.err = err - s.signalWrite() - } - if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { +func (s *stream) handleRstStreamFrame(frame *wire.RstStreamFrame) error { + if err := s.receiveStream.handleRstStreamFrame(frame); err != nil { return err } - if s.shouldSendReset() { - s.onReset(s.streamID, s.writeOffset) - s.rstSent.Set(true) + if !s.version.UsesIETFFrameFormat() { + s.handleStopSendingFrame(&wire.StopSendingFrame{ + StreamID: s.StreamID(), + ErrorCode: frame.ErrorCode, + }) } - s.mutex.Unlock() return nil } -func (s *stream) finishedWriteAndSentFin() bool { - return s.finishedWriting.Get() && s.finSent.Get() -} - -func (s *stream) Finished() bool { - return s.cancelled.Get() || - (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || - (s.resetRemotely.Get() && s.rstSent.Get()) || - (s.finishedReading.Get() && s.rstSent.Get()) || - (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) -} - -func (s *stream) Context() context.Context { - return s.ctx -} - -func (s *stream) StreamID() protocol.StreamID { - return s.streamID -} - -func (s *stream) UpdateSendWindow(n protocol.ByteCount) { - s.flowController.UpdateSendWindow(n) -} - -func (s *stream) IsFlowControlBlocked() bool { - return s.flowController.IsBlocked() -} - -func (s *stream) GetWindowUpdate() protocol.ByteCount { - return s.flowController.GetWindowUpdate() +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) + } } diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index 8928e49..c453f86 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -1,174 +1,98 @@ package quic import ( - "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "sync" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { - streamsMap *streamsMap - cryptoStream streamI + streamGetter streamGetter + cryptoStream cryptoStreamI + version protocol.VersionNumber - connFlowController flowcontrol.ConnectionFlowController - - retransmissionQueue []*wire.StreamFrame - blockedFrameQueue []wire.Frame + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + hasCryptoStreamData bool } func newStreamFramer( - cryptoStream streamI, - streamsMap *streamsMap, - cfc flowcontrol.ConnectionFlowController, + cryptoStream cryptoStreamI, + streamGetter streamGetter, + v protocol.VersionNumber, ) *streamFramer { return &streamFramer{ - streamsMap: streamsMap, - cryptoStream: cryptoStream, - connFlowController: cfc, + streamGetter: streamGetter, + cryptoStream: cryptoStream, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, } } -func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { - f.retransmissionQueue = append(f.retransmissionQueue, frame) -} - -func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { - fs, currentLen := f.maybePopFramesForRetransmission(maxLen) - return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) -} - -func (f *streamFramer) PopBlockedFrame() wire.Frame { - if len(f.blockedFrameQueue) == 0 { - return nil +func (f *streamFramer) AddActiveStream(id protocol.StreamID) { + if id == f.version.CryptoStreamID() { // the crypto stream is handled separately + f.streamQueueMutex.Lock() + f.hasCryptoStreamData = true + f.streamQueueMutex.Unlock() + return } - frame := f.blockedFrameQueue[0] - f.blockedFrameQueue = f.blockedFrameQueue[1:] - return frame + f.streamQueueMutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.streamQueueMutex.Unlock() } -func (f *streamFramer) HasFramesForRetransmission() bool { - return len(f.retransmissionQueue) > 0 +func (f *streamFramer) HasCryptoStreamData() bool { + f.streamQueueMutex.Lock() + hasCryptoStreamData := f.hasCryptoStreamData + f.streamQueueMutex.Unlock() + return hasCryptoStreamData } -func (f *streamFramer) HasCryptoStreamFrame() bool { - return f.cryptoStream.LenOfDataForWriting() > 0 -} - -// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - if !f.HasCryptoStreamFrame() { - return nil - } - frame := &wire.StreamFrame{ - StreamID: f.cryptoStream.StreamID(), - Offset: f.cryptoStream.GetWriteOffset(), - } - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) + f.streamQueueMutex.Lock() + frame, hasMoreData := f.cryptoStream.popStreamFrame(maxLen) + f.hasCryptoStreamData = hasMoreData + f.streamQueueMutex.Unlock() return frame } -func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { - for len(f.retransmissionQueue) > 0 { - frame := f.retransmissionQueue[0] - frame.DataLenPresent = true - - frameHeaderLen, _ := frame.MinLength(protocol.VersionWhatever) // can never error - if currentLen+frameHeaderLen >= maxLen { - break - } - - currentLen += frameHeaderLen - - splitFrame := maybeSplitOffFrame(frame, maxLen-currentLen) - if splitFrame != nil { // StreamFrame was split - res = append(res, splitFrame) - currentLen += splitFrame.DataLen() - break - } - - f.retransmissionQueue = f.retransmissionQueue[1:] - res = append(res, frame) - currentLen += frame.DataLen() - } - return -} - -func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) { - frame := &wire.StreamFrame{DataLenPresent: true} +func (f *streamFramer) PopStreamFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame { var currentLen protocol.ByteCount - - fn := func(s streamI) (bool, error) { - if s == nil { - return true, nil + var frames []*wire.StreamFrame + f.streamQueueMutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if maxTotalLen-currentLen < protocol.MinStreamFrameSize { + break } - - frame.StreamID = s.StreamID() - frame.Offset = s.GetWriteOffset() - // not perfect, but thread-safe since writeOffset is only written when getting data - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - if currentLen+frameHeaderBytes > maxBytes { - return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + // This should never return an error. Better check it anyway. + // The stream will only be in the streamQueue, if it enqueued itself there. + str, err := f.streamGetter.GetOrOpenSendStream(id) + // The stream can be nil if it completed after it said it had data. + if str == nil || err != nil { + delete(f.activeStreams, id) + continue } - maxLen := maxBytes - currentLen - frameHeaderBytes - - var data []byte - if s.LenOfDataForWriting() > 0 { - data = s.GetDataForWriting(maxLen) + frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) } - - // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.ShouldSendFin() - if data == nil && !shouldSendFin { - return true, nil + if frame == nil { // can happen if the receiveStream was canceled after it said it had data + continue } - - if shouldSendFin { - frame.FinBit = true - s.SentFin() - } - - frame.Data = data - - // Finally, check if we are now FC blocked and should queue a BLOCKED frame - if !frame.FinBit && s.IsFlowControlBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) - } - if f.connFlowController.IsBlocked() { - f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{}) - } - - res = append(res, frame) - currentLen += frameHeaderBytes + frame.DataLen() - - if currentLen == maxBytes { - return false, nil - } - - frame = &wire.StreamFrame{DataLenPresent: true} - return true, nil - } - - f.streamsMap.RoundRobinIterate(fn) - return -} - -// maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. -func maybeSplitOffFrame(frame *wire.StreamFrame, n protocol.ByteCount) *wire.StreamFrame { - if n >= frame.DataLen() { - return nil - } - - defer func() { - frame.Data = frame.Data[n:] - frame.Offset += n - }() - - return &wire.StreamFrame{ - FinBit: false, - StreamID: frame.StreamID, - Offset: frame.Offset, - Data: frame.Data[:n], - DataLenPresent: frame.DataLenPresent, + frames = append(frames, frame) + currentLen += frame.Length(f.version) } + f.streamQueueMutex.Unlock() + return frames } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/lucas-clemente/quic-go/streams_map.go index df5b4c9..b9a56d6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -1,344 +1,224 @@ package quic import ( - "errors" "fmt" - "sync" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type streamType int + +const ( + streamTypeOutgoingBidi streamType = iota + streamTypeIncomingBidi + streamTypeOutgoingUni + streamTypeIncomingUni ) type streamsMap struct { - mutex sync.RWMutex - perspective protocol.Perspective - streams map[protocol.StreamID]streamI - // needed for round-robin scheduling - openStreams []protocol.StreamID - roundRobinIndex int + sender streamSender + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController - nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() - highestStreamOpenedByPeer protocol.StreamID - nextStreamOrErrCond sync.Cond - openStreamOrErrCond sync.Cond - - closeErr error - nextStreamToAccept protocol.StreamID - - newStream newStreamLambda - - numOutgoingStreams uint32 - numIncomingStreams uint32 - maxIncomingStreams uint32 - maxOutgoingStreams uint32 + outgoingBidiStreams *outgoingBidiStreamsMap + outgoingUniStreams *outgoingUniStreamsMap + incomingBidiStreams *incomingBidiStreamsMap + incomingUniStreams *incomingUniStreamsMap } -type streamLambda func(streamI) (bool, error) -type newStreamLambda func(protocol.StreamID) streamI +var _ streamManager = &streamsMap{} -var errMapAccess = errors.New("streamsMap: Error accessing the streams map") - -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { - // add some tolerance to the maximum incoming streams value - maxStreams := uint32(protocol.MaxIncomingStreams) - maxIncomingStreams := utils.MaxUint32( - maxStreams+protocol.MaxStreamsMinimumIncrement, - uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), +func newStreamsMap( + sender streamSender, + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, + maxIncomingStreams int, + maxIncomingUniStreams int, + perspective protocol.Perspective, + version protocol.VersionNumber, +) streamManager { + m := &streamsMap{ + perspective: perspective, + newFlowController: newFlowController, + sender: sender, + } + var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID + if perspective == protocol.PerspectiveServer { + firstOutgoingBidiStream = 1 + firstIncomingBidiStream = 4 // the crypto stream is handled separately + firstOutgoingUniStream = 3 + firstIncomingUniStream = 2 + } else { + firstOutgoingBidiStream = 4 // the crypto stream is handled separately + firstIncomingBidiStream = 1 + firstOutgoingUniStream = 2 + firstIncomingUniStream = 3 + } + newBidiStream := func(id protocol.StreamID) streamI { + return newStream(id, m.sender, m.newFlowController(id), version) + } + newUniSendStream := func(id protocol.StreamID) sendStreamI { + return newSendStream(id, m.sender, m.newFlowController(id), version) + } + newUniReceiveStream := func(id protocol.StreamID) receiveStreamI { + return newReceiveStream(id, m.sender, m.newFlowController(id), version) + } + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + firstOutgoingBidiStream, + newBidiStream, + sender.queueControlFrame, ) - sm := streamsMap{ - perspective: pers, - streams: make(map[protocol.StreamID]streamI), - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - maxIncomingStreams: maxIncomingStreams, - } - sm.nextStreamOrErrCond.L = &sm.mutex - sm.openStreamOrErrCond.L = &sm.mutex - - nextOddStream := protocol.StreamID(1) - if ver.CryptoStreamID() == protocol.StreamID(1) { - nextOddStream = 3 - } - if pers == protocol.PerspectiveClient { - sm.nextStream = nextOddStream - sm.nextStreamToAccept = 2 - } else { - sm.nextStream = 2 - sm.nextStreamToAccept = nextOddStream - } - - return &sm + m.incomingBidiStreams = newIncomingBidiStreamsMap( + firstIncomingBidiStream, + protocol.MaxBidiStreamID(maxIncomingStreams, perspective), + maxIncomingStreams, + sender.queueControlFrame, + newBidiStream, + ) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + firstOutgoingUniStream, + newUniSendStream, + sender.queueControlFrame, + ) + m.incomingUniStreams = newIncomingUniStreamsMap( + firstIncomingUniStream, + protocol.MaxUniStreamID(maxIncomingUniStreams, perspective), + maxIncomingUniStreams, + sender.queueControlFrame, + newUniReceiveStream, + ) + return m } -// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. -// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. -func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { - m.mutex.RLock() - s, ok := m.streams[id] - m.mutex.RUnlock() - if ok { - return s, nil // s may be nil - } - - // ... we don't have an existing stream - m.mutex.Lock() - defer m.mutex.Unlock() - // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). - s, ok = m.streams[id] - if ok { - return s, nil - } - +func (m *streamsMap) getStreamType(id protocol.StreamID) streamType { if m.perspective == protocol.PerspectiveServer { - if id%2 == 0 { - if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already - return nil, nil - } - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + switch id % 4 { + case 0: + return streamTypeIncomingBidi + case 1: + return streamTypeOutgoingBidi + case 2: + return streamTypeIncomingUni + case 3: + return streamTypeOutgoingUni } - if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already - return nil, nil - } - } - if m.perspective == protocol.PerspectiveClient { - if id%2 == 1 { - if id <= m.nextStream { // this is a client-side stream that we already opened. - return nil, nil - } - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) - } - if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already - return nil, nil - } - } - - // sid is the next stream that will be opened - sid := m.highestStreamOpenedByPeer + 2 - // if there is no stream opened yet, and this is the server, stream 1 should be openend - if sid == 2 && m.perspective == protocol.PerspectiveServer { - sid = 1 - } - - for ; sid <= id; sid += 2 { - _, err := m.openRemoteStream(sid) - if err != nil { - return nil, err - } - } - - m.nextStreamOrErrCond.Broadcast() - return m.streams[id], nil -} - -func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { - if m.numIncomingStreams >= m.maxIncomingStreams { - return nil, qerr.TooManyOpenStreams - } - if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { - return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) - } - - if m.perspective == protocol.PerspectiveServer { - m.numIncomingStreams++ } else { - m.numOutgoingStreams++ + switch id % 4 { + case 0: + return streamTypeOutgoingBidi + case 1: + return streamTypeIncomingBidi + case 2: + return streamTypeOutgoingUni + case 3: + return streamTypeIncomingUni + } } - - if id > m.highestStreamOpenedByPeer { - m.highestStreamOpenedByPeer = id - } - - s := m.newStream(id) - m.putStream(s) - return s, nil + panic("") } -func (m *streamsMap) openStreamImpl() (streamI, error) { - id := m.nextStream - if m.numOutgoingStreams >= m.maxOutgoingStreams { - return nil, qerr.TooManyOpenStreams - } - - if m.perspective == protocol.PerspectiveServer { - m.numOutgoingStreams++ - } else { - m.numIncomingStreams++ - } - - m.nextStream += 2 - s := m.newStream(id) - m.putStream(s) - return s, nil +func (m *streamsMap) OpenStream() (Stream, error) { + return m.outgoingBidiStreams.OpenStream() } -// OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - return m.openStreamImpl() +func (m *streamsMap) OpenStreamSync() (Stream, error) { + return m.outgoingBidiStreams.OpenStreamSync() } -func (m *streamsMap) OpenStreamSync() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() +func (m *streamsMap) OpenUniStream() (SendStream, error) { + return m.outgoingUniStreams.OpenStream() +} - for { - if m.closeErr != nil { - return nil, m.closeErr - } - str, err := m.openStreamImpl() - if err == nil { - return str, err - } - if err != nil && err != qerr.TooManyOpenStreams { - return nil, err - } - m.openStreamOrErrCond.Wait() +func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { + return m.outgoingUniStreams.OpenStreamSync() +} + +func (m *streamsMap) AcceptStream() (Stream, error) { + return m.incomingBidiStreams.AcceptStream() +} + +func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) { + return m.incomingUniStreams.AcceptStream() +} + +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { + switch m.getStreamType(id) { + case streamTypeIncomingBidi: + return m.incomingBidiStreams.DeleteStream(id) + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.DeleteStream(id) + case streamTypeIncomingUni: + return m.incomingUniStreams.DeleteStream(id) + case streamTypeOutgoingUni: + return m.outgoingUniStreams.DeleteStream(id) + default: + panic("invalid stream type") } } -// AcceptStream returns the next stream opened by the peer -// it blocks until a new stream is opened -func (m *streamsMap) AcceptStream() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - var str streamI - for { - var ok bool - if m.closeErr != nil { - return nil, m.closeErr - } - str, ok = m.streams[m.nextStreamToAccept] - if ok { - break - } - m.nextStreamOrErrCond.Wait() +func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.GetStream(id) + case streamTypeIncomingBidi: + return m.incomingBidiStreams.GetOrOpenStream(id) + case streamTypeIncomingUni: + return m.incomingUniStreams.GetOrOpenStream(id) + case streamTypeOutgoingUni: + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + default: + panic("invalid stream type") } - m.nextStreamToAccept += 2 - return str, nil } -func (m *streamsMap) DeleteClosedStreams() error { - m.mutex.Lock() - defer m.mutex.Unlock() - - var numDeletedStreams int - // for every closed stream, the streamID is replaced by 0 in the openStreams slice - for i, streamID := range m.openStreams { - str, ok := m.streams[streamID] - if !ok { - return errMapAccess - } - if !str.Finished() { - continue - } - numDeletedStreams++ - m.openStreams[i] = 0 - if streamID%2 == 0 { - m.numOutgoingStreams-- - } else { - m.numIncomingStreams-- - } - delete(m.streams, streamID) +func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + return m.outgoingBidiStreams.GetStream(id) + case streamTypeIncomingBidi: + return m.incomingBidiStreams.GetOrOpenStream(id) + case streamTypeOutgoingUni: + return m.outgoingUniStreams.GetStream(id) + case streamTypeIncomingUni: + // an incoming unidirectional stream is a receive stream, not a send stream + return nil, fmt.Errorf("peer attempted to open send stream %d", id) + default: + panic("invalid stream type") } +} - if numDeletedStreams == 0 { +func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + id := f.StreamID + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + m.outgoingBidiStreams.SetMaxStream(id) return nil - } - - // remove all 0s (representing closed streams) from the openStreams slice - // and adjust the roundRobinIndex - var j int - for i, id := range m.openStreams { - if i != j { - m.openStreams[j] = m.openStreams[i] - } - if id != 0 { - j++ - } else if j < m.roundRobinIndex { - m.roundRobinIndex-- - } - } - m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams] - m.openStreamOrErrCond.Signal() - return nil -} - -// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false -// It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the the header-stream (StreamID 3) -func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - numStreams := len(m.streams) - startIndex := m.roundRobinIndex - - for i := 0; i < numStreams; i++ { - streamID := m.openStreams[(i+startIndex)%numStreams] - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err - } - m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams - if !cont { - break - } - } - return nil -} - -// Range executes a callback for all streams, in pseudo-random order -func (m *streamsMap) Range(cb func(s streamI)) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - for _, s := range m.streams { - if s != nil { - cb(s) - } + case streamTypeOutgoingUni: + m.outgoingUniStreams.SetMaxStream(id) + return nil + default: + return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) } } -func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { - str, ok := m.streams[streamID] - if !ok { - return true, errMapAccess +func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { + // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open. + // Invert the perspective to determine the value that we are allowed to open. + peerPers := protocol.PerspectiveServer + if m.perspective == protocol.PerspectiveServer { + peerPers = protocol.PerspectiveClient } - return fn(str) -} - -func (m *streamsMap) putStream(s streamI) error { - id := s.StreamID() - if _, ok := m.streams[id]; ok { - return fmt.Errorf("a stream with ID %d already exists", id) - } - - m.streams[id] = s - m.openStreams = append(m.openStreams, id) - return nil + m.outgoingBidiStreams.SetMaxStream(protocol.MaxBidiStreamID(int(p.MaxBidiStreams), peerPers)) + m.outgoingUniStreams.SetMaxStream(protocol.MaxUniStreamID(int(p.MaxUniStreams), peerPers)) } func (m *streamsMap) CloseWithError(err error) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.closeErr = err - m.nextStreamOrErrCond.Broadcast() - m.openStreamOrErrCond.Broadcast() - for _, s := range m.openStreams { - m.streams[s].Cancel(err) - } -} - -func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.maxOutgoingStreams = limit + m.outgoingBidiStreams.CloseWithError(err) + m.outgoingUniStreams.CloseWithError(err) + m.incomingBidiStreams.CloseWithError(err) + m.incomingUniStreams.CloseWithError(err) } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go new file mode 100644 index 0000000..f48db21 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go @@ -0,0 +1,11 @@ +package quic + +import "github.com/cheekybits/genny/generic" + +// In the auto-generated streams maps, we need to be able to close the streams. +// Therefore, extend the generic.Type with the stream close method. +// This definition must be in a file that Genny doesn't process. +type item interface { + generic.Type + closeForShutdown(error) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go new file mode 100644 index 0000000..317f5e2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go @@ -0,0 +1,131 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type incomingBidiStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]streamI + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) streamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingBidiStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) streamI, +) *incomingBidiStreamsMap { + m := &incomingBidiStreamsMap{ + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + if id > m.maxStream { + m.mutex.RUnlock() + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go new file mode 100644 index 0000000..58f1ccb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go @@ -0,0 +1,129 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream" +type incomingItemsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]item + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) item + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingItemsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) item, +) *incomingItemsMap { + m := &incomingItemsMap{ + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingItemsMap) AcceptStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str item + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { + m.mutex.RLock() + if id > m.maxStream { + m.mutex.RUnlock() + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go new file mode 100644 index 0000000..8e775aa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go @@ -0,0 +1,131 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type incomingUniStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]receiveStreamI + + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) receiveStreamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) + + closeErr error +} + +func newIncomingUniStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) receiveStreamI, +) *incomingUniStreamsMap { + m := &incomingUniStreamsMap{ + streams: make(map[protocol.StreamID]receiveStreamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + var str receiveStreamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStream] + if ok { + break + } + m.cond.Wait() + } + m.nextStream += 4 + return str, nil +} + +func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { + m.mutex.RLock() + if id > m.maxStream { + m.mutex.RUnlock() + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } + // if the id is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if id <= m.highestStream { + s := m.streams[id] + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + var start protocol.StreamID + if m.highestStream == 0 { + start = m.nextStream + } else { + start = m.highestStream + 4 + } + for newID := start; newID <= id; newID += 4 { + m.streams[newID] = m.newStream(newID) + m.cond.Signal() + } + m.highestStream = id + s := m.streams[id] + m.mutex.Unlock() + return s, nil +} + +func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } + return nil +} + +func (m *incomingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.mutex.Unlock() + m.cond.Broadcast() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go new file mode 100644 index 0000000..240eeea --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_legacy.go @@ -0,0 +1,277 @@ +package quic + +import ( + "errors" + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamsMapLegacy struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + streams map[protocol.StreamID]streamI + + nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() + highestStreamOpenedByPeer protocol.StreamID + nextStreamOrErrCond sync.Cond + openStreamOrErrCond sync.Cond + + closeErr error + nextStreamToAccept protocol.StreamID + + newStream func(protocol.StreamID) streamI + + numOutgoingStreams uint32 + numIncomingStreams uint32 + maxIncomingStreams uint32 + maxOutgoingStreams uint32 +} + +var _ streamManager = &streamsMapLegacy{} + +var errMapAccess = errors.New("streamsMap: Error accessing the streams map") + +func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, maxStreams int, pers protocol.Perspective) streamManager { + // add some tolerance to the maximum incoming streams value + maxIncomingStreams := utils.MaxUint32( + uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) + sm := streamsMapLegacy{ + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, + } + sm.nextStreamOrErrCond.L = &sm.mutex + sm.openStreamOrErrCond.L = &sm.mutex + + nextServerInitiatedStream := protocol.StreamID(2) + nextClientInitiatedStream := protocol.StreamID(3) + if pers == protocol.PerspectiveServer { + sm.highestStreamOpenedByPeer = 1 + } + if pers == protocol.PerspectiveServer { + sm.nextStreamToOpen = nextServerInitiatedStream + sm.nextStreamToAccept = nextClientInitiatedStream + } else { + sm.nextStreamToOpen = nextClientInitiatedStream + sm.nextStreamToAccept = nextServerInitiatedStream + } + return &sm +} + +// getStreamPerspective says which side should initiate a stream +func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective { + if id%2 == 0 { + return protocol.PerspectiveServer + } + return protocol.PerspectiveClient +} + +func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + // every bidirectional stream is also a receive stream + return m.getOrOpenStream(id) +} + +func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + // every bidirectional stream is also a send stream + return m.getOrOpenStream(id) +} + +// getOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. +// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. +func (m *streamsMapLegacy) getOrOpenStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + s, ok := m.streams[id] + m.mutex.RUnlock() + if ok { + return s, nil + } + + // ... we don't have an existing stream + m.mutex.Lock() + defer m.mutex.Unlock() + // We need to check whether another invocation has already created a stream (between RUnlock() and Lock()). + s, ok = m.streams[id] + if ok { + return s, nil + } + + if m.perspective == m.streamInitiatedBy(id) { + if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already + return nil, nil + } + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already + return nil, nil + } + + for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 { + if _, err := m.openRemoteStream(sid); err != nil { + return nil, err + } + } + + m.nextStreamOrErrCond.Broadcast() + return m.streams[id], nil +} + +func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) { + if m.numIncomingStreams >= m.maxIncomingStreams { + return nil, qerr.TooManyOpenStreams + } + // maxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened + // note that the number of streams is half this value, since the client can only open streams with open StreamID + maxStreamIDDelta := protocol.StreamID(4 * m.maxIncomingStreams) + if id+maxStreamIDDelta < m.highestStreamOpenedByPeer { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) + } + + m.numIncomingStreams++ + if id > m.highestStreamOpenedByPeer { + m.highestStreamOpenedByPeer = id + } + + s := m.newStream(id) + return s, m.putStream(s) +} + +func (m *streamsMapLegacy) openStreamImpl() (streamI, error) { + if m.numOutgoingStreams >= m.maxOutgoingStreams { + return nil, qerr.TooManyOpenStreams + } + + m.numOutgoingStreams++ + s := m.newStream(m.nextStreamToOpen) + m.nextStreamToOpen += 2 + return s, m.putStream(s) +} + +// OpenStream opens the next available stream +func (m *streamsMapLegacy) OpenStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + return m.openStreamImpl() +} + +func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.openStreamOrErrCond.Wait() + } +} + +func (m *streamsMapLegacy) OpenUniStream() (SendStream, error) { + return nil, errors.New("gQUIC doesn't support unidirectional streams") +} + +func (m *streamsMapLegacy) OpenUniStreamSync() (SendStream, error) { + return nil, errors.New("gQUIC doesn't support unidirectional streams") +} + +// AcceptStream returns the next stream opened by the peer +// it blocks until a new stream is opened +func (m *streamsMapLegacy) AcceptStream() (Stream, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + var str streamI + for { + var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } + str, ok = m.streams[m.nextStreamToAccept] + if ok { + break + } + m.nextStreamOrErrCond.Wait() + } + m.nextStreamToAccept += 2 + return str, nil +} + +func (m *streamsMapLegacy) AcceptUniStream() (ReceiveStream, error) { + return nil, errors.New("gQUIC doesn't support unidirectional streams") +} + +func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if !ok { + return errMapAccess + } + delete(m.streams, id) + if m.streamInitiatedBy(id) == m.perspective { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } + m.openStreamOrErrCond.Signal() + return nil +} + +func (m *streamsMapLegacy) putStream(s streamI) error { + id := s.StreamID() + if _, ok := m.streams[id]; ok { + return fmt.Errorf("a stream with ID %d already exists", id) + } + m.streams[id] = s + return nil +} + +func (m *streamsMapLegacy) CloseWithError(err error) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.closeErr = err + m.nextStreamOrErrCond.Broadcast() + m.openStreamOrErrCond.Broadcast() + for _, s := range m.streams { + s.closeForShutdown(err) + } +} + +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) { + m.mutex.Lock() + m.maxOutgoingStreams = params.MaxStreams + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() + m.openStreamOrErrCond.Broadcast() +} + +// should never be called, since MAX_STREAM_ID frames can only be unpacked for IETF QUIC +func (m *streamsMapLegacy) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + return errors.New("gQUIC doesn't have MAX_STREAM_ID frames") +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go new file mode 100644 index 0000000..ea9f47e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go @@ -0,0 +1,126 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type outgoingBidiStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]streamI + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) streamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingBidiStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) streamI, + queueControlFrame func(wire.Frame), +) *outgoingBidiStreamsMap { + m := &outgoingBidiStreamsMap{ + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) { + m.mutex.RLock() + if id >= m.nextStream { + m.mutex.RUnlock() + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go new file mode 100644 index 0000000..f4b3eb6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go @@ -0,0 +1,124 @@ +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream" +//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream" +type outgoingItemsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]item + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) item + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingItemsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) item, + queueControlFrame func(wire.Frame), +) *outgoingItemsMap { + m := &outgoingItemsMap{ + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingItemsMap) OpenStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingItemsMap) OpenStreamSync() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingItemsMap) openStreamImpl() (item, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) { + m.mutex.RLock() + if id >= m.nextStream { + m.mutex.RUnlock() + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go new file mode 100644 index 0000000..6ad0348 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go @@ -0,0 +1,126 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" +) + +type outgoingUniStreamsMap struct { + mutex sync.RWMutex + cond sync.Cond + + streams map[protocol.StreamID]sendStreamI + + nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamID // the maximum stream ID we're allowed to open + highestBlocked protocol.StreamID // the highest stream ID that we queued a STREAM_ID_BLOCKED frame for + + newStream func(protocol.StreamID) sendStreamI + queueStreamIDBlocked func(*wire.StreamIDBlockedFrame) + + closeErr error +} + +func newOutgoingUniStreamsMap( + nextStream protocol.StreamID, + newStream func(protocol.StreamID) sendStreamI, + queueControlFrame func(wire.Frame), +) *outgoingUniStreamsMap { + m := &outgoingUniStreamsMap{ + streams: make(map[protocol.StreamID]sendStreamI), + nextStream: nextStream, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamIDBlockedFrame) { queueControlFrame(f) }, + } + m.cond.L = &m.mutex + return m +} + +func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.openStreamImpl() +} + +func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + if m.maxStream == 0 || m.highestBlocked < m.maxStream { + m.queueStreamIDBlocked(&wire.StreamIDBlockedFrame{StreamID: m.maxStream}) + m.highestBlocked = m.maxStream + } + return nil, qerr.TooManyOpenStreams + } + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream += 4 + return s, nil +} + +func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) { + m.mutex.RLock() + if id >= m.nextStream { + m.mutex.RUnlock() + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) + } + s := m.streams[id] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[id]; !ok { + return fmt.Errorf("Tried to delete unknown stream %d", id) + } + delete(m.streams, id) + return nil +} + +func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + +func (m *outgoingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + m.cond.Broadcast() + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go new file mode 100644 index 0000000..dfbc45a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go @@ -0,0 +1,79 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]bool // used as a set + queuedConn bool // connection-level window update + + cryptoStream cryptoStreamI + streamGetter streamGetter + connFlowController flowcontrol.ConnectionFlowController + callback func(wire.Frame) +} + +func newWindowUpdateQueue( + streamGetter streamGetter, + cryptoStream cryptoStreamI, + connFC flowcontrol.ConnectionFlowController, + cb func(wire.Frame), +) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]bool), + streamGetter: streamGetter, + cryptoStream: cryptoStream, + connFlowController: connFC, + callback: cb, + } +} + +func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { + q.mutex.Lock() + q.queue[id] = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) AddConnection() { + q.mutex.Lock() + q.queuedConn = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + // queue a connection-level window update + if q.queuedConn { + q.callback(&wire.MaxDataFrame{ByteOffset: q.connFlowController.GetWindowUpdate()}) + q.queuedConn = false + } + // queue all stream-level window updates + var offset protocol.ByteCount + for id := range q.queue { + if id == q.cryptoStream.StreamID() { + offset = q.cryptoStream.getWindowUpdate() + } else { + str, err := q.streamGetter.GetOrOpenReceiveStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset = str.getWindowUpdate() + } + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } + q.callback(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: offset, + }) + delete(q.queue, id) + } + q.mutex.Unlock() +} diff --git a/vendor/gopkg.in/sorcix/irc.v2/LICENSE b/vendor/gopkg.in/sorcix/irc.v2/LICENSE new file mode 100644 index 0000000..10cecc4 --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/LICENSE @@ -0,0 +1,22 @@ +Copyright 2014 Vic Demuzere + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/gopkg.in/sorcix/irc.v2/README.md b/vendor/gopkg.in/sorcix/irc.v2/README.md new file mode 100644 index 0000000..496e5fb --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/README.md @@ -0,0 +1,60 @@ +# Go **irc** package + +[![Build Status](https://travis-ci.org/sorcix/irc.svg?branch=v2)](https://travis-ci.org/sorcix/irc) +[![GoDoc](https://godoc.org/gopkg.in/sorcix/irc.v2?status.svg)](https://godoc.org/gopkg.in/sorcix/irc.v2) + +## Features +Package irc allows your application to speak the IRC protocol. + + - **Limited scope**, does one thing and does it well. + - Focus on simplicity and **speed**. + - **Stable API**: updates shouldn't break existing software. + - Well [documented][Documentation] code. + +*This package does not manage your entire IRC connection. It only translates the protocol to easy to use Go types. It is meant as a single component in a larger IRC library, or for basic IRC bots for which a large IRC package would be overkill.* + +## Usage + +``` +import "gopkg.in/sorcix/irc.v2" +``` + +### Message +The [Message][] and [Prefix][] types provide translation to and from IRC message format. + + // Parse the IRC-encoded data and stores the result in a new struct. + message := irc.ParseMessage(raw) + + // Returns the IRC encoding of the message. + raw = message.String() + +### Encoder & Decoder +The [Encoder][] and [Decoder][] types allow working with IRC message streams. + + // Create a decoder that reads from given io.Reader + dec := irc.NewDecoder(reader) + + // Decode the next IRC message + message, err := dec.Decode() + + // Create an encoder that writes to given io.Writer + enc := irc.NewEncoder(writer) + + // Send a message to the writer. + enc.Encode(message) + +### Conn +The [Conn][] type combines an [Encoder][] and [Decoder][] for a duplex connection. + + c, err := irc.Dial("irc.server.net:6667") + + // Methods from both Encoder and Decoder are available + message, err := c.Decode() + +[Documentation]: https://godoc.org/gopkg.in/sorcix/irc.v2 "Package documentation by Godoc.org" +[Message]: https://godoc.org/gopkg.in/sorcix/irc.v2#Message "Message type documentation" +[Prefix]: https://godoc.org/gopkg.in/sorcix/irc.v2#Prefix "Prefix type documentation" +[Encoder]: https://godoc.org/gopkg.in/sorcix/irc.v2#Encoder "Encoder type documentation" +[Decoder]: https://godoc.org/gopkg.in/sorcix/irc.v2#Decoder "Decoder type documentation" +[Conn]: https://godoc.org/gopkg.in/sorcix/irc.v2#Conn "Conn type documentation" +[RFC1459]: https://tools.ietf.org/html/rfc1459.html "RFC 1459" diff --git a/vendor/gopkg.in/sorcix/irc.v2/constants.go b/vendor/gopkg.in/sorcix/irc.v2/constants.go new file mode 100644 index 0000000..022fb3d --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/constants.go @@ -0,0 +1,298 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +package irc + +// Various prefixes extracted from RFC 1459. +const ( + Channel = '#' // Normal channel + Distributed = '&' // Distributed channel + + Owner = '~' // Channel owner +q (non-standard) + Admin = '&' // Channel admin +a (non-standard) + Operator = '@' // Channel operator +o + HalfOperator = '%' // Channel half operator +h (non-standard) + Voice = '+' // User has voice +v +) + +// User modes as defined by RFC 1459 section 4.2.3.2. +const ( + UserModeInvisible = 'i' // User is invisible + UserModeServerNotices = 's' // User wants to receive server notices + UserModeWallops = 'w' // User wants to receive Wallops + UserModeOperator = 'o' // Server operator +) + +// Channel modes as defined by RFC 1459 section 4.2.3.1 +const ( + ModeOperator = 'o' // Operator privileges + ModeVoice = 'v' // Ability to speak on a moderated channel + ModePrivate = 'p' // Private channel + ModeSecret = 's' // Secret channel + ModeInviteOnly = 'i' // Users can't join without invite + ModeTopic = 't' // Topic can only be set by an operator + ModeModerated = 'm' // Only voiced users and operators can talk + ModeLimit = 'l' // User limit + ModeKey = 'k' // Channel password + + ModeOwner = 'q' // Owner privileges (non-standard) + ModeAdmin = 'a' // Admin privileges (non-standard) + ModeHalfOperator = 'h' // Half-operator privileges (non-standard) +) + +// IRC commands extracted from RFC 2812 section 3 and RFC 2813 section 4. +const ( + PASS = "PASS" + NICK = "NICK" + USER = "USER" + OPER = "OPER" + MODE = "MODE" + SERVICE = "SERVICE" + QUIT = "QUIT" + SQUIT = "SQUIT" + JOIN = "JOIN" + PART = "PART" + TOPIC = "TOPIC" + NAMES = "NAMES" + LIST = "LIST" + INVITE = "INVITE" + KICK = "KICK" + PRIVMSG = "PRIVMSG" + NOTICE = "NOTICE" + MOTD = "MOTD" + LUSERS = "LUSERS" + VERSION = "VERSION" + STATS = "STATS" + LINKS = "LINKS" + TIME = "TIME" + CONNECT = "CONNECT" + TRACE = "TRACE" + ADMIN = "ADMIN" + INFO = "INFO" + SERVLIST = "SERVLIST" + SQUERY = "SQUERY" + WHO = "WHO" + WHOIS = "WHOIS" + WHOWAS = "WHOWAS" + KILL = "KILL" + PING = "PING" + PONG = "PONG" + ERROR = "ERROR" + AWAY = "AWAY" + REHASH = "REHASH" + DIE = "DIE" + RESTART = "RESTART" + SUMMON = "SUMMON" + USERS = "USERS" + WALLOPS = "WALLOPS" + USERHOST = "USERHOST" + ISON = "ISON" + SERVER = "SERVER" + NJOIN = "NJOIN" +) + +// Numeric IRC replies extracted from RFC 2812 section 5. +const ( + RPL_WELCOME = "001" + RPL_YOURHOST = "002" + RPL_CREATED = "003" + RPL_MYINFO = "004" + RPL_BOUNCE = "005" + RPL_ISUPPORT = "005" + RPL_USERHOST = "302" + RPL_ISON = "303" + RPL_AWAY = "301" + RPL_UNAWAY = "305" + RPL_NOWAWAY = "306" + RPL_WHOISUSER = "311" + RPL_WHOISSERVER = "312" + RPL_WHOISOPERATOR = "313" + RPL_WHOISIDLE = "317" + RPL_ENDOFWHOIS = "318" + RPL_WHOISCHANNELS = "319" + RPL_WHOWASUSER = "314" + RPL_ENDOFWHOWAS = "369" + RPL_LISTSTART = "321" + RPL_LIST = "322" + RPL_LISTEND = "323" + RPL_UNIQOPIS = "325" + RPL_CHANNELMODEIS = "324" + RPL_NOTOPIC = "331" + RPL_TOPIC = "332" + RPL_INVITING = "341" + RPL_SUMMONING = "342" + RPL_INVITELIST = "346" + RPL_ENDOFINVITELIST = "347" + RPL_EXCEPTLIST = "348" + RPL_ENDOFEXCEPTLIST = "349" + RPL_VERSION = "351" + RPL_WHOREPLY = "352" + RPL_ENDOFWHO = "315" + RPL_NAMREPLY = "353" + RPL_ENDOFNAMES = "366" + RPL_LINKS = "364" + RPL_ENDOFLINKS = "365" + RPL_BANLIST = "367" + RPL_ENDOFBANLIST = "368" + RPL_INFO = "371" + RPL_ENDOFINFO = "374" + RPL_MOTDSTART = "375" + RPL_MOTD = "372" + RPL_ENDOFMOTD = "376" + RPL_YOUREOPER = "381" + RPL_REHASHING = "382" + RPL_YOURESERVICE = "383" + RPL_TIME = "391" + RPL_USERSSTART = "392" + RPL_USERS = "393" + RPL_ENDOFUSERS = "394" + RPL_NOUSERS = "395" + RPL_TRACELINK = "200" + RPL_TRACECONNECTING = "201" + RPL_TRACEHANDSHAKE = "202" + RPL_TRACEUNKNOWN = "203" + RPL_TRACEOPERATOR = "204" + RPL_TRACEUSER = "205" + RPL_TRACESERVER = "206" + RPL_TRACESERVICE = "207" + RPL_TRACENEWTYPE = "208" + RPL_TRACECLASS = "209" + RPL_TRACERECONNECT = "210" + RPL_TRACELOG = "261" + RPL_TRACEEND = "262" + RPL_STATSLINKINFO = "211" + RPL_STATSCOMMANDS = "212" + RPL_ENDOFSTATS = "219" + RPL_STATSUPTIME = "242" + RPL_STATSOLINE = "243" + RPL_UMODEIS = "221" + RPL_SERVLIST = "234" + RPL_SERVLISTEND = "235" + RPL_LUSERCLIENT = "251" + RPL_LUSEROP = "252" + RPL_LUSERUNKNOWN = "253" + RPL_LUSERCHANNELS = "254" + RPL_LUSERME = "255" + RPL_ADMINME = "256" + RPL_ADMINLOC1 = "257" + RPL_ADMINLOC2 = "258" + RPL_ADMINEMAIL = "259" + RPL_TRYAGAIN = "263" + ERR_NOSUCHNICK = "401" + ERR_NOSUCHSERVER = "402" + ERR_NOSUCHCHANNEL = "403" + ERR_CANNOTSENDTOCHAN = "404" + ERR_TOOMANYCHANNELS = "405" + ERR_WASNOSUCHNICK = "406" + ERR_TOOMANYTARGETS = "407" + ERR_NOSUCHSERVICE = "408" + ERR_NOORIGIN = "409" + ERR_NORECIPIENT = "411" + ERR_NOTEXTTOSEND = "412" + ERR_NOTOPLEVEL = "413" + ERR_WILDTOPLEVEL = "414" + ERR_BADMASK = "415" + ERR_UNKNOWNCOMMAND = "421" + ERR_NOMOTD = "422" + ERR_NOADMININFO = "423" + ERR_FILEERROR = "424" + ERR_NONICKNAMEGIVEN = "431" + ERR_ERRONEUSNICKNAME = "432" + ERR_NICKNAMEINUSE = "433" + ERR_NICKCOLLISION = "436" + ERR_UNAVAILRESOURCE = "437" + ERR_USERNOTINCHANNEL = "441" + ERR_NOTONCHANNEL = "442" + ERR_USERONCHANNEL = "443" + ERR_NOLOGIN = "444" + ERR_SUMMONDISABLED = "445" + ERR_USERSDISABLED = "446" + ERR_NOTREGISTERED = "451" + ERR_NEEDMOREPARAMS = "461" + ERR_ALREADYREGISTRED = "462" + ERR_NOPERMFORHOST = "463" + ERR_PASSWDMISMATCH = "464" + ERR_YOUREBANNEDCREEP = "465" + ERR_YOUWILLBEBANNED = "466" + ERR_KEYSET = "467" + ERR_CHANNELISFULL = "471" + ERR_UNKNOWNMODE = "472" + ERR_INVITEONLYCHAN = "473" + ERR_BANNEDFROMCHAN = "474" + ERR_BADCHANNELKEY = "475" + ERR_BADCHANMASK = "476" + ERR_NOCHANMODES = "477" + ERR_BANLISTFULL = "478" + ERR_NOPRIVILEGES = "481" + ERR_CHANOPRIVSNEEDED = "482" + ERR_CANTKILLSERVER = "483" + ERR_RESTRICTED = "484" + ERR_UNIQOPPRIVSNEEDED = "485" + ERR_NOOPERHOST = "491" + ERR_UMODEUNKNOWNFLAG = "501" + ERR_USERSDONTMATCH = "502" +) + +// IRC commands extracted from the IRCv3 spec at http://www.ircv3.org/. +const ( + CAP = "CAP" + CAP_LS = "LS" // Subcommand (param) + CAP_LIST = "LIST" // Subcommand (param) + CAP_REQ = "REQ" // Subcommand (param) + CAP_ACK = "ACK" // Subcommand (param) + CAP_NAK = "NAK" // Subcommand (param) + CAP_CLEAR = "CLEAR" // Subcommand (param) + CAP_END = "END" // Subcommand (param) + + AUTHENTICATE = "AUTHENTICATE" +) + +// Numeric IRC replies extracted from the IRCv3 spec. +const ( + RPL_LOGGEDIN = "900" + RPL_LOGGEDOUT = "901" + RPL_NICKLOCKED = "902" + RPL_SASLSUCCESS = "903" + ERR_SASLFAIL = "904" + ERR_SASLTOOLONG = "905" + ERR_SASLABORTED = "906" + ERR_SASLALREADY = "907" + RPL_SASLMECHS = "908" +) + +// RFC 2812 section 5.3 +const ( + RPL_STATSCLINE = "213" + RPL_STATSNLINE = "214" + RPL_STATSILINE = "215" + RPL_STATSKLINE = "216" + RPL_STATSQLINE = "217" + RPL_STATSYLINE = "218" + RPL_SERVICEINFO = "231" + RPL_ENDOFSERVICES = "232" + RPL_SERVICE = "233" + RPL_STATSVLINE = "240" + RPL_STATSLLINE = "241" + RPL_STATSHLINE = "244" + RPL_STATSSLINE = "245" + RPL_STATSPING = "246" + RPL_STATSBLINE = "247" + RPL_STATSDLINE = "250" + RPL_NONE = "300" + RPL_WHOISCHANOP = "316" + RPL_KILLDONE = "361" + RPL_CLOSING = "362" + RPL_CLOSEEND = "363" + RPL_INFOSTART = "373" + RPL_MYPORTIS = "384" + ERR_NOSERVICEHOST = "492" +) + +// Other constants +const ( + ERR_TOOMANYMATCHES = "416" // Used on IRCNet + RPL_TOPICWHOTIME = "333" // From ircu, in use on Freenode + RPL_LOCALUSERS = "265" // From aircd, Hybrid, Hybrid, Bahamut, in use on Freenode + RPL_GLOBALUSERS = "266" // From aircd, Hybrid, Hybrid, Bahamut, in use on Freenode +) diff --git a/vendor/gopkg.in/sorcix/irc.v2/doc.go b/vendor/gopkg.in/sorcix/irc.v2/doc.go new file mode 100644 index 0000000..a354664 --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/doc.go @@ -0,0 +1,36 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +// Package irc allows your application to speak the IRC protocol. +// +// The Message and Prefix structs provide translation to and from raw IRC messages: +// +// // Parse the IRC-encoded data and store the result in a new struct: +// message := irc.ParseMessage(raw) +// +// // Translate back to a raw IRC message string: +// raw = message.String() +// +// Decoder and Encoder can be used to decode and encode messages in a stream: +// +// // Create a decoder that reads from given io.Reader +// dec := irc.NewDecoder(reader) +// +// // Decode the next IRC message +// message, err := dec.Decode() +// +// // Create an encoder that writes to given io.Writer +// enc := irc.NewEncoder(writer) +// +// // Send a message to the writer. +// enc.Encode(message) +// +// The Conn type combines an Encoder and Decoder for a duplex connection. +// +// c, err := irc.Dial("irc.server.net:6667") +// +// // Methods from both Encoder and Decoder are available +// message, err := c.Decode() +// +package irc // import "gopkg.in/sorcix/irc.v2" diff --git a/vendor/gopkg.in/sorcix/irc.v2/internal/strings.go b/vendor/gopkg.in/sorcix/irc.v2/internal/strings.go new file mode 100644 index 0000000..0021193 --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/internal/strings.go @@ -0,0 +1,19 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +// +build go1.2 + +// Documented in strings_legacy.go + +package internal + +import ( + "strings" +) + +// IndexByte is a compatibility function so strings.IndexByte can be used in +// older versions of go. +func IndexByte(s string, c byte) int { + return strings.IndexByte(s, c) +} diff --git a/vendor/gopkg.in/sorcix/irc.v2/internal/strings_legacy.go b/vendor/gopkg.in/sorcix/irc.v2/internal/strings_legacy.go new file mode 100644 index 0000000..5cf8979 --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/internal/strings_legacy.go @@ -0,0 +1,22 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +// +build !go1.2 + +// Debian Wheezy only ships Go 1.0: +// https://github.com/sorcix/irc/issues/4 +// +// This code may be removed when Wheezy is no longer supported. + +package internal + +// IndexByte implements strings.IndexByte for Go versions < 1.2. +func IndexByte(s string, c byte) int { + for i := range s { + if s[i] == c { + return i + } + } + return -1 +} diff --git a/vendor/gopkg.in/sorcix/irc.v2/message.go b/vendor/gopkg.in/sorcix/irc.v2/message.go new file mode 100644 index 0000000..318775d --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/message.go @@ -0,0 +1,314 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +package irc + +import ( + "bytes" + "strings" + + "gopkg.in/sorcix/irc.v2/internal" +) + +// Various constants used for formatting IRC messages. +const ( + prefix byte = 0x3A // Prefix or last argument + prefixUser byte = 0x21 // Username + prefixHost byte = 0x40 // Hostname + space byte = 0x20 // Separator + + maxLength = 510 // Maximum length is 512 - 2 for the line endings. +) + +func cutsetFunc(r rune) bool { + // Characters to trim from prefixes/messages. + return r == '\r' || r == '\n' +} + +// Sender represents objects that are able to send messages to an IRC server. +// +// As there might be a message queue, it is possible that Send returns a nil +// error, but the message is not sent (yet). The error value is only used when +// it is certain that sending the message is impossible. +// +// This interface is not used inside this package, and shouldn't have been +// defined here in the first place. For backwards compatibility only. +type Sender interface { + Send(*Message) error +} + +// Prefix represents the prefix (sender) of an IRC message. +// See RFC1459 section 2.3.1. +// +// | [ '!' ] [ '@' ] +// +type Prefix struct { + Name string // Nick- or servername + User string // Username + Host string // Hostname +} + +// ParsePrefix takes a string and attempts to create a Prefix struct. +func ParsePrefix(raw string) (p *Prefix) { + + p = new(Prefix) + + user := internal.IndexByte(raw, prefixUser) + host := internal.IndexByte(raw, prefixHost) + + switch { + + case user > 0 && host > user: + p.Name = raw[:user] + p.User = raw[user+1 : host] + p.Host = raw[host+1:] + + case user > 0: + p.Name = raw[:user] + p.User = raw[user+1:] + + case host > 0: + p.Name = raw[:host] + p.Host = raw[host+1:] + + default: + p.Name = raw + + } + + return p +} + +// Len calculates the length of the string representation of this prefix. +func (p *Prefix) Len() (length int) { + length = len(p.Name) + if len(p.User) > 0 { + length = length + len(p.User) + 1 + } + if len(p.Host) > 0 { + length = length + len(p.Host) + 1 + } + return +} + +// Bytes returns a []byte representation of this prefix. +func (p *Prefix) Bytes() []byte { + buffer := new(bytes.Buffer) + p.writeTo(buffer) + return buffer.Bytes() +} + +// String returns a string representation of this prefix. +func (p *Prefix) String() (s string) { + // Benchmarks revealed that in this case simple string concatenation + // is actually faster than using a ByteBuffer as in (*Message).String() + s = p.Name + if len(p.User) > 0 { + s = s + string(prefixUser) + p.User + } + if len(p.Host) > 0 { + s = s + string(prefixHost) + p.Host + } + return +} + +// IsHostmask returns true if this prefix looks like a user hostmask. +func (p *Prefix) IsHostmask() bool { + return len(p.User) > 0 && len(p.Host) > 0 +} + +// IsServer returns true if this prefix looks like a server name. +func (p *Prefix) IsServer() bool { + return len(p.User) <= 0 && len(p.Host) <= 0 // && internal.IndexByte(p.Name, '.') > 0 +} + +// writeTo is an utility function to write the prefix to the bytes.Buffer in Message.String(). +func (p *Prefix) writeTo(buffer *bytes.Buffer) { + buffer.WriteString(p.Name) + if len(p.User) > 0 { + buffer.WriteByte(prefixUser) + buffer.WriteString(p.User) + } + if len(p.Host) > 0 { + buffer.WriteByte(prefixHost) + buffer.WriteString(p.Host) + } + return +} + +// Message represents an IRC protocol message. +// See RFC1459 section 2.3.1. +// +// ::= [':' ] +// ::= | [ '!' ] [ '@' ] +// ::= { } | +// ::= ' ' { ' ' } +// ::= [ ':' | ] +// +// ::= +// ::= +// +// ::= CR LF +type Message struct { + *Prefix + Command string + Params []string +} + +func (m *Message) Trailing() string { + if len(m.Params) > 0 { + return m.Params[len(m.Params)-1] + } + return "" +} + +// ParseMessage takes a string and attempts to create a Message struct. +// Returns nil if the Message is invalid. +func ParseMessage(raw string) (m *Message) { + + // Ignore empty messages. + if raw = strings.TrimFunc(raw, cutsetFunc); len(raw) < 2 { + return nil + } + + i, j := 0, 0 + + m = new(Message) + + if raw[0] == prefix { + + // Prefix ends with a space. + i = internal.IndexByte(raw, space) + + // Prefix string must not be empty if the indicator is present. + if i < 2 { + return nil + } + + m.Prefix = ParsePrefix(raw[1:i]) + + // Skip space at the end of the prefix + i++ + } + + // Find end of command + j = i + internal.IndexByte(raw[i:], space) + + // Extract command + if j > i { + m.Command = strings.ToUpper(raw[i:j]) + } else { + m.Command = strings.ToUpper(raw[i:]) + + // We're done here! + return m + } + + // Find prefix for trailer. Note that because we need to match the trailing + // argument even if it's the only one, we can't skip the space until we've + // searched for it. + i = strings.Index(raw[j:], " :") + + // Skip the space + j++ + + if i < 0 { + + // There is no trailing argument! + m.Params = strings.Split(raw[j:], string(space)) + + // We're done here! + return m + } + + // Compensate for index on substring. Note that we skipped the space after + // looking for i, so we need to subtract 1 to account for that. + i = i + j - 1 + + // Check if we need to parse arguments. + if i > j { + m.Params = strings.Split(raw[j:i], string(space)) + } + + m.Params = append(m.Params, raw[i+2:]) + + return m +} + +// Len calculates the length of the string representation of this message. +func (m *Message) Len() (length int) { + + if m.Prefix != nil { + length = m.Prefix.Len() + 2 // Include prefix and trailing space + } + + length = length + len(m.Command) + + if len(m.Params) > 0 { + length = length + len(m.Params) + for _, param := range m.Params { + length = length + len(param) + } + + if trailing := m.Trailing(); len(trailing) < 1 || strings.Contains(trailing, " ") || trailing[0] == ':' { + // Add one for the colon in the trailing parameter + length++ + } + } + + return +} + +// Bytes returns a []byte representation of this message. +// +// As noted in rfc2812 section 2.3, messages should not exceed 512 characters +// in length. This method forces that limit by discarding any characters +// exceeding the length limit. +func (m *Message) Bytes() []byte { + + buffer := new(bytes.Buffer) + + // Message prefix + if m.Prefix != nil { + buffer.WriteByte(prefix) + m.Prefix.writeTo(buffer) + buffer.WriteByte(space) + } + + // Command is required + buffer.WriteString(m.Command) + + // Space separated list of arguments + if len(m.Params) > 1 { + buffer.WriteByte(space) + buffer.WriteString(strings.Join(m.Params[:len(m.Params)-1], string(space))) + } + + if len(m.Params) > 0 { + buffer.WriteByte(space) + trailing := m.Trailing() + if len(trailing) < 1 || strings.Contains(trailing, " ") || trailing[0] == ':' { + buffer.WriteByte(prefix) + } + buffer.WriteString(trailing) + } + + // We need the limit the buffer length. + if buffer.Len() > (maxLength) { + buffer.Truncate(maxLength) + } + + return buffer.Bytes() +} + +// String returns a string representation of this message. +// +// As noted in rfc2812 section 2.3, messages should not exceed 512 characters +// in length. This method forces that limit by discarding any characters +// exceeding the length limit. +func (m *Message) String() string { + return string(m.Bytes()) +} diff --git a/vendor/gopkg.in/sorcix/irc.v2/stream.go b/vendor/gopkg.in/sorcix/irc.v2/stream.go new file mode 100644 index 0000000..1f6c0b6 --- /dev/null +++ b/vendor/gopkg.in/sorcix/irc.v2/stream.go @@ -0,0 +1,147 @@ +// Copyright 2014 Vic Demuzere +// +// Use of this source code is governed by the MIT license. + +package irc + +import ( + "bufio" + "io" + "net" + "crypto/tls" + "sync" +) + +// Messages are delimited with CR and LF line endings, +// we're using the last one to split the stream. Both are removed +// during message parsing. +const delim byte = '\n' + +var endline = []byte("\r\n") + +// A Conn represents an IRC network protocol connection. +// It consists of an Encoder and Decoder to manage I/O. +type Conn struct { + Encoder + Decoder + + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using rwc for I/O. +func NewConn(rwc io.ReadWriteCloser) *Conn { + return &Conn{ + Encoder: Encoder{ + writer: rwc, + }, + Decoder: Decoder{ + reader: bufio.NewReader(rwc), + }, + conn: rwc, + } +} + +// Dial connects to the given address using net.Dial and +// then returns a new Conn for the connection. +func Dial(addr string) (*Conn, error) { + c, err := net.Dial("tcp", addr) + + if err != nil { + return nil, err + } + + return NewConn(c), nil +} + +// DialTLS connects to the given address using tls.Dial and +// then returns a new Conn for the connection. +func DialTLS(addr string, config *tls.Config) (*Conn, error) { + c, err := tls.Dial("tcp", addr, config) + + if err != nil { + return nil, err + } + + return NewConn(c), nil +} + +// Close closes the underlying ReadWriteCloser. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// A Decoder reads Message objects from an input stream. +type Decoder struct { + reader *bufio.Reader + line string + mu sync.Mutex +} + +// NewDecoder returns a new Decoder that reads from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + reader: bufio.NewReader(r), + } +} + +// Decode attempts to read a single Message from the stream. +// +// Returns a non-nil error if the read failed. +func (dec *Decoder) Decode() (m *Message, err error) { + + dec.mu.Lock() + dec.line, err = dec.reader.ReadString(delim) + dec.mu.Unlock() + + if err != nil { + return nil, err + } + + return ParseMessage(dec.line), nil +} + +// An Encoder writes Message objects to an output stream. +type Encoder struct { + writer io.Writer + mu sync.Mutex +} + +// NewEncoder returns a new Encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + writer: w, + } +} + +// Encode writes the IRC encoding of m to the stream. +// +// This method may be used from multiple goroutines. +// +// Returns an non-nil error if the write to the underlying stream stopped early. +func (enc *Encoder) Encode(m *Message) (err error) { + + _, err = enc.Write(m.Bytes()) + + return +} + +// Write writes len(p) bytes from p followed by CR+LF. +// +// This method can be used simultaneously from multiple goroutines, +// it guarantees to serialize access. However, writing a single IRC message +// using multiple Write calls will cause corruption. +func (enc *Encoder) Write(p []byte) (n int, err error) { + + enc.mu.Lock() + n, err = enc.writer.Write(p) + + if err != nil { + enc.mu.Unlock() + return + } + + _, err = enc.writer.Write(endline) + enc.mu.Unlock() + + return +} diff --git a/vendor/vendor.json b/vendor/vendor.json index cb3ce1f..822204c 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -87,16 +87,22 @@ "revisionTime": "2017-01-16T20:05:12Z" }, { - "checksumSHA1": "p0SvpHmpHEGr2eYb1tohV4EZhD0=", + "checksumSHA1": "U6Xb8fxRhdSrxoWzL5ehTRk/704=", "path": "github.com/bifurcation/mint", - "revision": "64af8ab8ccb81bd5d4eab356f79ba0939117d9f6", - "revisionTime": "2017-10-31T22:03:52Z" + "revision": "198357931e6129b810c9c77c12e0dd754846170c", + "revisionTime": "2018-03-06T13:52:33Z" }, { - "checksumSHA1": "usbuF7R80ixs5RS8ZM99C6OTDlc=", + "checksumSHA1": "yZuFJmLUlFHTDWyCLxcDqif/H/w=", "path": "github.com/bifurcation/mint/syntax", - "revision": "64af8ab8ccb81bd5d4eab356f79ba0939117d9f6", - "revisionTime": "2017-10-31T22:03:52Z" + "revision": "198357931e6129b810c9c77c12e0dd754846170c", + "revisionTime": "2018-03-06T13:52:33Z" + }, + { + "checksumSHA1": "PYXuf7wvcj492uWhjvmXlyyQUYc=", + "path": "github.com/cheekybits/genny/generic", + "revision": "9127e812e1e9e501ce899a18121d316ecb52e4ba", + "revisionTime": "2017-03-28T20:00:08Z" }, { "checksumSHA1": "xqVDKHGnakGlcRhmWd1j9JYmfLc=", @@ -141,16 +147,22 @@ "revisionTime": "2016-01-25T20:49:56Z" }, { - "checksumSHA1": "d9PxF1XQGLMJZRct2R8qVM/eYlE=", + "checksumSHA1": "UquR8kc0nKU285HwLbkievlLQz4=", "path": "github.com/hashicorp/golang-lru", - "revision": "0a025b7e63adc15a622f29b0b2c4c3848243bbf6", - "revisionTime": "2016-08-13T22:13:03Z" + "revision": "0fb14efe8c47ae851c0034ed7a448854d3d34cf3", + "revisionTime": "2018-02-01T23:52:37Z" }, { - "checksumSHA1": "9hffs0bAIU6CquiRhKQdzjHnKt0=", + "checksumSHA1": "8Z637dcPkbR5HdLQQBp/9jTbx9Y=", "path": "github.com/hashicorp/golang-lru/simplelru", - "revision": "0a025b7e63adc15a622f29b0b2c4c3848243bbf6", - "revisionTime": "2016-08-13T22:13:03Z" + "revision": "0fb14efe8c47ae851c0034ed7a448854d3d34cf3", + "revisionTime": "2018-02-01T23:52:37Z" + }, + { + "checksumSHA1": "It/0oRIKZ4N13p+FB7midIrwzGk=", + "path": "github.com/isofew/go-stun/stun", + "revision": "87ca637a836189f118b7b12c62a22ebcf15ae246", + "revisionTime": "2018-05-15T16:46:54Z" }, { "checksumSHA1": "/EgCTbjJkJh2yi9lqEgzmau8O4I=", @@ -174,8 +186,8 @@ "checksumSHA1": "2RFzGcdTeQrFkkhT70WhQcMWF6c=", "origin": "github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12", "path": "github.com/lucas-clemente/aes12", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "cd47fb39b79f867c6e4e5cd39cf7abd799f71670", + "revisionTime": "2017-10-27T16:34:21Z" }, { "checksumSHA1": "ne1X+frkx5fJcpz9FaZPuUZ7amM=", @@ -185,17 +197,17 @@ "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "8wGsJmhHz2l6XAcpoJ4NiF2zXjY=", + "checksumSHA1": "8cInEDwkOAFQ6VKSTlMtKo+MOtA=", "path": "github.com/lucas-clemente/quic-go", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { "checksumSHA1": "OA9E+y7g05x/mWJJHmA7oPxWKQo=", "origin": "github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates", "path": "github.com/lucas-clemente/quic-go-certificates", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "d2f86524cced5186554df90d92529757d22c1cb6", + "revisionTime": "2016-08-23T09:51:56Z" }, { "checksumSHA1": "w1JPHEDB2mQW4qpKZbWj5zsavuM=", @@ -210,46 +222,58 @@ "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "lWOh0Q0bY/dd3G/MZDCkzk1dVTo=", + "checksumSHA1": "hYpTIYEUy7Nfx4MGRQjy//F5PZs=", + "path": "github.com/lucas-clemente/quic-go/internal/ackhandler", + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" + }, + { + "checksumSHA1": "i1yfut7QQqMehw5yE9llhWNnrxk=", + "path": "github.com/lucas-clemente/quic-go/internal/congestion", + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" + }, + { + "checksumSHA1": "8CRRInUpwdxqXFGWnrW1KTUYOUE=", "path": "github.com/lucas-clemente/quic-go/internal/crypto", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "cab7WtoBeOlbQGMEoTaKAjEbqZg=", + "checksumSHA1": "rnRicg73lPAeRh9Nko6a0CZQS5I=", "path": "github.com/lucas-clemente/quic-go/internal/flowcontrol", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "gWAXju/s95yWXKYLgCnxx+Ed22M=", + "checksumSHA1": "A9fe2DfiT694bmD1deedP84eUsE=", "path": "github.com/lucas-clemente/quic-go/internal/handshake", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "VjW23wuTXH3REwjcwhfdQrJTUDI=", + "checksumSHA1": "NwrPs5iGdniZnUEbJGlAcpoEPEY=", "path": "github.com/lucas-clemente/quic-go/internal/protocol", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "V9xXEL18b0TrZoe+dqHAmK0beCY=", + "checksumSHA1": "O3VAr5sAfdTINEGQdhrkZ88zWZs=", "path": "github.com/lucas-clemente/quic-go/internal/utils", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "Zs24W5nNpq5Adhl7069+094yW3A=", + "checksumSHA1": "xjxmJu6YsH81fJTD48OTd/zQgrE=", "path": "github.com/lucas-clemente/quic-go/internal/wire", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { - "checksumSHA1": "RaG0jfP+lFzgedW98Bfp0Uri7EY=", + "checksumSHA1": "bFSC4TOZGOZGBJEFmLAT3V4ieoo=", "path": "github.com/lucas-clemente/quic-go/qerr", - "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", - "revisionTime": "2017-11-13T03:10:14Z" + "revision": "407a563c73ad24115f0ac1e7df5f6097353642db", + "revisionTime": "2018-05-14T02:42:56Z" }, { "checksumSHA1": "ynJSWoF6v+3zMnh9R0QmmG6iGV8=", @@ -437,6 +461,18 @@ "revision": "3ab3a8b8831546bd18fd182c20687ca853b2bb13", "revisionTime": "2016-12-15T22:53:35Z" }, + { + "checksumSHA1": "GBscl6AAvoNke1G8ZgVbPWheqqE=", + "path": "gopkg.in/sorcix/irc.v2", + "revision": "1b25be7f891d1bd0190ac0ef159da153c9ffa22a", + "revisionTime": "2017-07-26T15:46:28Z" + }, + { + "checksumSHA1": "745d6gaBv5Ni3YrPvslpM4BoU6A=", + "path": "gopkg.in/sorcix/irc.v2/internal", + "revision": "1b25be7f891d1bd0190ac0ef159da153c9ffa22a", + "revisionTime": "2017-07-26T15:46:28Z" + }, { "checksumSHA1": "FJUY4FAPSimDxXpnlZM8ou26ztE=", "path": "gopkg.in/xtaci/kcp-go.v2",