From db4591cadd9fe5bc6484717d205e1f8f88770748 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Mon, 20 Nov 2017 16:36:35 +0800 Subject: [PATCH 1/3] update vendor for QUIC --- vendor/github.com/aead/chacha20/LICENSE | 21 + vendor/github.com/aead/chacha20/README.md | 79 ++ .../github.com/aead/chacha20/chacha/chacha.go | 176 ++++ .../aead/chacha20/chacha/chachaAVX2_amd64.s | 542 ++++++++++ .../aead/chacha20/chacha/chacha_386.go | 67 ++ .../aead/chacha20/chacha/chacha_386.s | 311 ++++++ .../aead/chacha20/chacha/chacha_amd64.s | 788 +++++++++++++++ .../aead/chacha20/chacha/chacha_generic.go | 319 ++++++ .../aead/chacha20/chacha/chacha_go16_amd64.go | 56 ++ .../aead/chacha20/chacha/chacha_go17_amd64.go | 72 ++ .../aead/chacha20/chacha/chacha_ref.go | 26 + vendor/github.com/aead/chacha20/chacha20.go | 41 + vendor/github.com/bifurcation/mint/LICENSE.md | 21 + vendor/github.com/bifurcation/mint/README.md | 88 ++ vendor/github.com/bifurcation/mint/alert.go | 99 ++ .../bifurcation/mint/client-state-machine.go | 942 ++++++++++++++++++ vendor/github.com/bifurcation/mint/common.go | 152 +++ vendor/github.com/bifurcation/mint/conn.go | 819 +++++++++++++++ vendor/github.com/bifurcation/mint/crypto.go | 654 ++++++++++++ .../github.com/bifurcation/mint/extensions.go | 586 +++++++++++ vendor/github.com/bifurcation/mint/ffdhe.go | 147 +++ .../bifurcation/mint/frame-reader.go | 98 ++ .../bifurcation/mint/handshake-layer.go | 253 +++++ .../bifurcation/mint/handshake-messages.go | 450 +++++++++ vendor/github.com/bifurcation/mint/log.go | 55 + vendor/github.com/bifurcation/mint/mint.svg | 101 ++ .../bifurcation/mint/negotiation.go | 217 ++++ .../bifurcation/mint/record-layer.go | 296 ++++++ .../bifurcation/mint/server-state-machine.go | 898 +++++++++++++++++ .../bifurcation/mint/state-machine.go | 230 +++++ .../bifurcation/mint/syntax/README.md | 74 ++ .../bifurcation/mint/syntax/decode.go | 243 +++++ .../bifurcation/mint/syntax/encode.go | 187 ++++ .../bifurcation/mint/syntax/tags.go | 30 + vendor/github.com/bifurcation/mint/tls.go | 168 ++++ .../lucas-clemente/aes12/cipher_generic.go | 2 +- .../lucas-clemente/quic-go/Changelog.md | 7 +- .../lucas-clemente/quic-go/README.md | 4 +- .../quic-go/ackhandler/interfaces.go | 14 +- .../quic-go/ackhandler/packet.go | 14 +- .../ackhandler/received_packet_handler.go | 60 +- .../ackhandler/received_packet_history.go | 69 +- .../quic-go/ackhandler/retransmittable.go | 16 +- .../quic-go/ackhandler/sent_packet_handler.go | 80 +- .../ackhandler/stop_waiting_manager.go | 12 +- .../lucas-clemente/quic-go/appveyor.yml | 9 +- .../lucas-clemente/quic-go/buffer_pool.go | 2 +- .../lucas-clemente/quic-go/client.go | 179 ++-- .../quic-go/congestion/bandwidth.go | 2 +- .../quic-go/congestion/cubic.go | 2 +- .../quic-go/congestion/cubic_sender.go | 2 +- .../quic-go/congestion/hybrid_slow_start.go | 2 +- .../quic-go/congestion/interface.go | 2 +- .../quic-go/congestion/prr_sender.go | 2 +- .../quic-go/congestion/rtt_stats.go | 1 + .../quic-go/congestion/stats.go | 2 +- .../quic-go/crypto/aesgcm_aead.go | 58 -- .../lucas-clemente/quic-go/crypto/nonce.go | 14 - .../flowcontrol/flow_control_manager.go | 240 ----- .../quic-go/flowcontrol/flow_controller.go | 198 ---- .../quic-go/flowcontrol/interface.go | 26 - .../quic-go/frames/ack_range.go | 9 - .../quic-go/frames/blocked_frame.go | 44 - .../lucas-clemente/quic-go/frames/log.go | 28 - .../quic-go/frames/window_update_frame.go | 54 - .../lucas-clemente/quic-go/h2quic/client.go | 296 ------ .../quic-go/h2quic/gzipreader.go | 35 - .../lucas-clemente/quic-go/h2quic/request.go | 80 -- .../quic-go/h2quic/request_body.go | 29 - .../quic-go/h2quic/request_writer.go | 201 ---- .../lucas-clemente/quic-go/h2quic/response.go | 111 --- .../quic-go/h2quic/response_writer.go | 108 -- .../quic-go/h2quic/roundtrip.go | 168 ---- .../lucas-clemente/quic-go/h2quic/server.go | 382 ------- .../connection_parameters_manager.go | 265 ----- .../quic-go/handshake/stk_generator.go | 100 -- .../lucas-clemente/quic-go/interface.go | 61 +- .../quic-go/{ => internal}/crypto/AEAD.go | 3 +- .../quic-go/internal/crypto/aesgcm12_aead.go | 72 ++ .../quic-go/internal/crypto/aesgcm_aead.go | 74 ++ .../{ => internal}/crypto/cert_cache.go | 2 +- .../{ => internal}/crypto/cert_chain.go | 0 .../{ => internal}/crypto/cert_compression.go | 14 +- .../{ => internal}/crypto/cert_dict.go | 0 .../{ => internal}/crypto/cert_manager.go | 0 .../{ => internal}/crypto/cert_sets.go | 0 .../crypto/chacha20poly1305_aead.go | 14 +- .../crypto/chacha20poly1305_aead_test.go | 0 .../{ => internal}/crypto/curve_25519.go | 0 .../quic-go/internal/crypto/key_derivation.go | 49 + .../crypto/key_derivation_quic_crypto.go} | 10 +- .../{ => internal}/crypto/key_exchange.go | 0 .../quic-go/internal/crypto/null_aead.go | 11 + .../internal/crypto/null_aead_aesgcm.go | 44 + .../crypto/null_aead_fnv128a.go} | 43 +- .../{ => internal}/crypto/server_proof.go | 0 .../crypto/source_address_token.go | 0 .../flowcontrol/base_flow_controller.go | 110 ++ .../flowcontrol/connection_flow_controller.go | 77 ++ .../quic-go/internal/flowcontrol/interface.go | 37 + .../flowcontrol/stream_flow_controller.go | 128 +++ .../internal/handshake/cookie_generator.go | 101 ++ .../internal/handshake/cookie_handler.go | 43 + .../handshake/crypto_setup_client.go | 148 ++- .../handshake/crypto_setup_server.go | 139 +-- .../internal/handshake/crypto_setup_tls.go | 242 +++++ .../handshake/ephermal_cache.go | 4 +- .../handshake/handshake_message.go | 20 +- .../{ => internal}/handshake/interface.go | 19 +- .../quic-go/internal/handshake/mint_utils.go | 127 +++ .../{ => internal}/handshake/server_config.go | 10 +- .../handshake/server_config_client.go | 2 +- .../quic-go/{ => internal}/handshake/tags.go | 3 + .../internal/handshake/tls_extension.go | 54 + .../handshake/tls_extension_handler_client.go | 122 +++ .../handshake/tls_extension_handler_server.go | 109 ++ .../handshake/transport_parameters.go | 167 ++++ .../protocol/encryption_level.go | 0 .../{ => internal}/protocol/packet_number.go | 12 +- .../{ => internal}/protocol/perspective.go | 0 .../{ => internal}/protocol/protocol.go | 22 +- .../protocol/server_parameters.go | 52 +- .../quic-go/internal/protocol/version.go | 114 +++ .../quic-go/internal/utils/byteorder.go | 33 + .../internal/utils/byteorder_big_endian.go | 157 +++ .../{utils.go => byteorder_little_endian.go} | 53 +- .../quic-go/internal/utils/connection_id.go | 2 +- .../quic-go/internal/utils/float16.go | 12 +- .../quic-go/internal/utils/log.go | 9 +- .../quic-go/internal/utils/minmax.go | 2 +- .../quic-go/internal/utils/packet_interval.go | 2 +- .../internal/utils/streamframe_interval.go | 2 +- .../{frames => internal/wire}/ack_frame.go | 86 +- .../quic-go/internal/wire/ack_range.go | 9 + .../quic-go/internal/wire/blocked_frame.go | 35 + .../internal/wire/blocked_frame_legacy.go | 38 + .../wire}/connection_close_frame.go | 22 +- .../{frames => internal/wire}/frame.go | 4 +- .../{frames => internal/wire}/goaway_frame.go | 27 +- .../quic-go/internal/wire/header.go | 111 +++ .../quic-go/internal/wire/ietf_header.go | 170 ++++ .../quic-go/internal/wire/log.go | 28 + .../quic-go/internal/wire/max_data_frame.go | 51 + .../internal/wire/max_stream_data_frame.go | 56 ++ .../{frames => internal/wire}/ping_frame.go | 6 +- .../{ => internal/wire}/public_header.go | 160 ++- .../quic-go/internal/wire/public_reset.go | 65 ++ .../wire}/rst_stream_frame.go | 22 +- .../wire}/stop_waiting_frame.go | 32 +- .../internal/wire/stream_blocked_frame.go | 44 + .../{frames => internal/wire}/stream_frame.go | 60 +- .../internal/wire/version_negotiation.go | 51 + .../internal/wire/window_update_frame.go | 35 + .../quic-go/packet_number_generator.go | 2 +- .../lucas-clemente/quic-go/packet_packer.go | 155 +-- .../lucas-clemente/quic-go/packet_unpacker.go | 111 ++- .../quic-go/protocol/version.go | 55 - .../lucas-clemente/quic-go/public_reset.go | 62 -- .../lucas-clemente/quic-go/server.go | 128 +-- .../lucas-clemente/quic-go/session.go | 404 ++++---- .../lucas-clemente/quic-go/stream.go | 120 ++- .../quic-go/stream_frame_sorter.go | 14 +- .../lucas-clemente/quic-go/stream_framer.go | 101 +- .../lucas-clemente/quic-go/streams_map.go | 181 ++-- .../x/crypto/curve25519/curve25519.go | 2 +- vendor/vendor.json | 129 ++- 166 files changed, 13451 insertions(+), 3891 deletions(-) create mode 100644 vendor/github.com/aead/chacha20/LICENSE create mode 100644 vendor/github.com/aead/chacha20/README.md create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha.go create mode 100644 vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_386.go create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_386.s create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_amd64.s create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_generic.go create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go create mode 100644 vendor/github.com/aead/chacha20/chacha/chacha_ref.go create mode 100644 vendor/github.com/aead/chacha20/chacha20.go create mode 100644 vendor/github.com/bifurcation/mint/LICENSE.md create mode 100644 vendor/github.com/bifurcation/mint/README.md create mode 100644 vendor/github.com/bifurcation/mint/alert.go create mode 100644 vendor/github.com/bifurcation/mint/client-state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/common.go create mode 100644 vendor/github.com/bifurcation/mint/conn.go create mode 100644 vendor/github.com/bifurcation/mint/crypto.go create mode 100644 vendor/github.com/bifurcation/mint/extensions.go create mode 100644 vendor/github.com/bifurcation/mint/ffdhe.go create mode 100644 vendor/github.com/bifurcation/mint/frame-reader.go create mode 100644 vendor/github.com/bifurcation/mint/handshake-layer.go create mode 100644 vendor/github.com/bifurcation/mint/handshake-messages.go create mode 100644 vendor/github.com/bifurcation/mint/log.go create mode 100644 vendor/github.com/bifurcation/mint/mint.svg create mode 100644 vendor/github.com/bifurcation/mint/negotiation.go create mode 100644 vendor/github.com/bifurcation/mint/record-layer.go create mode 100644 vendor/github.com/bifurcation/mint/server-state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/state-machine.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/README.md create mode 100644 vendor/github.com/bifurcation/mint/syntax/decode.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/encode.go create mode 100644 vendor/github.com/bifurcation/mint/syntax/tags.go create mode 100644 vendor/github.com/bifurcation/mint/tls.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/log.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/client.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/request.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/response.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/h2quic/server.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/AEAD.go (79%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_cache.go (94%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_chain.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_compression.go (94%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_dict.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_manager.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/cert_sets.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/chacha20poly1305_aead.go (72%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/chacha20poly1305_aead_test.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/curve_25519.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go rename vendor/github.com/lucas-clemente/quic-go/{crypto/key_derivation.go => internal/crypto/key_derivation_quic_crypto.go} (84%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/key_exchange.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go rename vendor/github.com/lucas-clemente/quic-go/{crypto/null_aead.go => internal/crypto/null_aead_fnv128a.go} (55%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/server_proof.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/crypto/source_address_token.go (100%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/crypto_setup_client.go (77%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/crypto_setup_server.go (78%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/ephermal_cache.go (92%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/handshake_message.go (86%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/interface.go (52%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/server_config.go (88%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/server_config_client.go (98%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/handshake/tags.go (96%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/encryption_level.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/packet_number.go (74%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/perspective.go (100%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/protocol.go (75%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal}/protocol/server_parameters.go (74%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go rename vendor/github.com/lucas-clemente/quic-go/internal/utils/{utils.go => byteorder_little_endian.go} (64%) rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/ack_frame.go (78%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/connection_close_frame.go (62%) rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/frame.go (74%) rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/goaway_frame.go (66%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/ping_frame.go (76%) rename vendor/github.com/lucas-clemente/quic-go/{ => internal/wire}/public_header.go (57%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/rst_stream_frame.go (60%) rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/stop_waiting_frame.go (78%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go rename vendor/github.com/lucas-clemente/quic-go/{frames => internal/wire}/stream_frame.go (69%) create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/protocol/version.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/public_reset.go diff --git a/vendor/github.com/aead/chacha20/LICENSE b/vendor/github.com/aead/chacha20/LICENSE new file mode 100644 index 0000000..b6a9210 --- /dev/null +++ b/vendor/github.com/aead/chacha20/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Andreas Auernhammer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/aead/chacha20/README.md b/vendor/github.com/aead/chacha20/README.md new file mode 100644 index 0000000..bb2a320 --- /dev/null +++ b/vendor/github.com/aead/chacha20/README.md @@ -0,0 +1,79 @@ +[![Godoc Reference](https://godoc.org/github.com/aead/chacha20?status.svg)](https://godoc.org/github.com/aead/chacha20) + +## The ChaCha20 stream cipher + +ChaCha is a stream cipher family created by Daniel J. Bernstein. +The most common ChaCha cipher is ChaCha20 (20 rounds). ChaCha20 is standardized in [RFC 7539](https://tools.ietf.org/html/rfc7539 "RFC 7539"). + +This package provides implementations of three ChaCha versions: +- ChaCha20 with a 64 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) +- ChaCha20 with a 96 bit nonce (can en/decrypt up to 2^32 * 64 bytes ~ 256 GB for one key-nonce combination) +- XChaCha20 with a 192 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) + +Furthermore the chacha subpackage implements ChaCha20/12 and ChaCha20/8. +These versions use 12 or 8 rounds instead of 20. +But it's recommended to use ChaCha20 (with 20 rounds) - it will be fast enough for almost all purposes. + +### Installation +Install in your GOPATH: `go get -u github.com/aead/chacha20` + +### Requirements +All go versions >= 1.5.3 are supported. +Please notice, that the amd64 AVX2 asm implementation requires go1.7 or newer. + +### Performance + +#### AMD64 +Hardware: Intel i7-6500U 2.50GHz x 2 +System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic +Go version: 1.8.0 +``` +AVX2 +name speed cpb +ChaCha20_64-4 573MB/s ± 0% 4.16 +ChaCha20_1K-4 2.19GB/s ± 0% 1.06 +XChaCha20_64-4 261MB/s ± 0% 9.13 +XChaCha20_1K-4 1.69GB/s ± 4% 1.37 +XORKeyStream64-4 474MB/s ± 2% 5.02 +XORKeyStream1K-4 2.09GB/s ± 1% 1.11 +XChaCha20_XORKeyStream64-4 262MB/s ± 0% 9.09 +XChaCha20_XORKeyStream1K-4 1.71GB/s ± 1% 1.36 + +SSSE3 +name speed cpb +ChaCha20_64-4 583MB/s ± 0% 4.08 +ChaCha20_1K-4 1.15GB/s ± 1% 2.02 +XChaCha20_64-4 267MB/s ± 0% 8.92 +XChaCha20_1K-4 984MB/s ± 5% 2.42 +XORKeyStream64-4 492MB/s ± 1% 4.84 +XORKeyStream1K-4 1.10GB/s ± 5% 2.11 +XChaCha20_XORKeyStream64-4 266MB/s ± 0% 8.96 +XChaCha20_XORKeyStream1K-4 1.00GB/s ± 2% 2.32 +``` +#### 386 +Hardware: Intel i7-6500U 2.50GHz x 2 +System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic +Go version: 1.8.0 +``` +SSSE3 +name                        speed cpb +ChaCha20_64-4               570MB/s ± 0% 4.18 +ChaCha20_1K-4               650MB/s ± 0% 3.66 +XChaCha20_64-4              223MB/s ± 0% 10.69 +XChaCha20_1K-4              584MB/s ± 1% 4.08 +XORKeyStream64-4            392MB/s ± 1% 6.08 +XORKeyStream1K-4            629MB/s ± 1% 3.79 +XChaCha20_XORKeyStream64-4  222MB/s ± 0% 10.73 +XChaCha20_XORKeyStream1K-4  585MB/s ± 0% 4.07 + +SSE2 +name speed cpb +ChaCha20_64-4 509MB/s ± 0% 4.68 +ChaCha20_1K-4 553MB/s ± 2% 4.31 +XChaCha20_64-4 201MB/s ± 0% 11.86 +XChaCha20_1K-4 498MB/s ± 4% 4.78 +XORKeyStream64-4 359MB/s ± 1% 6.64 +XORKeyStream1K-4 545MB/s ± 0% 4.37 +XChaCha20_XORKeyStream64-4 201MB/s ± 1% 11.86 +XChaCha20_XORKeyStream1K-4 507MB/s ± 0% 4.70 +``` diff --git a/vendor/github.com/aead/chacha20/chacha/chacha.go b/vendor/github.com/aead/chacha20/chacha/chacha.go new file mode 100644 index 0000000..8c387a9 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha.go @@ -0,0 +1,176 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// Package chacha implements some low-level functions of the +// ChaCha cipher family. +package chacha // import "github.com/aead/chacha20/chacha" + +import ( + "encoding/binary" + "errors" +) + +const ( + // NonceSize is the size of the ChaCha20 nonce in bytes. + NonceSize = 8 + + // INonceSize is the size of the IETF-ChaCha20 nonce in bytes. + INonceSize = 12 + + // XNonceSize is the size of the XChaCha20 nonce in bytes. + XNonceSize = 24 + + // KeySize is the size of the key in bytes. + KeySize = 32 +) + +var ( + useSSE2 bool + useSSSE3 bool + useAVX2 bool +) + +var ( + errKeySize = errors.New("chacha20/chacha: bad key length") + errInvalidNonce = errors.New("chacha20/chacha: bad nonce length") +) + +func setup(state *[64]byte, nonce, key []byte) (err error) { + if len(key) != KeySize { + err = errKeySize + return + } + var Nonce [16]byte + switch len(nonce) { + case NonceSize: + copy(Nonce[8:], nonce) + initialize(state, key, &Nonce) + case INonceSize: + copy(Nonce[4:], nonce) + initialize(state, key, &Nonce) + case XNonceSize: + var tmpKey [32]byte + var hNonce [16]byte + + copy(hNonce[:], nonce[:16]) + copy(tmpKey[:], key) + hChaCha20(&tmpKey, &hNonce, &tmpKey) + copy(Nonce[8:], nonce[16:]) + initialize(state, tmpKey[:], &Nonce) + + // BUG(aead): A "good" compiler will remove this (optimizations) + // But using the provided key instead of tmpKey, + // will change the key (-> probably confuses users) + for i := range tmpKey { + tmpKey[i] = 0 + } + default: + err = errInvalidNonce + } + return +} + +// XORKeyStream crypts bytes from src to dst using the given nonce and key. +// The length of the nonce determinds the version of ChaCha20: +// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period. +// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period. +// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period. +// The rounds argument specifies the number of rounds performed for keystream +// generation - valid values are 8, 12 or 20. The src and dst may be the same slice +// but otherwise should not overlap. If len(dst) < len(src) this function panics. +// If the nonce is neither 64, 96 nor 192 bits long, this function panics. +func XORKeyStream(dst, src, nonce, key []byte, rounds int) { + if rounds != 20 && rounds != 12 && rounds != 8 { + panic("chacha20/chacha: bad number of rounds") + } + if len(dst) < len(src) { + panic("chacha20/chacha: dst buffer is to small") + } + if len(nonce) == INonceSize && uint64(len(src)) > (1<<38) { + panic("chacha20/chacha: src is too large") + } + + var block, state [64]byte + if err := setup(&state, nonce, key); err != nil { + panic(err) + } + xorKeyStream(dst, src, &block, &state, rounds) +} + +// Cipher implements ChaCha20/r (XChaCha20/r) for a given number of rounds r. +type Cipher struct { + state, block [64]byte + off int + rounds int // 20 for ChaCha20 + noncesize int +} + +// NewCipher returns a new *chacha.Cipher implementing the ChaCha20/r or XChaCha20/r +// (r = 8, 12 or 20) stream cipher. The nonce must be unique for one key for all time. +// The length of the nonce determinds the version of ChaCha20: +// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period. +// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period. +// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period. +// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned. +func NewCipher(nonce, key []byte, rounds int) (*Cipher, error) { + if rounds != 20 && rounds != 12 && rounds != 8 { + panic("chacha20/chacha: bad number of rounds") + } + + c := new(Cipher) + if err := setup(&(c.state), nonce, key); err != nil { + return nil, err + } + c.rounds = rounds + + if len(nonce) == INonceSize { + c.noncesize = INonceSize + } else { + c.noncesize = NonceSize + } + + return c, nil +} + +// XORKeyStream crypts bytes from src to dst. Src and dst may be the same slice +// but otherwise should not overlap. If len(dst) < len(src) the function panics. +func (c *Cipher) XORKeyStream(dst, src []byte) { + if len(dst) < len(src) { + panic("chacha20/chacha: dst buffer is to small") + } + + if c.off > 0 { + n := len(c.block[c.off:]) + if len(src) <= n { + for i, v := range src { + dst[i] = v ^ c.block[c.off] + c.off++ + } + if c.off == 64 { + c.off = 0 + } + return + } + + for i, v := range c.block[c.off:] { + dst[i] = src[i] ^ v + } + src = src[n:] + dst = dst[n:] + c.off = 0 + } + + c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds) +} + +// SetCounter skips ctr * 64 byte blocks. SetCounter(0) resets the cipher. +// This function always skips the unused keystream of the current 64 byte block. +func (c *Cipher) SetCounter(ctr uint64) { + if c.noncesize == INonceSize { + binary.LittleEndian.PutUint32(c.state[48:], uint32(ctr)) + } else { + binary.LittleEndian.PutUint64(c.state[48:], ctr) + } + c.off = 0 +} diff --git a/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s b/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s new file mode 100644 index 0000000..8d02233 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chachaAVX2_amd64.s @@ -0,0 +1,542 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build go1.7,amd64,!gccgo,!appengine,!nacl + +#include "textflag.h" + +DATA ·sigma_AVX<>+0x00(SB)/4, $0x61707865 +DATA ·sigma_AVX<>+0x04(SB)/4, $0x3320646e +DATA ·sigma_AVX<>+0x08(SB)/4, $0x79622d32 +DATA ·sigma_AVX<>+0x0C(SB)/4, $0x6b206574 +GLOBL ·sigma_AVX<>(SB), (NOPTR+RODATA), $16 + +DATA ·one_AVX<>+0x00(SB)/8, $1 +DATA ·one_AVX<>+0x08(SB)/8, $0 +GLOBL ·one_AVX<>(SB), (NOPTR+RODATA), $16 + +DATA ·one_AVX2<>+0x00(SB)/8, $0 +DATA ·one_AVX2<>+0x08(SB)/8, $0 +DATA ·one_AVX2<>+0x10(SB)/8, $1 +DATA ·one_AVX2<>+0x18(SB)/8, $0 +GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 + +DATA ·two_AVX2<>+0x00(SB)/8, $2 +DATA ·two_AVX2<>+0x08(SB)/8, $0 +DATA ·two_AVX2<>+0x10(SB)/8, $2 +DATA ·two_AVX2<>+0x18(SB)/8, $0 +GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32 + +DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302 +DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A +GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 + +DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003 +DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B +GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 + +#define ROTL(n, t, v) \ + VPSLLD $n, v, t; \ + VPSRLD $(32-n), v, v; \ + VPXOR v, t, v + +#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \ + VPADDD v0, v1, v0; \ + VPXOR v3, v0, v3; \ + VPSHUFB c16, v3, v3; \ + VPADDD v2, v3, v2; \ + VPXOR v1, v2, v1; \ + ROTL(12, t, v1); \ + VPADDD v0, v1, v0; \ + VPXOR v3, v0, v3; \ + VPSHUFB c8, v3, v3; \ + VPADDD v2, v3, v2; \ + VPXOR v1, v2, v1; \ + ROTL(7, t, v1) + +#define CHACHA_SHUFFLE(v1, v2, v3) \ + VPSHUFD $0x39, v1, v1; \ + VPSHUFD $0x4E, v2, v2; \ + VPSHUFD $-109, v3, v3 + +#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ + VMOVDQU (0+off)(src), t0; \ + VPERM2I128 $32, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (0+off)(dst); \ + VMOVDQU (32+off)(src), t0; \ + VPERM2I128 $32, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (32+off)(dst); \ + VMOVDQU (64+off)(src), t0; \ + VPERM2I128 $49, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (64+off)(dst); \ + VMOVDQU (96+off)(src), t0; \ + VPERM2I128 $49, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (96+off)(dst) + +#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ + VMOVDQU (0+off)(src), t0; \ + VPERM2I128 $32, v1, v0, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (0+off)(dst); \ + VMOVDQU (32+off)(src), t0; \ + VPERM2I128 $32, v3, v2, t1; \ + VPXOR t0, t1, t0; \ + VMOVDQU t0, (32+off)(dst); \ + +#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \ + VPERM2I128 $49, v1, v0, t0; \ + VMOVDQU t0, 0(dst); \ + VPERM2I128 $49, v3, v2, t0; \ + VMOVDQU t0, 32(dst) + +#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \ + VPXOR 0+off(src), v0, t0; \ + VMOVDQU t0, 0+off(dst); \ + VPXOR 16+off(src), v1, t0; \ + VMOVDQU t0, 16+off(dst); \ + VPXOR 32+off(src), v2, t0; \ + VMOVDQU t0, 32+off(dst); \ + VPXOR 48+off(src), v3, t0; \ + VMOVDQU t0, 48+off(dst) + +#define TWO 0(SP) +#define C16 32(SP) +#define C8 64(SP) +#define STATE_0 96(SP) +#define STATE_1 128(SP) +#define STATE_2 160(SP) +#define STATE_3 192(SP) +#define TMP_0 224(SP) +#define TMP_1 256(SP) + +// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamAVX2(SB), 4, $320-80 + MOVQ dst_base+0(FP), DI + MOVQ src_base+24(FP), SI + MOVQ src_len+32(FP), CX + MOVQ block+48(FP), BX + MOVQ state+56(FP), AX + MOVQ rounds+64(FP), DX + + MOVQ SP, R8 + ADDQ $32, SP + ANDQ $-32, SP + + VMOVDQU 0(AX), Y2 + VMOVDQU 32(AX), Y3 + VPERM2I128 $0x22, Y2, Y0, Y0 + VPERM2I128 $0x33, Y2, Y1, Y1 + VPERM2I128 $0x22, Y3, Y2, Y2 + VPERM2I128 $0x33, Y3, Y3, Y3 + + TESTQ CX, CX + JZ done + + VMOVDQU ·one_AVX2<>(SB), Y4 + VPADDD Y4, Y3, Y3 + + VMOVDQA Y0, STATE_0 + VMOVDQA Y1, STATE_1 + VMOVDQA Y2, STATE_2 + VMOVDQA Y3, STATE_3 + + VMOVDQU ·rol16_AVX2<>(SB), Y4 + VMOVDQU ·rol8_AVX2<>(SB), Y5 + VMOVDQU ·two_AVX2<>(SB), Y6 + VMOVDQA Y4, Y14 + VMOVDQA Y5, Y15 + VMOVDQA Y4, C16 + VMOVDQA Y5, C8 + VMOVDQA Y6, TWO + + CMPQ CX, $64 + JBE between_0_and_64 + CMPQ CX, $192 + JBE between_64_and_192 + CMPQ CX, $320 + JBE between_192_and_320 + CMPQ CX, $448 + JBE between_320_and_448 + +at_least_512: + VMOVDQA Y0, Y4 + VMOVDQA Y1, Y5 + VMOVDQA Y2, Y6 + VPADDQ TWO, Y3, Y7 + VMOVDQA Y0, Y8 + VMOVDQA Y1, Y9 + VMOVDQA Y2, Y10 + VPADDQ TWO, Y7, Y11 + VMOVDQA Y0, Y12 + VMOVDQA Y1, Y13 + VMOVDQA Y2, Y14 + VPADDQ TWO, Y11, Y15 + + MOVQ DX, R9 + +chacha_loop_512: + VMOVDQA Y8, TMP_0 + CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) + VMOVDQA TMP_0, Y8 + VMOVDQA Y0, TMP_0 + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) + CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) + CHACHA_SHUFFLE(Y1, Y2, Y3) + CHACHA_SHUFFLE(Y5, Y6, Y7) + CHACHA_SHUFFLE(Y9, Y10, Y11) + CHACHA_SHUFFLE(Y13, Y14, Y15) + + CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) + VMOVDQA TMP_0, Y0 + VMOVDQA Y8, TMP_0 + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) + CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) + VMOVDQA TMP_0, Y8 + CHACHA_SHUFFLE(Y3, Y2, Y1) + CHACHA_SHUFFLE(Y7, Y6, Y5) + CHACHA_SHUFFLE(Y11, Y10, Y9) + CHACHA_SHUFFLE(Y15, Y14, Y13) + SUBQ $2, R9 + JA chacha_loop_512 + + VMOVDQA Y12, TMP_0 + VMOVDQA Y13, TMP_1 + VPADDD STATE_0, Y0, Y0 + VPADDD STATE_1, Y1, Y1 + VPADDD STATE_2, Y2, Y2 + VPADDD STATE_3, Y3, Y3 + XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) + VMOVDQA STATE_0, Y0 + VMOVDQA STATE_1, Y1 + VMOVDQA STATE_2, Y2 + VMOVDQA STATE_3, Y3 + VPADDQ TWO, Y3, Y3 + + VPADDD Y0, Y4, Y4 + VPADDD Y1, Y5, Y5 + VPADDD Y2, Y6, Y6 + VPADDD Y3, Y7, Y7 + XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) + VPADDQ TWO, Y3, Y3 + + VPADDD Y0, Y8, Y8 + VPADDD Y1, Y9, Y9 + VPADDD Y2, Y10, Y10 + VPADDD Y3, Y11, Y11 + XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) + VPADDQ TWO, Y3, Y3 + + VPADDD TMP_0, Y0, Y12 + VPADDD TMP_1, Y1, Y13 + VPADDD Y2, Y14, Y14 + VPADDD Y3, Y15, Y15 + VPADDQ TWO, Y3, Y3 + + CMPQ CX, $512 + JB less_than_512 + + XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) + VMOVDQA Y3, STATE_3 + ADDQ $512, SI + ADDQ $512, DI + SUBQ $512, CX + CMPQ CX, $448 + JA at_least_512 + + TESTQ CX, CX + JZ done + + VMOVDQA C16, Y14 + VMOVDQA C8, Y15 + + CMPQ CX, $64 + JBE between_0_and_64 + CMPQ CX, $192 + JBE between_64_and_192 + CMPQ CX, $320 + JBE between_192_and_320 + JMP between_320_and_448 + +less_than_512: + XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) + EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4) + ADDQ $448, SI + ADDQ $448, DI + SUBQ $448, CX + JMP finalize + +between_320_and_448: + VMOVDQA Y0, Y4 + VMOVDQA Y1, Y5 + VMOVDQA Y2, Y6 + VPADDQ TWO, Y3, Y7 + VMOVDQA Y0, Y8 + VMOVDQA Y1, Y9 + VMOVDQA Y2, Y10 + VPADDQ TWO, Y7, Y11 + + MOVQ DX, R9 + +chacha_loop_384: + CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y1, Y2, Y3) + CHACHA_SHUFFLE(Y5, Y6, Y7) + CHACHA_SHUFFLE(Y9, Y10, Y11) + CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y3, Y2, Y1) + CHACHA_SHUFFLE(Y7, Y6, Y5) + CHACHA_SHUFFLE(Y11, Y10, Y9) + SUBQ $2, R9 + JA chacha_loop_384 + + VPADDD STATE_0, Y0, Y0 + VPADDD STATE_1, Y1, Y1 + VPADDD STATE_2, Y2, Y2 + VPADDD STATE_3, Y3, Y3 + XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) + VMOVDQA STATE_0, Y0 + VMOVDQA STATE_1, Y1 + VMOVDQA STATE_2, Y2 + VMOVDQA STATE_3, Y3 + VPADDQ TWO, Y3, Y3 + + VPADDD Y0, Y4, Y4 + VPADDD Y1, Y5, Y5 + VPADDD Y2, Y6, Y6 + VPADDD Y3, Y7, Y7 + XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) + VPADDQ TWO, Y3, Y3 + + VPADDD Y0, Y8, Y8 + VPADDD Y1, Y9, Y9 + VPADDD Y2, Y10, Y10 + VPADDD Y3, Y11, Y11 + VPADDQ TWO, Y3, Y3 + + CMPQ CX, $384 + JB less_than_384 + + XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) + SUBQ $384, CX + TESTQ CX, CX + JE done + + ADDQ $384, SI + ADDQ $384, DI + JMP between_0_and_64 + +less_than_384: + XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) + EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) + ADDQ $320, SI + ADDQ $320, DI + SUBQ $320, CX + JMP finalize + +between_192_and_320: + VMOVDQA Y0, Y4 + VMOVDQA Y1, Y5 + VMOVDQA Y2, Y6 + VMOVDQA Y3, Y7 + VMOVDQA Y0, Y8 + VMOVDQA Y1, Y9 + VMOVDQA Y2, Y10 + VPADDQ TWO, Y3, Y11 + + MOVQ DX, R9 + +chacha_loop_256: + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y5, Y6, Y7) + CHACHA_SHUFFLE(Y9, Y10, Y11) + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y7, Y6, Y5) + CHACHA_SHUFFLE(Y11, Y10, Y9) + SUBQ $2, R9 + JA chacha_loop_256 + + VPADDD Y0, Y4, Y4 + VPADDD Y1, Y5, Y5 + VPADDD Y2, Y6, Y6 + VPADDD Y3, Y7, Y7 + VPADDQ TWO, Y3, Y3 + XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) + VPADDD Y0, Y8, Y8 + VPADDD Y1, Y9, Y9 + VPADDD Y2, Y10, Y10 + VPADDD Y3, Y11, Y11 + VPADDQ TWO, Y3, Y3 + + CMPQ CX, $256 + JB less_than_256 + + XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) + SUBQ $256, CX + TESTQ CX, CX + JE done + + ADDQ $256, SI + ADDQ $256, DI + JMP between_0_and_64 + +less_than_256: + XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) + EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) + ADDQ $192, SI + ADDQ $192, DI + SUBQ $192, CX + JMP finalize + +between_64_and_192: + VMOVDQA Y0, Y4 + VMOVDQA Y1, Y5 + VMOVDQA Y2, Y6 + VMOVDQA Y3, Y7 + + MOVQ DX, R9 + +chacha_loop_128: + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y5, Y6, Y7) + CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) + CHACHA_SHUFFLE(Y7, Y6, Y5) + SUBQ $2, R9 + JA chacha_loop_128 + + VPADDD Y0, Y4, Y4 + VPADDD Y1, Y5, Y5 + VPADDD Y2, Y6, Y6 + VPADDD Y3, Y7, Y7 + VPADDQ TWO, Y3, Y3 + + CMPQ CX, $128 + JB less_than_128 + + XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) + SUBQ $128, CX + TESTQ CX, CX + JE done + + ADDQ $128, SI + ADDQ $128, DI + JMP between_0_and_64 + +less_than_128: + XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) + EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13) + ADDQ $64, SI + ADDQ $64, DI + SUBQ $64, CX + JMP finalize + +between_0_and_64: + VMOVDQA X0, X4 + VMOVDQA X1, X5 + VMOVDQA X2, X6 + VMOVDQA X3, X7 + + MOVQ DX, R9 + +chacha_loop_64: + CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) + CHACHA_SHUFFLE(X7, X6, X5) + SUBQ $2, R9 + JA chacha_loop_64 + + VPADDD X0, X4, X4 + VPADDD X1, X5, X5 + VPADDD X2, X6, X6 + VPADDD X3, X7, X7 + VMOVDQU ·one_AVX<>(SB), X0 + VPADDQ X0, X3, X3 + + CMPQ CX, $64 + JB less_than_64 + + XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13) + SUBQ $64, CX + JMP done + +less_than_64: + VMOVDQU X4, 0(BX) + VMOVDQU X5, 16(BX) + VMOVDQU X6, 32(BX) + VMOVDQU X7, 48(BX) + +finalize: + XORQ R11, R11 + XORQ R12, R12 + MOVQ CX, BP + +xor_loop: + MOVB 0(SI), R11 + MOVB 0(BX), R12 + XORQ R11, R12 + MOVB R12, 0(DI) + INCQ SI + INCQ BX + INCQ DI + DECQ BP + JA xor_loop + +done: + VMOVDQU X3, 48(AX) + VZEROUPPER + MOVQ R8, SP + MOVQ CX, ret+72(FP) + RET + +// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20AVX(SB), 4, $0-24 + MOVQ out+0(FP), DI + MOVQ nonce+8(FP), AX + MOVQ key+16(FP), BX + + VMOVDQU ·sigma_AVX<>(SB), X0 + VMOVDQU 0(BX), X1 + VMOVDQU 16(BX), X2 + VMOVDQU 0(AX), X3 + VMOVDQU ·rol16_AVX2<>(SB), X5 + VMOVDQU ·rol8_AVX2<>(SB), X6 + + MOVQ $20, CX + +chacha_loop: + CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X3, X2, X1) + SUBQ $2, CX + JNZ chacha_loop + + VMOVDQU X0, 0(DI) + VMOVDQU X3, 16(DI) + VZEROUPPER + RET + +// func supportsAVX2() bool +TEXT ·supportsAVX2(SB), 4, $0-1 + MOVQ runtime·support_avx(SB), AX + MOVQ runtime·support_avx2(SB), BX + ANDQ AX, BX + MOVB BX, ret+0(FP) + RET diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_386.go b/vendor/github.com/aead/chacha20/chacha/chacha_386.go new file mode 100644 index 0000000..e3135ef --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_386.go @@ -0,0 +1,67 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build 386,!gccgo,!appengine,!nacl + +package chacha + +import "encoding/binary" + +func init() { + useSSE2 = supportsSSE2() + useSSSE3 = supportsSSSE3() + useAVX2 = false +} + +func initialize(state *[64]byte, key []byte, nonce *[16]byte) { + binary.LittleEndian.PutUint32(state[0:], sigma[0]) + binary.LittleEndian.PutUint32(state[4:], sigma[1]) + binary.LittleEndian.PutUint32(state[8:], sigma[2]) + binary.LittleEndian.PutUint32(state[12:], sigma[3]) + copy(state[16:], key[:]) + copy(state[48:], nonce[:]) +} + +// This function is implemented in chacha_386.s +//go:noescape +func supportsSSE2() bool + +// This function is implemented in chacha_386.s +//go:noescape +func supportsSSSE3() bool + +// This function is implemented in chacha_386.s +//go:noescape +func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_386.s +//go:noescape +func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_386.s +//go:noescape +func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int + +// This function is implemented in chacha_386.s +//go:noescape +func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int + +func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { + if useSSSE3 { + hChaCha20SSSE3(out, nonce, key) + } else if useSSE2 { + hChaCha20SSE2(out, nonce, key) + } else { + hChaCha20Generic(out, nonce, key) + } +} + +func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { + if useSSSE3 { + return xorKeyStreamSSSE3(dst, src, block, state, rounds) + } else if useSSE2 { + return xorKeyStreamSSE2(dst, src, block, state, rounds) + } + return xorKeyStreamGeneric(dst, src, block, state, rounds) +} diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_386.s b/vendor/github.com/aead/chacha20/chacha/chacha_386.s new file mode 100644 index 0000000..d7bba75 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_386.s @@ -0,0 +1,311 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build 386,!gccgo,!appengine,!nacl + +#include "textflag.h" + +DATA ·sigma<>+0x00(SB)/4, $0x61707865 +DATA ·sigma<>+0x04(SB)/4, $0x3320646e +DATA ·sigma<>+0x08(SB)/4, $0x79622d32 +DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 +GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 + +DATA ·one<>+0x00(SB)/8, $1 +DATA ·one<>+0x08(SB)/8, $0 +GLOBL ·one<>(SB), (NOPTR+RODATA), $16 + +DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 + +DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 + +#define ROTL_SSE2(n, t, v) \ + MOVO v, t; \ + PSLLL $n, t; \ + PSRLL $(32-n), v; \ + PXOR t, v + +#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE2(16, t0, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(12, t0, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE2(8, t0, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(7, t0, v1) + +#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r16, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(12, t0, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r8, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(7, t0, v1) + +#define CHACHA_SHUFFLE(v1, v2, v3) \ + PSHUFL $0x39, v1, v1; \ + PSHUFL $0x4E, v2, v2; \ + PSHUFL $0x93, v3, v3 + +#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ + MOVOU 0+off(src), t0; \ + PXOR v0, t0; \ + MOVOU t0, 0+off(dst); \ + MOVOU 16+off(src), t0; \ + PXOR v1, t0; \ + MOVOU t0, 16+off(dst); \ + MOVOU 32+off(src), t0; \ + PXOR v2, t0; \ + MOVOU t0, 32+off(dst); \ + MOVOU 48+off(src), t0; \ + PXOR v3, t0; \ + MOVOU t0, 48+off(dst) + +#define FINALIZE(dst, src, block, len, t0, t1) \ + XORL t0, t0; \ + XORL t1, t1; \ + finalize: \ + MOVB 0(src), t0; \ + MOVB 0(block), t1; \ + XORL t0, t1; \ + MOVB t1, 0(dst); \ + INCL src; \ + INCL block; \ + INCL dst; \ + DECL len; \ + JA finalize \ + +// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSE2(SB), 4, $0-40 + MOVL dst_base+0(FP), DI + MOVL src_base+12(FP), SI + MOVL src_len+16(FP), CX + MOVL state+28(FP), AX + MOVL rounds+32(FP), DX + + MOVOU 0(AX), X0 + MOVOU 16(AX), X1 + MOVOU 32(AX), X2 + MOVOU 48(AX), X3 + + TESTL CX, CX + JZ done + +at_least_64: + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + + MOVL DX, BX + +chacha_loop: + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) + CHACHA_SHUFFLE(X7, X6, X5) + SUBL $2, BX + JA chacha_loop + + MOVOU 0(AX), X0 + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + MOVOU ·one<>(SB), X0 + PADDQ X0, X3 + + CMPL CX, $64 + JB less_than_64 + + XOR(DI, SI, 0, X4, X5, X6, X7, X0) + MOVOU 0(AX), X0 + ADDL $64, SI + ADDL $64, DI + SUBL $64, CX + JNZ at_least_64 + +less_than_64: + MOVL CX, BP + TESTL BP, BP + JZ done + + MOVL block+24(FP), BX + MOVOU X4, 0(BX) + MOVOU X5, 16(BX) + MOVOU X6, 32(BX) + MOVOU X7, 48(BX) + FINALIZE(DI, SI, BX, BP, AX, DX) + +done: + MOVL state+28(FP), AX + MOVOU X3, 48(AX) + MOVL CX, ret+36(FP) + RET + +// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40 + MOVL dst_base+0(FP), DI + MOVL src_base+12(FP), SI + MOVL src_len+16(FP), CX + MOVL state+28(FP), AX + MOVL rounds+32(FP), DX + + MOVOU 48(AX), X3 + TESTL CX, CX + JZ done + + MOVL SP, BP + ADDL $16, SP + ANDL $-16, SP + + MOVOU ·one<>(SB), X0 + MOVOU 16(AX), X1 + MOVOU 32(AX), X2 + MOVO X0, 0(SP) + MOVO X1, 16(SP) + MOVO X2, 32(SP) + + MOVOU 0(AX), X0 + MOVOU ·rol16<>(SB), X1 + MOVOU ·rol8<>(SB), X2 + +at_least_64: + MOVO X0, X4 + MOVO 16(SP), X5 + MOVO 32(SP), X6 + MOVO X3, X7 + + MOVL DX, BX + +chacha_loop: + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) + CHACHA_SHUFFLE(X7, X6, X5) + SUBL $2, BX + JA chacha_loop + + MOVOU 0(AX), X0 + PADDL X0, X4 + PADDL 16(SP), X5 + PADDL 32(SP), X6 + PADDL X3, X7 + PADDQ 0(SP), X3 + + CMPL CX, $64 + JB less_than_64 + + XOR(DI, SI, 0, X4, X5, X6, X7, X0) + MOVOU 0(AX), X0 + ADDL $64, SI + ADDL $64, DI + SUBL $64, CX + JNZ at_least_64 + +less_than_64: + MOVL BP, SP + MOVL CX, BP + TESTL BP, BP + JE done + + MOVL block+24(FP), BX + MOVOU X4, 0(BX) + MOVOU X5, 16(BX) + MOVOU X6, 32(BX) + MOVOU X7, 48(BX) + FINALIZE(DI, SI, BX, BP, AX, DX) + +done: + MOVL state+28(FP), AX + MOVOU X3, 48(AX) + MOVL CX, ret+36(FP) + RET + +// func supportsSSE2() bool +TEXT ·supportsSSE2(SB), NOSPLIT, $0-1 + XORL AX, AX + INCL AX + CPUID + SHRL $26, DX + ANDL $1, DX + MOVB DX, ret+0(FP) + RET + +// func supportsSSSE3() bool +TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 + XORL AX, AX + INCL AX + CPUID + SHRL $9, CX + ANDL $1, CX + MOVB CX, ret+0(FP) + RET + +// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSE2(SB), 4, $0-12 + MOVL out+0(FP), DI + MOVL nonce+4(FP), AX + MOVL key+8(FP), BX + + MOVOU ·sigma<>(SB), X0 + MOVOU 0(BX), X1 + MOVOU 16(BX), X2 + MOVOU 0(AX), X3 + + MOVL $20, CX + +chacha_loop: + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE(X3, X2, X1) + SUBL $2, CX + JNZ chacha_loop + + MOVOU X0, 0(DI) + MOVOU X3, 16(DI) + RET + +// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSSE3(SB), 4, $0-12 + MOVL out+0(FP), DI + MOVL nonce+4(FP), AX + MOVL key+8(FP), BX + + MOVOU ·sigma<>(SB), X0 + MOVOU 0(BX), X1 + MOVOU 16(BX), X2 + MOVOU 0(AX), X3 + MOVOU ·rol16<>(SB), X5 + MOVOU ·rol8<>(SB), X6 + + MOVL $20, CX + +chacha_loop: + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X3, X2, X1) + SUBL $2, CX + JNZ chacha_loop + + MOVOU X0, 0(DI) + MOVOU X3, 16(DI) + RET diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s new file mode 100644 index 0000000..5bc41ef --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_amd64.s @@ -0,0 +1,788 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build amd64,!gccgo,!appengine,!nacl + +#include "textflag.h" + +DATA ·sigma<>+0x00(SB)/4, $0x61707865 +DATA ·sigma<>+0x04(SB)/4, $0x3320646e +DATA ·sigma<>+0x08(SB)/4, $0x79622d32 +DATA ·sigma<>+0x0C(SB)/4, $0x6b206574 +GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 + +DATA ·one<>+0x00(SB)/8, $1 +DATA ·one<>+0x08(SB)/8, $0 +GLOBL ·one<>(SB), (NOPTR+RODATA), $16 + +DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 + +DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 + +#define ROTL_SSE2(n, t, v) \ + MOVO v, t; \ + PSLLL $n, t; \ + PSRLL $(32-n), v; \ + PXOR t, v + +#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE2(16, t0, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(12, t0, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + ROTL_SSE2(8, t0, v3); \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(7, t0, v1) + +#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r16, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(12, t0, v1); \ + PADDL v1, v0; \ + PXOR v0, v3; \ + PSHUFB r8, v3; \ + PADDL v3, v2; \ + PXOR v2, v1; \ + ROTL_SSE2(7, t0, v1) + +#define CHACHA_SHUFFLE(v1, v2, v3) \ + PSHUFL $0x39, v1, v1; \ + PSHUFL $0x4E, v2, v2; \ + PSHUFL $0x93, v3, v3 + +#define XOR(dst, src, off, v0, v1, v2, v3, t0) \ + MOVOU 0+off(src), t0; \ + PXOR v0, t0; \ + MOVOU t0, 0+off(dst); \ + MOVOU 16+off(src), t0; \ + PXOR v1, t0; \ + MOVOU t0, 16+off(dst); \ + MOVOU 32+off(src), t0; \ + PXOR v2, t0; \ + MOVOU t0, 32+off(dst); \ + MOVOU 48+off(src), t0; \ + PXOR v3, t0; \ + MOVOU t0, 48+off(dst) + +// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSE2(SB), 4, $112-80 + MOVQ dst_base+0(FP), DI + MOVQ src_base+24(FP), SI + MOVQ src_len+32(FP), CX + MOVQ block+48(FP), BX + MOVQ state+56(FP), AX + MOVQ rounds+64(FP), DX + + MOVQ SP, R9 + ADDQ $16, SP + ANDQ $-16, SP + + MOVOU 0(AX), X0 + MOVOU 16(AX), X1 + MOVOU 32(AX), X2 + MOVOU 48(AX), X3 + MOVOU ·one<>(SB), X15 + + TESTQ CX, CX + JZ done + + CMPQ CX, $64 + JBE between_0_and_64 + + CMPQ CX, $128 + JBE between_64_and_128 + + MOVO X0, 0(SP) + MOVO X1, 16(SP) + MOVO X2, 32(SP) + MOVO X3, 48(SP) + MOVO X15, 64(SP) + + CMPQ CX, $192 + JBE between_128_and_192 + + MOVQ $192, R14 + +at_least_256: + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ 64(SP), X7 + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X7, X15 + PADDQ 64(SP), X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X15, X11 + PADDQ 64(SP), X11 + + MOVQ DX, R8 + +chacha_loop_256: + MOVO X8, 80(SP) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) + MOVO 80(SP), X8 + + MOVO X0, 80(SP) + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + MOVO 80(SP), X0 + + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X13, X14, X15) + CHACHA_SHUFFLE(X9, X10, X11) + + MOVO X8, 80(SP) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X8) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X8) + MOVO 80(SP), X8 + + MOVO X0, 80(SP) + CHACHA_QROUND_SSE2(X12, X13, X14, X15, X0) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X0) + MOVO 80(SP), X0 + + CHACHA_SHUFFLE(X3, X2, X1) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X15, X14, X13) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_256 + + MOVO X8, 80(SP) + + PADDL 0(SP), X0 + PADDL 16(SP), X1 + PADDL 32(SP), X2 + PADDL 48(SP), X3 + XOR(DI, SI, 0, X0, X1, X2, X3, X8) + + MOVO 0(SP), X0 + MOVO 16(SP), X1 + MOVO 32(SP), X2 + MOVO 48(SP), X3 + PADDQ 64(SP), X3 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 64(SP), X3 + XOR(DI, SI, 64, X4, X5, X6, X7, X8) + + MOVO 64(SP), X5 + MOVO 80(SP), X8 + + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ X5, X3 + XOR(DI, SI, 128, X12, X13, X14, X15, X4) + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X5, X3 + + CMPQ CX, $256 + JB less_than_64 + + XOR(DI, SI, 192, X8, X9, X10, X11, X4) + MOVO X3, 48(SP) + ADDQ $256, SI + ADDQ $256, DI + SUBQ $256, CX + CMPQ CX, $192 + JA at_least_256 + + TESTQ CX, CX + JZ done + MOVO 64(SP), X15 + CMPQ CX, $64 + JBE between_0_and_64 + CMPQ CX, $128 + JBE between_64_and_128 + +between_128_and_192: + MOVQ $128, R14 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ X15, X7 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X7, X11 + PADDQ X15, X11 + + MOVQ DX, R8 + +chacha_loop_192: + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X12) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X3, X2, X1) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_192 + + PADDL 0(SP), X0 + PADDL 16(SP), X1 + PADDL 32(SP), X2 + PADDL 48(SP), X3 + XOR(DI, SI, 0, X0, X1, X2, X3, X12) + + MOVO 0(SP), X0 + MOVO 16(SP), X1 + MOVO 32(SP), X2 + MOVO 48(SP), X3 + PADDQ X15, X3 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ X15, X3 + XOR(DI, SI, 64, X4, X5, X6, X7, X12) + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + + CMPQ CX, $192 + JB less_than_64 + + XOR(DI, SI, 128, X8, X9, X10, X11, X12) + SUBQ $192, CX + JMP done + +between_64_and_128: + MOVQ $64, R14 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + PADDQ X15, X11 + + MOVQ DX, R8 + +chacha_loop_128: + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSE2(X4, X5, X6, X7, X12) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_128 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ X15, X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + XOR(DI, SI, 0, X4, X5, X6, X7, X12) + + CMPQ CX, $128 + JB less_than_64 + + XOR(DI, SI, 64, X8, X9, X10, X11, X12) + SUBQ $128, CX + JMP done + +between_0_and_64: + MOVQ $0, R14 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + MOVQ DX, R8 + +chacha_loop_64: + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSE2(X8, X9, X10, X11, X12) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_64 + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + CMPQ CX, $64 + JB less_than_64 + + XOR(DI, SI, 0, X8, X9, X10, X11, X12) + SUBQ $64, CX + JMP done + +less_than_64: + // R14 contains the num of bytes already xor'd + ADDQ R14, SI + ADDQ R14, DI + SUBQ R14, CX + MOVOU X8, 0(BX) + MOVOU X9, 16(BX) + MOVOU X10, 32(BX) + MOVOU X11, 48(BX) + XORQ R11, R11 + XORQ R12, R12 + MOVQ CX, BP + +xor_loop: + MOVB 0(SI), R11 + MOVB 0(BX), R12 + XORQ R11, R12 + MOVB R12, 0(DI) + INCQ SI + INCQ BX + INCQ DI + DECQ BP + JA xor_loop + +done: + MOVOU X3, 48(AX) + MOVQ R9, SP + MOVQ CX, ret+72(FP) + RET + +// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int +TEXT ·xorKeyStreamSSSE3(SB), 4, $144-80 + MOVQ dst_base+0(FP), DI + MOVQ src_base+24(FP), SI + MOVQ src_len+32(FP), CX + MOVQ block+48(FP), BX + MOVQ state+56(FP), AX + MOVQ rounds+64(FP), DX + + MOVQ SP, R9 + ADDQ $16, SP + ANDQ $-16, SP + + MOVOU 0(AX), X0 + MOVOU 16(AX), X1 + MOVOU 32(AX), X2 + MOVOU 48(AX), X3 + MOVOU ·rol16<>(SB), X13 + MOVOU ·rol8<>(SB), X14 + MOVOU ·one<>(SB), X15 + + TESTQ CX, CX + JZ done + + CMPQ CX, $64 + JBE between_0_and_64 + + CMPQ CX, $128 + JBE between_64_and_128 + + MOVO X0, 0(SP) + MOVO X1, 16(SP) + MOVO X2, 32(SP) + MOVO X3, 48(SP) + MOVO X15, 64(SP) + + CMPQ CX, $192 + JBE between_128_and_192 + + MOVO X13, 96(SP) + MOVO X14, 112(SP) + MOVQ $192, R14 + +at_least_256: + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ 64(SP), X7 + MOVO X0, X12 + MOVO X1, X13 + MOVO X2, X14 + MOVO X7, X15 + PADDQ 64(SP), X15 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X15, X11 + PADDQ 64(SP), X11 + + MOVQ DX, R8 + +chacha_loop_256: + MOVO X8, 80(SP) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) + MOVO 80(SP), X8 + + MOVO X0, 80(SP) + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) + MOVO 80(SP), X0 + + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X13, X14, X15) + CHACHA_SHUFFLE(X9, X10, X11) + + MOVO X8, 80(SP) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X8, 96(SP), 112(SP)) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X8, 96(SP), 112(SP)) + MOVO 80(SP), X8 + + MOVO X0, 80(SP) + CHACHA_QROUND_SSSE3(X12, X13, X14, X15, X0, 96(SP), 112(SP)) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X0, 96(SP), 112(SP)) + MOVO 80(SP), X0 + + CHACHA_SHUFFLE(X3, X2, X1) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X15, X14, X13) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_256 + + MOVO X8, 80(SP) + + PADDL 0(SP), X0 + PADDL 16(SP), X1 + PADDL 32(SP), X2 + PADDL 48(SP), X3 + XOR(DI, SI, 0, X0, X1, X2, X3, X8) + MOVO 0(SP), X0 + MOVO 16(SP), X1 + MOVO 32(SP), X2 + MOVO 48(SP), X3 + PADDQ 64(SP), X3 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ 64(SP), X3 + XOR(DI, SI, 64, X4, X5, X6, X7, X8) + + MOVO 64(SP), X5 + MOVO 80(SP), X8 + + PADDL X0, X12 + PADDL X1, X13 + PADDL X2, X14 + PADDL X3, X15 + PADDQ X5, X3 + XOR(DI, SI, 128, X12, X13, X14, X15, X4) + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X5, X3 + + CMPQ CX, $256 + JB less_than_64 + + XOR(DI, SI, 192, X8, X9, X10, X11, X4) + MOVO X3, 48(SP) + ADDQ $256, SI + ADDQ $256, DI + SUBQ $256, CX + CMPQ CX, $192 + JA at_least_256 + + TESTQ CX, CX + JZ done + MOVOU ·rol16<>(SB), X13 + MOVOU ·rol8<>(SB), X14 + MOVO 64(SP), X15 + CMPQ CX, $64 + JBE between_0_and_64 + CMPQ CX, $128 + JBE between_64_and_128 + +between_128_and_192: + MOVQ $128, R14 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + PADDQ X15, X7 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X7, X11 + PADDQ X15, X11 + + MOVQ DX, R8 + +chacha_loop_192: + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X12, X13, X14) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X3, X2, X1) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_192 + + PADDL 0(SP), X0 + PADDL 16(SP), X1 + PADDL 32(SP), X2 + PADDL 48(SP), X3 + XOR(DI, SI, 0, X0, X1, X2, X3, X12) + + MOVO 0(SP), X0 + MOVO 16(SP), X1 + MOVO 32(SP), X2 + MOVO 48(SP), X3 + PADDQ X15, X3 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ X15, X3 + XOR(DI, SI, 64, X4, X5, X6, X7, X12) + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + + CMPQ CX, $192 + JB less_than_64 + + XOR(DI, SI, 128, X8, X9, X10, X11, X12) + SUBQ $192, CX + JMP done + +between_64_and_128: + MOVQ $64, R14 + MOVO X0, X4 + MOVO X1, X5 + MOVO X2, X6 + MOVO X3, X7 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + PADDQ X15, X11 + + MOVQ DX, R8 + +chacha_loop_128: + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X5, X6, X7) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X12, X13, X14) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X7, X6, X5) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_128 + + PADDL X0, X4 + PADDL X1, X5 + PADDL X2, X6 + PADDL X3, X7 + PADDQ X15, X3 + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + XOR(DI, SI, 0, X4, X5, X6, X7, X12) + + CMPQ CX, $128 + JB less_than_64 + + XOR(DI, SI, 64, X8, X9, X10, X11, X12) + SUBQ $128, CX + JMP done + +between_0_and_64: + MOVQ $0, R14 + MOVO X0, X8 + MOVO X1, X9 + MOVO X2, X10 + MOVO X3, X11 + MOVQ DX, R8 + +chacha_loop_64: + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X9, X10, X11) + CHACHA_QROUND_SSSE3(X8, X9, X10, X11, X12, X13, X14) + CHACHA_SHUFFLE(X11, X10, X9) + SUBQ $2, R8 + JA chacha_loop_64 + + PADDL X0, X8 + PADDL X1, X9 + PADDL X2, X10 + PADDL X3, X11 + PADDQ X15, X3 + CMPQ CX, $64 + JB less_than_64 + + XOR(DI, SI, 0, X8, X9, X10, X11, X12) + SUBQ $64, CX + JMP done + +less_than_64: + // R14 contains the num of bytes already xor'd + ADDQ R14, SI + ADDQ R14, DI + SUBQ R14, CX + MOVOU X8, 0(BX) + MOVOU X9, 16(BX) + MOVOU X10, 32(BX) + MOVOU X11, 48(BX) + XORQ R11, R11 + XORQ R12, R12 + MOVQ CX, BP + +xor_loop: + MOVB 0(SI), R11 + MOVB 0(BX), R12 + XORQ R11, R12 + MOVB R12, 0(DI) + INCQ SI + INCQ BX + INCQ DI + DECQ BP + JA xor_loop + +done: + MOVQ R9, SP + MOVOU X3, 48(AX) + MOVQ CX, ret+72(FP) + RET + +// func supportsSSSE3() bool +TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 + XORQ AX, AX + INCQ AX + CPUID + SHRQ $9, CX + ANDQ $1, CX + MOVB CX, ret+0(FP) + RET + +// func initialize(state *[64]byte, key []byte, nonce *[16]byte) +TEXT ·initialize(SB), 4, $0-40 + MOVQ state+0(FP), DI + MOVQ key+8(FP), AX + MOVQ nonce+32(FP), BX + + MOVOU ·sigma<>(SB), X0 + MOVOU 0(AX), X1 + MOVOU 16(AX), X2 + MOVOU 0(BX), X3 + + MOVOU X0, 0(DI) + MOVOU X1, 16(DI) + MOVOU X2, 32(DI) + MOVOU X3, 48(DI) + RET + +// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSE2(SB), 4, $0-24 + MOVQ out+0(FP), DI + MOVQ nonce+8(FP), AX + MOVQ key+16(FP), BX + + MOVOU ·sigma<>(SB), X0 + MOVOU 0(BX), X1 + MOVOU 16(BX), X2 + MOVOU 0(AX), X3 + + MOVQ $20, CX + +chacha_loop: + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) + CHACHA_SHUFFLE(X3, X2, X1) + SUBQ $2, CX + JNZ chacha_loop + + MOVOU X0, 0(DI) + MOVOU X3, 16(DI) + RET + +// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) +TEXT ·hChaCha20SSSE3(SB), 4, $0-24 + MOVQ out+0(FP), DI + MOVQ nonce+8(FP), AX + MOVQ key+16(FP), BX + + MOVOU ·sigma<>(SB), X0 + MOVOU 0(BX), X1 + MOVOU 16(BX), X2 + MOVOU 0(AX), X3 + MOVOU ·rol16<>(SB), X5 + MOVOU ·rol8<>(SB), X6 + + MOVQ $20, CX + +chacha_loop: + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X1, X2, X3) + CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) + CHACHA_SHUFFLE(X3, X2, X1) + SUBQ $2, CX + JNZ chacha_loop + + MOVOU X0, 0(DI) + MOVOU X3, 16(DI) + RET diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_generic.go b/vendor/github.com/aead/chacha20/chacha/chacha_generic.go new file mode 100644 index 0000000..8832d5b --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_generic.go @@ -0,0 +1,319 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package chacha + +import "encoding/binary" + +var sigma = [4]uint32{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574} + +func xorKeyStreamGeneric(dst, src []byte, block, state *[64]byte, rounds int) int { + for len(src) >= 64 { + chachaGeneric(block, state, rounds) + + for i, v := range block { + dst[i] = src[i] ^ v + } + src = src[64:] + dst = dst[64:] + } + + n := len(src) + if n > 0 { + chachaGeneric(block, state, rounds) + for i, v := range src { + dst[i] = v ^ block[i] + } + } + return n +} + +func chachaGeneric(dst *[64]byte, state *[64]byte, rounds int) { + v00 := binary.LittleEndian.Uint32(state[0:]) + v01 := binary.LittleEndian.Uint32(state[4:]) + v02 := binary.LittleEndian.Uint32(state[8:]) + v03 := binary.LittleEndian.Uint32(state[12:]) + v04 := binary.LittleEndian.Uint32(state[16:]) + v05 := binary.LittleEndian.Uint32(state[20:]) + v06 := binary.LittleEndian.Uint32(state[24:]) + v07 := binary.LittleEndian.Uint32(state[28:]) + v08 := binary.LittleEndian.Uint32(state[32:]) + v09 := binary.LittleEndian.Uint32(state[36:]) + v10 := binary.LittleEndian.Uint32(state[40:]) + v11 := binary.LittleEndian.Uint32(state[44:]) + v12 := binary.LittleEndian.Uint32(state[48:]) + v13 := binary.LittleEndian.Uint32(state[52:]) + v14 := binary.LittleEndian.Uint32(state[56:]) + v15 := binary.LittleEndian.Uint32(state[60:]) + + s00, s01, s02, s03, s04, s05, s06, s07 := v00, v01, v02, v03, v04, v05, v06, v07 + s08, s09, s10, s11, s12, s13, s14, s15 := v08, v09, v10, v11, v12, v13, v14, v15 + + for i := 0; i < rounds; i += 2 { + v00 += v04 + v12 ^= v00 + v12 = (v12 << 16) | (v12 >> 16) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 12) | (v04 >> 20) + v00 += v04 + v12 ^= v00 + v12 = (v12 << 8) | (v12 >> 24) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 7) | (v04 >> 25) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 16) | (v13 >> 16) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 12) | (v05 >> 20) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 8) | (v13 >> 24) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 7) | (v05 >> 25) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 16) | (v14 >> 16) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 12) | (v06 >> 20) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 8) | (v14 >> 24) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 7) | (v06 >> 25) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 16) | (v15 >> 16) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 12) | (v07 >> 20) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 8) | (v15 >> 24) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 7) | (v07 >> 25) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 16) | (v15 >> 16) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 12) | (v05 >> 20) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 8) | (v15 >> 24) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 7) | (v05 >> 25) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 16) | (v12 >> 16) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 12) | (v06 >> 20) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 8) | (v12 >> 24) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 7) | (v06 >> 25) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 16) | (v13 >> 16) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 12) | (v07 >> 20) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 8) | (v13 >> 24) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 7) | (v07 >> 25) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 16) | (v14 >> 16) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 12) | (v04 >> 20) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 8) | (v14 >> 24) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 7) | (v04 >> 25) + } + + v00 += s00 + v01 += s01 + v02 += s02 + v03 += s03 + v04 += s04 + v05 += s05 + v06 += s06 + v07 += s07 + v08 += s08 + v09 += s09 + v10 += s10 + v11 += s11 + v12 += s12 + v13 += s13 + v14 += s14 + v15 += s15 + + s12++ + binary.LittleEndian.PutUint32(state[48:], s12) + if s12 == 0 { // indicates overflow + s13++ + binary.LittleEndian.PutUint32(state[52:], s13) + } + + binary.LittleEndian.PutUint32(dst[0:], v00) + binary.LittleEndian.PutUint32(dst[4:], v01) + binary.LittleEndian.PutUint32(dst[8:], v02) + binary.LittleEndian.PutUint32(dst[12:], v03) + binary.LittleEndian.PutUint32(dst[16:], v04) + binary.LittleEndian.PutUint32(dst[20:], v05) + binary.LittleEndian.PutUint32(dst[24:], v06) + binary.LittleEndian.PutUint32(dst[28:], v07) + binary.LittleEndian.PutUint32(dst[32:], v08) + binary.LittleEndian.PutUint32(dst[36:], v09) + binary.LittleEndian.PutUint32(dst[40:], v10) + binary.LittleEndian.PutUint32(dst[44:], v11) + binary.LittleEndian.PutUint32(dst[48:], v12) + binary.LittleEndian.PutUint32(dst[52:], v13) + binary.LittleEndian.PutUint32(dst[56:], v14) + binary.LittleEndian.PutUint32(dst[60:], v15) +} + +func hChaCha20Generic(out *[32]byte, nonce *[16]byte, key *[32]byte) { + v00 := sigma[0] + v01 := sigma[1] + v02 := sigma[2] + v03 := sigma[3] + v04 := binary.LittleEndian.Uint32(key[0:]) + v05 := binary.LittleEndian.Uint32(key[4:]) + v06 := binary.LittleEndian.Uint32(key[8:]) + v07 := binary.LittleEndian.Uint32(key[12:]) + v08 := binary.LittleEndian.Uint32(key[16:]) + v09 := binary.LittleEndian.Uint32(key[20:]) + v10 := binary.LittleEndian.Uint32(key[24:]) + v11 := binary.LittleEndian.Uint32(key[28:]) + v12 := binary.LittleEndian.Uint32(nonce[0:]) + v13 := binary.LittleEndian.Uint32(nonce[4:]) + v14 := binary.LittleEndian.Uint32(nonce[8:]) + v15 := binary.LittleEndian.Uint32(nonce[12:]) + + for i := 0; i < 20; i += 2 { + v00 += v04 + v12 ^= v00 + v12 = (v12 << 16) | (v12 >> 16) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 12) | (v04 >> 20) + v00 += v04 + v12 ^= v00 + v12 = (v12 << 8) | (v12 >> 24) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 7) | (v04 >> 25) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 16) | (v13 >> 16) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 12) | (v05 >> 20) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 8) | (v13 >> 24) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 7) | (v05 >> 25) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 16) | (v14 >> 16) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 12) | (v06 >> 20) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 8) | (v14 >> 24) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 7) | (v06 >> 25) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 16) | (v15 >> 16) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 12) | (v07 >> 20) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 8) | (v15 >> 24) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 7) | (v07 >> 25) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 16) | (v15 >> 16) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 12) | (v05 >> 20) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 8) | (v15 >> 24) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 7) | (v05 >> 25) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 16) | (v12 >> 16) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 12) | (v06 >> 20) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 8) | (v12 >> 24) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 7) | (v06 >> 25) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 16) | (v13 >> 16) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 12) | (v07 >> 20) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 8) | (v13 >> 24) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 7) | (v07 >> 25) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 16) | (v14 >> 16) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 12) | (v04 >> 20) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 8) | (v14 >> 24) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 7) | (v04 >> 25) + } + + binary.LittleEndian.PutUint32(out[0:], v00) + binary.LittleEndian.PutUint32(out[4:], v01) + binary.LittleEndian.PutUint32(out[8:], v02) + binary.LittleEndian.PutUint32(out[12:], v03) + binary.LittleEndian.PutUint32(out[16:], v12) + binary.LittleEndian.PutUint32(out[20:], v13) + binary.LittleEndian.PutUint32(out[24:], v14) + binary.LittleEndian.PutUint32(out[28:], v15) +} diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go b/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go new file mode 100644 index 0000000..0dcb302 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_go16_amd64.go @@ -0,0 +1,56 @@ +// Copyright (c) 2017 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build amd64,!gccgo,!appengine,!nacl,!go1.7 + +package chacha + +func init() { + useSSE2 = true + useSSSE3 = supportsSSSE3() + useAVX2 = false +} + +// This function is implemented in chacha_amd64.s +//go:noescape +func initialize(state *[64]byte, key []byte, nonce *[16]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func supportsSSSE3() bool + +// This function is implemented in chacha_amd64.s +//go:noescape +func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int + +// This function is implemented in chacha_amd64.s +//go:noescape +func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int + +func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { + if useSSSE3 { + hChaCha20SSSE3(out, nonce, key) + } else if useSSE2 { // on amd64 this is always true - used to test generic on amd64 + hChaCha20SSE2(out, nonce, key) + } else { + hChaCha20Generic(out, nonce, key) + } +} + +func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { + if useSSSE3 { + return xorKeyStreamSSSE3(dst, src, block, state, rounds) + } else if useSSE2 { // on amd64 this is always true - used to test generic on amd64 + return xorKeyStreamSSE2(dst, src, block, state, rounds) + } + return xorKeyStreamGeneric(dst, src, block, state, rounds) +} diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go b/vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go new file mode 100644 index 0000000..9ff41cf --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_go17_amd64.go @@ -0,0 +1,72 @@ +// Copyright (c) 2017 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build go1.7,amd64,!gccgo,!appengine,!nacl + +package chacha + +func init() { + useSSE2 = true + useSSSE3 = supportsSSSE3() + useAVX2 = supportsAVX2() +} + +// This function is implemented in chacha_amd64.s +//go:noescape +func initialize(state *[64]byte, key []byte, nonce *[16]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func supportsSSSE3() bool + +// This function is implemented in chachaAVX2_amd64.s +//go:noescape +func supportsAVX2() bool + +// This function is implemented in chacha_amd64.s +//go:noescape +func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chachaAVX2_amd64.s +//go:noescape +func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) + +// This function is implemented in chacha_amd64.s +//go:noescape +func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int + +// This function is implemented in chacha_amd64.s +//go:noescape +func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int + +// This function is implemented in chachaAVX2_amd64.s +//go:noescape +func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int + +func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { + if useAVX2 { + hChaCha20AVX(out, nonce, key) + } else if useSSSE3 { + hChaCha20SSSE3(out, nonce, key) + } else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64 + hChaCha20SSE2(out, nonce, key) + } else { + hChaCha20Generic(out, nonce, key) + } +} + +func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { + if useAVX2 { + return xorKeyStreamAVX2(dst, src, block, state, rounds) + } else if useSSSE3 { + return xorKeyStreamSSSE3(dst, src, block, state, rounds) + } else if useSSE2 { // on amd64 this is always true - neccessary for testing generic on amd64 + return xorKeyStreamSSE2(dst, src, block, state, rounds) + } + return xorKeyStreamGeneric(dst, src, block, state, rounds) +} diff --git a/vendor/github.com/aead/chacha20/chacha/chacha_ref.go b/vendor/github.com/aead/chacha20/chacha/chacha_ref.go new file mode 100644 index 0000000..2c95a0c --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha/chacha_ref.go @@ -0,0 +1,26 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// +build !amd64,!386 gccgo appengine nacl + +package chacha + +import "encoding/binary" + +func initialize(state *[64]byte, key []byte, nonce *[16]byte) { + binary.LittleEndian.PutUint32(state[0:], sigma[0]) + binary.LittleEndian.PutUint32(state[4:], sigma[1]) + binary.LittleEndian.PutUint32(state[8:], sigma[2]) + binary.LittleEndian.PutUint32(state[12:], sigma[3]) + copy(state[16:], key[:]) + copy(state[48:], nonce[:]) +} + +func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int { + return xorKeyStreamGeneric(dst, src, block, state, rounds) +} + +func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { + hChaCha20Generic(out, nonce, key) +} diff --git a/vendor/github.com/aead/chacha20/chacha20.go b/vendor/github.com/aead/chacha20/chacha20.go new file mode 100644 index 0000000..df6ddd2 --- /dev/null +++ b/vendor/github.com/aead/chacha20/chacha20.go @@ -0,0 +1,41 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +// Package chacha20 implements the ChaCha20 / XChaCha20 stream chipher. +// Notice that one specific key-nonce combination must be unique for all time. +// +// There are three versions of ChaCha20: +// - ChaCha20 with a 64 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) +// - ChaCha20 with a 96 bit nonce (en/decrypt up to 2^32 * 64 bytes (~256 GB) for one key-nonce combination) +// - XChaCha20 with a 192 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination) +package chacha20 // import "github.com/aead/chacha20" + +import ( + "crypto/cipher" + + "github.com/aead/chacha20/chacha" +) + +// XORKeyStream crypts bytes from src to dst using the given nonce and key. +// The length of the nonce determinds the version of ChaCha20: +// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period. +// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period. +// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period. +// Src and dst may be the same slice but otherwise should not overlap. +// If len(dst) < len(src) this function panics. +// If the nonce is neither 64, 96 nor 192 bits long, this function panics. +func XORKeyStream(dst, src, nonce, key []byte) { + chacha.XORKeyStream(dst, src, nonce, key, 20) +} + +// NewCipher returns a new cipher.Stream implementing a ChaCha20 version. +// The nonce must be unique for one key for all time. +// The length of the nonce determinds the version of ChaCha20: +// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period. +// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period. +// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period. +// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned. +func NewCipher(nonce, key []byte) (cipher.Stream, error) { + return chacha.NewCipher(nonce, key, 20) +} diff --git a/vendor/github.com/bifurcation/mint/LICENSE.md b/vendor/github.com/bifurcation/mint/LICENSE.md new file mode 100644 index 0000000..6385812 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Richard Barnes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/bifurcation/mint/README.md b/vendor/github.com/bifurcation/mint/README.md new file mode 100644 index 0000000..0ac41e0 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/README.md @@ -0,0 +1,88 @@ +![A lock with a mint leaf](https://ipv.sx/mint/mint.svg) + +mint - A Minimal TLS 1.3 stack +============================== + +[![Build Status](https://circleci.com/gh/bifurcation/mint.svg)](https://circleci.com/gh/bifurcation/mint) + +This project is primarily a learning effort for me to understand the [TLS +1.3](http://tlswg.github.io/tls13-spec/) protocol. The goal is to arrive at a +pretty complete implementation of TLS 1.3, with minimal, elegant code that +demonstrates how things work. Testing is a priority to ensure correctness, but +otherwise, the quality of the software engineering might not be at a level where +it makes sense to integrate this with other libraries. Backward compatibility +is not an objective. + +We borrow liberally from the [Go TLS +library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns +with earlier TLS versions. However, unnecessary parts will be ruthlessly cut +off. + +## Quickstart + +Installation is the same as for any other Go package: + +``` +go get github.com/bifurcation/mint +``` + +The API is pretty much the same as for the TLS module, with `Dial` and `Listen` +methods wrapping the underlying socket APIs. + +``` +conn, err := mint.Dial("tcp", "localhost:4430", &mint.Config{...}) +... +listener, err := mint.Listen("tcp", "localhost:4430", &mint.Config{...}) +``` + +Documentation is available on +[godoc.org](https://godoc.org/github.com/bifurcation/mint) + + +## Interoperability testing + +The `mint-client` and `mint-server` executables are included to make it easy to +do basic interoperability tests with other TLS 1.3 implementations. The steps +for testing against NSS are as follows. + +``` +# Install mint +go get github.com/bifurcation/mint + +# Environment for NSS (you'll probably want a new directory) +NSS_ROOT= +mkdir $NSS_ROOT +cd $NSS_ROOT +export USE_64=1 +export ENABLE_TLS_1_3=1 +export HOST=localhost +export DOMSUF=localhost + +# Build NSS +hg clone https://hg.mozilla.org/projects/nss +hg clone https://hg.mozilla.org/projects/nspr +cd nss +make nss_build_all + +export PLATFORM=`cat $NSS_ROOT/dist/latest` +export DYLD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib +export LD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib + +# Run NSS tests (this creates data for the server to use) +cd tests/ssl_gtests +./ssl_gtests.sh + +# Test with client=mint server=NSS +cd $NSS_ROOT +./dist/$PLATFORM/bin/selfserv -d tests_results/security/$HOST.1/ssl_gtests/ -n rsa -p 4430 +# if you get `NSS_Init failed.`, check the path above, particularly around $HOST +# ... +go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-client/main.go + +# Test with client=NSS server=mint +go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-server/main.go +# ... +cd $NSS_ROOT +dist/$PLATFORM/bin/tstclnt -d tests_results/security/$HOST/ssl_gtests/ -V tls1.3:tls1.3 -h 127.0.0.1 -p 4430 -o +``` + diff --git a/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/bifurcation/mint/alert.go new file mode 100644 index 0000000..5e31035 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/alert.go @@ -0,0 +1,99 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mint + +import "strconv" + +type Alert uint8 + +const ( + // alert level + AlertLevelWarning = 1 + AlertLevelError = 2 +) + +const ( + AlertCloseNotify Alert = 0 + AlertUnexpectedMessage Alert = 10 + AlertBadRecordMAC Alert = 20 + AlertDecryptionFailed Alert = 21 + AlertRecordOverflow Alert = 22 + AlertDecompressionFailure Alert = 30 + AlertHandshakeFailure Alert = 40 + AlertBadCertificate Alert = 42 + AlertUnsupportedCertificate Alert = 43 + AlertCertificateRevoked Alert = 44 + AlertCertificateExpired Alert = 45 + AlertCertificateUnknown Alert = 46 + AlertIllegalParameter Alert = 47 + AlertUnknownCA Alert = 48 + AlertAccessDenied Alert = 49 + AlertDecodeError Alert = 50 + AlertDecryptError Alert = 51 + AlertProtocolVersion Alert = 70 + AlertInsufficientSecurity Alert = 71 + AlertInternalError Alert = 80 + AlertInappropriateFallback Alert = 86 + AlertUserCanceled Alert = 90 + AlertNoRenegotiation Alert = 100 + AlertMissingExtension Alert = 109 + AlertUnsupportedExtension Alert = 110 + AlertCertificateUnobtainable Alert = 111 + AlertUnrecognizedName Alert = 112 + AlertBadCertificateStatsResponse Alert = 113 + AlertBadCertificateHashValue Alert = 114 + AlertUnknownPSKIdentity Alert = 115 + AlertNoApplicationProtocol Alert = 120 + AlertWouldBlock Alert = 254 + AlertNoAlert Alert = 255 +) + +var alertText = map[Alert]string{ + AlertCloseNotify: "close notify", + AlertUnexpectedMessage: "unexpected message", + AlertBadRecordMAC: "bad record MAC", + AlertDecryptionFailed: "decryption failed", + AlertRecordOverflow: "record overflow", + AlertDecompressionFailure: "decompression failure", + AlertHandshakeFailure: "handshake failure", + AlertBadCertificate: "bad certificate", + AlertUnsupportedCertificate: "unsupported certificate", + AlertCertificateRevoked: "revoked certificate", + AlertCertificateExpired: "expired certificate", + AlertCertificateUnknown: "unknown certificate", + AlertIllegalParameter: "illegal parameter", + AlertUnknownCA: "unknown certificate authority", + AlertAccessDenied: "access denied", + AlertDecodeError: "error decoding message", + AlertDecryptError: "error decrypting message", + AlertProtocolVersion: "protocol version not supported", + AlertInsufficientSecurity: "insufficient security level", + AlertInternalError: "internal error", + AlertInappropriateFallback: "inappropriate fallback", + AlertUserCanceled: "user canceled", + AlertMissingExtension: "missing extension", + AlertUnsupportedExtension: "unsupported extension", + AlertCertificateUnobtainable: "certificate unobtainable", + AlertUnrecognizedName: "unrecognized name", + AlertBadCertificateStatsResponse: "bad certificate status response", + AlertBadCertificateHashValue: "bad certificate hash value", + AlertUnknownPSKIdentity: "unknown PSK identity", + AlertNoApplicationProtocol: "no application protocol", + AlertNoRenegotiation: "no renegotiation", + AlertWouldBlock: "would have blocked", + AlertNoAlert: "no alert", +} + +func (e Alert) String() string { + s, ok := alertText[e] + if ok { + return s + } + return "alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e Alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go new file mode 100644 index 0000000..290a930 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -0,0 +1,942 @@ +package mint + +import ( + "bytes" + "crypto" + "hash" + "time" +) + +// Client State Machine +// +// START <----+ +// Send ClientHello | | Recv HelloRetryRequest +// / v | +// | WAIT_SH ---+ +// Can | | Recv ServerHello +// send | V +// early | WAIT_EE +// data | | Recv EncryptedExtensions +// | +--------+--------+ +// | Using | | Using certificate +// | PSK | v +// | | WAIT_CERT_CR +// | | Recv | | Recv CertificateRequest +// | | Certificate | v +// | | | WAIT_CERT +// | | | | Recv Certificate +// | | v v +// | | WAIT_CV +// | | | Recv CertificateVerify +// | +> WAIT_FINISHED <+ +// | | Recv Finished +// \ | +// | [Send EndOfEarlyData] +// | [Send Certificate [+ CertificateVerify]] +// | Send Finished +// Can send v +// app data --> CONNECTED +// after +// here +// +// State Instructions +// START Send(CH); [RekeyOut; SendEarlyData] +// WAIT_SH Send(CH) || RekeyIn +// WAIT_EE {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +type ClientStateStart struct { + Caps Capabilities + Opts ConnectionOptions + Params ConnectionParameters + + cookie []byte + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage +} + +func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message") + return nil, nil, AlertUnexpectedMessage + } + + // key_shares + offeredDH := map[NamedGroup][]byte{} + ks := KeyShareExtension{ + HandshakeType: HandshakeTypeClientHello, + Shares: make([]KeyShareEntry, len(state.Caps.Groups)), + } + for i, group := range state.Caps.Groups { + pub, priv, err := newKeyShare(group) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) + return nil, nil, AlertInternalError + } + + ks.Shares[i].Group = group + ks.Shares[i].KeyExchange = pub + offeredDH[group] = priv + } + + logf(logTypeHandshake, "opts: %+v", state.Opts) + + // supported_versions, supported_groups, signature_algorithms, server_name + sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}} + sni := ServerNameExtension(state.Opts.ServerName) + sg := SupportedGroupsExtension{Groups: state.Caps.Groups} + sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + + state.Params.ServerName = state.Opts.ServerName + + // Application Layer Protocol Negotiation + var alpn *ALPNExtension + if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { + alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} + } + + // Construct base ClientHello + ch := &ClientHelloBody{ + CipherSuites: state.Caps.CipherSuites, + } + _, err := prng.Read(ch.Random[:]) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) + return nil, nil, AlertInternalError + } + for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { + err := ch.Extensions.Add(ext) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) + return nil, nil, AlertInternalError + } + } + // XXX: These optional extensions can't be folded into the above because Go + // interface-typed values are never reported as nil + if alpn != nil { + err := ch.Extensions.Add(alpn) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.cookie != nil { + err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Handle PSK and EarlyData just before transmitting, so that we can + // calculate the PSK binder value + var psk *PreSharedKeyExtension + var ed *EarlyDataExtension + var offeredPSK PreSharedKey + var earlyHash crypto.Hash + var earlySecret []byte + var clientEarlyTrafficKeys keySet + var clientHello *HandshakeMessage + if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok { + offeredPSK = key + + // Narrow ciphersuites to ones that match PSK hash + params, ok := cipherSuiteMap[key.CipherSuite] + if !ok { + logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") + return nil, nil, AlertInternalError + } + + compatibleSuites := []CipherSuite{} + for _, suite := range ch.CipherSuites { + if cipherSuiteMap[suite].Hash == params.Hash { + compatibleSuites = append(compatibleSuites, suite) + } + } + ch.CipherSuites = compatibleSuites + + // Signal early data if we're going to do it + if len(state.Opts.EarlyData) > 0 { + state.Params.ClientSendingEarlyData = true + ed = &EarlyDataExtension{} + err = ch.Extensions.Add(ed) + if err != nil { + logf(logTypeHandshake, "Error adding early data extension: %v", err) + return nil, nil, AlertInternalError + } + } + + // Signal supported PSK key exchange modes + if len(state.Caps.PSKModes) == 0 { + logf(logTypeHandshake, "PSK selected, but no PSKModes") + return nil, nil, AlertInternalError + } + kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes} + err = ch.Extensions.Add(kem) + if err != nil { + logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) + return nil, nil, AlertInternalError + } + + // Add the shim PSK extension to the ClientHello + logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) + psk = &PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + Identities: []PSKIdentity{ + { + Identity: key.Identity, + ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, + }, + }, + Binders: []PSKBinderEntry{ + // Note: Stub to get the length fields right + {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, + }, + } + ch.Extensions.Add(psk) + + // Compute the binder key + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + earlyHash = params.Hash + earlySecret = HkdfExtract(params.Hash, zero, key.Key) + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + + binderLabel := labelExternalBinder + if key.IsResumption { + binderLabel = labelResumptionBinder + } + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) + + // Compute the binder value + trunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + + truncHash := params.Hash.New() + truncHash.Write(trunc) + + binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) + + // Replace the PSK extension + psk.Binders[0].Binder = binder + ch.Extensions.Add(psk) + + // If we got here, the earlier marshal succeeded (in ch.Truncated()), so + // this one should too. + clientHello, _ = HandshakeMessageFromBody(ch) + + // Compute early traffic keys + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) + clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) + } else if len(state.Opts.EarlyData) > 0 { + logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") + return nil, nil, AlertInternalError + } else { + clientHello, err = HandshakeMessageFromBody(ch) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) + return nil, nil, AlertInternalError + } + } + + logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") + nextState := ClientStateWaitSH{ + Caps: state.Caps, + Opts: state.Opts, + Params: state.Params, + OfferedDH: offeredDH, + OfferedPSK: offeredPSK, + + earlySecret: earlySecret, + earlyHash: earlyHash, + + firstClientHello: state.firstClientHello, + helloRetryRequest: state.helloRetryRequest, + clientHello: clientHello, + } + + toSend := []HandshakeAction{ + SendHandshakeMessage{clientHello}, + } + if state.Params.ClientSendingEarlyData { + toSend = append(toSend, []HandshakeAction{ + RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys}, + SendEarlyData{}, + }...) + } + + return nextState, toSend, AlertNoAlert +} + +type ClientStateWaitSH struct { + Caps Capabilities + Opts ConnectionOptions + Params ConnectionParameters + OfferedDH map[NamedGroup][]byte + OfferedPSK PreSharedKey + PSK []byte + + earlySecret []byte + earlyHash crypto.Hash + + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + switch body := bodyGeneric.(type) { + case *HelloRetryRequestBody: + hrr := body + + if state.helloRetryRequest != nil { + logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest") + return nil, nil, AlertUnexpectedMessage + } + + // Check that the version sent by the server is the one we support + if hrr.Version != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version) + return nil, nil, AlertProtocolVersion + } + + // Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Caps.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Narrow the supported ciphersuites to the server-provided one + state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite} + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // The only thing we know how to respond to in an HRR is the Cookie + // extension, so if there is either no Cookie extension or anything other + // than a Cookie extension, we have to fail. + serverCookie := new(CookieExtension) + foundCookie := hrr.Extensions.Find(serverCookie) + if !foundCookie || len(hrr.Extensions) != 1 { + logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) + return nil, nil, AlertIllegalParameter + } + + // Hash the body into a pseudo-message + // XXX: Ignoring some errors here + params := cipherSuiteMap[hrr.CipherSuite] + h := params.Hash.New() + h.Write(state.clientHello.Marshal()) + firstClientHello := &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: h.Sum(nil), + } + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") + return ClientStateStart{ + Caps: state.Caps, + Opts: state.Opts, + cookie: serverCookie.Cookie, + firstClientHello: firstClientHello, + helloRetryRequest: hm, + }.Next(nil) + + case *ServerHelloBody: + sh := body + + // Check that the version sent by the server is the one we support + if sh.Version != supportedVersion { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version) + return nil, nil, AlertProtocolVersion + } + + // Check that the server provided a supported ciphersuite + supportedCipherSuite := false + for _, suite := range state.Caps.CipherSuites { + supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) + } + if !supportedCipherSuite { + logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Do PSK or key agreement depending on extensions + serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} + serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + + foundPSK := sh.Extensions.Find(&serverPSK) + foundKeyShare := sh.Extensions.Find(&serverKeyShare) + + if foundPSK && (serverPSK.SelectedIdentity == 0) { + state.Params.UsingPSK = true + } + + var dhSecret []byte + if foundKeyShare { + sks := serverKeyShare.Shares[0] + priv, ok := state.OfferedDH[sks.Group] + if !ok { + logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingDH = true + dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) + } + + suite := sh.CipherSuite + state.Params.CipherSuite = suite + + params, ok := cipherSuiteMap[suite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(hm.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + if params.Hash != state.earlyHash { + logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", + state.earlyHash, suite, params.Hash) + } + + earlySecret = state.earlySecret + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if dhSecret == nil { + dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") + nextState := ClientStateWaitEE{ + Caps: state.Caps, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + certificates: state.Caps.Certificates, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, + } + toSend := []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys}, + } + return nextState, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType) + return nil, nil, AlertUnexpectedMessage +} + +type ClientStateWaitEE struct { + Caps Capabilities + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + certificates []*Certificate + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { + logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ee := EncryptedExtensionsBody{} + _, err := ee.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverALPN := ALPNExtension{} + serverEarlyData := EarlyDataExtension{} + + gotALPN := ee.Extensions.Find(&serverALPN) + state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData) + + if gotALPN && len(serverALPN.Protocols) > 0 { + state.Params.NextProto = serverALPN.Protocols[0] + } + + state.handshakeHash.Write(hm.Marshal()) + + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") + nextState := ClientStateWaitCertCR{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCertCR struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + certificates []*Certificate + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + switch body := bodyGeneric.(type) { + case *CertificateBody: + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificate: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + + case *CertificateRequestBody: + // A certificate request in the handshake should have a zero-length context + if len(body.CertificateRequestContext) > 0 { + logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) + return nil, nil, AlertIllegalParameter + } + + state.Params.UsingClientAuth = true + + logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") + nextState := ClientStateWaitCert{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificateRequest: body, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert + } + + return nil, nil, AlertUnexpectedMessage +} + +type ClientStateWaitCert struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + _, err := cert.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") + nextState := ClientStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificate: cert, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitCV struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificate *CertificateBody + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + certVerify := CertificateVerifyBody{} + _, err := certVerify.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(serverPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") + return nil, nil, AlertHandshakeFailure + } + + if state.AuthCertificate != nil { + err := state.AuthCertificate(state.serverCertificate.CertificateList) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate") + return nil, nil, AlertBadCertificate + } + } else { + logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate") + } + + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") + nextState := ClientStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + certificates: state.certificates, + serverCertificateRequest: state.serverCertificateRequest, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, + } + return nextState, nil, AlertNoAlert +} + +type ClientStateWaitFinished struct { + Params ConnectionParameters + cryptoParams CipherSuiteParams + handshakeHash hash.Hash + + certificates []*Certificate + serverCertificateRequest *CertificateRequestBody + + masterSecret []byte + clientHandshakeTrafficSecret []byte + serverHandshakeTrafficSecret []byte +} + +func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Verify server's Finished + h3 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} + _, err := fin.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + if !bytes.Equal(fin.VerifyData, serverFinishedData) { + logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", + fin.VerifyData, serverFinishedData) + return nil, nil, AlertHandshakeFailure + } + + // Update the handshake hash with the Finished + state.handshakeHash.Write(hm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) + h4 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) + + // Compute traffic secrets and keys + clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) + serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) + + exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + // Assemble client's second flight + toSend := []HandshakeAction{} + + if state.Params.UsingEarlyData { + // Note: We only send EOED if the server is actually going to use the early + // data. Otherwise, it will never see it, and the transcripts will + // mismatch. + // EOED marshal is infallible + eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{}) + toSend = append(toSend, SendHandshakeMessage{eoedm}) + state.handshakeHash.Write(eoedm.Marshal()) + logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) + } + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys}) + + if state.Params.UsingClientAuth { + // Extract constraints from certicateRequest + schemes := SignatureAlgorithmsExtension{} + gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes) + if !gotSchemes { + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + return nil, nil, AlertIllegalParameter + } + + // Select a certificate + cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) + if err != nil { + // XXX: Signal this to the application layer? + logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) + + certificate := &CertificateBody{} + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + } else { + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(cert.Chain)), + } + for i, entry := range cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + state.handshakeHash.Write(certm.Marshal()) + + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) + + err = certificateVerify.Sign(cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certvm}) + state.handshakeHash.Write(certvm.Marshal()) + } + } + + // Compute the client's Finished message + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + fin = &FinishedBody{ + VerifyDataLen: len(clientFinishedData), + VerifyData: clientFinishedData, + } + finm, err := HandshakeMessageFromBody(fin) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) + return nil, nil, AlertInternalError + } + + // Compute the resumption secret + state.handshakeHash.Write(finm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + toSend = append(toSend, []HandshakeAction{ + SendHandshakeMessage{finm}, + RekeyIn{Label: "application", KeySet: serverTrafficKeys}, + RekeyOut{Label: "application", KeySet: clientTrafficKeys}, + }...) + + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + isClient: true, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go new file mode 100644 index 0000000..dfda7c3 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/common.go @@ -0,0 +1,152 @@ +package mint + +import ( + "fmt" + "strconv" +) + +var ( + supportedVersion uint16 = 0x7f15 // draft-21 + + // Flags for some minor compat issues + allowWrongVersionNumber = true + allowPKCS1 = true +) + +// enum {...} ContentType; +type RecordType byte + +const ( + RecordTypeAlert RecordType = 21 + RecordTypeHandshake RecordType = 22 + RecordTypeApplicationData RecordType = 23 +) + +// enum {...} HandshakeType; +type HandshakeType byte + +const ( + // Omitted: *_RESERVED + HandshakeTypeClientHello HandshakeType = 1 + HandshakeTypeServerHello HandshakeType = 2 + HandshakeTypeNewSessionTicket HandshakeType = 4 + HandshakeTypeEndOfEarlyData HandshakeType = 5 + HandshakeTypeHelloRetryRequest HandshakeType = 6 + HandshakeTypeEncryptedExtensions HandshakeType = 8 + HandshakeTypeCertificate HandshakeType = 11 + HandshakeTypeCertificateRequest HandshakeType = 13 + HandshakeTypeCertificateVerify HandshakeType = 15 + HandshakeTypeServerConfiguration HandshakeType = 17 + HandshakeTypeFinished HandshakeType = 20 + HandshakeTypeKeyUpdate HandshakeType = 24 + HandshakeTypeMessageHash HandshakeType = 254 +) + +// uint8 CipherSuite[2]; +type CipherSuite uint16 + +const ( + // XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero + // value for this type so that we can detect when a field is set. + CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 + TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 + TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 + TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 + TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 +) + +func (c CipherSuite) String() string { + switch c { + case CIPHER_SUITE_UNKNOWN: + return "unknown" + case TLS_AES_128_GCM_SHA256: + return "TLS_AES_128_GCM_SHA256" + case TLS_AES_256_GCM_SHA384: + return "TLS_AES_256_GCM_SHA384" + case TLS_CHACHA20_POLY1305_SHA256: + return "TLS_CHACHA20_POLY1305_SHA256" + case TLS_AES_128_CCM_SHA256: + return "TLS_AES_128_CCM_SHA256" + case TLS_AES_256_CCM_8_SHA256: + return "TLS_AES_256_CCM_8_SHA256" + } + // cannot use %x here, since it calls String(), leading to infinite recursion + return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) +} + +// enum {...} SignatureScheme +type SignatureScheme uint16 + +const ( + // RSASSA-PKCS1-v1_5 algorithms + RSA_PKCS1_SHA1 SignatureScheme = 0x0201 + RSA_PKCS1_SHA256 SignatureScheme = 0x0401 + RSA_PKCS1_SHA384 SignatureScheme = 0x0501 + RSA_PKCS1_SHA512 SignatureScheme = 0x0601 + // ECDSA algorithms + ECDSA_P256_SHA256 SignatureScheme = 0x0403 + ECDSA_P384_SHA384 SignatureScheme = 0x0503 + ECDSA_P521_SHA512 SignatureScheme = 0x0603 + // RSASSA-PSS algorithms + RSA_PSS_SHA256 SignatureScheme = 0x0804 + RSA_PSS_SHA384 SignatureScheme = 0x0805 + RSA_PSS_SHA512 SignatureScheme = 0x0806 + // EdDSA algorithms + Ed25519 SignatureScheme = 0x0807 + Ed448 SignatureScheme = 0x0808 +) + +// enum {...} ExtensionType +type ExtensionType uint16 + +const ( + ExtensionTypeServerName ExtensionType = 0 + ExtensionTypeSupportedGroups ExtensionType = 10 + ExtensionTypeSignatureAlgorithms ExtensionType = 13 + ExtensionTypeALPN ExtensionType = 16 + ExtensionTypeKeyShare ExtensionType = 40 + ExtensionTypePreSharedKey ExtensionType = 41 + ExtensionTypeEarlyData ExtensionType = 42 + ExtensionTypeSupportedVersions ExtensionType = 43 + ExtensionTypeCookie ExtensionType = 44 + ExtensionTypePSKKeyExchangeModes ExtensionType = 45 + ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 +) + +// enum {...} NamedGroup +type NamedGroup uint16 + +const ( + // Elliptic Curve Groups. + P256 NamedGroup = 23 + P384 NamedGroup = 24 + P521 NamedGroup = 25 + // ECDH functions. + X25519 NamedGroup = 29 + X448 NamedGroup = 30 + // Finite field groups. + FFDHE2048 NamedGroup = 256 + FFDHE3072 NamedGroup = 257 + FFDHE4096 NamedGroup = 258 + FFDHE6144 NamedGroup = 259 + FFDHE8192 NamedGroup = 260 +) + +// enum {...} PskKeyExchangeMode; +type PSKKeyExchangeMode uint8 + +const ( + PSKModeKE PSKKeyExchangeMode = 0 + PSKModeDHEKE PSKKeyExchangeMode = 1 +) + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +type KeyUpdateRequest uint8 + +const ( + KeyUpdateNotRequested KeyUpdateRequest = 0 + KeyUpdateRequested KeyUpdateRequest = 1 +) diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go new file mode 100644 index 0000000..08eb58d --- /dev/null +++ b/vendor/github.com/bifurcation/mint/conn.go @@ -0,0 +1,819 @@ +package mint + +import ( + "crypto" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "net" + "reflect" + "sync" + "time" +) + +var WouldBlock = fmt.Errorf("Would have blocked") + +type Certificate struct { + Chain []*x509.Certificate + PrivateKey crypto.Signer +} + +type PreSharedKey struct { + CipherSuite CipherSuite + IsResumption bool + Identity []byte + Key []byte + NextProto string + ReceivedAt time.Time + ExpiresAt time.Time + TicketAgeAdd uint32 +} + +type PreSharedKeyCache interface { + Get(string) (PreSharedKey, bool) + Put(string, PreSharedKey) + Size() int +} + +type PSKMapCache map[string]PreSharedKey + +// A CookieHandler does two things: +// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest +// - validates this byte string echoed by the client in the ClientHello +type CookieHandler interface { + Generate(*Conn) ([]byte, error) + Validate(*Conn, []byte) bool +} + +func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { + psk, ok = cache[key] + return +} + +func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { + (*cache)[key] = psk +} + +func (cache PSKMapCache) Size() int { + return len(cache) +} + +// Config is the struct used to pass configuration settings to a TLS client or +// server instance. The settings for client and server are pretty different, +// but we just throw them all in here. +type Config struct { + // Client fields + ServerName string + + // Server fields + SendSessionTickets bool + TicketLifetime uint32 + TicketLen int + EarlyDataLifetime uint32 + AllowEarlyData bool + // Require the client to echo a cookie. + RequireCookie bool + // If cookies are required and no CookieHandler is set, a default cookie handler is used. + // The default cookie handler uses 32 random bytes as a cookie. + CookieHandler CookieHandler + RequireClientAuth bool + + // Shared fields + Certificates []*Certificate + AuthCertificate func(chain []CertificateEntry) error + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + NextProtos []string + PSKs PreSharedKeyCache + PSKModes []PSKKeyExchangeMode + NonBlocking bool + + // The same config object can be shared among different connections, so it + // needs its own mutex + mutex sync.RWMutex +} + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + c.mutex.Lock() + defer c.mutex.Unlock() + + return &Config{ + ServerName: c.ServerName, + + SendSessionTickets: c.SendSessionTickets, + TicketLifetime: c.TicketLifetime, + TicketLen: c.TicketLen, + EarlyDataLifetime: c.EarlyDataLifetime, + AllowEarlyData: c.AllowEarlyData, + RequireCookie: c.RequireCookie, + RequireClientAuth: c.RequireClientAuth, + + Certificates: c.Certificates, + AuthCertificate: c.AuthCertificate, + CipherSuites: c.CipherSuites, + Groups: c.Groups, + SignatureSchemes: c.SignatureSchemes, + NextProtos: c.NextProtos, + PSKs: c.PSKs, + PSKModes: c.PSKModes, + NonBlocking: c.NonBlocking, + } +} + +func (c *Config) Init(isClient bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Set defaults + if len(c.CipherSuites) == 0 { + c.CipherSuites = defaultSupportedCipherSuites + } + if len(c.Groups) == 0 { + c.Groups = defaultSupportedGroups + } + if len(c.SignatureSchemes) == 0 { + c.SignatureSchemes = defaultSignatureSchemes + } + if c.TicketLen == 0 { + c.TicketLen = defaultTicketLen + } + if !reflect.ValueOf(c.PSKs).IsValid() { + c.PSKs = &PSKMapCache{} + } + if len(c.PSKModes) == 0 { + c.PSKModes = defaultPSKModes + } + + // If there is no certificate, generate one + if !isClient && len(c.Certificates) == 0 { + logf(logTypeHandshake, "Generating key name=%v", c.ServerName) + priv, err := newSigningKey(RSA_PSS_SHA256) + if err != nil { + return err + } + + cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv) + if err != nil { + return err + } + + c.Certificates = []*Certificate{ + { + Chain: []*x509.Certificate{cert}, + PrivateKey: priv, + }, + } + } + + return nil +} + +func (c *Config) ValidForServer() bool { + return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || + (len(c.Certificates) > 0 && + len(c.Certificates[0].Chain) > 0 && + c.Certificates[0].PrivateKey != nil) +} + +func (c *Config) ValidForClient() bool { + return len(c.ServerName) > 0 +} + +var ( + defaultSupportedCipherSuites = []CipherSuite{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + + defaultSupportedGroups = []NamedGroup{ + P256, + P384, + FFDHE2048, + X25519, + } + + defaultSignatureSchemes = []SignatureScheme{ + RSA_PSS_SHA256, + RSA_PSS_SHA384, + RSA_PSS_SHA512, + ECDSA_P256_SHA256, + ECDSA_P384_SHA384, + ECDSA_P521_SHA512, + } + + defaultTicketLen = 16 + + defaultPSKModes = []PSKKeyExchangeMode{ + PSKModeKE, + PSKModeDHEKE, + } +) + +type ConnectionState struct { + HandshakeState string // string representation of the handshake state. + CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement + NextProto string // Selected ALPN proto +} + +// Conn implements the net.Conn interface, as with "crypto/tls" +// * Read, Write, and Close are provided locally +// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn +type Conn struct { + config *Config + conn net.Conn + isClient bool + + EarlyData []byte + + state StateConnected + hState HandshakeState + handshakeMutex sync.Mutex + handshakeAlert Alert + handshakeComplete bool + + readBuffer []byte + in, out *RecordLayer + hIn, hOut *HandshakeLayer + + extHandler AppExtensionHandler +} + +func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { + c := &Conn{conn: conn, config: config, isClient: isClient} + c.in = NewRecordLayer(c.conn) + c.out = NewRecordLayer(c.conn) + c.hIn = NewHandshakeLayer(c.in) + c.hIn.nonblocking = c.config.NonBlocking + c.hOut = NewHandshakeLayer(c.out) + return c +} + +// Read up +func (c *Conn) consumeRecord() error { + pt, err := c.in.ReadRecord() + if pt == nil { + logf(logTypeIO, "extendBuffer returns error %v", err) + return err + } + + switch pt.contentType { + case RecordTypeHandshake: + logf(logTypeHandshake, "Received post-handshake message") + // We do not support fragmentation of post-handshake handshake messages. + // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() + start := 0 + for start < len(pt.fragment) { + if len(pt.fragment[start:]) < handshakeHeaderLen { + return fmt.Errorf("Post-handshake handshake message too short for header") + } + + hm := &HandshakeMessage{} + hm.msgType = HandshakeType(pt.fragment[start]) + hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) + + if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen { + return fmt.Errorf("Post-handshake handshake message too short for body") + } + hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen] + + // Advance state machine + state, actions, alert := c.state.Next(hm) + + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error in state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return io.EOF + } + } + + // XXX: If we want to support more advanced cases, e.g., post-handshake + // authentication, we'll need to allow transitions other than + // Connected -> Connected + var connected bool + c.state, connected = state.(StateConnected) + if !connected { + logf(logTypeHandshake, "Disconnected after state transition: %v", alert) + c.sendAlert(alert) + return io.EOF + } + + start += handshakeHeaderLen + hmLen + } + case RecordTypeAlert: + logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) + if len(pt.fragment) != 2 { + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + if Alert(pt.fragment[1]) == AlertCloseNotify { + return io.EOF + } + + switch pt.fragment[0] { + case AlertLevelWarning: + // drop on the floor + case AlertLevelError: + return Alert(pt.fragment[1]) + default: + c.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + + case RecordTypeApplicationData: + c.readBuffer = append(c.readBuffer, pt.fragment...) + logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } + + return err +} + +// Read application data up to the size of buffer. Handshake and alert records +// are consumed by the Conn object directly. +func (c *Conn) Read(buffer []byte) (int, error) { + logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) + if alert := c.Handshake(); alert != AlertNoAlert { + return 0, alert + } + + if len(buffer) == 0 { + return 0, nil + } + + // Lock the input channel + c.in.Lock() + defer c.in.Unlock() + for len(c.readBuffer) == 0 { + err := c.consumeRecord() + + // err can be nil if consumeRecord processed a non app-data + // record. + if err != nil { + if c.config.NonBlocking || err != WouldBlock { + logf(logTypeIO, "conn.Read returns err=%v", err) + return 0, err + } + } + } + + var read int + n := len(buffer) + logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) + if len(c.readBuffer) <= n { + buffer = buffer[:len(c.readBuffer)] + copy(buffer, c.readBuffer) + read = len(c.readBuffer) + c.readBuffer = c.readBuffer[:0] + } else { + logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) + copy(buffer[:n], c.readBuffer[:n]) + c.readBuffer = c.readBuffer[n:] + read = n + } + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read, nil +} + +// Write application data +func (c *Conn) Write(buffer []byte) (int, error) { + // Lock the output channel + c.out.Lock() + defer c.out.Unlock() + + // Send full-size fragments + var start int + sent := 0 + for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start : start+maxFragmentLen], + }) + + if err != nil { + return sent, err + } + sent += maxFragmentLen + } + + // Send a final partial fragment if necessary + if start < len(buffer) { + err := c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: buffer[start:], + }) + + if err != nil { + return sent, err + } + sent += len(buffer[start:]) + } + return sent, nil +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlert(err Alert) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var level int + switch err { + case AlertNoRenegotiation, AlertCloseNotify: + level = AlertLevelWarning + default: + level = AlertLevelError + } + + buf := []byte{byte(err), byte(level)} + c.out.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: buf, + }) + + // close_notify and end_of_early_data are not actually errors + if level == AlertLevelWarning { + return &net.OpError{Op: "local error", Err: err} + } + + return c.Close() +} + +// Close closes the connection. +func (c *Conn) Close() error { + // XXX crypto/tls has an interlock with Write here. Do we need that? + + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + switch action := actionGeneric.(type) { + case SendHandshakeMessage: + err := c.hOut.WriteMessage(action.Message) + if err != nil { + logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) + return AlertInternalError + } + + case RekeyIn: + logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet) + err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) + return AlertInternalError + } + + case RekeyOut: + logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet) + err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) + if err != nil { + logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) + return AlertInternalError + } + + case SendEarlyData: + logf(logTypeHandshake, "%s Sending early data...", label) + _, err := c.Write(c.EarlyData) + if err != nil { + logf(logTypeHandshake, "%s Error writing early data: %v", label, err) + return AlertInternalError + } + + case ReadPastEarlyData: + logf(logTypeHandshake, "%s Reading past early data...", label) + // Scan past all records that fail to decrypt + _, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok := err.(DecryptError) + + for ok { + _, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err == nil { + break + } + _, ok = err.(DecryptError) + } + + case ReadEarlyData: + logf(logTypeHandshake, "%s Reading early data...", label) + t, err := c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type(1): %v", label, t) + + for t == RecordTypeApplicationData { + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in in + // PeekRecordType. + pt, err := c.in.ReadRecord() + if err != nil { + logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) + return AlertInternalError + } + + logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) + c.EarlyData = append(c.EarlyData, pt.fragment...) + + t, err = c.in.PeekRecordType(!c.config.NonBlocking) + if err != nil { + logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) + return AlertInternalError + } + logf(logTypeHandshake, "%s Got record type (2): %v", label, t) + } + logf(logTypeHandshake, "%s Done reading early data", label) + + case StorePSK: + logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) + if c.isClient { + // Clients look up PSKs based on server name + c.config.PSKs.Put(c.config.ServerName, action.PSK) + } else { + // Servers look them up based on the identity in the extension + c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) + } + + default: + logf(logTypeHandshake, "%s Unknown actionuction type", label) + return AlertInternalError + } + + return AlertNoAlert +} + +func (c *Conn) HandshakeSetup() Alert { + var state HandshakeState + var actions []HandshakeAction + var alert Alert + + if err := c.config.Init(c.isClient); err != nil { + logf(logTypeHandshake, "Error initializing config: %v", err) + return AlertInternalError + } + + // Set things up + caps := Capabilities{ + CipherSuites: c.config.CipherSuites, + Groups: c.config.Groups, + SignatureSchemes: c.config.SignatureSchemes, + PSKs: c.config.PSKs, + PSKModes: c.config.PSKModes, + AllowEarlyData: c.config.AllowEarlyData, + RequireCookie: c.config.RequireCookie, + CookieHandler: c.config.CookieHandler, + RequireClientAuth: c.config.RequireClientAuth, + NextProtos: c.config.NextProtos, + Certificates: c.config.Certificates, + ExtensionHandler: c.extHandler, + } + opts := ConnectionOptions{ + ServerName: c.config.ServerName, + NextProtos: c.config.NextProtos, + EarlyData: c.EarlyData, + } + + if caps.RequireCookie && caps.CookieHandler == nil { + caps.CookieHandler = &defaultCookieHandler{} + } + + if c.isClient { + state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error initializing client state: %v", alert) + return alert + } + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + return alert + } + } + } else { + state = ServerStateStart{Caps: caps, conn: c} + } + + c.hState = state + + return AlertNoAlert +} + +// Handshake causes a TLS handshake on the connection. The `isClient` member +// determines whether a client or server handshake is performed. If a +// handshake has already been performed, then its result will be returned. +func (c *Conn) Handshake() Alert { + label := "[server]" + if c.isClient { + label = "[client]" + } + + // TODO Lock handshakeMutex + // TODO Remove CloseNotify hack + if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { + logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) + return c.handshakeAlert + } + if c.handshakeComplete { + return AlertNoAlert + } + + var alert Alert + if c.hState == nil { + logf(logTypeHandshake, "%s First time through handshake, setting up", label) + alert = c.HandshakeSetup() + if alert != AlertNoAlert { + return alert + } + } else { + logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState) + } + + state := c.hState + _, connected := state.(StateConnected) + + var actions []HandshakeAction + + for !connected { + // Read a handshake message + hm, err := c.hIn.ReadMessage() + if err == WouldBlock { + logf(logTypeHandshake, "%s Would block reading message: %v", label, err) + return AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "%s Error reading message: %v", label, err) + c.sendAlert(AlertCloseNotify) + return AlertCloseNotify + } + logf(logTypeHandshake, "Read message with type: %v", hm.msgType) + + // Advance the state machine + state, actions, alert = state.Next(hm) + + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error in state transition: %v", alert) + return alert + } + + for index, action := range actions { + logf(logTypeHandshake, "%s taking next action (%d)", label, index) + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + + c.hState = state + logf(logTypeHandshake, "state is now %s", c.GetHsState()) + + _, connected = state.(StateConnected) + } + + c.state = state.(StateConnected) + + // Send NewSessionTicket if acting as server + if !c.isClient && c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + + c.handshakeComplete = true + return AlertNoAlert +} + +func (c *Conn) SendKeyUpdate(requestUpdate bool) error { + if !c.handshakeComplete { + return fmt.Errorf("Cannot update keys until after handshake") + } + + request := KeyUpdateNotRequested + if requestUpdate { + request = KeyUpdateRequested + } + + // Create the key update and update state + actions, alert := c.state.KeyUpdate(request) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert while generating key update: %v", alert) + } + + // Take actions (send key update and rekey) + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + c.sendAlert(alert) + return fmt.Errorf("Alert during key update actions: %v", alert) + } + } + + return nil +} + +func (c *Conn) GetHsState() string { + return reflect.TypeOf(c.hState).Name() +} + +func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + _, connected := c.hState.(StateConnected) + if !connected { + return nil, fmt.Errorf("Cannot compute exporter when state is not connected") + } + + if c.state.exporterSecret == nil { + return nil, fmt.Errorf("Internal error: no exporter secret") + } + + h0 := c.state.cryptoParams.Hash.New().Sum(nil) + tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) + + hc := c.state.cryptoParams.Hash.New().Sum(context) + return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil +} + +func (c *Conn) State() ConnectionState { + state := ConnectionState{ + HandshakeState: c.GetHsState(), + } + + if c.handshakeComplete { + state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] + state.NextProto = c.state.Params.NextProto + } + + return state +} + +func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error { + if c.hState != nil { + return fmt.Errorf("Can't set extension handler after setup") + } + + c.extHandler = h + return nil +} diff --git a/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/bifurcation/mint/crypto.go new file mode 100644 index 0000000..60d3437 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/crypto.go @@ -0,0 +1,654 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "math/big" + "time" + + "golang.org/x/crypto/curve25519" + + // Blank includes to ensure hash support + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +var prng = rand.Reader + +type aeadFactory func(key []byte) (cipher.AEAD, error) + +type CipherSuiteParams struct { + Suite CipherSuite + Cipher aeadFactory // Cipher factory + Hash crypto.Hash // Hash function + KeyLen int // Key length in octets + IvLen int // IV length in octets +} + +type signatureAlgorithm uint8 + +const ( + signatureAlgorithmUnknown = iota + signatureAlgorithmRSA_PKCS1 + signatureAlgorithmRSA_PSS + signatureAlgorithmECDSA +) + +var ( + hashMap = map[SignatureScheme]crypto.Hash{ + RSA_PKCS1_SHA1: crypto.SHA1, + RSA_PKCS1_SHA256: crypto.SHA256, + RSA_PKCS1_SHA384: crypto.SHA384, + RSA_PKCS1_SHA512: crypto.SHA512, + ECDSA_P256_SHA256: crypto.SHA256, + ECDSA_P384_SHA384: crypto.SHA384, + ECDSA_P521_SHA512: crypto.SHA512, + RSA_PSS_SHA256: crypto.SHA256, + RSA_PSS_SHA384: crypto.SHA384, + RSA_PSS_SHA512: crypto.SHA512, + } + + sigMap = map[SignatureScheme]signatureAlgorithm{ + RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, + RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, + ECDSA_P256_SHA256: signatureAlgorithmECDSA, + ECDSA_P384_SHA384: signatureAlgorithmECDSA, + ECDSA_P521_SHA512: signatureAlgorithmECDSA, + RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, + RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, + } + + curveMap = map[SignatureScheme]NamedGroup{ + ECDSA_P256_SHA256: P256, + ECDSA_P384_SHA384: P384, + ECDSA_P521_SHA512: P521, + } + + newAESGCM = func(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // TLS always uses 12-byte nonces + return cipher.NewGCMWithNonceSize(block, 12) + } + + cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ + TLS_AES_128_GCM_SHA256: { + Suite: TLS_AES_128_GCM_SHA256, + Cipher: newAESGCM, + Hash: crypto.SHA256, + KeyLen: 16, + IvLen: 12, + }, + TLS_AES_256_GCM_SHA384: { + Suite: TLS_AES_256_GCM_SHA384, + Cipher: newAESGCM, + Hash: crypto.SHA384, + KeyLen: 32, + IvLen: 12, + }, + } + + x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ + RSA_PKCS1_SHA1: x509.SHA1WithRSA, + RSA_PKCS1_SHA256: x509.SHA256WithRSA, + RSA_PKCS1_SHA384: x509.SHA384WithRSA, + RSA_PKCS1_SHA512: x509.SHA512WithRSA, + ECDSA_P256_SHA256: x509.ECDSAWithSHA256, + ECDSA_P384_SHA384: x509.ECDSAWithSHA384, + ECDSA_P521_SHA512: x509.ECDSAWithSHA512, + } + + defaultRSAKeySize = 2048 +) + +func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { + switch group { + case P256: + crv = elliptic.P256() + case P384: + crv = elliptic.P384() + case P521: + crv = elliptic.P521() + } + return +} + +func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { + switch key.Curve.Params().Name { + case elliptic.P256().Params().Name: + g = P256 + case elliptic.P384().Params().Name: + g = P384 + case elliptic.P521().Params().Name: + g = P521 + } + return +} + +func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { + size = 0 + switch group { + case X25519: + size = 32 + case P256: + size = 65 + case P384: + size = 97 + case P521: + size = 133 + case FFDHE2048: + size = 256 + case FFDHE3072: + size = 384 + case FFDHE4096: + size = 512 + case FFDHE6144: + size = 768 + case FFDHE8192: + size = 1024 + } + return +} + +func primeFromNamedGroup(group NamedGroup) (p *big.Int) { + switch group { + case FFDHE2048: + p = finiteFieldPrime2048 + case FFDHE3072: + p = finiteFieldPrime3072 + case FFDHE4096: + p = finiteFieldPrime4096 + case FFDHE6144: + p = finiteFieldPrime6144 + case FFDHE8192: + p = finiteFieldPrime8192 + } + return +} + +func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { + sigType := sigMap[alg] + switch key.(type) { + case *rsa.PrivateKey: + return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS + case *ecdsa.PrivateKey: + return sigType == signatureAlgorithmECDSA + default: + return false + } +} + +func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { + primeLen := len(p.Bytes()) + for { + // g = 2 for all ffdhe groups + priv, err = rand.Int(prng, p) + if err != nil { + return + } + + pub = big.NewInt(0) + pub.Exp(big.NewInt(2), priv, p) + + if len(pub.Bytes()) == primeLen { + return + } + } +} + +func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { + switch group { + case P256, P384, P521: + var x, y *big.Int + crv := curveFromNamedGroup(group) + priv, x, y, err = elliptic.GenerateKey(crv, prng) + if err != nil { + return + } + + pub = elliptic.Marshal(crv, x, y) + return + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + p := primeFromNamedGroup(group) + x, X, err2 := ffdheKeyShareFromPrime(p) + if err2 != nil { + err = err2 + return + } + + priv = x.Bytes() + pubBytes := X.Bytes() + + numBytes := keyExchangeSizeFromNamedGroup(group) + + pub = make([]byte, numBytes) + copy(pub[numBytes-len(pubBytes):], pubBytes) + + return + + case X25519: + var private, public [32]byte + _, err = prng.Read(private[:]) + if err != nil { + return + } + + curve25519.ScalarBaseMult(&public, &private) + priv = private[:] + pub = public[:] + return + + default: + return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) + } +} + +func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { + switch group { + case P256, P384, P521: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + crv := curveFromNamedGroup(group) + pubX, pubY := elliptic.Unmarshal(crv, pub) + x, _ := crv.Params().ScalarMult(pubX, pubY, priv) + xBytes := x.Bytes() + + numBytes := len(crv.Params().P.Bytes()) + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(xBytes):], xBytes) + + return ret, nil + + case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: + numBytes := keyExchangeSizeFromNamedGroup(group) + if len(pub) != numBytes { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + p := primeFromNamedGroup(group) + x := big.NewInt(0).SetBytes(priv) + Y := big.NewInt(0).SetBytes(pub) + ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() + + ret := make([]byte, numBytes) + copy(ret[numBytes-len(ZBytes):], ZBytes) + + return ret, nil + + case X25519: + if len(pub) != keyExchangeSizeFromNamedGroup(group) { + return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") + } + + var private, public, ret [32]byte + copy(private[:], priv) + copy(public[:], pub) + curve25519.ScalarMult(&ret, &private, &public) + + return ret[:], nil + + default: + return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) + } +} + +func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { + switch sig { + case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, + RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, + RSA_PSS_SHA256, RSA_PSS_SHA384, + RSA_PSS_SHA512: + return rsa.GenerateKey(prng, defaultRSAKeySize) + case ECDSA_P256_SHA256: + return ecdsa.GenerateKey(elliptic.P256(), prng) + case ECDSA_P384_SHA384: + return ecdsa.GenerateKey(elliptic.P384(), prng) + case ECDSA_P521_SHA512: + return ecdsa.GenerateKey(elliptic.P521(), prng) + default: + return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) + } +} + +func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { + sigAlg, ok := x509AlgMap[alg] + if !ok { + return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) + } + if len(name) == 0 { + return nil, fmt.Errorf("tls.selfsigned: No name provided") + } + + serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + SignatureAlgorithm: sigAlg, + Subject: pkix.Name{CommonName: name}, + DNSNames: []string{name}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) + if err != nil { + return nil, err + } + + // It is safe to ignore the error here because we're parsing known-good data + cert, _ := x509.ParseCertificate(der) + return cert, nil +} + +// XXX(rlb): Copied from crypto/x509 +type ecdsaSignature struct { + R, S *big.Int +} + +func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { + var opts crypto.SignerOpts + + hash := hashMap[alg] + if hash == crypto.SHA1 { + return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + var realInput []byte + switch key := privateKey.(type) { + case *rsa.PrivateKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) + opts = hash + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) + opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + case *ecdsa.PrivateKey: + if sigType != signatureAlgorithmECDSA { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") + } + + algGroup := curveMap[alg] + keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) + if algGroup != keyGroup { + return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") + } + + h := hash.New() + h.Write(sigInput) + realInput = h.Sum(nil) + default: + return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") + } + + sig, err := privateKey.Sign(prng, realInput, opts) + logf(logTypeCrypto, "signature: %x", sig) + return sig, err +} + +func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { + hash := hashMap[alg] + + if hash == crypto.SHA1 { + return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") + } + + sigType := sigMap[alg] + switch pub := publicKey.(type) { + case *rsa.PublicKey: + switch { + case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) + case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: + fallthrough + case sigType == signatureAlgorithmRSA_PSS: + logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) + opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + return rsa.VerifyPSS(pub, hash, realInput, sig, opts) + default: + return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") + } + + case *ecdsa.PublicKey: + if sigType != signatureAlgorithmECDSA { + return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") + } + + if curveMap[alg] != namedGroupFromECDSAKey(pub) { + return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") + } + + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return fmt.Errorf("tls.verify: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") + } + + h := hash.New() + h.Write(sigInput) + realInput := h.Sum(nil) + if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { + return fmt.Errorf("tls.verify: ECDSA verification failure") + } + return nil + default: + return fmt.Errorf("tls.verify: Unsupported key type") + } +} + +// 0 +// | +// v +// PSK -> HKDF-Extract = Early Secret +// | +// +-----> Derive-Secret(., +// | "ext binder" | +// | "res binder", +// | "") +// | = binder_key +// | +// +-----> Derive-Secret(., "c e traffic", +// | ClientHello) +// | = client_early_traffic_secret +// | +// +-----> Derive-Secret(., "e exp master", +// | ClientHello) +// | = early_exporter_master_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// (EC)DHE -> HKDF-Extract = Handshake Secret +// | +// +-----> Derive-Secret(., "c hs traffic", +// | ClientHello...ServerHello) +// | = client_handshake_traffic_secret +// | +// +-----> Derive-Secret(., "s hs traffic", +// | ClientHello...ServerHello) +// | = server_handshake_traffic_secret +// v +// Derive-Secret(., "derived", "") +// | +// v +// 0 -> HKDF-Extract = Master Secret +// | +// +-----> Derive-Secret(., "c ap traffic", +// | ClientHello...server Finished) +// | = client_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "s ap traffic", +// | ClientHello...server Finished) +// | = server_application_traffic_secret_0 +// | +// +-----> Derive-Secret(., "exp master", +// | ClientHello...server Finished) +// | = exporter_master_secret +// | +// +-----> Derive-Secret(., "res master", +// ClientHello...client Finished) +// = resumption_master_secret + +// From RFC 5869 +// PRK = HMAC-Hash(salt, IKM) +func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { + salt := saltIn + + // if [salt is] not provided, it is set to a string of HashLen zeros + if salt == nil { + salt = bytes.Repeat([]byte{0}, hash.Size()) + } + + h := hmac.New(hash.New, salt) + h.Write(input) + out := h.Sum(nil) + + logf(logTypeCrypto, "HKDF Extract:\n") + logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) + logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) + logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) + + return out +} + +const ( + labelExternalBinder = "ext binder" + labelResumptionBinder = "res binder" + labelEarlyTrafficSecret = "c e traffic" + labelEarlyExporterSecret = "e exp master" + labelClientHandshakeTrafficSecret = "c hs traffic" + labelServerHandshakeTrafficSecret = "s hs traffic" + labelClientApplicationTrafficSecret = "c ap traffic" + labelServerApplicationTrafficSecret = "s ap traffic" + labelExporterSecret = "exp master" + labelResumptionSecret = "res master" + labelDerived = "derived" + labelFinished = "finished" + labelResumption = "resumption" +) + +// struct HkdfLabel { +// uint16 length; +// opaque label<9..255>; +// opaque hash_value<0..255>; +// }; +func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { + label := "tls13 " + labelIn + + labelLen := len(label) + hashLen := len(hashValue) + hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) + hkdfLabel[0] = byte(outLen >> 8) + hkdfLabel[1] = byte(outLen) + hkdfLabel[2] = byte(labelLen) + copy(hkdfLabel[3:3+labelLen], []byte(label)) + hkdfLabel[3+labelLen] = byte(hashLen) + copy(hkdfLabel[3+labelLen+1:], hashValue) + + return hkdfLabel +} + +func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { + out := []byte{} + T := []byte{} + i := byte(1) + for len(out) < outLen { + block := append(T, info...) + block = append(block, i) + + h := hmac.New(hash.New, prk) + h.Write(block) + + T = h.Sum(nil) + out = append(out, T...) + i++ + } + return out[:outLen] +} + +func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { + info := hkdfEncodeLabel(label, hashValue, outLen) + derived := HkdfExpand(hash, secret, info, outLen) + + logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) + logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) + logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) + logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) + logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) + + return derived +} + +func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { + return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) +} + +func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { + macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) + mac := hmac.New(params.Hash.New, macKey) + mac.Write(input) + return mac.Sum(nil) +} + +type keySet struct { + cipher aeadFactory + key []byte + iv []byte +} + +func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { + logf(logTypeCrypto, "making traffic keys: secret=%x", secret) + return keySet{ + cipher: params.Cipher, + key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), + iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), + } +} diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go new file mode 100644 index 0000000..1dbe7bd --- /dev/null +++ b/vendor/github.com/bifurcation/mint/extensions.go @@ -0,0 +1,586 @@ +package mint + +import ( + "bytes" + "fmt" + + "github.com/bifurcation/mint/syntax" +) + +type ExtensionBody interface { + Type() ExtensionType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ExtensionType extension_type; +// opaque extension_data<0..2^16-1>; +// } Extension; +type Extension struct { + ExtensionType ExtensionType + ExtensionData []byte `tls:"head=2"` +} + +func (ext Extension) Marshal() ([]byte, error) { + return syntax.Marshal(ext) +} + +func (ext *Extension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ext) +} + +type ExtensionList []Extension + +type extensionListInner struct { + List []Extension `tls:"head=2"` +} + +func (el ExtensionList) Marshal() ([]byte, error) { + return syntax.Marshal(extensionListInner{el}) +} + +func (el *ExtensionList) Unmarshal(data []byte) (int, error) { + var list extensionListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + *el = list.List + return read, nil +} + +func (el *ExtensionList) Add(src ExtensionBody) error { + data, err := src.Marshal() + if err != nil { + return err + } + + if el == nil { + el = new(ExtensionList) + } + + // If one already exists with this type, replace it + for i := range *el { + if (*el)[i].ExtensionType == src.Type() { + (*el)[i].ExtensionData = data + return nil + } + } + + // Otherwise append + *el = append(*el, Extension{ + ExtensionType: src.Type(), + ExtensionData: data, + }) + return nil +} + +func (el ExtensionList) Find(dst ExtensionBody) bool { + for _, ext := range el { + if ext.ExtensionType == dst.Type() { + _, err := dst.Unmarshal(ext.ExtensionData) + return err == nil + } + } + return false +} + +// struct { +// NameType name_type; +// select (name_type) { +// case host_name: HostName; +// } name; +// } ServerName; +// +// enum { +// host_name(0), (255) +// } NameType; +// +// opaque HostName<1..2^16-1>; +// +// struct { +// ServerName server_name_list<1..2^16-1> +// } ServerNameList; +// +// But we only care about the case where there's a single DNS hostname. We +// will never create anything else, and throw if we receive something else +// +// 2 1 2 +// | listLen | NameType | nameLen | name | +type ServerNameExtension string + +type serverNameInner struct { + NameType uint8 + HostName []byte `tls:"head=2,min=1"` +} + +type serverNameListInner struct { + ServerNameList []serverNameInner `tls:"head=2,min=1"` +} + +func (sni ServerNameExtension) Type() ExtensionType { + return ExtensionTypeServerName +} + +func (sni ServerNameExtension) Marshal() ([]byte, error) { + list := serverNameListInner{ + ServerNameList: []serverNameInner{{ + NameType: 0x00, // host_name + HostName: []byte(sni), + }}, + } + + return syntax.Marshal(list) +} + +func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { + var list serverNameListInner + read, err := syntax.Unmarshal(data, &list) + if err != nil { + return 0, err + } + + // Syntax requires at least one entry + // Entries beyond the first are ignored + if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { + return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) + } + + *sni = ServerNameExtension(list.ServerNameList[0].HostName) + return read, nil +} + +// struct { +// NamedGroup group; +// opaque key_exchange<1..2^16-1>; +// } KeyShareEntry; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// KeyShareEntry client_shares<0..2^16-1>; +// +// case hello_retry_request: +// NamedGroup selected_group; +// +// case server_hello: +// KeyShareEntry server_share; +// }; +// } KeyShare; +type KeyShareEntry struct { + Group NamedGroup + KeyExchange []byte `tls:"head=2,min=1"` +} + +func (kse KeyShareEntry) SizeValid() bool { + return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) +} + +type KeyShareExtension struct { + HandshakeType HandshakeType + SelectedGroup NamedGroup + Shares []KeyShareEntry +} + +type KeyShareClientHelloInner struct { + ClientShares []KeyShareEntry `tls:"head=2,min=0"` +} +type KeyShareHelloRetryInner struct { + SelectedGroup NamedGroup +} +type KeyShareServerHelloInner struct { + ServerShare KeyShareEntry +} + +func (ks KeyShareExtension) Type() ExtensionType { + return ExtensionTypeKeyShare +} + +func (ks KeyShareExtension) Marshal() ([]byte, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + for _, share := range ks.Shares { + if !share.SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) + + case HandshakeTypeHelloRetryRequest: + if len(ks.Shares) > 0 { + return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") + } + + return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) + + case HandshakeTypeServerHello: + if len(ks.Shares) != 1 { + return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") + } + + if !ks.Shares[0].SizeValid() { + return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) + + default: + return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { + switch ks.HandshakeType { + case HandshakeTypeClientHello: + var inner KeyShareClientHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + for _, share := range inner.ClientShares { + if !share.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + } + + ks.Shares = inner.ClientShares + return read, nil + + case HandshakeTypeHelloRetryRequest: + var inner KeyShareHelloRetryInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + ks.SelectedGroup = inner.SelectedGroup + return read, nil + + case HandshakeTypeServerHello: + var inner KeyShareServerHelloInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if !inner.ServerShare.SizeValid() { + return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") + } + + ks.Shares = []KeyShareEntry{inner.ServerShare} + return read, nil + + default: + return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") + } +} + +// struct { +// NamedGroup named_group_list<2..2^16-1>; +// } NamedGroupList; +type SupportedGroupsExtension struct { + Groups []NamedGroup `tls:"head=2,min=2"` +} + +func (sg SupportedGroupsExtension) Type() ExtensionType { + return ExtensionTypeSupportedGroups +} + +func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sg) +} + +func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sg) +} + +// struct { +// SignatureScheme supported_signature_algorithms<2..2^16-2>; +// } SignatureSchemeList +type SignatureAlgorithmsExtension struct { + Algorithms []SignatureScheme `tls:"head=2,min=2"` +} + +func (sa SignatureAlgorithmsExtension) Type() ExtensionType { + return ExtensionTypeSignatureAlgorithms +} + +func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sa) +} + +func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sa) +} + +// struct { +// opaque identity<1..2^16-1>; +// uint32 obfuscated_ticket_age; +// } PskIdentity; +// +// opaque PskBinderEntry<32..255>; +// +// struct { +// select (Handshake.msg_type) { +// case client_hello: +// PskIdentity identities<7..2^16-1>; +// PskBinderEntry binders<33..2^16-1>; +// +// case server_hello: +// uint16 selected_identity; +// }; +// +// } PreSharedKeyExtension; +type PSKIdentity struct { + Identity []byte `tls:"head=2,min=1"` + ObfuscatedTicketAge uint32 +} + +type PSKBinderEntry struct { + Binder []byte `tls:"head=1,min=32"` +} + +type PreSharedKeyExtension struct { + HandshakeType HandshakeType + Identities []PSKIdentity + Binders []PSKBinderEntry + SelectedIdentity uint16 +} + +type preSharedKeyClientInner struct { + Identities []PSKIdentity `tls:"head=2,min=7"` + Binders []PSKBinderEntry `tls:"head=2,min=33"` +} + +type preSharedKeyServerInner struct { + SelectedIdentity uint16 +} + +func (psk PreSharedKeyExtension) Type() ExtensionType { + return ExtensionTypePreSharedKey +} + +func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(preSharedKeyClientInner{ + Identities: psk.Identities, + Binders: psk.Binders, + }) + + case HandshakeTypeServerHello: + if len(psk.Identities) > 0 || len(psk.Binders) > 0 { + return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") + } + return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) + + default: + return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { + switch psk.HandshakeType { + case HandshakeTypeClientHello: + var inner preSharedKeyClientInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + if len(inner.Identities) != len(inner.Binders) { + return 0, fmt.Errorf("Lengths of identities and binders not equal") + } + + psk.Identities = inner.Identities + psk.Binders = inner.Binders + return read, nil + + case HandshakeTypeServerHello: + var inner preSharedKeyServerInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + psk.SelectedIdentity = inner.SelectedIdentity + return read, nil + + default: + return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") + } +} + +func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { + for i, localID := range psk.Identities { + if bytes.Equal(localID.Identity, id) { + return psk.Binders[i].Binder, true + } + } + return nil, false +} + +// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; +// +// struct { +// PskKeyExchangeMode ke_modes<1..255>; +// } PskKeyExchangeModes; +type PSKKeyExchangeModesExtension struct { + KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` +} + +func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { + return ExtensionTypePSKKeyExchangeModes +} + +func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { + return syntax.Marshal(pkem) +} + +func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, pkem) +} + +// struct { +// } EarlyDataIndication; + +type EarlyDataExtension struct{} + +func (ed EarlyDataExtension) Type() ExtensionType { + return ExtensionTypeEarlyData +} + +func (ed EarlyDataExtension) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { + return 0, nil +} + +// struct { +// uint32 max_early_data_size; +// } TicketEarlyDataInfo; + +type TicketEarlyDataInfoExtension struct { + MaxEarlyDataSize uint32 +} + +func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { + return ExtensionTypeTicketEarlyDataInfo +} + +func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { + return syntax.Marshal(tedi) +} + +func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tedi) +} + +// opaque ProtocolName<1..2^8-1>; +// +// struct { +// ProtocolName protocol_name_list<2..2^16-1> +// } ProtocolNameList; +type ALPNExtension struct { + Protocols []string +} + +type protocolNameInner struct { + Name []byte `tls:"head=1,min=1"` +} + +type alpnExtensionInner struct { + Protocols []protocolNameInner `tls:"head=2,min=2"` +} + +func (alpn ALPNExtension) Type() ExtensionType { + return ExtensionTypeALPN +} + +func (alpn ALPNExtension) Marshal() ([]byte, error) { + protocols := make([]protocolNameInner, len(alpn.Protocols)) + for i, protocol := range alpn.Protocols { + protocols[i] = protocolNameInner{[]byte(protocol)} + } + return syntax.Marshal(alpnExtensionInner{protocols}) +} + +func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { + var inner alpnExtensionInner + read, err := syntax.Unmarshal(data, &inner) + + if err != nil { + return 0, err + } + + alpn.Protocols = make([]string, len(inner.Protocols)) + for i, protocol := range inner.Protocols { + alpn.Protocols[i] = string(protocol.Name) + } + return read, nil +} + +// struct { +// ProtocolVersion versions<2..254>; +// } SupportedVersions; +type SupportedVersionsExtension struct { + Versions []uint16 `tls:"head=1,min=2,max=254"` +} + +func (sv SupportedVersionsExtension) Type() ExtensionType { + return ExtensionTypeSupportedVersions +} + +func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { + return syntax.Marshal(sv) +} + +func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sv) +} + +// struct { +// opaque cookie<1..2^16-1>; +// } Cookie; +type CookieExtension struct { + Cookie []byte `tls:"head=2,min=1"` +} + +func (c CookieExtension) Type() ExtensionType { + return ExtensionTypeCookie +} + +func (c CookieExtension) Marshal() ([]byte, error) { + return syntax.Marshal(c) +} + +func (c *CookieExtension) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, c) +} + +// defaultCookieLength is the default length of a cookie +const defaultCookieLength = 32 + +type defaultCookieHandler struct { + data []byte +} + +var _ CookieHandler = &defaultCookieHandler{} + +// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data +func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) { + h.data = make([]byte, defaultCookieLength) + if _, err := prng.Read(h.data); err != nil { + return nil, err + } + return h.data, nil +} + +func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool { + return bytes.Equal(h.data, data) +} diff --git a/vendor/github.com/bifurcation/mint/ffdhe.go b/vendor/github.com/bifurcation/mint/ffdhe.go new file mode 100644 index 0000000..59d1f7f --- /dev/null +++ b/vendor/github.com/bifurcation/mint/ffdhe.go @@ -0,0 +1,147 @@ +package mint + +import ( + "encoding/hex" + "math/big" +) + +var ( + finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B423861285C97FFFFFFFFFFFFFFFF" + finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) + finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) + + finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" + finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) + finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) + + finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + + "FFFFFFFFFFFFFFFF" + finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) + finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) + + finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" + finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) + finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) + + finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + + "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + + "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + + "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + + "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + + "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + + "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + + "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + + "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + + "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + + "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + + "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + + "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + + "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + + "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + + "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + + "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + + "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + + "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + + "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + + "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + + "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + + "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + + "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + + "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + + "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + + "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + + "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + + "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + + "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + + "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + + "A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + + "1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + + "0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + + "CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + + "2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + + "BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + + "51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + + "D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + + "1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + + "FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + + "97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + + "D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" + finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) + finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) +) diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go new file mode 100644 index 0000000..99ea470 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/frame-reader.go @@ -0,0 +1,98 @@ +// Read a generic "framed" packet consisting of a header and a +// This is used for both TLS Records and TLS Handshake Messages +package mint + +type framing interface { + headerLen() int + defaultReadLen() int + frameLen(hdr []byte) (int, error) +} + +const ( + kFrameReaderHdr = 0 + kFrameReaderBody = 1 +) + +type frameNextAction func(f *frameReader) error + +type frameReader struct { + details framing + state uint8 + header []byte + body []byte + working []byte + writeOffset int + remainder []byte +} + +func newFrameReader(d framing) *frameReader { + hdr := make([]byte, d.headerLen()) + return &frameReader{ + d, + kFrameReaderHdr, + hdr, + nil, + hdr, + 0, + nil, + } +} + +func dup(a []byte) []byte { + r := make([]byte, len(a)) + copy(r, a) + return r +} + +func (f *frameReader) needed() int { + tmp := (len(f.working) - f.writeOffset) - len(f.remainder) + if tmp < 0 { + return 0 + } + return tmp +} + +func (f *frameReader) addChunk(in []byte) { + // Append to the buffer. + logf(logTypeFrameReader, "Appending %v", len(in)) + f.remainder = append(f.remainder, in...) +} + +func (f *frameReader) process() (hdr []byte, body []byte, err error) { + for f.needed() == 0 { + logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) + // Fill out our working block + copied := copy(f.working[f.writeOffset:], f.remainder) + f.remainder = f.remainder[copied:] + f.writeOffset += copied + if f.writeOffset < len(f.working) { + logf(logTypeFrameReader, "Read would have blocked 1") + return nil, nil, WouldBlock + } + // Reset the write offset, because we are now full. + f.writeOffset = 0 + + // We have read a full frame + if f.state == kFrameReaderBody { + logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) + f.state = kFrameReaderHdr + f.working = f.header + return dup(f.header), dup(f.body), nil + } + + // We have read the header + bodyLen, err := f.details.frameLen(f.header) + if err != nil { + return nil, nil, err + } + logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) + + f.body = make([]byte, bodyLen) + f.working = f.body + f.writeOffset = 0 + f.state = kFrameReaderBody + } + + logf(logTypeFrameReader, "Read would have blocked 2") + return nil, nil, WouldBlock +} diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go new file mode 100644 index 0000000..2b04ac5 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -0,0 +1,253 @@ +package mint + +import ( + "fmt" + "io" + "net" +) + +const ( + handshakeHeaderLen = 4 // handshake message header length + maxHandshakeMessageLen = 1 << 24 // max handshake message length +) + +// struct { +// HandshakeType msg_type; /* handshake type */ +// uint24 length; /* bytes in message */ +// select (HandshakeType) { +// ... +// } body; +// } Handshake; +// +// We do the select{...} part in a different layer, so we treat the +// actual message body as opaque: +// +// struct { +// HandshakeType msg_type; +// opaque msg<0..2^24-1> +// } Handshake; +// +// TODO: File a spec bug +type HandshakeMessage struct { + // Omitted: length + msgType HandshakeType + body []byte +} + +// Note: This could be done with the `syntax` module, using the simplified +// syntax as discussed above. However, since this is so simple, there's not +// much benefit to doing so. +func (hm *HandshakeMessage) Marshal() []byte { + if hm == nil { + return []byte{} + } + + msgLen := len(hm.body) + data := make([]byte, 4+len(hm.body)) + data[0] = byte(hm.msgType) + data[1] = byte(msgLen >> 16) + data[2] = byte(msgLen >> 8) + data[3] = byte(msgLen) + copy(data[4:], hm.body) + return data +} + +func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { + logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) + + var body HandshakeMessageBody + switch hm.msgType { + case HandshakeTypeClientHello: + body = new(ClientHelloBody) + case HandshakeTypeServerHello: + body = new(ServerHelloBody) + case HandshakeTypeHelloRetryRequest: + body = new(HelloRetryRequestBody) + case HandshakeTypeEncryptedExtensions: + body = new(EncryptedExtensionsBody) + case HandshakeTypeCertificate: + body = new(CertificateBody) + case HandshakeTypeCertificateRequest: + body = new(CertificateRequestBody) + case HandshakeTypeCertificateVerify: + body = new(CertificateVerifyBody) + case HandshakeTypeFinished: + body = &FinishedBody{VerifyDataLen: len(hm.body)} + case HandshakeTypeNewSessionTicket: + body = new(NewSessionTicketBody) + case HandshakeTypeKeyUpdate: + body = new(KeyUpdateBody) + case HandshakeTypeEndOfEarlyData: + body = new(EndOfEarlyDataBody) + default: + return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") + } + + _, err := body.Unmarshal(hm.body) + return body, err +} + +func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { + data, err := body.Marshal() + if err != nil { + return nil, err + } + + return &HandshakeMessage{ + msgType: body.Type(), + body: data, + }, nil +} + +type HandshakeLayer struct { + nonblocking bool // Should we operate in nonblocking mode + conn *RecordLayer // Used for reading/writing records + frame *frameReader // The buffered frame reader +} + +type handshakeLayerFrameDetails struct{} + +func (d handshakeLayerFrameDetails) headerLen() int { + return handshakeHeaderLen +} + +func (d handshakeLayerFrameDetails) defaultReadLen() int { + return handshakeHeaderLen + maxFragmentLen +} + +func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { + logf(logTypeIO, "Header=%x", hdr) + return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil +} + +func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer { + h := HandshakeLayer{} + h.conn = r + h.frame = newFrameReader(&handshakeLayerFrameDetails{}) + return &h +} + +func (h *HandshakeLayer) readRecord() error { + logf(logTypeIO, "Trying to read record") + pt, err := h.conn.ReadRecord() + if err != nil { + return err + } + + if pt.contentType != RecordTypeHandshake && + pt.contentType != RecordTypeAlert { + return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) + } + + if pt.contentType == RecordTypeAlert { + logf(logTypeIO, "read alert %v", pt.fragment[1]) + if len(pt.fragment) < 2 { + h.sendAlert(AlertUnexpectedMessage) + return io.EOF + } + return Alert(pt.fragment[1]) + } + + logf(logTypeIO, "read handshake record of len %v", len(pt.fragment)) + h.frame.addChunk(pt.fragment) + + return nil +} + +// sendAlert sends a TLS alert message. +func (h *HandshakeLayer) sendAlert(err Alert) error { + tmp := make([]byte, 2) + tmp[0] = AlertLevelError + tmp[1] = byte(err) + h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAlert, + fragment: tmp}, + ) + + // closeNotify is a special case in that it isn't an error: + if err != AlertCloseNotify { + return &net.OpError{Op: "local error", Err: err} + } + return nil +} + +func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { + var hdr, body []byte + var err error + + for { + logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder)) + if h.frame.needed() > 0 { + logf(logTypeHandshake, "Trying to read a new record") + err = h.readRecord() + } + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + + hdr, body, err = h.frame.process() + if err == nil { + break + } + if err != nil && (h.nonblocking || err != WouldBlock) { + return nil, err + } + } + + logf(logTypeHandshake, "read handshake message") + + hm := &HandshakeMessage{} + hm.msgType = HandshakeType(hdr[0]) + + hm.body = make([]byte, len(body)) + copy(hm.body, body) + + return hm, nil +} + +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { + return h.WriteMessages([]*HandshakeMessage{hm}) +} + +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { + for _, hm := range hms { + logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) + } + + // Write out headers and bodies + buffer := []byte{} + for _, msg := range hms { + msgLen := len(msg.body) + if msgLen > maxHandshakeMessageLen { + return fmt.Errorf("tls.handshakelayer: Message too large to send") + } + + buffer = append(buffer, msg.Marshal()...) + } + + // Send full-size fragments + var start int + for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { + err := h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buffer[start : start+maxFragmentLen], + }) + + if err != nil { + return err + } + } + + // Send a final partial fragment if necessary + if start < len(buffer) { + err := h.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buffer[start:], + }) + + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go new file mode 100644 index 0000000..339bbcd --- /dev/null +++ b/vendor/github.com/bifurcation/mint/handshake-messages.go @@ -0,0 +1,450 @@ +package mint + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/binary" + "fmt" + + "github.com/bifurcation/mint/syntax" +) + +type HandshakeMessageBody interface { + Type() HandshakeType + Marshal() ([]byte, error) + Unmarshal(data []byte) (int, error) +} + +// struct { +// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ +// Random random; +// opaque legacy_session_id<0..32>; +// CipherSuite cipher_suites<2..2^16-2>; +// opaque legacy_compression_methods<1..2^8-1>; +// Extension extensions<0..2^16-1>; +// } ClientHello; +type ClientHelloBody struct { + // Omitted: clientVersion + // Omitted: legacySessionID + // Omitted: legacyCompressionMethods + Random [32]byte + CipherSuites []CipherSuite + Extensions ExtensionList +} + +type clientHelloBodyInner struct { + LegacyVersion uint16 + Random [32]byte + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuites []CipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []Extension `tls:"head=2"` +} + +func (ch ClientHelloBody) Type() HandshakeType { + return HandshakeTypeClientHello +} + +func (ch ClientHelloBody) Marshal() ([]byte, error) { + return syntax.Marshal(clientHelloBodyInner{ + LegacyVersion: 0x0303, + Random: ch.Random, + LegacySessionID: []byte{}, + CipherSuites: ch.CipherSuites, + LegacyCompressionMethods: []byte{0}, + Extensions: ch.Extensions, + }) +} + +func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { + var inner clientHelloBodyInner + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return 0, err + } + + // We are strict about these things because we only support 1.3 + if inner.LegacyVersion != 0x0303 { + return 0, fmt.Errorf("tls.clienthello: Incorrect version number") + } + + if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { + return 0, fmt.Errorf("tls.clienthello: Invalid compression method") + } + + ch.Random = inner.Random + ch.CipherSuites = inner.CipherSuites + ch.Extensions = inner.Extensions + return read, nil +} + +// TODO: File a spec bug to clarify this +func (ch ClientHelloBody) Truncated() ([]byte, error) { + if len(ch.Extensions) == 0 { + return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") + } + + pskExt := ch.Extensions[len(ch.Extensions)-1] + if pskExt.ExtensionType != ExtensionTypePreSharedKey { + return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") + } + + chm, err := HandshakeMessageFromBody(&ch) + if err != nil { + return nil, err + } + chData := chm.Marshal() + + psk := PreSharedKeyExtension{ + HandshakeType: HandshakeTypeClientHello, + } + _, err = psk.Unmarshal(pskExt.ExtensionData) + if err != nil { + return nil, err + } + + // Marshal just the binders so that we know how much to truncate + binders := struct { + Binders []PSKBinderEntry `tls:"head=2,min=33"` + }{Binders: psk.Binders} + binderData, _ := syntax.Marshal(binders) + binderLen := len(binderData) + + chLen := len(chData) + return chData[:chLen-binderLen], nil +} + +// struct { +// ProtocolVersion server_version; +// CipherSuite cipher_suite; +// Extension extensions<2..2^16-1>; +// } HelloRetryRequest; +type HelloRetryRequestBody struct { + Version uint16 + CipherSuite CipherSuite + Extensions ExtensionList `tls:"head=2,min=2"` +} + +func (hrr HelloRetryRequestBody) Type() HandshakeType { + return HandshakeTypeHelloRetryRequest +} + +func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) { + return syntax.Marshal(hrr) +} + +func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, hrr) +} + +// struct { +// ProtocolVersion version; +// Random random; +// CipherSuite cipher_suite; +// Extension extensions<0..2^16-1>; +// } ServerHello; +type ServerHelloBody struct { + Version uint16 + Random [32]byte + CipherSuite CipherSuite + Extensions ExtensionList `tls:"head=2"` +} + +func (sh ServerHelloBody) Type() HandshakeType { + return HandshakeTypeServerHello +} + +func (sh ServerHelloBody) Marshal() ([]byte, error) { + return syntax.Marshal(sh) +} + +func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, sh) +} + +// struct { +// opaque verify_data[verify_data_length]; +// } Finished; +// +// verifyDataLen is not a field in the TLS struct, but we add it here so +// that calling code can tell us how much data to expect when we marshal / +// unmarshal. (We could add this to the marshal/unmarshal methods, but let's +// try to keep the signature consistent for now.) +// +// For similar reasons, we don't use the `syntax` module here, because this +// struct doesn't map well to standard TLS presentation language concepts. +// +// TODO: File a spec bug +type FinishedBody struct { + VerifyDataLen int + VerifyData []byte +} + +func (fin FinishedBody) Type() HandshakeType { + return HandshakeTypeFinished +} + +func (fin FinishedBody) Marshal() ([]byte, error) { + if len(fin.VerifyData) != fin.VerifyDataLen { + return nil, fmt.Errorf("tls.finished: data length mismatch") + } + + body := make([]byte, len(fin.VerifyData)) + copy(body, fin.VerifyData) + return body, nil +} + +func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { + if len(data) < fin.VerifyDataLen { + return 0, fmt.Errorf("tls.finished: Malformed finished; too short") + } + + fin.VerifyData = make([]byte, fin.VerifyDataLen) + copy(fin.VerifyData, data[:fin.VerifyDataLen]) + return fin.VerifyDataLen, nil +} + +// struct { +// Extension extensions<0..2^16-1>; +// } EncryptedExtensions; +// +// Marshal() and Unmarshal() are handled by ExtensionList +type EncryptedExtensionsBody struct { + Extensions ExtensionList `tls:"head=2"` +} + +func (ee EncryptedExtensionsBody) Type() HandshakeType { + return HandshakeTypeEncryptedExtensions +} + +func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { + return syntax.Marshal(ee) +} + +func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ee) +} + +// opaque ASN1Cert<1..2^24-1>; +// +// struct { +// ASN1Cert cert_data; +// Extension extensions<0..2^16-1> +// } CertificateEntry; +// +// struct { +// opaque certificate_request_context<0..2^8-1>; +// CertificateEntry certificate_list<0..2^24-1>; +// } Certificate; +type CertificateEntry struct { + CertData *x509.Certificate + Extensions ExtensionList +} + +type CertificateBody struct { + CertificateRequestContext []byte + CertificateList []CertificateEntry +} + +type certificateEntryInner struct { + CertData []byte `tls:"head=3,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +type certificateBodyInner struct { + CertificateRequestContext []byte `tls:"head=1"` + CertificateList []certificateEntryInner `tls:"head=3"` +} + +func (c CertificateBody) Type() HandshakeType { + return HandshakeTypeCertificate +} + +func (c CertificateBody) Marshal() ([]byte, error) { + inner := certificateBodyInner{ + CertificateRequestContext: c.CertificateRequestContext, + CertificateList: make([]certificateEntryInner, len(c.CertificateList)), + } + + for i, entry := range c.CertificateList { + inner.CertificateList[i] = certificateEntryInner{ + CertData: entry.CertData.Raw, + Extensions: entry.Extensions, + } + } + + return syntax.Marshal(inner) +} + +func (c *CertificateBody) Unmarshal(data []byte) (int, error) { + inner := certificateBodyInner{} + read, err := syntax.Unmarshal(data, &inner) + if err != nil { + return read, err + } + + c.CertificateRequestContext = inner.CertificateRequestContext + c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) + + for i, entry := range inner.CertificateList { + c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) + if err != nil { + return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) + } + + c.CertificateList[i].Extensions = entry.Extensions + } + + return read, nil +} + +// struct { +// SignatureScheme algorithm; +// opaque signature<0..2^16-1>; +// } CertificateVerify; +type CertificateVerifyBody struct { + Algorithm SignatureScheme + Signature []byte `tls:"head=2"` +} + +func (cv CertificateVerifyBody) Type() HandshakeType { + return HandshakeTypeCertificateVerify +} + +func (cv CertificateVerifyBody) Marshal() ([]byte, error) { + return syntax.Marshal(cv) +} + +func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cv) +} + +func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { + // TODO: Change context for client auth + // TODO: Put this in a const + const context = "TLS 1.3, server CertificateVerify" + sigInput := bytes.Repeat([]byte{0x20}, 64) + sigInput = append(sigInput, []byte(context)...) + sigInput = append(sigInput, []byte{0}...) + sigInput = append(sigInput, data...) + return sigInput +} + +func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { + sigInput := cv.EncodeSignatureInput(handshakeHash) + cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) + logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return +} + +func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { + sigInput := cv.EncodeSignatureInput(handshakeHash) + logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) + return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) +} + +// struct { +// opaque certificate_request_context<0..2^8-1>; +// Extension extensions<2..2^16-1>; +// } CertificateRequest; +type CertificateRequestBody struct { + CertificateRequestContext []byte `tls:"head=1"` + Extensions ExtensionList `tls:"head=2"` +} + +func (cr CertificateRequestBody) Type() HandshakeType { + return HandshakeTypeCertificateRequest +} + +func (cr CertificateRequestBody) Marshal() ([]byte, error) { + return syntax.Marshal(cr) +} + +func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, cr) +} + +// struct { +// uint32 ticket_lifetime; +// uint32 ticket_age_add; +// opaque ticket_nonce<1..255>; +// opaque ticket<1..2^16-1>; +// Extension extensions<0..2^16-2>; +// } NewSessionTicket; +type NewSessionTicketBody struct { + TicketLifetime uint32 + TicketAgeAdd uint32 + TicketNonce []byte `tls:"head=1,min=1"` + Ticket []byte `tls:"head=2,min=1"` + Extensions ExtensionList `tls:"head=2"` +} + +const ticketNonceLen = 16 + +func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { + buf := make([]byte, 4+ticketNonceLen+ticketLen) + _, err := prng.Read(buf) + if err != nil { + return nil, err + } + + tkt := &NewSessionTicketBody{ + TicketLifetime: ticketLifetime, + TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), + TicketNonce: buf[4 : 4+ticketNonceLen], + Ticket: buf[4+ticketNonceLen:], + } + + return tkt, err +} + +func (tkt NewSessionTicketBody) Type() HandshakeType { + return HandshakeTypeNewSessionTicket +} + +func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { + return syntax.Marshal(tkt) +} + +func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, tkt) +} + +// enum { +// update_not_requested(0), update_requested(1), (255) +// } KeyUpdateRequest; +// +// struct { +// KeyUpdateRequest request_update; +// } KeyUpdate; +type KeyUpdateBody struct { + KeyUpdateRequest KeyUpdateRequest +} + +func (ku KeyUpdateBody) Type() HandshakeType { + return HandshakeTypeKeyUpdate +} + +func (ku KeyUpdateBody) Marshal() ([]byte, error) { + return syntax.Marshal(ku) +} + +func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { + return syntax.Unmarshal(data, ku) +} + +// struct {} EndOfEarlyData; +type EndOfEarlyDataBody struct{} + +func (eoed EndOfEarlyDataBody) Type() HandshakeType { + return HandshakeTypeEndOfEarlyData +} + +func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { + return 0, nil +} diff --git a/vendor/github.com/bifurcation/mint/log.go b/vendor/github.com/bifurcation/mint/log.go new file mode 100644 index 0000000..2fba90d --- /dev/null +++ b/vendor/github.com/bifurcation/mint/log.go @@ -0,0 +1,55 @@ +package mint + +import ( + "fmt" + "log" + "os" + "strings" +) + +// We use this environment variable to control logging. It should be a +// comma-separated list of log tags (see below) or "*" to enable all logging. +const logConfigVar = "MINT_LOG" + +// Pre-defined log types +const ( + logTypeCrypto = "crypto" + logTypeHandshake = "handshake" + logTypeNegotiation = "negotiation" + logTypeIO = "io" + logTypeFrameReader = "frame" + logTypeVerbose = "verbose" +) + +var ( + logFunction = log.Printf + logAll = false + logSettings = map[string]bool{} +) + +func init() { + parseLogEnv(os.Environ()) +} + +func parseLogEnv(env []string) { + for _, stmt := range env { + if strings.HasPrefix(stmt, logConfigVar+"=") { + val := stmt[len(logConfigVar)+1:] + + if val == "*" { + logAll = true + } else { + for _, t := range strings.Split(val, ",") { + logSettings[t] = true + } + } + } + } +} + +func logf(tag string, format string, args ...interface{}) { + if logAll || logSettings[tag] { + fullFormat := fmt.Sprintf("[%s] %s", tag, format) + logFunction(fullFormat, args...) + } +} diff --git a/vendor/github.com/bifurcation/mint/mint.svg b/vendor/github.com/bifurcation/mint/mint.svg new file mode 100644 index 0000000..ae32703 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/mint.svg @@ -0,0 +1,101 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go new file mode 100644 index 0000000..f4ead72 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/negotiation.go @@ -0,0 +1,217 @@ +package mint + +import ( + "bytes" + "encoding/hex" + "fmt" + "time" +) + +func VersionNegotiation(offered, supported []uint16) (bool, uint16) { + for _, offeredVersion := range offered { + for _, supportedVersion := range supported { + logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) + if offeredVersion == supportedVersion { + // XXX: Should probably be highest supported version, but for now, we + // only support one version, so it doesn't really matter. + return true, offeredVersion + } + } + } + + return false, 0 +} + +func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { + for _, share := range keyShares { + for _, group := range groups { + if group != share.Group { + continue + } + + pub, priv, err := newKeyShare(share.Group) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) + if err != nil { + // If we encounter an error, just keep looking + continue + } + + return true, group, pub, dhSecret + } + } + + return false, 0, nil, nil +} + +const ( + ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds +) + +func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { + logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) + for i, id := range identities { + identityHex := hex.EncodeToString(id.Identity) + + psk, ok := psks.Get(identityHex) + if !ok { + logf(logTypeNegotiation, "No PSK for identity %x", identityHex) + continue + } + + // For resumption, make sure the ticket age is correct + if psk.IsResumption { + extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd + knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) + ticketAgeDelta := knownTicketAge - extTicketAge + if knownTicketAge < extTicketAge { + ticketAgeDelta = extTicketAge - knownTicketAge + } + if ticketAgeDelta > ticketAgeTolerance { + logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) + logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", + extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) + } + } + + params, ok := cipherSuiteMap[psk.CipherSuite] + if !ok { + err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) + return false, 0, nil, CipherSuiteParams{}, err + } + + // Compute binder + binderLabel := labelExternalBinder + if psk.IsResumption { + binderLabel = labelResumptionBinder + } + + h0 := params.Hash.New().Sum(nil) + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, psk.Key) + binderKey := deriveSecret(params, earlySecret, binderLabel, h0) + + // context = ClientHello[truncated] + // context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated] + ctxHash := params.Hash.New() + ctxHash.Write(context) + + binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) + if !bytes.Equal(binder, binders[i].Binder) { + logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) + return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) + } + + logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) + return true, i, &psk, params, nil + } + + logf(logTypeNegotiation, "Failed to find a usable PSK") + return false, 0, nil, CipherSuiteParams{}, nil +} + +func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { + logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) + dhAllowed := false + dhRequired := true + for _, mode := range modes { + dhAllowed = dhAllowed || (mode == PSKModeDHEKE) + dhRequired = dhRequired && (mode == PSKModeDHEKE) + } + + // Use PSK if we can meet DH requirement and modes were provided + usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) + + // Use DH if allowed + usingDH := canDoDH && (dhAllowed || !usingPSK) + + logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) + return usingDH, usingPSK +} + +func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { + // Select for server name if provided + candidates := certs + if serverName != nil { + candidatesByName := []*Certificate{} + for _, cert := range certs { + for _, name := range cert.Chain[0].DNSNames { + if len(*serverName) > 0 && name == *serverName { + candidatesByName = append(candidatesByName, cert) + } + } + } + + if len(candidatesByName) == 0 { + return nil, 0, fmt.Errorf("No certificates available for server name") + } + + candidates = candidatesByName + } + + // Select for signature scheme + for _, cert := range candidates { + for _, scheme := range signatureSchemes { + if !schemeValidForKey(scheme, cert.PrivateKey) { + continue + } + + return cert, scheme, nil + } + } + + return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") +} + +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { + usingEarlyData := gotEarlyData && usingPSK && allowEarlyData + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) + return usingEarlyData +} + +func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { + for _, s1 := range offered { + if psk != nil { + if s1 == psk.CipherSuite { + return s1, nil + } + continue + } + + for _, s2 := range supported { + if s1 == s2 { + return s1, nil + } + } + } + + return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) +} + +func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { + for _, p1 := range offered { + if psk != nil { + if p1 != psk.NextProto { + continue + } + } + + for _, p2 := range supported { + if p1 == p2 { + return p1, nil + } + } + } + + // If the client offers ALPN on resumption, it must match the earlier one + var err error + if psk != nil && psk.IsResumption && (len(offered) > 0) { + err = fmt.Errorf("ALPN for PSK not provided") + } + return "", err +} diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go new file mode 100644 index 0000000..bcef613 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/record-layer.go @@ -0,0 +1,296 @@ +package mint + +import ( + "bytes" + "crypto/cipher" + "fmt" + "io" + "sync" +) + +const ( + sequenceNumberLen = 8 // sequence number length + recordHeaderLen = 5 // record header length + maxFragmentLen = 1 << 14 // max number of bytes in a record +) + +type DecryptError string + +func (err DecryptError) Error() string { + return string(err) +} + +// struct { +// ContentType type; +// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */ +// uint16 length; +// opaque fragment[TLSPlaintext.length]; +// } TLSPlaintext; +type TLSPlaintext struct { + // Omitted: record_version (static) + // Omitted: length (computed from fragment) + contentType RecordType + fragment []byte +} + +type RecordLayer struct { + sync.Mutex + + conn io.ReadWriter // The underlying connection + frame *frameReader // The buffered frame reader + nextData []byte // The next record to send + cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" + cachedError error // Error on the last record read + + ivLength int // Length of the seq and nonce fields + seq []byte // Zero-padded sequence number + nonce []byte // Buffer for per-record nonces + cipher cipher.AEAD // AEAD cipher +} + +type recordLayerFrameDetails struct{} + +func (d recordLayerFrameDetails) headerLen() int { + return recordHeaderLen +} + +func (d recordLayerFrameDetails) defaultReadLen() int { + return recordHeaderLen + maxFragmentLen +} + +func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { + return (int(hdr[3]) << 8) | int(hdr[4]), nil +} + +func NewRecordLayer(conn io.ReadWriter) *RecordLayer { + r := RecordLayer{} + r.conn = conn + r.frame = newFrameReader(recordLayerFrameDetails{}) + r.ivLength = 0 + return &r +} + +func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error { + var err error + r.cipher, err = cipher(key) + if err != nil { + return err + } + + r.ivLength = len(iv) + r.seq = bytes.Repeat([]byte{0}, r.ivLength) + r.nonce = make([]byte, r.ivLength) + copy(r.nonce, iv) + return nil +} + +func (r *RecordLayer) incrementSequenceNumber() { + if r.ivLength == 0 { + return + } + + for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- { + r.seq[i]++ + r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i] + if r.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext { + // Expand the fragment to hold contentType, padding, and overhead + originalLen := len(pt.fragment) + plaintextLen := originalLen + 1 + padLen + ciphertextLen := plaintextLen + r.cipher.Overhead() + + // Assemble the revised plaintext + out := &TLSPlaintext{ + contentType: RecordTypeApplicationData, + fragment: make([]byte, ciphertextLen), + } + copy(out.fragment, pt.fragment) + out.fragment[originalLen] = byte(pt.contentType) + for i := 1; i <= padLen; i++ { + out.fragment[originalLen+i] = 0 + } + + // Encrypt the fragment + payload := out.fragment[:plaintextLen] + r.cipher.Seal(payload[:0], r.nonce, payload, nil) + return out +} + +func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) { + if len(pt.fragment) < r.cipher.Overhead() { + msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead()) + return nil, 0, DecryptError(msg) + } + + decryptLen := len(pt.fragment) - r.cipher.Overhead() + out := &TLSPlaintext{ + contentType: pt.contentType, + fragment: make([]byte, decryptLen), + } + + // Decrypt + _, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil) + if err != nil { + return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") + } + + // Find the padding boundary + padLen := 0 + for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { + } + + // Transfer the content type + newLen := decryptLen - padLen - 1 + out.contentType = RecordType(out.fragment[newLen]) + + // Truncate the message to remove contentType, padding, overhead + out.fragment = out.fragment[:newLen] + return out, padLen, nil +} + +func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { + var pt *TLSPlaintext + var err error + + for { + pt, err = r.nextRecord() + if err == nil { + break + } + if !block || err != WouldBlock { + return 0, err + } + } + return pt.contentType, nil +} + +func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { + pt, err := r.nextRecord() + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { + if r.cachedRecord != nil { + logf(logTypeIO, "Returning cached record") + return r.cachedRecord, r.cachedError + } + + // Loop until one of three things happens: + // + // 1. We get a frame + // 2. We try to read off the socket and get nothing, in which case + // return WouldBlock + // 3. We get an error. + err := WouldBlock + var header, body []byte + + for err != nil { + if r.frame.needed() > 0 { + buf := make([]byte, recordHeaderLen+maxFragmentLen) + n, err := r.conn.Read(buf) + if err != nil { + logf(logTypeIO, "Error reading, %v", err) + return nil, err + } + + if n == 0 { + return nil, WouldBlock + } + + logf(logTypeIO, "Read %v bytes", n) + + buf = buf[:n] + r.frame.addChunk(buf) + } + + header, body, err = r.frame.process() + // Loop around on WouldBlock to see if some + // data is now available. + if err != nil && err != WouldBlock { + return nil, err + } + } + + pt := &TLSPlaintext{} + // Validate content type + switch RecordType(header[0]) { + default: + return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + pt.contentType = RecordType(header[0]) + } + + // Validate version + if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { + return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) + } + + // Validate size < max + size := (int(header[3]) << 8) + int(header[4]) + if size > maxFragmentLen+256 { + return nil, fmt.Errorf("tls.record: Ciphertext size too big") + } + + pt.fragment = make([]byte, size) + copy(pt.fragment, body) + + // Attempt to decrypt fragment + if r.cipher != nil { + pt, _, err = r.decrypt(pt) + if err != nil { + return nil, err + } + } + + // Check that plaintext length is not too long + if len(pt.fragment) > maxFragmentLen { + return nil, fmt.Errorf("tls.record: Plaintext size too big") + } + + logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + + r.cachedRecord = pt + r.incrementSequenceNumber() + return pt, nil +} + +func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { + return r.WriteRecordWithPadding(pt, 0) +} + +func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { + if r.cipher != nil { + pt = r.encrypt(pt, padLen) + } else if padLen > 0 { + return fmt.Errorf("tls.record: Padding can only be done on encrypted records") + } + + if len(pt.fragment) > maxFragmentLen { + return fmt.Errorf("tls.record: Record size too big") + } + + length := len(pt.fragment) + header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)} + record := append(header, pt.fragment...) + + logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment) + + r.incrementSequenceNumber() + _, err := r.conn.Write(record) + return err +} diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go new file mode 100644 index 0000000..60df9b6 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -0,0 +1,898 @@ +package mint + +import ( + "bytes" + "hash" + "reflect" +) + +// Server State Machine +// +// START <-----+ +// Recv ClientHello | | Send HelloRetryRequest +// v | +// RECVD_CH ----+ +// | Select parameters +// | Send ServerHello +// v +// NEGOTIATED +// | Send EncryptedExtensions +// | [Send CertificateRequest] +// Can send | [Send Certificate + CertificateVerify] +// app data --> | Send Finished +// after +--------+--------+ +// here No 0-RTT | | 0-RTT +// | v +// | WAIT_EOED <---+ +// | Recv | | | Recv +// | EndOfEarlyData | | | early data +// | | +-----+ +// +> WAIT_FLIGHT2 <-+ +// | +// +--------+--------+ +// No auth | | Client auth +// | | +// | v +// | WAIT_CERT +// | Recv | | Recv Certificate +// | empty | v +// | Certificate | WAIT_CV +// | | | Recv +// | v | CertificateVerify +// +-> WAIT_FINISHED <---+ +// | Recv Finished +// v +// CONNECTED +// +// NB: Not using state RECVD_CH +// +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) + +type ServerStateStart struct { + Caps Capabilities + conn *Conn + + cookieSent bool + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage +} + +func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeClientHello { + logf(logTypeHandshake, "[ServerStateStart] unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + ch := &ClientHelloBody{} + _, err := ch.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + clientHello := hm + connParams := ConnectionParameters{} + + supportedVersions := new(SupportedVersionsExtension) + serverName := new(ServerNameExtension) + supportedGroups := new(SupportedGroupsExtension) + signatureAlgorithms := new(SignatureAlgorithmsExtension) + clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} + clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} + clientEarlyData := &EarlyDataExtension{} + clientALPN := new(ALPNExtension) + clientPSKModes := new(PSKKeyExchangeModesExtension) + clientCookie := new(CookieExtension) + + // Handle external extensions. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) + return nil, nil, AlertInternalError + } + } + + gotSupportedVersions := ch.Extensions.Find(supportedVersions) + gotServerName := ch.Extensions.Find(serverName) + gotSupportedGroups := ch.Extensions.Find(supportedGroups) + gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms) + gotEarlyData := ch.Extensions.Find(clientEarlyData) + ch.Extensions.Find(clientKeyShares) + ch.Extensions.Find(clientPSK) + ch.Extensions.Find(clientALPN) + ch.Extensions.Find(clientPSKModes) + ch.Extensions.Find(clientCookie) + + if gotServerName { + connParams.ServerName = string(*serverName) + } + + // If the client didn't send supportedVersions or doesn't support 1.3, + // then we're done here. + if !gotSupportedVersions { + logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") + return nil, nil, AlertProtocolVersion + } + versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) + if !versionOK { + logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") + return nil, nil, AlertProtocolVersion + } + + if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) { + logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") + return nil, nil, AlertAccessDenied + } + + // Figure out if we can do DH + canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups) + + // Figure out if we can do PSK + canDoPSK := false + var selectedPSK int + var psk *PreSharedKey + var params CipherSuiteParams + if len(clientPSK.Identities) > 0 { + contextBase := []byte{} + if state.helloRetryRequest != nil { + chBytes := state.firstClientHello.Marshal() + hrrBytes := state.helloRetryRequest.Marshal() + contextBase = append(chBytes, hrrBytes...) + } + + chTrunc, err := ch.Truncated() + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) + return nil, nil, AlertDecodeError + } + + context := append(contextBase, chTrunc...) + + canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Figure out if we actually should do DH / PSK + connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) + + // Select a ciphersuite + connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) + return nil, nil, AlertHandshakeFailure + } + + // Send a cookie if required + // NB: Need to do this here because it's after ciphersuite selection, which + // has to be after PSK selection. + // XXX: Doing this statefully for now, could be stateless + var cookieData []byte + if state.Caps.RequireCookie && !state.cookieSent { + var err error + cookieData, err = state.Caps.CookieHandler.Generate(state.conn) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) + return nil, nil, AlertInternalError + } + } + if cookieData != nil { + // Ignoring errors because everything here is newly constructed, so there + // shouldn't be marshal errors + hrr := &HelloRetryRequestBody{ + Version: supportedVersion, + CipherSuite: connParams.CipherSuite, + } + hrr.Extensions.Add(&CookieExtension{Cookie: cookieData}) + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + helloRetryRequest, err := HandshakeMessageFromBody(hrr) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) + return nil, nil, AlertInternalError + } + + params := cipherSuiteMap[connParams.CipherSuite] + h := params.Hash.New() + h.Write(clientHello.Marshal()) + firstClientHello := &HandshakeMessage{ + msgType: HandshakeTypeMessageHash, + body: h.Sum(nil), + } + + nextState := ServerStateStart{ + Caps: state.Caps, + conn: state.conn, + cookieSent: true, + firstClientHello: firstClientHello, + helloRetryRequest: helloRetryRequest, + } + toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}} + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") + return nextState, toSend, AlertNoAlert + } + + // If we've got no entropy to make keys from, fail + if !connParams.UsingDH && !connParams.UsingPSK { + logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") + return nil, nil, AlertHandshakeFailure + } + + var pskSecret []byte + var cert *Certificate + var certScheme SignatureScheme + if connParams.UsingPSK { + pskSecret = psk.Key + } else { + psk = nil + + // If we're not using a PSK mode, then we need to have certain extensions + if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms { + logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)", + gotServerName, gotSupportedGroups, gotSignatureAlgorithms) + return nil, nil, AlertMissingExtension + } + + // Select a certificate + name := string(*serverName) + var err error + cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) + return nil, nil, AlertAccessDenied + } + } + + if !connParams.UsingDH { + dhSecret = nil + } + + // Figure out if we're going to do early data + var clientEarlyTrafficSecret []byte + connParams.ClientSendingEarlyData = gotEarlyData + connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData) + if connParams.UsingEarlyData { + + h := params.Hash.New() + h.Write(clientHello.Marshal()) + chHash := h.Sum(nil) + + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + earlySecret := HkdfExtract(params.Hash, zero, pskSecret) + clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) + } + + // Select a next protocol + connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos) + if err != nil { + logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) + return nil, nil, AlertNoApplicationProtocol + } + + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") + return ServerStateNegotiated{ + Caps: state.Caps, + Params: connParams, + + dhGroup: dhGroup, + dhPublic: dhPublic, + dhSecret: dhSecret, + pskSecret: pskSecret, + selectedPSK: selectedPSK, + cert: cert, + certScheme: certScheme, + clientEarlyTrafficSecret: clientEarlyTrafficSecret, + + firstClientHello: state.firstClientHello, + helloRetryRequest: state.helloRetryRequest, + clientHello: clientHello, + }.Next(nil) +} + +type ServerStateNegotiated struct { + Caps Capabilities + Params ConnectionParameters + + dhGroup NamedGroup + dhPublic []byte + dhSecret []byte + pskSecret []byte + clientEarlyTrafficSecret []byte + selectedPSK int + cert *Certificate + certScheme SignatureScheme + + firstClientHello *HandshakeMessage + helloRetryRequest *HandshakeMessage + clientHello *HandshakeMessage +} + +func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + // Create the ServerHello + sh := &ServerHelloBody{ + Version: supportedVersion, + CipherSuite: state.Params.CipherSuite, + } + _, err := prng.Read(sh.Random[:]) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) + return nil, nil, AlertInternalError + } + if state.Params.UsingDH { + logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") + err = sh.Extensions.Add(&KeyShareExtension{ + HandshakeType: HandshakeTypeServerHello, + Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") + err = sh.Extensions.Add(&PreSharedKeyExtension{ + HandshakeType: HandshakeTypeServerHello, + SelectedIdentity: uint16(state.selectedPSK), + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) + return nil, nil, AlertInternalError + } + } + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + serverHello, err := HandshakeMessageFromBody(sh) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) + return nil, nil, AlertInternalError + } + + // Look up crypto params + params, ok := cipherSuiteMap[sh.CipherSuite] + if !ok { + logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) + return nil, nil, AlertHandshakeFailure + } + + // Start up the handshake hash + handshakeHash := params.Hash.New() + handshakeHash.Write(state.firstClientHello.Marshal()) + handshakeHash.Write(state.helloRetryRequest.Marshal()) + handshakeHash.Write(state.clientHello.Marshal()) + handshakeHash.Write(serverHello.Marshal()) + + // Compute handshake secrets + zero := bytes.Repeat([]byte{0}, params.Hash.Size()) + + var earlySecret []byte + if state.Params.UsingPSK { + earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) + } else { + earlySecret = HkdfExtract(params.Hash, zero, zero) + } + + if state.dhSecret == nil { + state.dhSecret = zero + } + + h0 := params.Hash.New().Sum(nil) + h2 := handshakeHash.Sum(nil) + preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) + handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) + clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) + serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) + preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) + masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) + + logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) + logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) + logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) + logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) + logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) + + clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) + serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) + + // Send an EncryptedExtensions message (even if it's empty) + eeList := ExtensionList{} + if state.Params.NextProto != "" { + logf(logTypeHandshake, "[server] sending ALPN extension") + err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.Params.UsingEarlyData { + logf(logTypeHandshake, "[server] sending EDI extension") + err = eeList.Add(&EarlyDataExtension{}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + ee := &EncryptedExtensionsBody{eeList} + + // Run the external extension handler. + if state.Caps.ExtensionHandler != nil { + err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) + return nil, nil, AlertInternalError + } + } + + eem, err := HandshakeMessageFromBody(ee) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + + handshakeHash.Write(eem.Marshal()) + + toSend := []HandshakeAction{ + SendHandshakeMessage{serverHello}, + RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys}, + SendHandshakeMessage{eem}, + } + + // Authenticate with a certificate if required + if !state.Params.UsingPSK { + // Send a CertificateRequest message if we want client auth + if state.Caps.RequireClientAuth { + state.Params.UsingClientAuth = true + + // XXX: We don't support sending any constraints besides a list of + // supported signature algorithms + cr := &CertificateRequestBody{} + schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes} + err := cr.Extensions.Add(schemes) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + + crm, err := HandshakeMessageFromBody(cr) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) + return nil, nil, AlertInternalError + } + //TODO state.state.serverCertificateRequest = cr + + toSend = append(toSend, SendHandshakeMessage{crm}) + handshakeHash.Write(crm.Marshal()) + } + + // Create and send Certificate, CertificateVerify + certificate := &CertificateBody{ + CertificateList: make([]CertificateEntry, len(state.cert.Chain)), + } + for i, entry := range state.cert.Chain { + certificate.CertificateList[i] = CertificateEntry{CertData: entry} + } + certm, err := HandshakeMessageFromBody(certificate) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certm}) + handshakeHash.Write(certm.Marshal()) + + certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} + logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) + + hcv := handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + err = certificateVerify.Sign(state.cert.PrivateKey, hcv) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + certvm, err := HandshakeMessageFromBody(certificateVerify) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) + return nil, nil, AlertInternalError + } + + toSend = append(toSend, SendHandshakeMessage{certvm}) + handshakeHash.Write(certvm.Marshal()) + } + + // Compute secrets resulting from the server's first flight + h3 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) + + serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) + logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) + + // Assemble the Finished message + fin := &FinishedBody{ + VerifyDataLen: len(serverFinishedData), + VerifyData: serverFinishedData, + } + finm, _ := HandshakeMessageFromBody(fin) + + toSend = append(toSend, SendHandshakeMessage{finm}) + handshakeHash.Write(finm.Marshal()) + + // Compute traffic secrets + h4 := handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) + logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) + + clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) + serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) + logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) + logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) + + serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) + toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys}) + + exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) + logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) + + if state.Params.UsingEarlyData { + clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") + nextState := ServerStateWaitEOED{ + AuthCertificate: state.Caps.AuthCertificate, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + toSend = append(toSend, []HandshakeAction{ + RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys}, + ReadEarlyData{}, + }...) + return nextState, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") + toSend = append(toSend, []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, + ReadPastEarlyData{}, + }...) + waitFlight2 := ServerStateWaitFlight2{ + AuthCertificate: state.Caps.AuthCertificate, + Params: state.Params, + cryptoParams: params, + handshakeHash: handshakeHash, + masterSecret: masterSecret, + clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, + clientTrafficSecret: clientTrafficSecret, + serverTrafficSecret: serverTrafficSecret, + exporterSecret: exporterSecret, + } + nextState, moreToSend, alert := waitFlight2.Next(nil) + toSend = append(toSend, moreToSend...) + return nextState, toSend, alert +} + +type ServerStateWaitEOED struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { + logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + if len(hm.body) > 0 { + logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") + toSend := []HandshakeAction{ + RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys}, + } + waitFlight2 := ServerStateWaitFlight2{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + nextState, moreToSend, alert := waitFlight2.Next(nil) + toSend = append(toSend, moreToSend...) + return nextState, toSend, alert +} + +type ServerStateWaitFlight2 struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm != nil { + logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + if state.Params.UsingClientAuth { + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") + nextState := ServerStateWaitCert{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + handshakeHash: state.handshakeHash, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCert struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + masterSecret []byte + clientHandshakeTrafficSecret []byte + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificate { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + cert := &CertificateBody{} + _, err := cert.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") + return nil, nil, AlertDecodeError + } + + state.handshakeHash.Write(hm.Marshal()) + + if len(cert.CertificateList) == 0 { + logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert + } + + logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") + nextState := ServerStateWaitCV{ + AuthCertificate: state.AuthCertificate, + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + clientCertificate: cert, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitCV struct { + AuthCertificate func(chain []CertificateEntry) error + Params ConnectionParameters + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte + + clientCertificate *CertificateBody +} + +func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { + logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) + return nil, nil, AlertUnexpectedMessage + } + + certVerify := &CertificateVerifyBody{} + _, err := certVerify.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + // Verify client signature over handshake hash + hcv := state.handshakeHash.Sum(nil) + logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) + + clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey + if err := certVerify.Verify(clientPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) + return nil, nil, AlertHandshakeFailure + } + + if state.AuthCertificate != nil { + err := state.AuthCertificate(state.clientCertificate.CertificateList) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate") + return nil, nil, AlertBadCertificate + } + } else { + logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate") + } + + // If it passes, record the certificateVerify in the transcript hash + state.handshakeHash.Write(hm.Marshal()) + + logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") + nextState := ServerStateWaitFinished{ + Params: state.Params, + cryptoParams: state.cryptoParams, + masterSecret: state.masterSecret, + clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, + handshakeHash: state.handshakeHash, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + return nextState, nil, AlertNoAlert +} + +type ServerStateWaitFinished struct { + Params ConnectionParameters + cryptoParams CipherSuiteParams + + masterSecret []byte + clientHandshakeTrafficSecret []byte + + handshakeHash hash.Hash + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil || hm.msgType != HandshakeTypeFinished { + logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} + _, err := fin.Unmarshal(hm.body) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) + return nil, nil, AlertDecodeError + } + + // Verify client Finished data + h5 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) + + clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) + logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) + + if !bytes.Equal(fin.VerifyData, clientFinishedData) { + logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") + return nil, nil, AlertHandshakeFailure + } + + // Compute the resumption secret + state.handshakeHash.Write(hm.Marshal()) + h6 := state.handshakeHash.Sum(nil) + logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) + + resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) + logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) + + // Compute client traffic keys + clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") + nextState := StateConnected{ + Params: state.Params, + isClient: false, + cryptoParams: state.cryptoParams, + resumptionSecret: resumptionSecret, + clientTrafficSecret: state.clientTrafficSecret, + serverTrafficSecret: state.serverTrafficSecret, + exporterSecret: state.exporterSecret, + } + toSend := []HandshakeAction{ + RekeyIn{Label: "application", KeySet: clientTrafficKeys}, + } + return nextState, toSend, AlertNoAlert +} diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go new file mode 100644 index 0000000..4eb468c --- /dev/null +++ b/vendor/github.com/bifurcation/mint/state-machine.go @@ -0,0 +1,230 @@ +package mint + +import ( + "time" +) + +// Marker interface for actions that an implementation should take based on +// state transitions. +type HandshakeAction interface{} + +type SendHandshakeMessage struct { + Message *HandshakeMessage +} + +type SendEarlyData struct{} + +type ReadEarlyData struct{} + +type ReadPastEarlyData struct{} + +type RekeyIn struct { + Label string + KeySet keySet +} + +type RekeyOut struct { + Label string + KeySet keySet +} + +type StorePSK struct { + PSK PreSharedKey +} + +type HandshakeState interface { + Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) +} + +type AppExtensionHandler interface { + Send(hs HandshakeType, el *ExtensionList) error + Receive(hs HandshakeType, el *ExtensionList) error +} + +// Capabilities objects represent the capabilities of a TLS client or server, +// as an input to TLS negotiation +type Capabilities struct { + // For both client and server + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + PSKs PreSharedKeyCache + Certificates []*Certificate + AuthCertificate func(chain []CertificateEntry) error + ExtensionHandler AppExtensionHandler + + // For client + PSKModes []PSKKeyExchangeMode + + // For server + NextProtos []string + AllowEarlyData bool + RequireCookie bool + CookieHandler CookieHandler + RequireClientAuth bool +} + +// ConnectionOptions objects represent per-connection settings for a client +// initiating a connection +type ConnectionOptions struct { + ServerName string + NextProtos []string + EarlyData []byte +} + +// ConnectionParameters objects represent the parameters negotiated for a +// connection. +type ConnectionParameters struct { + UsingPSK bool + UsingDH bool + ClientSendingEarlyData bool + UsingEarlyData bool + UsingClientAuth bool + + CipherSuite CipherSuite + ServerName string + NextProto string +} + +// StateConnected is symmetric between client and server +type StateConnected struct { + Params ConnectionParameters + isClient bool + cryptoParams CipherSuiteParams + resumptionSecret []byte + clientTrafficSecret []byte + serverTrafficSecret []byte + exporterSecret []byte +} + +func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { + var trafficKeys keySet + if state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + SendHandshakeMessage{kum}, + RekeyOut{Label: "update", KeySet: trafficKeys}, + } + return toSend, AlertNoAlert +} + +func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { + tkt, err := NewSessionTicket(length, lifetime) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) + + newPSK := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: tkt.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), + TicketAgeAdd: tkt.TicketAgeAdd, + } + + tktm, err := HandshakeMessageFromBody(tkt) + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) + return nil, AlertInternalError + } + + toSend := []HandshakeAction{ + StorePSK{newPSK}, + SendHandshakeMessage{tktm}, + } + return toSend, AlertNoAlert +} + +func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { + if hm == nil { + logf(logTypeHandshake, "[StateConnected] Unexpected message") + return nil, nil, AlertUnexpectedMessage + } + + bodyGeneric, err := hm.ToBody() + if err != nil { + logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) + return nil, nil, AlertDecodeError + } + + switch body := bodyGeneric.(type) { + case *KeyUpdateBody: + var trafficKeys keySet + if !state.isClient { + state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, + labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + } else { + state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, + labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) + trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) + } + + toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}} + + // If requested, roll outbound keys and send a KeyUpdate + if body.KeyUpdateRequest == KeyUpdateRequested { + moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) + if alert != AlertNoAlert { + return nil, nil, alert + } + + toSend = append(toSend, moreToSend...) + } + + return state, toSend, AlertNoAlert + + case *NewSessionTicketBody: + // XXX: Allow NewSessionTicket in both directions? + if !state.isClient { + return nil, nil, AlertUnexpectedMessage + } + + resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, + labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) + + psk := PreSharedKey{ + CipherSuite: state.cryptoParams.Suite, + IsResumption: true, + Identity: body.Ticket, + Key: resumptionKey, + NextProto: state.Params.NextProto, + ReceivedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), + TicketAgeAdd: body.TicketAgeAdd, + } + + toSend := []HandshakeAction{StorePSK{psk}} + return state, toSend, AlertNoAlert + } + + logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) + return nil, nil, AlertUnexpectedMessage +} diff --git a/vendor/github.com/bifurcation/mint/syntax/README.md b/vendor/github.com/bifurcation/mint/syntax/README.md new file mode 100644 index 0000000..dbf4ec2 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/README.md @@ -0,0 +1,74 @@ +TLS Syntax +========== + +TLS defines [its own syntax](https://tlswg.github.io/tls13-spec/#rfc.section.3) +for describing structures used in that protocol. To facilitate experimentation +with TLS in Go, this module maps that syntax to the Go structure syntax, taking +advantage of Go's type annotations to encode non-type information carried in the +TLS presentation format. + +For example, in the TLS specification, a ClientHello message has the following +structure: + +~~~~~ +uint16 ProtocolVersion; +opaque Random[32]; +uint8 CipherSuite[2]; +enum { ... (65535)} ExtensionType; + +struct { + ExtensionType extension_type; + opaque extension_data<0..2^16-1>; +} Extension; + +struct { + ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ + Random random; + opaque legacy_session_id<0..32>; + CipherSuite cipher_suites<2..2^16-2>; + opaque legacy_compression_methods<1..2^8-1>; + Extension extensions<0..2^16-1>; +} ClientHello; +~~~~~ + +This maps to the following Go type definitions: + +~~~~~ +type protocolVersion uint16 +type random [32]byte +type cipherSuite uint16 // or [2]byte +type extensionType uint16 + +type extension struct { + ExtensionType extensionType + ExtensionData []byte `tls:"head=2"` +} + +type clientHello struct { + LegacyVersion protocolVersion + Random random + LegacySessionID []byte `tls:"head=1,max=32"` + CipherSuites []cipherSuite `tls:"head=2,min=2"` + LegacyCompressionMethods []byte `tls:"head=1,min=1"` + Extensions []extension `tls:"head=2"` +} +~~~~~ + +Then you can just declare, marshal, and unmarshal structs just like you would +with, say JSON. + +The available annotations right now are all related to vectors: + +* `head`: The number of bytes of length to use as a "header" +* `min`: The minimum length of the vector, in bytes +* `max`: The maximum length of the vector, in bytes + +## Not supported + +* The `select()` syntax for creating alternate version of the same struct (see, + e.g., the KeyShare extension) + +* The backreference syntax for array lengths or select parameters, as in `opaque + fragment[TLSPlaintext.length]`. Note, however, that in cases where the length + immediately preceds the array, these can be reframed as vectors with + appropriate sizes. diff --git a/vendor/github.com/bifurcation/mint/syntax/decode.go b/vendor/github.com/bifurcation/mint/syntax/decode.go new file mode 100644 index 0000000..cd5aada --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/decode.go @@ -0,0 +1,243 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Unmarshal(data []byte, v interface{}) (int, error) { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + d := decodeState{} + d.Write(data) + return d.unmarshal(v) +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type decOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes +} + +type decodeState struct { + bytes.Buffer +} + +func (d *decodeState) unmarshal(v interface{}) (read int, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)") + } + + read = d.value(rv) + return read, nil +} + +func (e *decodeState) value(v reflect.Value) int { + return valueDecoder(v)(e, v, decOpts{}) +} + +type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int + +func valueDecoder(v reflect.Value) decoderFunc { + return typeDecoder(v.Type().Elem()) +} + +func typeDecoder(t reflect.Type) decoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeDecoder(t) +} + +func newTypeDecoder(t reflect.Type) decoderFunc { + // Note: Does not support Marshaler, so don't need the allowAddr argument + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintDecoder + case reflect.Array: + return newArrayDecoder(t) + case reflect.Slice: + return newSliceDecoder(t) + case reflect.Struct: + return newStructDecoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific decoders below + +func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int { + var uintLen int + switch v.Elem().Kind() { + case reflect.Uint8: + uintLen = 1 + case reflect.Uint16: + uintLen = 2 + case reflect.Uint32: + uintLen = 4 + case reflect.Uint64: + uintLen = 8 + } + + buf := make([]byte, uintLen) + n, err := d.Read(buf) + if err != nil { + panic(err) + } + if n != uintLen { + panic(fmt.Errorf("Insufficient data to read uint")) + } + + val := uint64(0) + for _, b := range buf { + val = (val << 8) + uint64(b) + } + + v.Elem().SetUint(val) + return uintLen +} + +////////// + +type arrayDecoder struct { + elemDec decoderFunc +} + +func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + n := v.Elem().Type().Len() + read := 0 + for i := 0; i < n; i += 1 { + read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts) + } + return read +} + +func newArrayDecoder(t reflect.Type) decoderFunc { + dec := &arrayDecoder{typeDecoder(t.Elem())} + return dec.decode +} + +////////// + +type sliceDecoder struct { + elementType reflect.Type + elementDec decoderFunc +} + +func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + if opts.head == 0 { + panic(fmt.Errorf("Cannot decode a slice without a header length")) + } + + lengthBytes := make([]byte, opts.head) + n, err := d.Read(lengthBytes) + if err != nil { + panic(err) + } + if uint(n) != opts.head { + panic(fmt.Errorf("Not enough data to read header")) + } + + length := uint(0) + for _, b := range lengthBytes { + length = (length << 8) + uint(b) + } + + if opts.max > 0 && length > opts.max { + panic(fmt.Errorf("Length of vector exceeds declared max")) + } + if length < opts.min { + panic(fmt.Errorf("Length of vector below declared min")) + } + + data := make([]byte, length) + n, err = d.Read(data) + if err != nil { + panic(err) + } + if uint(n) != length { + panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length)) + } + + elemBuf := &decodeState{} + elemBuf.Write(data) + elems := []reflect.Value{} + read := int(opts.head) + for elemBuf.Len() > 0 { + elem := reflect.New(sd.elementType) + read += sd.elementDec(elemBuf, elem, opts) + elems = append(elems, elem) + } + + v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems))) + for i := 0; i < len(elems); i += 1 { + v.Elem().Index(i).Set(elems[i].Elem()) + } + return read +} + +func newSliceDecoder(t reflect.Type) decoderFunc { + dec := &sliceDecoder{ + elementType: t.Elem(), + elementDec: typeDecoder(t.Elem()), + } + return dec.decode +} + +////////// + +type structDecoder struct { + fieldOpts []decOpts + fieldDecs []decoderFunc +} + +func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int { + read := 0 + for i := range sd.fieldDecs { + read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i]) + } + return read +} + +func newStructDecoder(t reflect.Type) decoderFunc { + n := t.NumField() + sd := structDecoder{ + fieldOpts: make([]decOpts, n), + fieldDecs: make([]decoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + sd.fieldOpts[i] = decOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + } + + sd.fieldDecs[i] = typeDecoder(f.Type) + } + + return sd.decode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/encode.go b/vendor/github.com/bifurcation/mint/syntax/encode.go new file mode 100644 index 0000000..2874f40 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/encode.go @@ -0,0 +1,187 @@ +package syntax + +import ( + "bytes" + "fmt" + "reflect" + "runtime" +) + +func Marshal(v interface{}) ([]byte, error) { + e := &encodeState{} + err := e.marshal(v, encOpts{}) + if err != nil { + return nil, err + } + return e.Bytes(), nil +} + +// These are the options that can be specified in the struct tag. Right now, +// all of them apply to variable-length vectors and nothing else +type encOpts struct { + head uint // length of length in bytes + min uint // minimum size in bytes + max uint // maximum size in bytes +} + +type encodeState struct { + bytes.Buffer +} + +func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if s, ok := r.(string); ok { + panic(s) + } + err = r.(error) + } + }() + e.reflectValue(reflect.ValueOf(v), opts) + return nil +} + +func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) { + valueEncoder(v)(e, v, opts) +} + +type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts) + +func valueEncoder(v reflect.Value) encoderFunc { + if !v.IsValid() { + panic(fmt.Errorf("Cannot encode an invalid value")) + } + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + // Note: Omits the caching / wait-group things that encoding/json uses + return newTypeEncoder(t) +} + +func newTypeEncoder(t reflect.Type) encoderFunc { + // Note: Does not support Marshaler, so don't need the allowAddr argument + + switch t.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uintEncoder + case reflect.Array: + return newArrayEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + case reflect.Struct: + return newStructEncoder(t) + default: + panic(fmt.Errorf("Unsupported type (%s)", t)) + } +} + +///// Specific encoders below + +func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) { + u := v.Uint() + switch v.Type().Kind() { + case reflect.Uint8: + e.WriteByte(byte(u)) + case reflect.Uint16: + e.Write([]byte{byte(u >> 8), byte(u)}) + case reflect.Uint32: + e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) + case reflect.Uint64: + e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32), + byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) + } +} + +////////// + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + n := v.Len() + for i := 0; i < n; i += 1 { + ae.elemEnc(e, v.Index(i), opts) + } +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +////////// + +type sliceEncoder struct { + ae *arrayEncoder +} + +func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + if opts.head == 0 { + panic(fmt.Errorf("Cannot encode a slice without a header length")) + } + + arrayState := &encodeState{} + se.ae.encode(arrayState, v, opts) + + n := uint(arrayState.Len()) + if opts.max > 0 && n > opts.max { + panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max)) + } + if n>>(8*opts.head) > 0 { + panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head)) + } + if n < opts.min { + panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min)) + } + + for i := int(opts.head - 1); i >= 0; i -= 1 { + e.WriteByte(byte(n >> (8 * uint(i)))) + } + e.Write(arrayState.Bytes()) +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}} + return enc.encode +} + +////////// + +type structEncoder struct { + fieldOpts []encOpts + fieldEncs []encoderFunc +} + +func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { + for i := range se.fieldEncs { + se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i]) + } +} + +func newStructEncoder(t reflect.Type) encoderFunc { + n := t.NumField() + se := structEncoder{ + fieldOpts: make([]encOpts, n), + fieldEncs: make([]encoderFunc, n), + } + + for i := 0; i < n; i += 1 { + f := t.Field(i) + tag := f.Tag.Get("tls") + tagOpts := parseTag(tag) + + se.fieldOpts[i] = encOpts{ + head: tagOpts["head"], + max: tagOpts["max"], + min: tagOpts["min"], + } + se.fieldEncs[i] = typeEncoder(f.Type) + } + + return se.encode +} diff --git a/vendor/github.com/bifurcation/mint/syntax/tags.go b/vendor/github.com/bifurcation/mint/syntax/tags.go new file mode 100644 index 0000000..a6c9c88 --- /dev/null +++ b/vendor/github.com/bifurcation/mint/syntax/tags.go @@ -0,0 +1,30 @@ +package syntax + +import ( + "strconv" + "strings" +) + +// `tls:"head=2,min=2,max=255"` + +type tagOptions map[string]uint + +// parseTag parses a struct field's "tls" tag as a comma-separated list of +// name=value pairs, where the values MUST be unsigned integers +func parseTag(tag string) tagOptions { + opts := tagOptions{} + for _, token := range strings.Split(tag, ",") { + if strings.Index(token, "=") == -1 { + continue + } + + parts := strings.Split(token, "=") + if len(parts[0]) == 0 { + continue + } + if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 { + opts[parts[0]] = uint(val) + } + } + return opts +} diff --git a/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/bifurcation/mint/tls.go new file mode 100644 index 0000000..0c57aba --- /dev/null +++ b/vendor/github.com/bifurcation/mint/tls.go @@ -0,0 +1,168 @@ +package mint + +// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls + +import ( + "errors" + "net" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, false) +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + return NewConn(conn, config, true) +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type Listener struct { + net.Listener + config *Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *Listener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + server := Server(c, l.config) + err = server.Handshake() + if err == AlertNoAlert { + err = nil + } + c = server + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config) net.Listener { + l := new(Listener) + l.Listener = inner + l.config = config + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || !config.ValidForServer() { + return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config), nil +} + +type TimeoutError struct{} + +func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (TimeoutError) Timeout() bool { return true } +func (TimeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- TimeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = &Config{} + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + if err == AlertNoAlert { + err = nil + } + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + if err == AlertNoAlert { + err = nil + } + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} diff --git a/vendor/github.com/lucas-clemente/aes12/cipher_generic.go b/vendor/github.com/lucas-clemente/aes12/cipher_generic.go index a9a6abd..6861677 100644 --- a/vendor/github.com/lucas-clemente/aes12/cipher_generic.go +++ b/vendor/github.com/lucas-clemente/aes12/cipher_generic.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !amd64,!s390x +// +build !amd64 package aes12 diff --git a/vendor/github.com/lucas-clemente/quic-go/Changelog.md b/vendor/github.com/lucas-clemente/quic-go/Changelog.md index 0b357ab..4725779 100644 --- a/vendor/github.com/lucas-clemente/quic-go/Changelog.md +++ b/vendor/github.com/lucas-clemente/quic-go/Changelog.md @@ -2,16 +2,19 @@ ## v0.6.0 (unreleased) +- Add support for QUIC 39, drop support for QUIC 35 - 37 - Added `quic.Config` options for maximal flow control windows - Add a `quic.Config` option for QUIC versions -- Add a `quic.Config` option to request truncation of the connection ID from a server +- Add a `quic.Config` option to request omission of the connection ID from a server - Add a `quic.Config` option to configure the source address validation - Add a `quic.Config` option to configure the handshake timeout +- Add a `quic.Config` option to configure the idle timeout - Add a `quic.Config` option to configure keep-alive +- Rename the STK to Cookie - Implement `net.Conn`-style deadlines for streams - Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. - Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. - Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` - Changed `h2quic.Server.Serve()` to accept a `net.PacketConn` -- Drop support for Go 1.7. +- Drop support for Go 1.7 and 1.8. - Various bugfixes diff --git a/vendor/github.com/lucas-clemente/quic-go/README.md b/vendor/github.com/lucas-clemente/quic-go/README.md index 6f3a11e..1a6b1c2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/README.md +++ b/vendor/github.com/lucas-clemente/quic-go/README.md @@ -16,7 +16,7 @@ As Google's QUIC versions are expected to converge towards the [IETF QUIC draft] ## Guides -We currently support Go 1.8+. +We currently support Go 1.9+. Installing and updating dependencies: @@ -69,4 +69,4 @@ http.Client{ ## Contributing -We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [want-help](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aopen+is%3Aissue+label%3Awant-help). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. +We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go index e100264..8492fd4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go @@ -3,18 +3,20 @@ package ackhandler import ( "time" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) error - ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error + ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error + SetHandshakeComplete() SendingAllowed() bool - GetStopWaitingFrame(force bool) *frames.StopWaitingFrame + GetStopWaitingFrame(force bool) *wire.StopWaitingFrame + ShouldSendRetransmittablePacket() bool DequeuePacketForRetransmission() (packet *Packet) GetLeastUnacked() protocol.PacketNumber @@ -25,8 +27,8 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error - ReceivedStopWaiting(*frames.StopWaitingFrame) error + SetLowerLimit(protocol.PacketNumber) GetAlarmTimeout() time.Time - GetAckFrame() *frames.AckFrame + GetAckFrame() *wire.AckFrame } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go index e9dbf6a..9c4ee30 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go @@ -3,15 +3,15 @@ package ackhandler import ( "time" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // A Packet is a packet // +gen linkedlist type Packet struct { PacketNumber protocol.PacketNumber - Frames []frames.Frame + Frames []wire.Frame Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel @@ -19,13 +19,13 @@ type Packet struct { } // GetFramesForRetransmission gets all the frames for retransmission -func (p *Packet) GetFramesForRetransmission() []frames.Frame { - var fs []frames.Frame +func (p *Packet) GetFramesForRetransmission() []wire.Frame { + var fs []wire.Frame for _, frame := range p.Frames { switch frame.(type) { - case *frames.AckFrame: + case *wire.AckFrame: continue - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: continue } fs = append(fs, frame) diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go index c5e9dc2..d0cf78d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go @@ -4,15 +4,15 @@ import ( "errors" "time" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") type receivedPacketHandler struct { largestObserved protocol.PacketNumber - ignorePacketsBelow protocol.PacketNumber + lowerLimit protocol.PacketNumber largestObservedReceivedTime time.Time packetHistory *receivedPacketHistory @@ -23,14 +23,17 @@ type receivedPacketHandler struct { retransmittablePacketsReceivedSinceLastAck int ackQueued bool ackAlarm time.Time - lastAck *frames.AckFrame + lastAck *wire.AckFrame + + version protocol.VersionNumber } // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler() ReceivedPacketHandler { +func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { return &receivedPacketHandler{ packetHistory: newReceivedPacketHistory(), ackSendDelay: protocol.AckSendDelay, + version: version, } } @@ -39,31 +42,27 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe return errInvalidPacketNumber } - if packetNumber > h.ignorePacketsBelow { - if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { - return err - } - } - if packetNumber > h.largestObserved { h.largestObserved = packetNumber h.largestObservedReceivedTime = time.Now() } + if packetNumber <= h.lowerLimit { + return nil + } + + if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { + return err + } h.maybeQueueAck(packetNumber, shouldInstigateAck) return nil } -func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) error { - // ignore if StopWaiting is unneeded, because we already received a StopWaiting with a higher LeastUnacked - if h.ignorePacketsBelow >= f.LeastUnacked { - return nil - } - - h.ignorePacketsBelow = f.LeastUnacked - 1 - - h.packetHistory.DeleteBelow(f.LeastUnacked) - return nil +// SetLowerLimit sets a lower limit for acking packets. +// Packets with packet numbers smaller or equal than p will not be acked. +func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) { + h.lowerLimit = p + h.packetHistory.DeleteUpTo(p) } func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { @@ -78,10 +77,13 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber h.ackQueued = true } - // Always send an ack every 20 packets in order to allow the peer to discard - // information from the SentPacketManager and provide an RTT measurement. - if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { - h.ackQueued = true + if h.version < protocol.Version39 { + // Always send an ack every 20 packets in order to allow the peer to discard + // information from the SentPacketManager and provide an RTT measurement. + // From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet. + if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { + h.ackQueued = true + } } // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK @@ -91,7 +93,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } // check if a new missing range above the previously was created - if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked { + if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { h.ackQueued = true } @@ -111,15 +113,15 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber } } -func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { +func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { return nil } ackRanges := h.packetHistory.GetAckRanges() - ack := &frames.AckFrame{ + ack := &wire.AckFrame{ LargestAcked: h.largestObserved, - LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, + LowestAcked: ackRanges[len(ackRanges)-1].First, PacketReceivedTime: h.largestObservedReceivedTime, } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go index 791dec1..14bdfd5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go @@ -1,30 +1,26 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) +// The receivedPacketHistory stores if a packet number has already been received. +// It does not store packet contents. type receivedPacketHistory struct { ranges *utils.PacketIntervalList - // the map is used as a replacement for a set here. The bool is always supposed to be set to true - receivedPacketNumbers map[protocol.PacketNumber]bool lowestInReceivedPacketNumbers protocol.PacketNumber } -var ( - errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges") - errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets") -) +var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges") // newReceivedPacketHistory creates a new received packet history func newReceivedPacketHistory() *receivedPacketHistory { return &receivedPacketHistory{ - ranges: utils.NewPacketIntervalList(), - receivedPacketNumbers: make(map[protocol.PacketNumber]bool), + ranges: utils.NewPacketIntervalList(), } } @@ -34,12 +30,6 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { return errTooManyOutstandingReceivedAckRanges } - if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets { - return errTooManyOutstandingReceivedPackets - } - - h.receivedPacketNumbers[p] = true - if h.ranges.Len() == 0 { h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) return nil @@ -84,23 +74,17 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { return nil } -// DeleteBelow deletes all entries below the leastUnacked packet number -func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) { - h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked) +// DeleteUpTo deletes all entries up to (and including) p +func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) { + h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1) nextEl := h.ranges.Front() for el := h.ranges.Front(); nextEl != nil; el = nextEl { nextEl = el.Next() - if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End { - for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range - delete(h.receivedPacketNumbers, i) - } - el.Value.Start = leastUnacked - } else if el.Value.End < leastUnacked { // delete a whole range - for i := el.Value.Start; i <= el.Value.End; i++ { - delete(h.receivedPacketNumbers, i) - } + if p >= el.Value.Start && p < el.Value.End { + el.Value.Start = p + 1 + } else if el.Value.End <= p { // delete a whole range h.ranges.Remove(el) } else { // no ranges affected. Nothing to do return @@ -108,38 +92,27 @@ func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) } } -// IsDuplicate determines if a packet should be regarded as a duplicate packet -// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed -func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool { - if p < h.lowestInReceivedPacketNumbers { - return true - } - - _, ok := h.receivedPacketNumbers[p] - return ok -} - // GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame -func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { +func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { if h.ranges.Len() == 0 { return nil } - var ackRanges []frames.AckRange - + ackRanges := make([]wire.AckRange, h.ranges.Len()) + i := 0 for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges = append(ackRanges, frames.AckRange{FirstPacketNumber: el.Value.Start, LastPacketNumber: el.Value.End}) + ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End} + i++ } - return ackRanges } -func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange { - ackRange := frames.AckRange{} +func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { + ackRange := wire.AckRange{} if h.ranges.Len() > 0 { r := h.ranges.Back().Value - ackRange.FirstPacketNumber = r.Start - ackRange.LastPacketNumber = r.End + ackRange.First = r.Start + ackRange.Last = r.End } return ackRange } diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go index 17437b8..e6ce46f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/retransmittable.go @@ -1,12 +1,10 @@ package ackhandler -import ( - "github.com/lucas-clemente/quic-go/frames" -) +import "github.com/lucas-clemente/quic-go/internal/wire" // Returns a new slice with all non-retransmittable frames deleted. -func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame { - res := make([]frames.Frame, 0, len(fs)) +func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame { + res := make([]wire.Frame, 0, len(fs)) for _, f := range fs { if IsFrameRetransmittable(f) { res = append(res, f) @@ -16,11 +14,11 @@ func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame { } // IsFrameRetransmittable returns true if the frame should be retransmitted. -func IsFrameRetransmittable(f frames.Frame) bool { +func IsFrameRetransmittable(f wire.Frame) bool { switch f.(type) { - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: return false - case *frames.AckFrame: + case *wire.AckFrame: return false default: return true @@ -28,7 +26,7 @@ func IsFrameRetransmittable(f frames.Frame) bool { } // HasRetransmittableFrames returns true if at least one frame is retransmittable. -func HasRetransmittableFrames(fs []frames.Frame) bool { +func HasRetransmittableFrames(fs []wire.Frame) bool { for _, f := range fs { if IsFrameRetransmittable(f) { return true diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go index 300b665..68267aa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go @@ -6,9 +6,9 @@ import ( "time" "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -16,8 +16,13 @@ const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // In fraction of an RTT. timeReorderingFraction = 1.0 / 8 + // The default RTT used before an RTT sample is taken. + // Note: This constant is also defined in the congestion package. + defaultInitialRTT = 100 * time.Millisecond // defaultRTOTimeout is the RTO time on new connections defaultRTOTimeout = 500 * time.Millisecond + // Minimum time in the future a tail loss probe alarm may be set for. + minTPLTimeout = 10 * time.Millisecond // Minimum time in the future an RTO alarm may be set for. minRTOTimeout = 200 * time.Millisecond // maxRTOTimeout is the maximum RTO time @@ -40,6 +45,8 @@ type sentPacketHandler struct { lastSentPacketNumber protocol.PacketNumber skippedPackets []protocol.PacketNumber + numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet + LargestAcked protocol.PacketNumber largestReceivedPacketWithAck protocol.PacketNumber @@ -54,6 +61,10 @@ type sentPacketHandler struct { congestion congestion.SendAlgorithm rttStats *congestion.RTTStats + handshakeComplete bool + // The number of times the handshake packets have been retransmitted without receiving an ack. + handshakeCount uint32 + // The number of times an RTO has been sent without receiving an ack. rtoCount uint32 @@ -89,6 +100,14 @@ func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber { return h.LargestAcked } +func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { + return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets +} + +func (h *sentPacketHandler) SetHandshakeComplete() { + h.handshakeComplete = true +} + func (h *sentPacketHandler) SentPacket(packet *Packet) error { if packet.PacketNumber <= h.lastSentPacketNumber { return errPacketNumberNotIncreasing @@ -116,6 +135,9 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { packet.SendTime = now h.bytesInFlight += packet.Length h.packetHistory.PushBack(*packet) + h.numNonRetransmittablePackets = 0 + } else { + h.numNonRetransmittablePackets++ } h.congestion.OnPacketSent( @@ -130,7 +152,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { return nil } -func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, rcvTime time.Time) error { +func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { if ackFrame.LargestAcked > h.lastSentPacketNumber { return errAckForUnsentPacket } @@ -164,6 +186,9 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum if len(ackedPackets) > 0 { for _, p := range ackedPackets { + if encLevel < p.Value.EncryptionLevel { + return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel) + } h.onPacketAcked(p) h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight) } @@ -178,7 +203,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum return nil } -func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) { +func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) { var ackedPackets []*PacketElement ackRangeIndex := 0 for el := h.packetHistory.Front(); el != nil; el = el.Next() { @@ -197,14 +222,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame if ackFrame.HasMissingRanges() { ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] - for packetNumber > ackRange.LastPacketNumber && ackRangeIndex < len(ackFrame.AckRanges)-1 { + for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 { ackRangeIndex++ ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] } - if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range - if packetNumber > ackRange.LastPacketNumber { - return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber) + if packetNumber >= ackRange.First { // packet i contained in ACK range + if packetNumber > ackRange.Last { + return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last) } ackedPackets = append(ackedPackets, el) } @@ -238,9 +263,10 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { return } - // TODO(#496): Handle handshake packets separately // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.alarm = time.Now().Add(h.computeHandshakeTimeout()) + } else if !h.lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { @@ -282,9 +308,11 @@ func (h *sentPacketHandler) detectLostPackets() { } func (h *sentPacketHandler) OnAlarm() { - // TODO(#496): Handle handshake packets separately // TODO(#497): TLP - if !h.lossTime.IsZero() { + if !h.handshakeComplete { + h.queueHandshakePacketsForRetransmission() + h.handshakeCount++ + } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection h.detectLostPackets() } else { @@ -303,6 +331,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time { func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { h.bytesInFlight -= packetElement.Value.Length h.rtoCount = 0 + h.handshakeCount = 0 // TODO(#497): h.tlpCount = 0 h.packetHistory.Remove(packetElement) } @@ -323,7 +352,7 @@ func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return h.largestInOrderAcked() + 1 } -func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { +func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { return h.stopWaitingManager.GetStopWaitingFrame(force) } @@ -363,6 +392,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) { h.congestion.OnRetransmissionTimeout(true) } +func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() { + var handshakePackets []*PacketElement + for el := h.packetHistory.Front(); el != nil; el = el.Next() { + if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + handshakePackets = append(handshakePackets, el) + } + } + for _, el := range handshakePackets { + h.queuePacketForRetransmission(el) + } +} + func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { packet := &packetElement.Value h.bytesInFlight -= packet.Length @@ -371,6 +412,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) } +func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { + duration := 2 * h.rttStats.SmoothedRTT() + if duration == 0 { + duration = 2 * defaultInitialRTT + } + duration = utils.MaxDuration(duration, minTPLTimeout) + // exponential backoff + // There's an implicit limit to this set by the handshake timeout. + return duration << h.handshakeCount +} + func (h *sentPacketHandler) computeRTOTimeout() time.Duration { rto := h.congestion.RetransmissionDelay() if rto == 0 { @@ -382,7 +434,7 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration { return utils.MinDuration(rto, maxRTOTimeout) } -func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool { +func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool { for _, p := range h.skippedPackets { if ackFrame.AcksPacket(p) { return true diff --git a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go b/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go index dfd79ae..04cb61f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go +++ b/vendor/github.com/lucas-clemente/quic-go/ackhandler/stop_waiting_manager.go @@ -1,8 +1,8 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) // This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33 @@ -10,10 +10,10 @@ type stopWaitingManager struct { largestLeastUnackedSent protocol.PacketNumber nextLeastUnacked protocol.PacketNumber - lastStopWaitingFrame *frames.StopWaitingFrame + lastStopWaitingFrame *wire.StopWaitingFrame } -func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { +func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { if s.nextLeastUnacked <= s.largestLeastUnackedSent { if force { return s.lastStopWaitingFrame @@ -22,14 +22,14 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaiting } s.largestLeastUnackedSent = s.nextLeastUnacked - swf := &frames.StopWaitingFrame{ + swf := &wire.StopWaitingFrame{ LeastUnacked: s.nextLeastUnacked, } s.lastStopWaitingFrame = swf return swf } -func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) { +func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) { if ack.LargestAcked >= s.nextLeastUnacked { s.nextLeastUnacked = ack.LargestAcked + 1 } diff --git a/vendor/github.com/lucas-clemente/quic-go/appveyor.yml b/vendor/github.com/lucas-clemente/quic-go/appveyor.yml index 8a3c907..bcd3ac5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/appveyor.yml +++ b/vendor/github.com/lucas-clemente/quic-go/appveyor.yml @@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go install: - rmdir c:\go /s /q - - appveyor DownloadFile https://storage.googleapis.com/golang/go1.8.3.windows-amd64.zip - - 7z x go1.8.3.windows-amd64.zip -y -oC:\ > NUL + - appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip + - 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - echo %PATH% - echo %GOPATH% @@ -27,9 +27,8 @@ install: - go get -v -t ./... build_script: - - rm -r integrationtests - - ginkgo -r --randomizeAllSpecs --randomizeSuites --trace --progress -skipPackage benchmark - - ginkgo --randomizeAllSpecs --randomizeSuites --trace --progress benchmark -- -samples=1 + - ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage benchmark,integrationtests + - ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1 test: off diff --git a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go index f592d47..5032ca7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go +++ b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go @@ -3,7 +3,7 @@ package quic import ( "sync" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) var bufferPool sync.Pool diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index 2e18de8..d13dd81 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -10,24 +10,26 @@ import ( "sync" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type client struct { - mutex sync.Mutex - listenErr error + mutex sync.Mutex conn connection hostname string - errorChan chan struct{} handshakeChan <-chan handshakeEvent - tlsConf *tls.Config - config *Config - versionNegotiated bool // has version negotiation completed yet + versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version + versionNegotiated bool // has version negotiation completed yet + receivedVersionNegotiationPacket bool + + tlsConf *tls.Config + config *Config connectionID protocol.ConnectionID version protocol.VersionNumber @@ -36,6 +38,8 @@ type client struct { } var ( + // make it possible to mock connection ID generation in the tests + generateConnectionID = utils.GenerateConnectionID errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) @@ -80,7 +84,7 @@ func DialNonFWSecure( tlsConf *tls.Config, config *Config, ) (NonFWSession, error) { - connID, err := utils.GenerateConnectionID() + connID, err := generateConnectionID() if err != nil { return nil, err } @@ -99,23 +103,21 @@ func DialNonFWSecure( clientConfig := populateClientConfig(config) c := &client{ - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - connectionID: connID, - hostname: hostname, - tlsConf: tlsConf, - config: clientConfig, - version: clientConfig.Versions[0], - errorChan: make(chan struct{}), + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + connectionID: connID, + hostname: hostname, + tlsConf: tlsConf, + config: clientConfig, + version: clientConfig.Versions[0], + versionNegotiationChan: make(chan struct{}), } - err = c.createNewSession(nil) - if err != nil { + utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + + if err := c.establishSecureConnection(); err != nil { return nil, err } - - utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) - - return c.session.(NonFWSession), c.establishSecureConnection() + return c.session.(NonFWSession), nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -131,8 +133,7 @@ func Dial( if err != nil { return nil, err } - err = sess.WaitUntilHandshakeComplete() - if err != nil { + if err := sess.WaitUntilHandshakeComplete(); err != nil { return nil, err } return sess, nil @@ -153,6 +154,10 @@ func populateClientConfig(config *Config) *Config { if config.HandshakeTimeout != 0 { handshakeTimeout = config.HandshakeTimeout } + idleTimeout := protocol.DefaultIdleTimeout + if config.IdleTimeout != 0 { + idleTimeout = config.IdleTimeout + } maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow if maxReceiveStreamFlowControlWindow == 0 { @@ -166,7 +171,8 @@ func populateClientConfig(config *Config) *Config { return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, - RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, + IdleTimeout: idleTimeout, + RequestConnectionIDOmission: config.RequestConnectionIDOmission, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, KeepAlive: config.KeepAlive, @@ -175,16 +181,40 @@ func populateClientConfig(config *Config) *Config { // establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) func (c *client) establishSecureConnection() error { + if err := c.createNewSession(c.version, nil); err != nil { + return err + } go c.listen() + var runErr error + errorChan := make(chan struct{}) + go func() { + // session.run() returns as soon as the session is closed + runErr = c.session.run() + if runErr == errCloseSessionForNewVersion { + // run the new session + runErr = c.session.run() + } + close(errorChan) + utils.Infof("Connection %x closed.", c.connectionID) + c.conn.Close() + }() + + // wait until the server accepts the QUIC version (or an error occurs) select { - case <-c.errorChan: - return c.listenErr + case <-errorChan: + return runErr + case <-c.versionNegotiationChan: + } + + select { + case <-errorChan: + return runErr case ev := <-c.handshakeChan: if ev.err != nil { return ev.err } - if ev.encLevel != protocol.EncryptionSecure { + if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure { return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) } return nil @@ -219,10 +249,18 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) + hdr, err := wire.ParseHeaderSentByServer(r, c.version) if err != nil { utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the Public Header + // drop this packet if we can't parse the header + return + } + // reject packets with truncated connection id if we didn't request truncation + if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { + return + } + // reject packets with the wrong connection ID + if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID { return } hdr.Raw = packet[:len(packet)-r.Len()] @@ -238,44 +276,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { utils.Infof("Received a spoofed Public Reset. Ignoring.") return } - pr, err := parsePublicReset(r) + pr, err := wire.ParsePublicReset(r) if err != nil { - utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") + utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) return } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber) - c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber))) + utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) return } - // ignore delayed / duplicated version negotiation packets - if c.versionNegotiated && hdr.VersionFlag { - return - } + isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */ - // this is the first packet after the client sent a packet with the VersionFlag set - // if the server doesn't send a version negotiation packet, it supports the suggested version - if !hdr.VersionFlag && !c.versionNegotiated { - c.versionNegotiated = true - } + // handle Version Negotiation Packets + if isVersionNegotiationPacket { + // ignore delayed / duplicated version negotiation packets + if c.receivedVersionNegotiationPacket || c.versionNegotiated { + return + } - if hdr.VersionFlag { // version negotiation packets have no payload - if err := c.handlePacketWithVersionFlag(hdr); err != nil { + if err := c.handleVersionNegotiationPacket(hdr); err != nil { c.session.Close(err) } return } + // this is the first packet we are receiving + // since it is not a Version Negotiation Packet, this means the server supports the suggested version + if !c.versionNegotiated { + c.versionNegotiated = true + close(c.versionNegotiationChan) + } + c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, }) } -func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { +func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { for _, v := range hdr.SupportedVersions { if v == c.version { // the version negotiation packet contains the version that we offered @@ -285,27 +327,33 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { } } - newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) - if newVersion == protocol.VersionUnsupported { + c.receivedVersionNegotiationPacket = true + + newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) + if !ok { return qerr.InvalidVersion } // switch to negotiated version + initialVersion := c.version c.version = newVersion - c.versionNegotiated = true var err error c.connectionID, err = utils.GenerateConnectionID() if err != nil { return err } - utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID) + utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) - c.session.Close(errCloseSessionForNewVersion) - return c.createNewSession(hdr.SupportedVersions) + // create a new session and close the old one + // the new session must be created first to update client member variables + oldSession := c.session + defer oldSession.Close(errCloseSessionForNewVersion) + return c.createNewSession(initialVersion, hdr.SupportedVersions) } -func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { +func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { var err error + utils.Debugf("createNewSession with initial version %s", initialVersion) c.session, c.handshakeChan, err = newClientSession( c.conn, c.hostname, @@ -313,23 +361,8 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.connectionID, c.tlsConf, c.config, + initialVersion, negotiatedVersions, ) - if err != nil { - return err - } - - go func() { - // session.run() returns as soon as the session is closed - err := c.session.run() - if err == errCloseSessionForNewVersion { - return - } - c.listenErr = err - close(c.errorChan) - - utils.Infof("Connection %x closed.", c.connectionID) - c.conn.Close() - }() - return nil + return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go b/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go index e76ea16..54269c5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/bandwidth.go @@ -3,7 +3,7 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // Bandwidth of a connection diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go b/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go index 62e7355..3922f47 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/cubic.go @@ -4,8 +4,8 @@ import ( "math" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // This cubic implementation is based on the one found in Chromiums's QUIC diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go b/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go index 02e4206..f2c8c2d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/cubic_sender.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) const ( diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go b/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go index 01a64f8..f41c1e5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/hybrid_slow_start.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // Note(pwestin): the magic clamping numbers come from the original code in diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go b/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go index bbce0a6..411a5f2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/interface.go @@ -3,7 +3,7 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // A SendAlgorithm performs congestion control and calculates the congestion window diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go b/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go index b8a0a10..18a3736 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/prr_sender.go @@ -3,8 +3,8 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go index 546c1cb..624957c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/rtt_stats.go @@ -7,6 +7,7 @@ import ( ) const ( + // Note: This constant is also defined in the ackhandler package. initialRTTus = 100 * 1000 rttAlpha float32 = 0.125 oneMinusAlpha float32 = (1 - rttAlpha) diff --git a/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go b/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go index 8f272b2..ed669c1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go +++ b/vendor/github.com/lucas-clemente/quic-go/congestion/stats.go @@ -1,6 +1,6 @@ package congestion -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" type connectionStats struct { slowstartPacketsLost protocol.PacketNumber diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go b/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go deleted file mode 100644 index a738cc2..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go +++ /dev/null @@ -1,58 +0,0 @@ -package crypto - -import ( - "crypto/cipher" - "errors" - - "github.com/lucas-clemente/aes12" - - "github.com/lucas-clemente/quic-go/protocol" -) - -type aeadAESGCM struct { - otherIV []byte - myIV []byte - encrypter cipher.AEAD - decrypter cipher.AEAD -} - -// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size -// -// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte -// tag size, and couples the cipher and aes packages closely. -// See https://github.com/lucas-clemente/aes12. -func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { - if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { - return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") - } - encrypterCipher, err := aes12.NewCipher(myKey) - if err != nil { - return nil, err - } - encrypter, err := aes12.NewGCM(encrypterCipher) - if err != nil { - return nil, err - } - decrypterCipher, err := aes12.NewCipher(otherKey) - if err != nil { - return nil, err - } - decrypter, err := aes12.NewGCM(decrypterCipher) - if err != nil { - return nil, err - } - return &aeadAESGCM{ - otherIV: otherIV, - myIV: myIV, - encrypter: encrypter, - decrypter: decrypter, - }, nil -} - -func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData) -} - -func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go b/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go deleted file mode 100644 index 9b6d416..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go +++ /dev/null @@ -1,14 +0,0 @@ -package crypto - -import ( - "encoding/binary" - - "github.com/lucas-clemente/quic-go/protocol" -) - -func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { - res := make([]byte, 12) - copy(res[0:4], iv) - binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) - return res -} diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go deleted file mode 100644 index 9362d60..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go +++ /dev/null @@ -1,240 +0,0 @@ -package flowcontrol - -import ( - "errors" - "fmt" - "sync" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -type flowControlManager struct { - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats - - streamFlowController map[protocol.StreamID]*flowController - connFlowController *flowController - mutex sync.RWMutex -} - -var _ FlowControlManager = &flowControlManager{} - -var errMapAccess = errors.New("Error accessing the flowController map.") - -// NewFlowControlManager creates a new flow control manager -func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { - return &flowControlManager{ - connectionParameters: connectionParameters, - rttStats: rttStats, - streamFlowController: make(map[protocol.StreamID]*flowController), - connFlowController: newFlowController(0, false, connectionParameters, rttStats), - } -} - -// NewStream creates new flow controllers for a stream -// it does nothing if the stream already exists -func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) { - f.mutex.Lock() - defer f.mutex.Unlock() - - if _, ok := f.streamFlowController[streamID]; ok { - return - } - - f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats) -} - -// RemoveStream removes a closed stream from flow control -func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { - f.mutex.Lock() - delete(f.streamFlowController, streamID) - f.mutex.Unlock() -} - -// ResetStream should be called when receiving a RstStreamFrame -// it updates the byte offset to the value in the RstStreamFrame -// streamID must not be 0 here -func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - increment, err := streamFlowController.UpdateHighestReceived(byteOffset) - if err != nil { - return qerr.StreamDataAfterTermination - } - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// UpdateHighestReceived updates the highest received byte offset for a stream -// it adds the number of additional bytes to connection level flow control -// streamID must not be 0 here -func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - streamFlowController, err := f.getFlowController(streamID) - if err != nil { - return err - } - // UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered - // this error can be ignored here - increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) - - if streamFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow)) - } - - if streamFlowController.ContributesToConnection() { - f.connFlowController.IncrementHighestReceived(increment) - if f.connFlowController.CheckFlowControlViolation() { - return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow)) - } - } - - return nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesRead(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesRead(n) - } - - return nil -} - -func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { - f.mutex.Lock() - defer f.mutex.Unlock() - - // get WindowUpdates for streams - for id, fc := range f.streamFlowController { - if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: id, Offset: offset}) - if fc.ContributesToConnection() && newIncrement != 0 { - f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier)) - } - } - } - // get a WindowUpdate for the connection - if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary { - res = append(res, WindowUpdate{StreamID: 0, Offset: offset}) - } - - return -} - -func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - // StreamID can be 0 when retransmitting - if streamID == 0 { - return f.connFlowController.receiveWindow, nil - } - - flowController, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - return flowController.receiveWindow, nil -} - -// streamID must not be 0 here -func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { - f.mutex.Lock() - defer f.mutex.Unlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return err - } - - fc.AddBytesSent(n) - if fc.ContributesToConnection() { - f.connFlowController.AddBytesSent(n) - } - - return nil -} - -// must not be called with StreamID 0 -func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { - f.mutex.RLock() - defer f.mutex.RUnlock() - - fc, err := f.getFlowController(streamID) - if err != nil { - return 0, err - } - res := fc.SendWindowSize() - - if fc.ContributesToConnection() { - res = utils.MinByteCount(res, f.connFlowController.SendWindowSize()) - } - - return res, nil -} - -func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - f.mutex.RLock() - defer f.mutex.RUnlock() - - return f.connFlowController.SendWindowSize() -} - -// streamID may be 0 here -func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { - f.mutex.Lock() - defer f.mutex.Unlock() - - var fc *flowController - if streamID == 0 { - fc = f.connFlowController - } else { - var err error - fc, err = f.getFlowController(streamID) - if err != nil { - return false, err - } - } - - return fc.UpdateSendWindow(offset), nil -} - -func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) { - streamFlowController, ok := f.streamFlowController[streamID] - if !ok { - return nil, errMapAccess - } - return streamFlowController, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go deleted file mode 100644 index 387ee05..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go +++ /dev/null @@ -1,198 +0,0 @@ -package flowcontrol - -import ( - "errors" - "time" - - "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -type flowController struct { - streamID protocol.StreamID - contributesToConnection bool // does the stream contribute to connection level flow control - - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats - - bytesSent protocol.ByteCount - sendWindow protocol.ByteCount - - lastWindowUpdateTime time.Time - - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveWindow protocol.ByteCount - receiveWindowIncrement protocol.ByteCount - maxReceiveWindowIncrement protocol.ByteCount -} - -// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously -var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") - -// newFlowController gets a new flow controller -func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { - fc := flowController{ - streamID: streamID, - contributesToConnection: contributesToConnection, - connectionParameters: connectionParameters, - rttStats: rttStats, - } - - if streamID == 0 { - fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() - fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() - } else { - fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow() - fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() - } - - return &fc -} - -func (c *flowController) ContributesToConnection() bool { - return c.contributesToConnection -} - -func (c *flowController) getSendWindow() protocol.ByteCount { - if c.sendWindow == 0 { - if c.streamID == 0 { - return c.connectionParameters.GetSendConnectionFlowControlWindow() - } - return c.connectionParameters.GetSendStreamFlowControlWindow() - } - return c.sendWindow -} - -func (c *flowController) AddBytesSent(n protocol.ByteCount) { - c.bytesSent += n -} - -// UpdateSendWindow should be called after receiving a WindowUpdateFrame -// it returns true if the window was actually updated -func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { - if newOffset > c.sendWindow { - c.sendWindow = newOffset - return true - } - return false -} - -func (c *flowController) SendWindowSize() protocol.ByteCount { - sendWindow := c.getSendWindow() - - if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here - return 0 - } - return sendWindow - c.bytesSent -} - -func (c *flowController) SendWindowOffset() protocol.ByteCount { - return c.getSendWindow() -} - -// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher -// Should **only** be used for the stream-level FlowController -// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before -// This error occurs every time StreamFrames get reordered and has to be ignored in that case -// It should only be treated as an error when resetting a stream -func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) { - if byteOffset == c.highestReceived { - return 0, nil - } - if byteOffset > c.highestReceived { - increment := byteOffset - c.highestReceived - c.highestReceived = byteOffset - return increment, nil - } - return 0, ErrReceivedSmallerByteOffset -} - -// IncrementHighestReceived adds an increment to the highestReceived value -// Should **only** be used for the connection-level FlowController -func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) { - c.highestReceived += increment -} - -func (c *flowController) AddBytesRead(n protocol.ByteCount) { - // pretend we sent a WindowUpdate when reading the first byte - // this way auto-tuning of the window increment already works for the first WindowUpdate - if c.bytesRead == 0 { - c.lastWindowUpdateTime = time.Now() - } - c.bytesRead += n -} - -// MaybeUpdateWindow updates the receive window, if necessary -// if the receive window increment is changed, the new value is returned, otherwise a 0 -// the last return value is the new offset of the receive window -func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) { - diff := c.receiveWindow - c.bytesRead - - // Chromium implements the same threshold - if diff < (c.receiveWindowIncrement / 2) { - var newWindowIncrement protocol.ByteCount - oldWindowIncrement := c.receiveWindowIncrement - - c.maybeAdjustWindowIncrement() - if c.receiveWindowIncrement != oldWindowIncrement { - newWindowIncrement = c.receiveWindowIncrement - } - - c.lastWindowUpdateTime = time.Now() - c.receiveWindow = c.bytesRead + c.receiveWindowIncrement - return true, newWindowIncrement, c.receiveWindow - } - - return false, 0, 0 -} - -// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often -func (c *flowController) maybeAdjustWindowIncrement() { - if c.lastWindowUpdateTime.IsZero() { - return - } - - rtt := c.rttStats.SmoothedRTT() - if rtt == 0 { - return - } - - timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) - - // interval between the window updates is sufficiently large, no need to increase the increment - if timeSinceLastWindowUpdate >= 2*rtt { - return - } - - oldWindowSize := c.receiveWindowIncrement - c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) - - // debug log, if the window size was actually increased - if oldWindowSize < c.receiveWindowIncrement { - newWindowSize := c.receiveWindowIncrement / (1 << 10) - if c.streamID == 0 { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize) - } else { - utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize) - } - } -} - -// EnsureMinimumWindowIncrement sets a minimum window increment -// it is intended be used for the connection-level flow controller -// it should make sure that the connection-level window is increased when a stream-level window grows -func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { - if inc > c.receiveWindowIncrement { - c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) - c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update - } -} - -func (c *flowController) CheckFlowControlViolation() bool { - return c.highestReceived > c.receiveWindow -} diff --git a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go deleted file mode 100644 index e1ea3fa..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go +++ /dev/null @@ -1,26 +0,0 @@ -package flowcontrol - -import "github.com/lucas-clemente/quic-go/protocol" - -// WindowUpdate provides the data for WindowUpdateFrames. -type WindowUpdate struct { - StreamID protocol.StreamID - Offset protocol.ByteCount -} - -// A FlowControlManager manages the flow control -type FlowControlManager interface { - NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) - RemoveStream(streamID protocol.StreamID) - // methods needed for receiving data - ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error - AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error - GetWindowUpdates() []WindowUpdate - GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) - // methods needed for sending data - AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error - SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) - RemainingConnectionWindowSize() protocol.ByteCount - UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go deleted file mode 100644 index ac65d33..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go +++ /dev/null @@ -1,9 +0,0 @@ -package frames - -import "github.com/lucas-clemente/quic-go/protocol" - -// AckRange is an ACK range -type AckRange struct { - FirstPacketNumber protocol.PacketNumber - LastPacketNumber protocol.PacketNumber -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go deleted file mode 100644 index 4464578..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go +++ /dev/null @@ -1,44 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -// A BlockedFrame in QUIC -type BlockedFrame struct { - StreamID protocol.StreamID -} - -//Write writes a BlockedFrame frame -func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x05) - utils.WriteUint32(b, uint32(f.StreamID)) - return nil -} - -// MinLength of a written frame -func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4, nil -} - -// ParseBlockedFrame parses a BLOCKED frame -func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) { - frame := &BlockedFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - sid, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - return frame, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/log.go b/vendor/github.com/lucas-clemente/quic-go/frames/log.go deleted file mode 100644 index 6b7fdce..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/log.go +++ /dev/null @@ -1,28 +0,0 @@ -package frames - -import "github.com/lucas-clemente/quic-go/internal/utils" - -// LogFrame logs a frame, either sent or received -func LogFrame(frame Frame, sent bool) { - if !utils.Debug() { - return - } - dir := "<-" - if sent { - dir = "->" - } - switch f := frame.(type) { - case *StreamFrame: - utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) - case *StopWaitingFrame: - if sent { - utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) - } else { - utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) - } - case *AckFrame: - utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) - default: - utils.Debugf("\t%s %#v", dir, frame) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go b/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go deleted file mode 100644 index 9b8b459..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go +++ /dev/null @@ -1,54 +0,0 @@ -package frames - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -// A WindowUpdateFrame in QUIC -type WindowUpdateFrame struct { - StreamID protocol.StreamID - ByteOffset protocol.ByteCount -} - -//Write writes a RST_STREAM frame -func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x04) - b.WriteByte(typeByte) - - utils.WriteUint32(b, uint32(f.StreamID)) - utils.WriteUint64(b, uint64(f.ByteOffset)) - return nil -} - -// MinLength of a written frame -func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { - return 1 + 4 + 8, nil -} - -// ParseWindowUpdateFrame parses a RST_STREAM frame -func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) { - frame := &WindowUpdateFrame{} - - // read the TypeByte - _, err := r.ReadByte() - if err != nil { - return nil, err - } - - sid, err := utils.ReadUint32(r) - if err != nil { - return nil, err - } - frame.StreamID = protocol.StreamID(sid) - - byteOffset, err := utils.ReadUint64(r) - if err != nil { - return nil, err - } - frame.ByteOffset = protocol.ByteCount(byteOffset) - - return frame, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go deleted file mode 100644 index 866b11a..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go +++ /dev/null @@ -1,296 +0,0 @@ -package h2quic - -import ( - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "strings" - "sync" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" - "golang.org/x/net/idna" - - quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -type roundTripperOpts struct { - DisableCompression bool -} - -var dialAddr = quic.DialAddr - -// client is a HTTP2 client doing QUIC requests -type client struct { - mutex sync.RWMutex - - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts - - hostname string - encryptionLevel protocol.EncryptionLevel - handshakeErr error - dialOnce sync.Once - - session quic.Session - headerStream quic.Stream - headerErr *qerr.QuicError - headerErrored chan struct{} // this channel is closed if an error occurs on the header stream - requestWriter *requestWriter - - responses map[protocol.StreamID]chan *http.Response -} - -var _ http.RoundTripper = &client{} - -var defaultQuicConfig = &quic.Config{ - RequestConnectionIDTruncation: true, - KeepAlive: true, -} - -// newClient creates a new client -func newClient( - hostname string, - tlsConfig *tls.Config, - opts *roundTripperOpts, - quicConfig *quic.Config, -) *client { - config := defaultQuicConfig - if quicConfig != nil { - config = quicConfig - } - return &client{ - hostname: authorityAddr("https", hostname), - responses: make(map[protocol.StreamID]chan *http.Response), - encryptionLevel: protocol.EncryptionUnencrypted, - tlsConf: tlsConfig, - config: config, - opts: opts, - headerErrored: make(chan struct{}), - } -} - -// dial dials the connection -func (c *client) dial() error { - var err error - c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) - if err != nil { - return err - } - - // once the version has been negotiated, open the header stream - c.headerStream, err = c.session.OpenStream() - if err != nil { - return err - } - if c.headerStream.StreamID() != 3 { - return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3") - } - c.requestWriter = newRequestWriter(c.headerStream) - go c.handleHeaderStream() - return nil -} - -func (c *client) handleHeaderStream() { - decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) - h2framer := http2.NewFramer(nil, c.headerStream) - - var lastStream protocol.StreamID - - for { - frame, err := h2framer.ReadFrame() - if err != nil { - c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") - break - } - lastStream = protocol.StreamID(frame.Header().StreamID) - hframe, ok := frame.(*http2.HeadersFrame) - if !ok { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") - break - } - mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} - mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) - if err != nil { - c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") - break - } - - c.mutex.RLock() - responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] - c.mutex.RUnlock() - if !ok { - c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) - break - } - - rsp, err := responseFromHeaders(mhframe) - if err != nil { - c.headerErr = qerr.Error(qerr.InternalError, err.Error()) - } - responseChan <- rsp - } - - // stop all running request - utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) - close(c.headerErrored) -} - -// Roundtrip executes a request and returns a response -func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { - // TODO: add port to address, if it doesn't have one - if req.URL.Scheme != "https" { - return nil, errors.New("quic http2: unsupported scheme") - } - if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { - return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host) - } - - c.dialOnce.Do(func() { - c.handshakeErr = c.dial() - }) - - if c.handshakeErr != nil { - return nil, c.handshakeErr - } - - hasBody := (req.Body != nil) - - responseChan := make(chan *http.Response) - dataStream, err := c.session.OpenStreamSync() - if err != nil { - _ = c.CloseWithError(err) - return nil, err - } - c.mutex.Lock() - c.responses[dataStream.StreamID()] = responseChan - c.mutex.Unlock() - - var requestedGzip bool - if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { - requestedGzip = true - } - // TODO: add support for trailers - endStream := !hasBody - err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) - if err != nil { - _ = c.CloseWithError(err) - return nil, err - } - - resc := make(chan error, 1) - if hasBody { - go func() { - resc <- c.writeRequestBody(dataStream, req.Body) - }() - } - - var res *http.Response - - var receivedResponse bool - var bodySent bool - - if !hasBody { - bodySent = true - } - - for !(bodySent && receivedResponse) { - select { - case res = <-responseChan: - receivedResponse = true - c.mutex.Lock() - delete(c.responses, dataStream.StreamID()) - c.mutex.Unlock() - case err := <-resc: - bodySent = true - if err != nil { - return nil, err - } - case <-c.headerErrored: - // an error occured on the header stream - _ = c.CloseWithError(c.headerErr) - return nil, c.headerErr - } - } - - // TODO: correctly set this variable - var streamEnded bool - isHead := (req.Method == "HEAD") - - res = setLength(res, isHead, streamEnded) - - if streamEnded || isHead { - res.Body = noBody - } else { - res.Body = dataStream - if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = &gzipReader{body: res.Body} - res.Uncompressed = true - } - } - - res.Request = req - return res, nil -} - -func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { - defer func() { - cerr := body.Close() - if err == nil { - // TODO: what to do with dataStream here? Maybe reset it? - err = cerr - } - }() - - _, err = io.Copy(dataStream, body) - if err != nil { - // TODO: what to do with dataStream here? Maybe reset it? - return err - } - return dataStream.Close() -} - -// Close closes the client -func (c *client) CloseWithError(e error) error { - if c.session == nil { - return nil - } - return c.session.Close(e) -} - -func (c *client) Close() error { - return c.CloseWithError(nil) -} - -// copied from net/transport.go - -// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) -// and returns a host:port. The port 443 is added if needed. -func authorityAddr(scheme string, authority string) (addr string) { - host, port, err := net.SplitHostPort(authority) - if err != nil { // authority didn't have a port - port = "443" - if scheme == "http" { - port = "80" - } - host = authority - } - if a, err := idna.ToASCII(host); err == nil { - host = a - } - // IPv6 address literal, without a port: - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - return host + ":" + port - } - return net.JoinHostPort(host, port) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go deleted file mode 100644 index 91c226b..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go +++ /dev/null @@ -1,35 +0,0 @@ -package h2quic - -// copied from net/transport.go - -// gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read -import ( - "compress/gzip" - "io" -) - -// call gzip.NewReader on the first call to Read -type gzipReader struct { - body io.ReadCloser // underlying Response.Body - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // sticky error -} - -func (gz *gzipReader) Read(p []byte) (n int, err error) { - if gz.zerr != nil { - return 0, gz.zerr - } - if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) - if err != nil { - gz.zerr = err - return 0, err - } - } - return gz.zr.Read(p) -} - -func (gz *gzipReader) Close() error { - return gz.body.Close() -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go deleted file mode 100644 index 911485e..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go +++ /dev/null @@ -1,80 +0,0 @@ -package h2quic - -import ( - "crypto/tls" - "errors" - "net/http" - "net/url" - "strconv" - "strings" - - "golang.org/x/net/http2/hpack" -) - -func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { - var path, authority, method, contentLengthStr string - httpHeaders := http.Header{} - - for _, h := range headers { - switch h.Name { - case ":path": - path = h.Value - case ":method": - method = h.Value - case ":authority": - authority = h.Value - case "content-length": - contentLengthStr = h.Value - default: - if !h.IsPseudo() { - httpHeaders.Add(h.Name, h.Value) - } - } - } - - // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 - if len(httpHeaders["Cookie"]) > 0 { - httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) - } - - if len(path) == 0 || len(authority) == 0 || len(method) == 0 { - return nil, errors.New(":path, :authority and :method must not be empty") - } - - u, err := url.Parse(path) - if err != nil { - return nil, err - } - - var contentLength int64 - if len(contentLengthStr) > 0 { - contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) - if err != nil { - return nil, err - } - } - - return &http.Request{ - Method: method, - URL: u, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Header: httpHeaders, - Body: nil, - ContentLength: contentLength, - Host: authority, - RequestURI: path, - TLS: &tls.ConnectionState{}, - }, nil -} - -func hostnameFromRequest(req *http.Request) string { - if len(req.Host) > 0 { - return req.Host - } - if req.URL != nil { - return req.URL.Host - } - return "" -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go deleted file mode 100644 index 2d4d595..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go +++ /dev/null @@ -1,29 +0,0 @@ -package h2quic - -import ( - "io" - - quic "github.com/lucas-clemente/quic-go" -) - -type requestBody struct { - requestRead bool - dataStream quic.Stream -} - -// make sure the requestBody can be used as a http.Request.Body -var _ io.ReadCloser = &requestBody{} - -func newRequestBody(stream quic.Stream) *requestBody { - return &requestBody{dataStream: stream} -} - -func (b *requestBody) Read(p []byte) (int, error) { - b.requestRead = true - return b.dataStream.Read(p) -} - -func (b *requestBody) Close() error { - // stream's Close() closes the write side, not the read side - return nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go deleted file mode 100644 index dad591c..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go +++ /dev/null @@ -1,201 +0,0 @@ -package h2quic - -import ( - "bytes" - "fmt" - "net/http" - "strconv" - "strings" - "sync" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" - "golang.org/x/net/lex/httplex" - - quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -type requestWriter struct { - mutex sync.Mutex - headerStream quic.Stream - - henc *hpack.Encoder - hbuf bytes.Buffer // HPACK encoder writes into this -} - -const defaultUserAgent = "quic-go" - -func newRequestWriter(headerStream quic.Stream) *requestWriter { - rw := &requestWriter{ - headerStream: headerStream, - } - rw.henc = hpack.NewEncoder(&rw.hbuf) - return rw -} - -func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error { - // TODO: add support for trailers - // TODO: add support for gzip compression - // TODO: write continuation frames, if the header frame is too long - - w.mutex.Lock() - defer w.mutex.Unlock() - - w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) - h2framer := http2.NewFramer(w.headerStream, nil) - return h2framer.WriteHeaders(http2.HeadersFrameParam{ - StreamID: uint32(dataStreamID), - EndHeaders: true, - EndStream: endStream, - BlockFragment: w.hbuf.Bytes(), - Priority: http2.PriorityParam{Weight: 0xff}, - }) -} - -// the rest of this files is copied from http2.Transport -func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { - w.hbuf.Reset() - - host := req.Host - if host == "" { - host = req.URL.Host - } - host, err := httplex.PunycodeHostPort(host) - if err != nil { - return nil, err - } - - var path string - if req.Method != "CONNECT" { - path = req.URL.RequestURI() - if !validPseudoPath(path) { - orig := path - path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) - if !validPseudoPath(path) { - if req.URL.Opaque != "" { - return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) - } else { - return nil, fmt.Errorf("invalid request :path %q", orig) - } - } - } - } - - // Check for any invalid headers and return an error before we - // potentially pollute our hpack state. (We want to be able to - // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httplex.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httplex.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) - } - } - } - - // 8.1.2.3 Request Pseudo-Header Fields - // The :path pseudo-header field includes the path and query parts of the - // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of - // [RFC3986]). - w.writeHeader(":authority", host) - w.writeHeader(":method", req.Method) - if req.Method != "CONNECT" { - w.writeHeader(":path", path) - w.writeHeader(":scheme", req.URL.Scheme) - } - if trailers != "" { - w.writeHeader("trailer", trailers) - } - - var didUA bool - for k, vv := range req.Header { - lowKey := strings.ToLower(k) - switch lowKey { - case "host", "content-length": - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. - continue - case "user-agent": - // Match Go's http1 behavior: at most one - // User-Agent. If set to nil or empty string, - // then omit it. Otherwise if not mentioned, - // include the default (below). - didUA = true - if len(vv) < 1 { - continue - } - vv = vv[:1] - if vv[0] == "" { - continue - } - } - for _, v := range vv { - w.writeHeader(lowKey, v) - } - } - if shouldSendReqContentLength(req.Method, contentLength) { - w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) - } - if addGzipHeader { - w.writeHeader("accept-encoding", "gzip") - } - if !didUA { - w.writeHeader("user-agent", defaultUserAgent) - } - return w.hbuf.Bytes(), nil -} - -func (w *requestWriter) writeHeader(name, value string) { - utils.Debugf("http2: Transport encoding header %q = %q", name, value) - w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) -} - -// shouldSendReqContentLength reports whether the http2.Transport should send -// a "content-length" request header. This logic is basically a copy of the net/http -// transferWriter.shouldSendContentLength. -// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). -// -1 means unknown. -func shouldSendReqContentLength(method string, contentLength int64) bool { - if contentLength > 0 { - return true - } - if contentLength < 0 { - return false - } - // For zero bodies, whether we send a content-length depends on the method. - // It also kinda doesn't matter for http2 either way, with END_STREAM. - switch method { - case "POST", "PUT", "PATCH": - return true - default: - return false - } -} - -func validPseudoPath(v string) bool { - return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" -} - -// actualContentLength returns a sanitized version of -// req.ContentLength, where 0 actually means zero (not unknown) and -1 -// means unknown. -func actualContentLength(req *http.Request) int64 { - if req.Body == nil { - return 0 - } - if req.ContentLength != 0 { - return req.ContentLength - } - return -1 -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go deleted file mode 100644 index 13efdf8..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go +++ /dev/null @@ -1,111 +0,0 @@ -package h2quic - -import ( - "bytes" - "errors" - "io" - "io/ioutil" - "net/http" - "net/textproto" - "strconv" - "strings" - - "golang.org/x/net/http2" -) - -// copied from net/http2/transport.go - -var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") -var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) - -// from the handleResponse function -func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { - if f.Truncated { - return nil, errResponseHeaderListSize - } - - status := f.PseudoValue("status") - if status == "" { - return nil, errors.New("missing status pseudo header") - } - statusCode, err := strconv.Atoi(status) - if err != nil { - return nil, errors.New("malformed non-numeric status pseudo header") - } - - if statusCode == 100 { - // TODO: handle this - - // traceGot100Continue(cs.trace) - // if cs.on100 != nil { - // cs.on100() // forces any write delay timer to fire - // } - // cs.pastHeaders = false // do it all again - // return nil, nil - } - - header := make(http.Header) - res := &http.Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, - Header: header, - StatusCode: statusCode, - Status: status + " " + http.StatusText(statusCode), - } - for _, hf := range f.RegularFields() { - key := http.CanonicalHeaderKey(hf.Name) - if key == "Trailer" { - t := res.Trailer - if t == nil { - t = make(http.Header) - res.Trailer = t - } - foreachHeaderElement(hf.Value, func(v string) { - t[http.CanonicalHeaderKey(v)] = nil - }) - } else { - header[key] = append(header[key], hf.Value) - } - } - - return res, nil -} - -// continuation of the handleResponse function -func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { - if !streamEnded || isHead { - res.ContentLength = -1 - if clens := res.Header["Content-Length"]; len(clens) == 1 { - if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { - res.ContentLength = clen64 - } else { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } - } else if len(clens) > 1 { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } - } - return res -} - -// copied from net/http/server.go - -// foreachHeaderElement splits v according to the "#rule" construction -// in RFC 2616 section 2.1 and calls fn for each non-empty element. -func foreachHeaderElement(v string, fn func(string)) { - v = textproto.TrimString(v) - if v == "" { - return - } - if !strings.Contains(v, ",") { - fn(v) - return - } - for _, f := range strings.Split(v, ",") { - if f = textproto.TrimString(f); f != "" { - fn(f) - } - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go deleted file mode 100644 index 2468934..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go +++ /dev/null @@ -1,108 +0,0 @@ -package h2quic - -import ( - "bytes" - "net/http" - "strconv" - "strings" - "sync" - - quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" -) - -type responseWriter struct { - dataStreamID protocol.StreamID - dataStream quic.Stream - - headerStream quic.Stream - headerStreamMutex *sync.Mutex - - header http.Header - status int // status code passed to WriteHeader - headerWritten bool -} - -func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter { - return &responseWriter{ - header: http.Header{}, - headerStream: headerStream, - headerStreamMutex: headerStreamMutex, - dataStream: dataStream, - dataStreamID: dataStreamID, - } -} - -func (w *responseWriter) Header() http.Header { - return w.header -} - -func (w *responseWriter) WriteHeader(status int) { - if w.headerWritten { - return - } - w.headerWritten = true - w.status = status - - var headers bytes.Buffer - enc := hpack.NewEncoder(&headers) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - - for k, v := range w.header { - for index := range v { - enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) - } - } - - utils.Infof("Responding with %d", status) - w.headerStreamMutex.Lock() - defer w.headerStreamMutex.Unlock() - h2framer := http2.NewFramer(w.headerStream, nil) - err := h2framer.WriteHeaders(http2.HeadersFrameParam{ - StreamID: uint32(w.dataStreamID), - EndHeaders: true, - BlockFragment: headers.Bytes(), - }) - if err != nil { - utils.Errorf("could not write h2 header: %s", err.Error()) - } -} - -func (w *responseWriter) Write(p []byte) (int, error) { - if !w.headerWritten { - w.WriteHeader(200) - } - if !bodyAllowedForStatus(w.status) { - return 0, http.ErrBodyNotAllowed - } - return w.dataStream.Write(p) -} - -func (w *responseWriter) Flush() {} - -// TODO: Implement a functional CloseNotify method. -func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } - -// test that we implement http.Flusher -var _ http.Flusher = &responseWriter{} - -// test that we implement http.CloseNotifier -var _ http.CloseNotifier = &responseWriter{} - -// copied from http2/http2.go -// bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC 2616, section 4.4. -func bodyAllowedForStatus(status int) bool { - switch { - case status >= 100 && status <= 199: - return false - case status == 204: - return false - case status == 304: - return false - } - return true -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go deleted file mode 100644 index 9ac5f19..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go +++ /dev/null @@ -1,168 +0,0 @@ -package h2quic - -import ( - "crypto/tls" - "errors" - "fmt" - "io" - "net/http" - "strings" - "sync" - - quic "github.com/lucas-clemente/quic-go" - - "golang.org/x/net/lex/httplex" -) - -type roundTripCloser interface { - http.RoundTripper - io.Closer -} - -// RoundTripper implements the http.RoundTripper interface -type RoundTripper struct { - mutex sync.Mutex - - // DisableCompression, if true, prevents the Transport from - // requesting compression with an "Accept-Encoding: gzip" - // request header when the Request contains no existing - // Accept-Encoding value. If the Transport requests gzip on - // its own and gets a gzipped response, it's transparently - // decoded in the Response.Body. However, if the user - // explicitly requested gzip it is not automatically - // uncompressed. - DisableCompression bool - - // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. If nil, the default configuration is used. - TLSClientConfig *tls.Config - - // QuicConfig is the quic.Config used for dialing new connections. - // If nil, reasonable default values will be used. - QuicConfig *quic.Config - - clients map[string]roundTripCloser -} - -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may - // create a new QUIC connection. If set true and - // no cached connection is available, RoundTrip - // will return ErrNoCachedConn. - OnlyCachedConn bool -} - -var _ roundTripCloser = &RoundTripper{} - -// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set -var ErrNoCachedConn = errors.New("h2quic: no cached connection was available") - -// RoundTripOpt is like RoundTrip, but takes options. -func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - if req.URL == nil { - closeRequestBody(req) - return nil, errors.New("quic: nil Request.URL") - } - if req.URL.Host == "" { - closeRequestBody(req) - return nil, errors.New("quic: no Host in request URL") - } - if req.Header == nil { - closeRequestBody(req) - return nil, errors.New("quic: nil Request.Header") - } - - if req.URL.Scheme == "https" { - for k, vv := range req.Header { - if !httplex.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("quic: invalid http header field name %q", k) - } - for _, v := range vv { - if !httplex.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) - } - } - } - } else { - closeRequestBody(req) - return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) - } - - if req.Method != "" && !validMethod(req.Method) { - closeRequestBody(req) - return nil, fmt.Errorf("quic: invalid method %q", req.Method) - } - - hostname := authorityAddr("https", hostnameFromRequest(req)) - cl, err := r.getClient(hostname, opt.OnlyCachedConn) - if err != nil { - return nil, err - } - return cl.RoundTrip(req) -} - -// RoundTrip does a round trip. -func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return r.RoundTripOpt(req, RoundTripOpt{}) -} - -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) { - r.mutex.Lock() - defer r.mutex.Unlock() - - if r.clients == nil { - r.clients = make(map[string]roundTripCloser) - } - - client, ok := r.clients[hostname] - if !ok { - if onlyCached { - return nil, ErrNoCachedConn - } - client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) - r.clients[hostname] = client - } - return client, nil -} - -// Close closes the QUIC connections that this RoundTripper has used -func (r *RoundTripper) Close() error { - r.mutex.Lock() - defer r.mutex.Unlock() - for _, client := range r.clients { - if err := client.Close(); err != nil { - return err - } - } - r.clients = nil - return nil -} - -func closeRequestBody(req *http.Request) { - if req.Body != nil { - req.Body.Close() - } -} - -func validMethod(method string) bool { - /* - Method = "OPTIONS" ; Section 9.2 - | "GET" ; Section 9.3 - | "HEAD" ; Section 9.4 - | "POST" ; Section 9.5 - | "PUT" ; Section 9.6 - | "DELETE" ; Section 9.7 - | "TRACE" ; Section 9.8 - | "CONNECT" ; Section 9.9 - | extension-method - extension-method = token - token = 1* - */ - return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 -} - -// copied from net/http/http.go -func isNotToken(r rune) bool { - return !httplex.IsTokenRune(r) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go b/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go deleted file mode 100644 index 3647dc6..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go +++ /dev/null @@ -1,382 +0,0 @@ -package h2quic - -import ( - "crypto/tls" - "errors" - "fmt" - "net" - "net/http" - "runtime" - "strconv" - "sync" - "sync/atomic" - "time" - - quic "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" -) - -type streamCreator interface { - quic.Session - GetOrOpenStream(protocol.StreamID) (quic.Stream, error) -} - -type remoteCloser interface { - CloseRemote(protocol.ByteCount) -} - -// allows mocking of quic.Listen and quic.ListenAddr -var ( - quicListen = quic.Listen - quicListenAddr = quic.ListenAddr -) - -// Server is a HTTP2 server listening for QUIC connections. -type Server struct { - *http.Server - - // By providing a quic.Config, it is possible to set parameters of the QUIC connection. - // If nil, it uses reasonable default values. - QuicConfig *quic.Config - - // Private flag for demo, do not use - CloseAfterFirstRequest bool - - port uint32 // used atomically - - listenerMutex sync.Mutex - listener quic.Listener - - supportedVersionsAsString string -} - -// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. -func (s *Server) ListenAndServe() error { - if s.Server == nil { - return errors.New("use of h2quic.Server without http.Server") - } - return s.serveImpl(s.TLSConfig, nil) -} - -// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. -func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { - var err error - certs := make([]tls.Certificate, 1) - certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - // We currently only use the cert-related stuff from tls.Config, - // so we don't need to make a full copy. - config := &tls.Config{ - Certificates: certs, - } - return s.serveImpl(config, nil) -} - -// Serve an existing UDP connection. -func (s *Server) Serve(conn net.PacketConn) error { - return s.serveImpl(s.TLSConfig, conn) -} - -func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { - if s.Server == nil { - return errors.New("use of h2quic.Server without http.Server") - } - s.listenerMutex.Lock() - if s.listener != nil { - s.listenerMutex.Unlock() - return errors.New("ListenAndServe may only be called once") - } - - var ln quic.Listener - var err error - if conn == nil { - ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig) - } else { - ln, err = quicListen(conn, tlsConfig, s.QuicConfig) - } - if err != nil { - s.listenerMutex.Unlock() - return err - } - s.listener = ln - s.listenerMutex.Unlock() - - for { - sess, err := ln.Accept() - if err != nil { - return err - } - go s.handleHeaderStream(sess.(streamCreator)) - } -} - -func (s *Server) handleHeaderStream(session streamCreator) { - stream, err := session.AcceptStream() - if err != nil { - session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) - return - } - if stream.StreamID() != 3 { - session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")) - return - } - - hpackDecoder := hpack.NewDecoder(4096, nil) - h2framer := http2.NewFramer(nil, stream) - - go func() { - var headerStreamMutex sync.Mutex // Protects concurrent calls to Write() - for { - if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { - // QuicErrors must originate from stream.Read() returning an error. - // In this case, the session has already logged the error, so we don't - // need to log it again. - if _, ok := err.(*qerr.QuicError); !ok { - utils.Errorf("error handling h2 request: %s", err.Error()) - } - session.Close(err) - return - } - } - }() -} - -func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { - h2frame, err := h2framer.ReadFrame() - if err != nil { - return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") - } - h2headersFrame, ok := h2frame.(*http2.HeadersFrame) - if !ok { - return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame") - } - if !h2headersFrame.HeadersEnded() { - return errors.New("http2 header continuation not implemented") - } - headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment()) - if err != nil { - utils.Errorf("invalid http2 headers encoding: %s", err.Error()) - return err - } - - req, err := requestFromHeaders(headers) - if err != nil { - return err - } - - req.RemoteAddr = session.RemoteAddr().String() - - if utils.Debug() { - utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) - } else { - utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) - } - - dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) - if err != nil { - return err - } - // this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request - if dataStream == nil { - return nil - } - - var streamEnded bool - if h2headersFrame.StreamEnded() { - dataStream.(remoteCloser).CloseRemote(0) - streamEnded = true - _, _ = dataStream.Read([]byte{0}) // read the eof - } - - reqBody := newRequestBody(dataStream) - req.Body = reqBody - - responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) - - go func() { - handler := s.Handler - if handler == nil { - handler = http.DefaultServeMux - } - panicked := false - func() { - defer func() { - if p := recover(); p != nil { - // Copied from net/http/server.go - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - utils.Errorf("http: panic serving: %v\n%s", p, buf) - panicked = true - } - }() - handler.ServeHTTP(responseWriter, req) - }() - if panicked { - responseWriter.WriteHeader(500) - } else { - responseWriter.WriteHeader(200) - } - if responseWriter.dataStream != nil { - if !streamEnded && !reqBody.requestRead { - responseWriter.dataStream.Reset(nil) - } - responseWriter.dataStream.Close() - } - if s.CloseAfterFirstRequest { - time.Sleep(100 * time.Millisecond) - session.Close(nil) - } - }() - - return nil -} - -// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. -// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) Close() error { - s.listenerMutex.Lock() - defer s.listenerMutex.Unlock() - if s.listener != nil { - err := s.listener.Close() - s.listener = nil - return err - } - return nil -} - -// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. -// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) CloseGracefully(timeout time.Duration) error { - // TODO: implement - return nil -} - -// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. -// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): -// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" -func (s *Server) SetQuicHeaders(hdr http.Header) error { - port := atomic.LoadUint32(&s.port) - - if port == 0 { - // Extract port from s.Server.Addr - _, portStr, err := net.SplitHostPort(s.Server.Addr) - if err != nil { - return err - } - portInt, err := net.LookupPort("tcp", portStr) - if err != nil { - return err - } - port = uint32(portInt) - atomic.StoreUint32(&s.port, port) - } - - if s.supportedVersionsAsString == "" { - for i, v := range protocol.SupportedVersions { - s.supportedVersionsAsString += strconv.Itoa(int(v)) - if i != len(protocol.SupportedVersions)-1 { - s.supportedVersionsAsString += "," - } - } - } - - hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) - - return nil -} - -// ListenAndServeQUIC listens on the UDP network address addr and calls the -// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is -// used when handler is nil. -func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { - server := &Server{ - Server: &http.Server{ - Addr: addr, - Handler: handler, - }, - } - return server.ListenAndServeTLS(certFile, keyFile) -} - -// ListenAndServe listens on the given network address for both, TLS and QUIC -// connetions in parallel. It returns if one of the two returns an error. -// http.DefaultServeMux is used when handler is nil. -// The correct Alt-Svc headers for QUIC are set. -func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { - // Load certs - var err error - certs := make([]tls.Certificate, 1) - certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - // We currently only use the cert-related stuff from tls.Config, - // so we don't need to make a full copy. - config := &tls.Config{ - Certificates: certs, - } - - // Open the listeners - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return err - } - udpConn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return err - } - defer udpConn.Close() - - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return err - } - tcpConn, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return err - } - defer tcpConn.Close() - - // Start the servers - httpServer := &http.Server{ - Addr: addr, - TLSConfig: config, - } - - quicServer := &Server{ - Server: httpServer, - } - - if handler == nil { - handler = http.DefaultServeMux - } - httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - quicServer.SetQuicHeaders(w.Header()) - handler.ServeHTTP(w, r) - }) - - hErr := make(chan error) - qErr := make(chan error) - go func() { - hErr <- httpServer.Serve(tcpConn) - }() - go func() { - qErr <- quicServer.Serve(udpConn) - }() - - select { - case err := <-hErr: - quicServer.Close() - return err - case err := <-qErr: - // Cannot close the HTTP server or wait for requests to complete properly :/ - return err - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go b/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go deleted file mode 100644 index 1ad9a3a..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go +++ /dev/null @@ -1,265 +0,0 @@ -package handshake - -import ( - "bytes" - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" -) - -// ConnectionParametersManager negotiates and stores the connection parameters -// A ConnectionParametersManager can be used for a server as well as a client -// For the server: -// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation -// 2. call GetHelloMap to get the values to send in the SHLO -// For the client: -// 1. call GetHelloMap to get the values to send in a CHLO -// 2. call SetFromMap with the values received in the SHLO -type ConnectionParametersManager interface { - SetFromMap(map[Tag][]byte) error - GetHelloMap() (map[Tag][]byte, error) - - GetSendStreamFlowControlWindow() protocol.ByteCount - GetSendConnectionFlowControlWindow() protocol.ByteCount - GetReceiveStreamFlowControlWindow() protocol.ByteCount - GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount - GetReceiveConnectionFlowControlWindow() protocol.ByteCount - GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount - GetMaxOutgoingStreams() uint32 - GetMaxIncomingStreams() uint32 - GetIdleConnectionStateLifetime() time.Duration - TruncateConnectionID() bool -} - -type connectionParametersManager struct { - mutex sync.RWMutex - - version protocol.VersionNumber - perspective protocol.Perspective - - flowControlNegotiated bool - - truncateConnectionID bool - maxStreamsPerConnection uint32 - maxIncomingDynamicStreamsPerConnection uint32 - idleConnectionStateLifetime time.Duration - sendStreamFlowControlWindow protocol.ByteCount - sendConnectionFlowControlWindow protocol.ByteCount - receiveStreamFlowControlWindow protocol.ByteCount - receiveConnectionFlowControlWindow protocol.ByteCount - maxReceiveStreamFlowControlWindow protocol.ByteCount - maxReceiveConnectionFlowControlWindow protocol.ByteCount -} - -var _ ConnectionParametersManager = &connectionParametersManager{} - -// ErrMalformedTag is returned when the tag value cannot be read -var ( - ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") - ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported") -) - -// NewConnectionParamatersManager creates a new connection parameters manager -func NewConnectionParamatersManager( - pers protocol.Perspective, v protocol.VersionNumber, - maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount, -) ConnectionParametersManager { - h := &connectionParametersManager{ - perspective: pers, - version: v, - sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client - sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client - receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, - maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, - } - - if h.perspective == protocol.PerspectiveServer { - h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent - h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective - } else { - h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient - h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent - h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective - } - - return h -} - -// SetFromMap reads all params -func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error { - h.mutex.Lock() - defer h.mutex.Unlock() - - if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.truncateConnectionID = (clientValue == 0) - } - if value, ok := params[TagMSPC]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) - } - if value, ok := params[TagMIDS]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue) - } - if value, ok := params[TagICSL]; ok { - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) - } - if value, ok := params[TagSFCW]; ok { - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) - } - if value, ok := params[TagCFCW]; ok { - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) - } - - _, containsSFCW := params[TagSFCW] - _, containsCFCW := params[TagCFCW] - if containsCFCW || containsSFCW { - h.flowControlNegotiated = true - } - - return nil -} - -func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) -} - -func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { - return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) -} - -func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { - if h.perspective == protocol.PerspectiveServer { - return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer) - } - return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient) -} - -// GetHelloMap gets all parameters needed for the Hello message -func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) { - sfcw := bytes.NewBuffer([]byte{}) - utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) - cfcw := bytes.NewBuffer([]byte{}) - utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) - mspc := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mspc, h.maxStreamsPerConnection) - mids := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) - icsl := bytes.NewBuffer([]byte{}) - utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) - - return map[Tag][]byte{ - TagICSL: icsl.Bytes(), - TagMSPC: mspc.Bytes(), - TagMIDS: mids.Bytes(), - TagCFCW: cfcw.Bytes(), - TagSFCW: sfcw.Bytes(), - }, nil -} - -// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendStreamFlowControlWindow -} - -// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendConnectionFlowControlWindow -} - -// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.receiveStreamFlowControlWindow -} - -// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { - return h.maxReceiveStreamFlowControlWindow -} - -// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.receiveConnectionFlowControlWindow -} - -// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data -func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { - return h.maxReceiveConnectionFlowControlWindow -} - -// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection -func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - return h.maxIncomingDynamicStreamsPerConnection -} - -// GetMaxIncomingStreams get the maximum number of incoming streams per connection -func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - - maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection - return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) -} - -// GetIdleConnectionStateLifetime gets the idle timeout -func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.idleConnectionStateLifetime -} - -// TruncateConnectionID determines if the client requests truncated ConnectionIDs -func (h *connectionParametersManager) TruncateConnectionID() bool { - if h.perspective == protocol.PerspectiveClient { - return false - } - - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.truncateConnectionID -} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go b/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go deleted file mode 100644 index c3caea3..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go +++ /dev/null @@ -1,100 +0,0 @@ -package handshake - -import ( - "encoding/asn1" - "fmt" - "net" - "time" - - "github.com/lucas-clemente/quic-go/crypto" -) - -const ( - stkPrefixIP byte = iota - stkPrefixString -) - -// An STK is a source address token -type STK struct { - RemoteAddr string - SentTime time.Time -} - -// token is the struct that is used for ASN1 serialization and deserialization -type token struct { - Data []byte - Timestamp int64 -} - -// An STKGenerator generates STKs -type STKGenerator struct { - stkSource crypto.StkSource -} - -// NewSTKGenerator initializes a new STKGenerator -func NewSTKGenerator() (*STKGenerator, error) { - stkSource, err := crypto.NewStkSource() - if err != nil { - return nil, err - } - return &STKGenerator{ - stkSource: stkSource, - }, nil -} - -// NewToken generates a new STK token for a given source address -func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { - data, err := asn1.Marshal(token{ - Data: encodeRemoteAddr(raddr), - Timestamp: time.Now().Unix(), - }) - if err != nil { - return nil, err - } - return g.stkSource.NewToken(data) -} - -// DecodeToken decodes an STK token -func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { - // if the client didn't send any STK, DecodeToken will be called with a nil-slice - if len(encrypted) == 0 { - return nil, nil - } - - data, err := g.stkSource.DecodeToken(encrypted) - if err != nil { - return nil, err - } - t := &token{} - rest, err := asn1.Unmarshal(data, t) - if err != nil { - return nil, err - } - if len(rest) != 0 { - return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) - } - return &STK{ - RemoteAddr: decodeRemoteAddr(t.Data), - SentTime: time.Unix(t.Timestamp, 0), - }, nil -} - -// encodeRemoteAddr encodes a remote address such that it can be saved in the STK -func encodeRemoteAddr(remoteAddr net.Addr) []byte { - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return append([]byte{stkPrefixIP}, udpAddr.IP...) - } - return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) -} - -// decodeRemoteAddr decodes the remote address saved in the STK -func decodeRemoteAddr(data []byte) string { - // data will never be empty for an STK that we generated. Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == stkPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go index c4b7ed3..87bf9ea 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -1,13 +1,24 @@ package quic import ( + "context" "io" "net" "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" ) +// The StreamID is the ID of a QUIC stream. +type StreamID = protocol.StreamID + +// A VersionNumber is a QUIC version number. +type VersionNumber = protocol.VersionNumber + +// A Cookie can be used to verify the ownership of the client address. +type Cookie = handshake.Cookie + // Stream is the interface implemented by QUIC streams type Stream interface { // Read reads data from the stream. @@ -19,9 +30,13 @@ type Stream interface { // after a fixed time limit; see SetDeadline and SetWriteDeadline. io.Writer io.Closer - StreamID() protocol.StreamID + StreamID() StreamID // Reset closes the stream with an error. Reset(error) + // The context is canceled as soon as the write-side of the stream is closed. + // This happens when Close() is called, or when the stream is reset (either locally or remotely). + // Warning: This API should not be considered stable and might change soon. + Context() context.Context // SetReadDeadline sets the deadline for future Read calls and // any currently-blocked Read call. // A zero value for t means Read will not time out. @@ -43,7 +58,7 @@ type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. // Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server). AcceptStream() (Stream, error) - // OpenStream opens a new QUIC stream, returning a special error when the peeer's concurrent stream limit is reached. + // OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached. // New streams always have the smallest possible stream ID. // TODO: Enable testing for the special error OpenStream() (Stream, error) @@ -56,9 +71,9 @@ type Session interface { RemoteAddr() net.Addr // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. Close(error) error - // WaitUntilClosed() blocks until the session is closed. + // The context is cancelled when the session is closed. // Warning: This API should not be considered stable and might change soon. - WaitUntilClosed() + Context() context.Context } // A NonFWSession is a QUIC connection between two peers half-way through the handshake. @@ -68,44 +83,36 @@ type NonFWSession interface { WaitUntilHandshakeComplete() error } -// An STK is a Source Address token. -// It is issued by the server and sent to the client. For the client, it is an opaque blob. -// The client can send the STK in subsequent handshakes to prove ownership of its IP address. -type STK struct { - // The remote address this token was issued for. - // If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String()) - // Otherwise, this is the string representation of the net.Addr (net.Addr.String()) - remoteAddr string - // The time that the STK was issued (resolution 1 second) - sentTime time.Time -} - // Config contains all configuration data needed for a QUIC server or client. -// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. type Config struct { // The QUIC versions that can be negotiated. // If not set, it uses all versions available. // Warning: This API should not be considered stable and will change soon. - Versions []protocol.VersionNumber - // Ask the server to truncate the connection ID sent in the Public Header. + Versions []VersionNumber + // Ask the server to omit the connection ID sent in the Public Header. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // Currently only valid for the client. - RequestConnectionIDTruncation bool + RequestConnectionIDOmission bool // HandshakeTimeout is the maximum duration that the cryptographic handshake may take. // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 10 seconds. HandshakeTimeout time.Duration - // AcceptSTK determines if an STK is accepted. - // It is called with stk = nil if the client didn't send an STK. - // If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours. + // IdleTimeout is the maximum duration that may pass without any incoming network activity. + // This value only applies after the handshake has completed. + // If the timeout is exceeded, the connection is closed. + // If this value is zero, the timeout is set to 30 seconds. + IdleTimeout time.Duration + // AcceptCookie determines if a Cookie is accepted. + // It is called with cookie = nil if the client didn't send an Cookie. + // If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours. // This option is only valid for the server. - AcceptSTK func(clientAddr net.Addr, stk *STK) bool + AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool // MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data. // If this value is zero, it will default to 1 MB for the server and 6 MB for the client. - MaxReceiveStreamFlowControlWindow protocol.ByteCount + MaxReceiveStreamFlowControlWindow uint64 // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. - MaxReceiveConnectionFlowControlWindow protocol.ByteCount + MaxReceiveConnectionFlowControlWindow uint64 // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. KeepAlive bool } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go similarity index 79% rename from vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go index a59ce6e..d190515 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/AEAD.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/AEAD.go @@ -1,9 +1,10 @@ package crypto -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // An AEAD implements QUIC's authenticated encryption and associated data type AEAD interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + Overhead() int } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go new file mode 100644 index 0000000..55e45be --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go @@ -0,0 +1,72 @@ +package crypto + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/aes12" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type aeadAESGCM12 struct { + otherIV []byte + myIV []byte + encrypter cipher.AEAD + decrypter cipher.AEAD +} + +var _ AEAD = &aeadAESGCM12{} + +// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size +// +// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte +// tag size, and couples the cipher and aes packages closely. +// See https://github.com/lucas-clemente/aes12. +func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { + if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 { + return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs") + } + encrypterCipher, err := aes12.NewCipher(myKey) + if err != nil { + return nil, err + } + encrypter, err := aes12.NewGCM(encrypterCipher) + if err != nil { + return nil, err + } + decrypterCipher, err := aes12.NewCipher(otherKey) + if err != nil { + return nil, err + } + decrypter, err := aes12.NewGCM(decrypterCipher) + if err != nil { + return nil, err + } + return &aeadAESGCM12{ + otherIV: otherIV, + myIV: myIV, + encrypter: encrypter, + decrypter: decrypter, + }, nil +} + +func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + res := make([]byte, 12) + copy(res[0:4], iv) + binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) + return res +} + +func (aead *aeadAESGCM12) Overhead() int { + return aead.encrypter.Overhead() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go new file mode 100644 index 0000000..d55974e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go @@ -0,0 +1,74 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type aeadAESGCM struct { + otherIV []byte + myIV []byte + encrypter cipher.AEAD + decrypter cipher.AEAD +} + +var _ AEAD = &aeadAESGCM{} + +const ivLen = 12 + +// NewAEADAESGCM creates a AEAD using AES-GCM +func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) { + // the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce + if len(otherIV) != ivLen || len(myIV) != ivLen { + return nil, errors.New("AES-GCM: expected 12 byte IVs") + } + + encrypterCipher, err := aes.NewCipher(myKey) + if err != nil { + return nil, err + } + encrypter, err := cipher.NewGCM(encrypterCipher) + if err != nil { + return nil, err + } + decrypterCipher, err := aes.NewCipher(otherKey) + if err != nil { + return nil, err + } + decrypter, err := cipher.NewGCM(decrypterCipher) + if err != nil { + return nil, err + } + + return &aeadAESGCM{ + otherIV: otherIV, + myIV: myIV, + encrypter: encrypter, + decrypter: decrypter, + }, nil +} + +func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + nonce := make([]byte, ivLen) + binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber)) + for i := 0; i < ivLen; i++ { + nonce[i] ^= iv[i] + } + return nonce +} + +func (aead *aeadAESGCM) Overhead() int { + return aead.encrypter.Overhead() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go index 3ebdc1a..d8e8d8f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_cache.go @@ -5,7 +5,7 @@ import ( "hash/fnv" "github.com/hashicorp/golang-lru" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) var ( diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_chain.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go index ea5ecff..908b7ce 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_compression.go @@ -51,10 +51,10 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by res.WriteByte(uint8(e.t)) switch e.t { case entryCached: - utils.WriteUint64(res, e.h) + utils.LittleEndian.WriteUint64(res, e.h) case entryCommon: - utils.WriteUint64(res, e.h) - utils.WriteUint32(res, e.i) + utils.LittleEndian.WriteUint64(res, e.h) + utils.LittleEndian.WriteUint32(res, e.i) case entryCompressed: totalUncompressedLen += 4 + len(chain[i]) } @@ -67,7 +67,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by return nil, fmt.Errorf("cert compression failed: %s", err.Error()) } - utils.WriteUint32(res, uint32(totalUncompressedLen)) + utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen)) for i, e := range entries { if e.t != entryCompressed { @@ -115,11 +115,11 @@ func decompressChain(data []byte) ([][]byte, error) { return nil, errors.New("unexpected cached certificate") case entryCommon: e := entry{t: entryCommon} - e.h, err = utils.ReadUint64(r) + e.h, err = utils.LittleEndian.ReadUint64(r) if err != nil { return nil, err } - e.i, err = utils.ReadUint32(r) + e.i, err = utils.LittleEndian.ReadUint32(r) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func decompressChain(data []byte) ([][]byte, error) { } if hasCompressedCerts { - uncompressedLength, err := utils.ReadUint32(r) + uncompressedLength, err := utils.LittleEndian.ReadUint32(r) if err != nil { fmt.Println(4) return nil, err diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_dict.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_dict.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_dict.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_dict.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_manager.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/cert_sets.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_sets.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/cert_sets.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/cert_sets.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go similarity index 72% rename from vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go index 5c58c4e..5d2e36f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go @@ -4,11 +4,12 @@ package crypto import ( "crypto/cipher" + "encoding/binary" "errors" "github.com/aead/chacha20" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) type aeadChacha20Poly1305 struct { @@ -45,9 +46,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV } func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData) + return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData) } func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData) + return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData) +} + +func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { + res := make([]byte, 12) + copy(res[0:4], iv) + binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) + return res } diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead_test.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/curve_25519.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/curve_25519.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/curve_25519.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go new file mode 100644 index 0000000..316bd1b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go @@ -0,0 +1,49 @@ +package crypto + +import ( + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +const ( + clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" + serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" +) + +// A TLSExporter gets the negotiated ciphersuite and computes exporter +type TLSExporter interface { + GetCipherSuite() mint.CipherSuiteParams + ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) +} + +// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance +func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { + var myLabel, otherLabel string + if pers == protocol.PerspectiveClient { + myLabel = clientExporterLabel + otherLabel = serverExporterLabel + } else { + myLabel = serverExporterLabel + otherLabel = clientExporterLabel + } + myKey, myIV, err := computeKeyAndIV(tls, myLabel) + if err != nil { + return nil, err + } + otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel) + if err != nil { + return nil, err + } + return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) +} + +func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { + cs := tls.GetCipherSuite() + secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) + if err != nil { + return nil, nil, err + } + key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) + iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) + return key, iv, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go similarity index 84% rename from vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go index accdbea..28f6c2c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go @@ -5,8 +5,8 @@ import ( "crypto/sha256" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "golang.org/x/crypto/hkdf" ) @@ -20,8 +20,8 @@ import ( // return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV) // } -// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance -func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) { +// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance +func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) { var swap bool if pers == protocol.PerspectiveClient { swap = true @@ -30,7 +30,7 @@ func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID pr if err != nil { return nil, err } - return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) + return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV) } // deriveKeys derives the keys and the IVs @@ -42,7 +42,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol } else { info.Write([]byte("QUIC key expansion\x00")) } - utils.WriteUint64(&info, uint64(connID)) + utils.BigEndian.WriteUint64(&info, uint64(connID)) info.Write(chlo) info.Write(scfg) info.Write(cert) diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/key_exchange.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_exchange.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/key_exchange.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_exchange.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go new file mode 100644 index 0000000..27158be --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go @@ -0,0 +1,11 @@ +package crypto + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// NewNullAEAD creates a NullAEAD +func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) { + if v.UsesTLS() { + return newNullAEADAESGCM(connID, p) + } + return &nullAEADFNV128a{perspective: p}, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go new file mode 100644 index 0000000..a647ad7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go @@ -0,0 +1,44 @@ +package crypto + +import ( + "crypto" + "encoding/binary" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} + +func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { + clientSecret, serverSecret := computeSecrets(connectionID) + + var mySecret, otherSecret []byte + if pers == protocol.PerspectiveClient { + mySecret = clientSecret + otherSecret = serverSecret + } else { + mySecret = serverSecret + otherSecret = clientSecret + } + + myKey, myIV := computeNullAEADKeyAndIV(mySecret) + otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret) + + return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) +} + +func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { + connID := make([]byte, 8) + binary.BigEndian.PutUint64(connID, uint64(connectionID)) + cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) + clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) + serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) + return +} + +func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { + key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) + iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) + return +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go similarity index 55% rename from vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go index ed85663..ecc4010 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto/null_aead.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_fnv128a.go @@ -5,27 +5,18 @@ import ( "errors" "github.com/lucas-clemente/fnv128a" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // nullAEAD handles not-yet encrypted packets -type nullAEAD struct { +type nullAEADFNV128a struct { perspective protocol.Perspective - version protocol.VersionNumber } -var _ AEAD = &nullAEAD{} - -// NewNullAEAD creates a NullAEAD -func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD { - return &nullAEAD{ - perspective: p, - version: v, - } -} +var _ AEAD = &nullAEADFNV128a{} // Open and verify the ciphertext -func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { +func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { if len(src) < 12 { return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") } @@ -33,12 +24,10 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass hash := fnv128a.New() hash.Write(associatedData) hash.Write(src[12:]) - if n.version >= protocol.Version37 { - if n.perspective == protocol.PerspectiveServer { - hash.Write([]byte("Client")) - } else { - hash.Write([]byte("Server")) - } + if n.perspective == protocol.PerspectiveServer { + hash.Write([]byte("Client")) + } else { + hash.Write([]byte("Server")) } testHigh, testLow := hash.Sum128() @@ -52,7 +41,7 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass } // Seal writes hash and ciphertext to the buffer -func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { +func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { if cap(dst) < 12+len(src) { dst = make([]byte, 12+len(src)) } else { @@ -63,12 +52,10 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass hash.Write(associatedData) hash.Write(src) - if n.version >= protocol.Version37 { - if n.perspective == protocol.PerspectiveServer { - hash.Write([]byte("Server")) - } else { - hash.Write([]byte("Client")) - } + if n.perspective == protocol.PerspectiveServer { + hash.Write([]byte("Server")) + } else { + hash.Write([]byte("Client")) } high, low := hash.Sum128() @@ -78,3 +65,7 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass binary.LittleEndian.PutUint32(dst[8:], uint32(high)) return dst } + +func (n *nullAEADFNV128a) Overhead() int { + return 12 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/server_proof.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/server_proof.go diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go b/vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go rename to vendor/github.com/lucas-clemente/quic-go/internal/crypto/source_address_token.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go new file mode 100644 index 0000000..e74c1d1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -0,0 +1,110 @@ +package flowcontrol + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type baseFlowController struct { + mutex sync.RWMutex + + rttStats *congestion.RTTStats + + bytesSent protocol.ByteCount + sendWindow protocol.ByteCount + + lastWindowUpdateTime time.Time + + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowIncrement protocol.ByteCount + maxReceiveWindowIncrement protocol.ByteCount +} + +func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.bytesSent += n +} + +// UpdateSendWindow should be called after receiving a WindowUpdateFrame +// it returns true if the window was actually updated +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if offset > c.sendWindow { + c.sendWindow = offset + } +} + +func (c *baseFlowController) sendWindowSize() protocol.ByteCount { + // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters + if c.bytesSent > c.sendWindow { + return 0 + } + return c.sendWindow - c.bytesSent +} + +func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // pretend we sent a WindowUpdate when reading the first byte + // this way auto-tuning of the window increment already works for the first WindowUpdate + if c.bytesRead == 0 { + c.lastWindowUpdateTime = time.Now() + } + c.bytesRead += n +} + +// getWindowUpdate updates the receive window, if necessary +// it returns the new offset +func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { + diff := c.receiveWindow - c.bytesRead + // update the window when more than half of it was already consumed + if diff >= (c.receiveWindowIncrement / 2) { + return 0 + } + + c.maybeAdjustWindowIncrement() + c.receiveWindow = c.bytesRead + c.receiveWindowIncrement + c.lastWindowUpdateTime = time.Now() + return c.receiveWindow +} + +func (c *baseFlowController) IsBlocked() bool { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.sendWindowSize() == 0 +} + +// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often +func (c *baseFlowController) maybeAdjustWindowIncrement() { + if c.lastWindowUpdateTime.IsZero() { + return + } + + rtt := c.rttStats.SmoothedRTT() + if rtt == 0 { + return + } + + timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime) + // interval between the window updates is sufficiently large, no need to increase the increment + if timeSinceLastWindowUpdate >= 2*rtt { + return + } + c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement) +} + +func (c *baseFlowController) checkFlowControlViolation() bool { + return c.highestReceived > c.receiveWindow +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go new file mode 100644 index 0000000..934d646 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -0,0 +1,77 @@ +package flowcontrol + +import ( + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +type connectionFlowController struct { + baseFlowController +} + +var _ ConnectionFlowController = &connectionFlowController{} + +// NewConnectionFlowController gets a new flow controller for the connection +// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0. +func NewConnectionFlowController( + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + rttStats *congestion.RTTStats, +) ConnectionFlowController { + return &connectionFlowController{ + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowIncrement: receiveWindow, + maxReceiveWindowIncrement: maxReceiveWindow, + }, + } +} + +func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.baseFlowController.sendWindowSize() +} + +// IncrementHighestReceived adds an increment to the highestReceived value +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.highestReceived += increment + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow)) + } + return nil +} + +func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + oldWindowIncrement := c.receiveWindowIncrement + offset := c.baseFlowController.getWindowUpdate() + if oldWindowIncrement < c.receiveWindowIncrement { + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + } + return offset +} + +// EnsureMinimumWindowIncrement sets a minimum window increment +// it should make sure that the connection-level window is increased when a stream-level window grows +func (c *connectionFlowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if inc > c.receiveWindowIncrement { + c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement) + c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go new file mode 100644 index 0000000..75ec6fa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -0,0 +1,37 @@ +package flowcontrol + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +type flowController interface { + // for sending + SendWindowSize() protocol.ByteCount + IsBlocked() bool + UpdateSendWindow(protocol.ByteCount) + AddBytesSent(protocol.ByteCount) + // for receiving + AddBytesRead(protocol.ByteCount) + GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary +} + +// A StreamFlowController is a flow controller for a QUIC stream. +type StreamFlowController interface { + flowController + // for receiving + // UpdateHighestReceived should be called when a new highest offset is received + // final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame + UpdateHighestReceived(offset protocol.ByteCount, final bool) error +} + +// The ConnectionFlowController is the flow controller for the connection. +type ConnectionFlowController interface { + flowController +} + +type connectionFlowControllerI interface { + ConnectionFlowController + // The following two methods are not supposed to be called from outside this packet, but are needed internally + // for sending + EnsureMinimumWindowIncrement(protocol.ByteCount) + // for receiving + IncrementHighestReceived(protocol.ByteCount) error +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go new file mode 100644 index 0000000..96e13dc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -0,0 +1,128 @@ +package flowcontrol + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +type streamFlowController struct { + baseFlowController + + streamID protocol.StreamID + + connection connectionFlowControllerI + contributesToConnection bool // does the stream contribute to connection level flow control + + receivedFinalOffset bool +} + +var _ StreamFlowController = &streamFlowController{} + +// NewStreamFlowController gets a new flow controller for a stream +func NewStreamFlowController( + streamID protocol.StreamID, + contributesToConnection bool, + cfc ConnectionFlowController, + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + initialSendWindow protocol.ByteCount, + rttStats *congestion.RTTStats, +) StreamFlowController { + return &streamFlowController{ + streamID: streamID, + contributesToConnection: contributesToConnection, + connection: cfc.(connectionFlowControllerI), + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowIncrement: receiveWindow, + maxReceiveWindowIncrement: maxReceiveWindow, + sendWindow: initialSendWindow, + }, + } +} + +// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher +// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before +func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier + if final && c.receivedFinalOffset && byteOffset != c.highestReceived { + return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset)) + } + // if we already received a final offset, check that the offset in the STREAM frames is below the final offset + if c.receivedFinalOffset && byteOffset > c.highestReceived { + return qerr.StreamDataAfterTermination + } + if final { + c.receivedFinalOffset = true + } + if byteOffset == c.highestReceived { + return nil + } + if byteOffset <= c.highestReceived { + // a STREAM_FRAME with a higher offset was received before. + if final { + // If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream + return qerr.StreamDataAfterTermination + } + // this is a reordered STREAM_FRAME + return nil + } + + increment := byteOffset - c.highestReceived + c.highestReceived = byteOffset + if c.checkFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow)) + } + if c.contributesToConnection { + return c.connection.IncrementHighestReceived(increment) + } + return nil +} + +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { + c.baseFlowController.AddBytesRead(n) + if c.contributesToConnection { + c.connection.AddBytesRead(n) + } +} + +func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { + c.baseFlowController.AddBytesSent(n) + if c.contributesToConnection { + c.connection.AddBytesSent(n) + } +} + +func (c *streamFlowController) SendWindowSize() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + window := c.baseFlowController.sendWindowSize() + if c.contributesToConnection { + window = utils.MinByteCount(window, c.connection.SendWindowSize()) + } + return window +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + defer c.mutex.Unlock() + + oldWindowIncrement := c.receiveWindowIncrement + offset := c.baseFlowController.getWindowUpdate() + if c.receiveWindowIncrement > oldWindowIncrement { // auto-tuning enlarged the window increment + utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowIncrement/(1<<10)) + if c.contributesToConnection { + c.connection.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(c.receiveWindowIncrement) * protocol.ConnectionFlowControlMultiplier)) + } + } + return offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go new file mode 100644 index 0000000..10281fa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/crypto" +) + +const ( + cookiePrefixIP byte = iota + cookiePrefixString +) + +// A Cookie is derived from the client address and can be used to verify the ownership of this address. +type Cookie struct { + RemoteAddr string + // The time that the STK was issued (resolution 1 second) + SentTime time.Time +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + Data []byte + Timestamp int64 +} + +// A CookieGenerator generates Cookies +type CookieGenerator struct { + cookieSource crypto.StkSource +} + +// NewCookieGenerator initializes a new CookieGenerator +func NewCookieGenerator() (*CookieGenerator, error) { + stkSource, err := crypto.NewStkSource() + if err != nil { + return nil, err + } + return &CookieGenerator{ + cookieSource: stkSource, + }, nil +} + +// NewToken generates a new Cookie for a given source address +func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + Data: encodeRemoteAddr(raddr), + Timestamp: time.Now().Unix(), + }) + if err != nil { + return nil, err + } + return g.cookieSource.NewToken(data) +} + +// DecodeToken decodes a Cookie +func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { + // if the client didn't send any Cookie, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.cookieSource.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + return &Cookie{ + RemoteAddr: decodeRemoteAddr(t.Data), + SentTime: time.Unix(t.Timestamp, 0), + }, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{cookiePrefixIP}, udpAddr.IP...) + } + return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the Cookie +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a Cookie that we generated. Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == cookiePrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go new file mode 100644 index 0000000..317f6e5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go @@ -0,0 +1,43 @@ +package handshake + +import ( + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type cookieHandler struct { + callback func(net.Addr, *Cookie) bool + + cookieGenerator *CookieGenerator +} + +var _ mint.CookieHandler = &cookieHandler{} + +func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { + cookieGenerator, err := NewCookieGenerator() + if err != nil { + return nil, err + } + return &cookieHandler{ + callback: callback, + cookieGenerator: cookieGenerator, + }, nil +} + +func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { + if h.callback(conn.RemoteAddr(), nil) { + return nil, nil + } + return h.cookieGenerator.NewToken(conn.RemoteAddr()) +} + +func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { + data, err := h.cookieGenerator.DecodeToken(token) + if err != nil { + utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) + return false + } + return h.callback(conn.RemoteAddr(), data) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go index a8d8812..c923bbc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go @@ -11,9 +11,9 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -23,6 +23,7 @@ type cryptoSetupClient struct { hostname string connID protocol.ConnectionID version protocol.VersionNumber + initialVersion protocol.VersionNumber negotiatedVersions []protocol.VersionNumber cryptoStream io.ReadWriter @@ -42,17 +43,18 @@ type cryptoSetupClient struct { clientHelloCounter int serverVerified bool // has the certificate chain and the proof already been verified - keyDerivation KeyDerivationFunction + keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction receivedSecurePacket bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - aeadChanged chan<- protocol.EncryptionLevel - params *TransportParameters - connectionParameters ConnectionParametersManager + paramsChan chan<- TransportParameters + aeadChanged chan<- protocol.EncryptionLevel + + params *TransportParameters } var _ CryptoSetup = &cryptoSetupClient{} @@ -65,30 +67,36 @@ var ( // NewCryptoSetupClient creates a new CryptoSetup instance for a client func NewCryptoSetupClient( + cryptoStream io.ReadWriter, hostname string, connID protocol.ConnectionID, version protocol.VersionNumber, - cryptoStream io.ReadWriter, tlsConfig *tls.Config, - connectionParameters ConnectionParametersManager, - aeadChanged chan<- protocol.EncryptionLevel, params *TransportParameters, + paramsChan chan<- TransportParameters, + aeadChanged chan<- protocol.EncryptionLevel, + initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) + if err != nil { + return nil, err + } return &cryptoSetupClient{ - hostname: hostname, - connID: connID, - version: version, - cryptoStream: cryptoStream, - certManager: crypto.NewCertManager(tlsConfig), - connectionParameters: connectionParameters, - keyDerivation: crypto.DeriveKeysAESGCM, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), - aeadChanged: aeadChanged, - negotiatedVersions: negotiatedVersions, - divNonceChan: make(chan []byte), - params: params, + cryptoStream: cryptoStream, + hostname: hostname, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConfig), + params: params, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: nullAEAD, + paramsChan: paramsChan, + aeadChanged: aeadChanged, + initialVersion: initialVersion, + negotiatedVersions: negotiatedVersions, + divNonceChan: make(chan []byte), }, nil } @@ -141,15 +149,21 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { utils.Debugf("Got %s", message) switch message.Tag { case TagREJ: - err = h.handleREJMessage(message.Data) + if err := h.handleREJMessage(message.Data); err != nil { + return err + } case TagSHLO: - err = h.handleSHLOMessage(message.Data) + params, err := h.handleSHLOMessage(message.Data) + if err != nil { + return err + } + // blocks until the session has received the parameters + h.paramsChan <- *params + h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.aeadChanged) default: return qerr.InvalidCryptoMessageType } - if err != nil { - return err - } } } @@ -215,12 +229,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { return nil } -func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { +func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) { h.mutex.Lock() defer h.mutex.Unlock() if !h.receivedSecurePacket { - return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") + return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") } if sno, ok := cryptoData[TagSNO]; ok { @@ -229,22 +243,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { serverPubs, ok := cryptoData[TagPUBS] if !ok { - return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") + return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") } verTag, ok := cryptoData[TagVER] if !ok { - return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") + return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") } if !h.validateVersionList(verTag) { - return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") } nonce := append(h.nonc, h.sno...) ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) if err != nil { - return err + return nil, err } leafCert := h.certManager.GetLeafCert() @@ -261,39 +275,32 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { protocol.PerspectiveClient, ) if err != nil { - return err + return nil, err } - err = h.connectionParameters.SetFromMap(cryptoData) + params, err := readHelloMap(cryptoData) if err != nil { - return qerr.InvalidCryptoMessageParameter + return nil, qerr.InvalidCryptoMessageParameter } - - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) - - return nil + return params, nil } func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { - if len(h.negotiatedVersions) == 0 { + numNegotiatedVersions := len(h.negotiatedVersions) + if numNegotiatedVersions == 0 { return true } - if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { + if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions { return false } b := bytes.NewReader(verTags) - for _, negotiatedVersion := range h.negotiatedVersions { - verTag, err := utils.ReadUint32(b) + for i := 0; i < numNegotiatedVersions; i++ { + v, err := utils.BigEndian.ReadUint32(b) if err != nil { // should never occur, since the length was already checked return false } - ver := protocol.VersionTagToNumber(verTag) - if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) { - ver = protocol.VersionUnsupported - } - if ver != negotiatedVersion { + if protocol.VersionNumber(v) != h.negotiatedVersions[i] { return false } } @@ -333,16 +340,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.forwardSecureAEAD != nil { - return protocol.EncryptionForwardSecure, h.sealForwardSecure + return protocol.EncryptionForwardSecure, h.forwardSecureAEAD } else if h.secureAEAD != nil { - return protocol.EncryptionSecure, h.sealSecure + return protocol.EncryptionSecure, h.secureAEAD } else { - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } } func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { @@ -351,33 +358,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry switch encLevel { case protocol.EncryptionUnencrypted: - return h.sealUnencrypted, nil + return h.nullAEAD, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { return nil, errors.New("CryptoSetupClient: no secureAEAD") } - return h.sealSecure, nil + return h.secureAEAD, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD") } - return h.sealForwardSecure, nil + return h.forwardSecureAEAD, nil } return nil, errors.New("CryptoSetupClient: no encryption level specified") } -func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.nullAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) -} - func (h *cryptoSetupClient) DiversificationNonce() []byte { panic("not needed for cryptoSetupClient") } @@ -386,6 +381,10 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { h.divNonceChan <- data } +func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType { + panic("not needed for cryptoSetupServer") +} + func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { @@ -413,15 +412,11 @@ func (h *cryptoSetupClient) sendCHLO() error { } h.lastSentCHLO = b.Bytes() - return nil } func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { - tags, err := h.connectionParameters.GetHelloMap() - if err != nil { - return nil, err - } + tags := h.params.getHelloMap() tags[TagSNI] = []byte(h.hostname) tags[TagPDMD] = []byte("X509") @@ -431,12 +426,9 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { } versionTag := make([]byte, 4) - binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) + binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion)) tags[TagVER] = versionTag - if h.params.RequestConnectionIDTruncation { - tags[TagTCID] = []byte{0, 0, 0, 0} - } if len(h.stk) > 0 { tags[TagSTK] = h.stk } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go similarity index 78% rename from vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go index f639f51..50e2618 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go @@ -9,14 +9,14 @@ import ( "net" "sync" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) -// KeyDerivationFunction is used for key derivation -type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) +// QuicCryptoKeyDerivationFunction is used for key derivation +type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX type KeyExchangeFunction func() crypto.KeyExchange @@ -26,70 +26,77 @@ type cryptoSetupServer struct { connID protocol.ConnectionID remoteAddr net.Addr scfg *ServerConfig - stkGenerator *STKGenerator diversificationNonce []byte version protocol.VersionNumber supportedVersions []protocol.VersionNumber - acceptSTKCallback func(net.Addr, *STK) bool + acceptSTKCallback func(net.Addr, *Cookie) bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD receivedForwardSecurePacket bool - sentSHLO bool receivedSecurePacket bool - aeadChanged chan<- protocol.EncryptionLevel + sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written - keyDerivation KeyDerivationFunction + receivedParams bool + paramsChan chan<- TransportParameters + aeadChanged chan<- protocol.EncryptionLevel + + keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction cryptoStream io.ReadWriter - connectionParameters ConnectionParametersManager + params *TransportParameters mutex sync.RWMutex } var _ CryptoSetup = &cryptoSetupServer{} -// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO -// this is an expiremnt implemented by Chrome in QUIC 36, which we don't support +// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO. +// This is an experiment implemented by Chrome in QUIC 36, which we don't support. // TODO: remove this when dropping support for QUIC 36 var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported") +// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO. +// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point. +var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported") + // NewCryptoSetup creates a new CryptoSetup instance for a server func NewCryptoSetup( + cryptoStream io.ReadWriter, connID protocol.ConnectionID, remoteAddr net.Addr, version protocol.VersionNumber, scfg *ServerConfig, - cryptoStream io.ReadWriter, - connectionParametersManager ConnectionParametersManager, + params *TransportParameters, supportedVersions []protocol.VersionNumber, - acceptSTK func(net.Addr, *STK) bool, + acceptSTK func(net.Addr, *Cookie) bool, + paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, ) (CryptoSetup, error) { - stkGenerator, err := NewSTKGenerator() + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { return nil, err } - return &cryptoSetupServer{ - connID: connID, - remoteAddr: remoteAddr, - version: version, - supportedVersions: supportedVersions, - scfg: scfg, - stkGenerator: stkGenerator, - keyDerivation: crypto.DeriveKeysAESGCM, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - cryptoStream: cryptoStream, - connectionParameters: connectionParametersManager, - acceptSTKCallback: acceptSTK, - aeadChanged: aeadChanged, + cryptoStream: cryptoStream, + connID: connID, + remoteAddr: remoteAddr, + version: version, + supportedVersions: supportedVersions, + scfg: scfg, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: nullAEAD, + params: params, + acceptSTKCallback: acceptSTK, + sentSHLO: make(chan struct{}), + paramsChan: paramsChan, + aeadChanged: aeadChanged, }, nil } @@ -120,6 +127,9 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment { return false, ErrHOLExperiment } + if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment { + return false, ErrNSTPExperiment + } sniSlice, ok := cryptoData[TagSNI] if !ok { @@ -139,8 +149,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] if len(verSlice) != 4 { return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag") } - verTag := binary.LittleEndian.Uint32(verSlice) - ver := protocol.VersionTagToNumber(verTag) + ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice)) // If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack. if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) { return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") @@ -154,16 +163,27 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] return false, err } + params, err := readHelloMap(cryptoData) + if err != nil { + return false, err + } + // blocks until the session has received the parameters + if !h.receivedParams { + h.receivedParams = true + h.paramsChan <- *params + } + if !h.isInchoateCHLO(cryptoData, certUncompressed) { // We have a CHLO with a proper server config ID, do a 0-RTT handshake reply, err = h.handleCHLO(sni, chloData, cryptoData) if err != nil { return false, err } - _, err = h.cryptoStream.Write(reply) - if err != nil { + if _, err := h.cryptoStream.Write(reply); err != nil { return false, err } + h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.sentSHLO) return true, nil } @@ -186,6 +206,8 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu if err == nil { if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client h.receivedForwardSecurePacket = true + // wait until protocol.EncryptionForwardSecure was sent on the aeadChan + <-h.sentSHLO close(h.aeadChanged) } return res, protocol.EncryptionForwardSecure, nil @@ -215,18 +237,18 @@ func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.forwardSecureAEAD != nil { - return protocol.EncryptionForwardSecure, h.sealForwardSecure + return protocol.EncryptionForwardSecure, h.forwardSecureAEAD } - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { h.mutex.RLock() defer h.mutex.RUnlock() if h.secureAEAD != nil { - return protocol.EncryptionSecure, h.sealSecure + return protocol.EncryptionSecure, h.secureAEAD } - return protocol.EncryptionUnencrypted, h.sealUnencrypted + return protocol.EncryptionUnencrypted, h.nullAEAD } func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { @@ -235,33 +257,21 @@ func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.Encry switch encLevel { case protocol.EncryptionUnencrypted: - return h.sealUnencrypted, nil + return h.nullAEAD, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { return nil, errors.New("CryptoSetupServer: no secureAEAD") } - return h.sealSecure, nil + return h.secureAEAD, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD") } - return h.sealForwardSecure, nil + return h.forwardSecureAEAD, nil } return nil, errors.New("CryptoSetupServer: no encryption level specified") } -func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.nullAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) -} - -func (h *cryptoSetupServer) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) -} - func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { if _, ok := cryptoData[TagPUBS]; !ok { return true @@ -282,7 +292,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt } func (h *cryptoSetupServer) acceptSTK(token []byte) bool { - stk, err := h.stkGenerator.DecodeToken(token) + stk, err := h.scfg.cookieGenerator.DecodeToken(token) if err != nil { utils.Debugf("STK invalid: %s", err.Error()) return false @@ -295,7 +305,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") } - token, err := h.stkGenerator.NewToken(h.remoteAddr) + token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr) if err != nil { return nil, err } @@ -418,19 +428,11 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } - err = h.connectionParameters.SetFromMap(cryptoData) - if err != nil { - return nil, err - } - - replyMap, err := h.connectionParameters.GetHelloMap() - if err != nil { - return nil, err - } + replyMap := h.params.getHelloMap() // add crypto parameters verTag := &bytes.Buffer{} for _, v := range h.supportedVersions { - utils.WriteUint32(verTag, protocol.VersionNumberToTag(v)) + utils.BigEndian.WriteUint32(verTag, uint32(v)) } replyMap[TagPUBS] = ephermalKex.PublicKey() replyMap[TagSNO] = serverNonce @@ -444,9 +446,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T var reply bytes.Buffer message.Write(&reply) utils.Debugf("Sending %s", message) - - h.aeadChanged <- protocol.EncryptionForwardSecure - return reply.Bytes(), nil } @@ -459,6 +458,10 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { panic("not needed for cryptoSetupServer") } +func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType { + panic("not needed for cryptoSetupServer") +} + func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go new file mode 100644 index 0000000..e14e7ad --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go @@ -0,0 +1,242 @@ +package handshake + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "sync" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// KeyDerivationFunction is used for key derivation +type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) + +type cryptoSetupTLS struct { + mutex sync.RWMutex + + perspective protocol.Perspective + + tls mintTLS + conn *fakeConn + + nextPacketType protocol.PacketType + + keyDerivation KeyDerivationFunction + nullAEAD crypto.AEAD + aead crypto.AEAD + + aeadChanged chan<- protocol.EncryptionLevel +} + +// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server +func NewCryptoSetupTLSServer( + cryptoStream io.ReadWriter, + connID protocol.ConnectionID, + tlsConfig *tls.Config, + remoteAddr net.Addr, + params *TransportParameters, + paramsChan chan<- TransportParameters, + aeadChanged chan<- protocol.EncryptionLevel, + checkCookie func(net.Addr, *Cookie) bool, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) (CryptoSetup, error) { + mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) + if err != nil { + return nil, err + } + mintConf.RequireCookie = true + mintConf.CookieHandler, err = newCookieHandler(checkCookie) + if err != nil { + return nil, err + } + conn := &fakeConn{ + stream: cryptoStream, + pers: protocol.PerspectiveServer, + remoteAddr: remoteAddr, + } + mintConn := mint.Server(conn, mintConf) + eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) + if err := mintConn.SetExtensionHandler(eh); err != nil { + return nil, err + } + + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) + if err != nil { + return nil, err + } + + return &cryptoSetupTLS{ + perspective: protocol.PerspectiveServer, + tls: &mintController{mintConn}, + conn: conn, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, + }, nil +} + +// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client +func NewCryptoSetupTLSClient( + cryptoStream io.ReadWriter, + connID protocol.ConnectionID, + hostname string, + tlsConfig *tls.Config, + params *TransportParameters, + paramsChan chan<- TransportParameters, + aeadChanged chan<- protocol.EncryptionLevel, + initialVersion protocol.VersionNumber, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) (CryptoSetup, error) { + mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) + if err != nil { + return nil, err + } + mintConf.ServerName = hostname + conn := &fakeConn{ + stream: cryptoStream, + pers: protocol.PerspectiveClient, + } + mintConn := mint.Client(conn, mintConf) + eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) + if err := mintConn.SetExtensionHandler(eh); err != nil { + return nil, err + } + + nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) + if err != nil { + return nil, err + } + + return &cryptoSetupTLS{ + conn: conn, + perspective: protocol.PerspectiveClient, + tls: &mintController{mintConn}, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, + nextPacketType: protocol.PacketTypeInitial, + }, nil +} + +func (h *cryptoSetupTLS) HandleCryptoStream() error { +handshakeLoop: + for { + switch alert := h.tls.Handshake(); alert { + case mint.AlertNoAlert: // handshake complete + break handshakeLoop + case mint.AlertWouldBlock: + h.determineNextPacketType() + if err := h.conn.Continue(); err != nil { + return err + } + default: + return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) + } + } + + aead, err := h.keyDerivation(h.tls, h.perspective) + if err != nil { + return err + } + h.mutex.Lock() + h.aead = aead + h.mutex.Unlock() + + // signal to the outside world that the handshake completed + h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.aeadChanged) + return nil +} + +func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { + h.mutex.RLock() + defer h.mutex.RUnlock() + + if h.aead != nil { + data, err := h.aead.Open(dst, src, packetNumber, associatedData) + if err != nil { + return nil, protocol.EncryptionUnspecified, err + } + return data, protocol.EncryptionForwardSecure, nil + } + data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData) + if err != nil { + return nil, protocol.EncryptionUnspecified, err + } + return data, protocol.EncryptionUnencrypted, nil +} + +func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { + h.mutex.RLock() + defer h.mutex.RUnlock() + + if h.aead != nil { + return protocol.EncryptionForwardSecure, h.aead + } + return protocol.EncryptionUnencrypted, h.nullAEAD +} + +func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { + errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String()) + h.mutex.RLock() + defer h.mutex.RUnlock() + + switch encLevel { + case protocol.EncryptionUnencrypted: + return h.nullAEAD, nil + case protocol.EncryptionForwardSecure: + if h.aead == nil { + return nil, errNoSealer + } + return h.aead, nil + default: + return nil, errNoSealer + } +} + +func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { + return protocol.EncryptionUnencrypted, h.nullAEAD +} + +func (h *cryptoSetupTLS) determineNextPacketType() error { + h.mutex.Lock() + defer h.mutex.Unlock() + state := h.tls.State().HandshakeState + if h.perspective == protocol.PerspectiveServer { + switch state { + case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest + h.nextPacketType = protocol.PacketTypeRetry + case "ServerStateWaitFinished": + h.nextPacketType = protocol.PacketTypeHandshake + default: + // TODO: accept 0-RTT data + return fmt.Errorf("Unexpected handshake state: %s", state) + } + return nil + } + // client + if state != "ClientStateWaitSH" { + h.nextPacketType = protocol.PacketTypeHandshake + } + return nil +} + +func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType { + h.mutex.RLock() + defer h.mutex.RUnlock() + return h.nextPacketType +} + +func (h *cryptoSetupTLS) DiversificationNonce() []byte { + panic("diversification nonce not needed for TLS") +} + +func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) { + panic("diversification nonce not needed for TLS") +} diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go index da6724f..3bccbef 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/ephermal_cache.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/ephermal_cache.go @@ -4,9 +4,9 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) var ( diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go index 0744cbd..c09db26 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go @@ -7,8 +7,8 @@ import ( "io" "sort" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -72,9 +72,9 @@ func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) { // Write writes a crypto message func (h HandshakeMessage) Write(b *bytes.Buffer) { data := h.Data - utils.WriteUint32(b, uint32(h.Tag)) - utils.WriteUint16(b, uint16(len(data))) - utils.WriteUint16(b, 0) + utils.LittleEndian.WriteUint32(b, uint32(h.Tag)) + utils.LittleEndian.WriteUint16(b, uint16(len(data))) + utils.LittleEndian.WriteUint16(b, 0) // Save current position in the buffer, so that we can update the index in-place later indexStart := b.Len() @@ -87,7 +87,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { v := data[Tag(t)] b.Write(v) offset += uint32(len(v)) - binary.LittleEndian.PutUint32(indexData[i*8:], t) + binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t)) binary.LittleEndian.PutUint32(indexData[i*8+4:], offset) } @@ -95,14 +95,16 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { copy(b.Bytes()[indexStart:], indexData) } -func (h *HandshakeMessage) getTagsSorted() []uint32 { - tags := make([]uint32, len(h.Data)) +func (h *HandshakeMessage) getTagsSorted() []Tag { + tags := make([]Tag, len(h.Data)) i := 0 for t := range h.Data { - tags[i] = uint32(t) + tags[i] = t i++ } - sort.Sort(utils.Uint32Slice(tags)) + sort.Slice(tags, func(i, j int) bool { + return tags[i] < tags[j] + }) return tags } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go similarity index 52% rename from vendor/github.com/lucas-clemente/quic-go/handshake/interface.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go index 751aae1..c34c8f1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/interface.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -1,24 +1,25 @@ package handshake -import "github.com/lucas-clemente/quic-go/protocol" +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" +) // Sealer seals a packet -type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte +type Sealer interface { + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + Overhead() int +} // CryptoSetup is a crypto setup type CryptoSetup interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) HandleCryptoStream() error // TODO: clean up this interface - DiversificationNonce() []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) } - -// TransportParameters are parameters sent to the peer during the handshake -type TransportParameters struct { - RequestConnectionIDTruncation bool -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go new file mode 100644 index 0000000..8c3a83b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mint_utils.go @@ -0,0 +1,127 @@ +package handshake + +import ( + "bytes" + gocrypto "crypto" + "crypto/tls" + "crypto/x509" + "io" + "net" + "time" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { + mconf := &mint.Config{ + NonBlocking: true, + CipherSuites: []mint.CipherSuite{ + mint.TLS_AES_128_GCM_SHA256, + mint.TLS_AES_256_GCM_SHA384, + }, + } + if tlsConf != nil { + mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) + for i, certChain := range tlsConf.Certificates { + mconf.Certificates[i] = &mint.Certificate{ + Chain: make([]*x509.Certificate, len(certChain.Certificate)), + PrivateKey: certChain.PrivateKey.(gocrypto.Signer), + } + for j, cert := range certChain.Certificate { + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + mconf.Certificates[i].Chain[j] = c + } + } + } + if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { + return nil, err + } + return mconf, nil +} + +type mintTLS interface { + // These two methods are the same as the crypto.TLSExporter interface. + // Cannot use embedding here, because mockgen source mode refuses to generate mocks then. + GetCipherSuite() mint.CipherSuiteParams + ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) + // additional methods + Handshake() mint.Alert + State() mint.ConnectionState +} + +var _ crypto.TLSExporter = (mintTLS)(nil) + +type mintController struct { + conn *mint.Conn +} + +var _ mintTLS = &mintController{} + +func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams { + return mc.conn.State().CipherSuite +} + +func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + return mc.conn.ComputeExporter(label, context, keyLength) +} + +func (mc *mintController) Handshake() mint.Alert { + return mc.conn.Handshake() +} + +func (mc *mintController) State() mint.ConnectionState { + return mc.conn.State() +} + +// mint expects a net.Conn, but we're doing the handshake on a stream +// so we wrap a stream such that implements a net.Conn +type fakeConn struct { + stream io.ReadWriter + pers protocol.Perspective + remoteAddr net.Addr + + blockRead bool + writeBuffer bytes.Buffer +} + +var _ net.Conn = &fakeConn{} + +func (c *fakeConn) Read(b []byte) (int, error) { + if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock + return 0, nil + } + c.blockRead = true // block the next Read call + return c.stream.Read(b) +} + +func (c *fakeConn) Write(p []byte) (int, error) { + if c.pers == protocol.PerspectiveClient { + return c.stream.Write(p) + } + // Buffer all writes by the server. + // Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data. + return c.writeBuffer.Write(p) +} + +func (c *fakeConn) Continue() error { + c.blockRead = false + if c.pers == protocol.PerspectiveClient { + return nil + } + // write all contents of the write buffer to the stream. + _, err := c.stream.Write(c.writeBuffer.Bytes()) + c.writeBuffer.Reset() + return err +} + +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *fakeConn) SetReadDeadline(time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil } +func (c *fakeConn) SetDeadline(time.Time) error { return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go index fce66ef..2b7fba6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config.go @@ -4,7 +4,7 @@ import ( "bytes" "crypto/rand" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" ) // ServerConfig is a server config @@ -13,6 +13,7 @@ type ServerConfig struct { certChain crypto.CertChain ID []byte obit []byte + cookieGenerator *CookieGenerator } // NewServerConfig creates a new server config @@ -28,11 +29,18 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve return nil, err } + cookieGenerator, err := NewCookieGenerator() + + if err != nil { + return nil, err + } + return &ServerConfig{ kex: kex, certChain: certChain, ID: id, obit: obit, + cookieGenerator: cookieGenerator, }, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go similarity index 98% rename from vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go index 4201419..eb042f6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go @@ -7,7 +7,7 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) diff --git a/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go similarity index 96% rename from vendor/github.com/lucas-clemente/quic-go/handshake/tags.go rename to vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go index 2b3783f..19ec78d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tags.go @@ -54,6 +54,9 @@ const ( // Chrome experiment (see https://codereview.chromium.org/2115033002) // unsupported by quic-go TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24 + // TagNSTP is the no STOP_WAITING experiment + // currently unsupported by quic-go + TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24 // TagSTK is the source-address token TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go new file mode 100644 index 0000000..7e56e92 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go @@ -0,0 +1,54 @@ +package handshake + +import ( + "github.com/bifurcation/mint" +) + +type transportParameterID uint16 + +const quicTLSExtensionType = 26 + +const ( + initialMaxStreamDataParameterID transportParameterID = iota + initialMaxDataParameterID + initialMaxStreamIDParameterID + idleTimeoutParameterID + omitConnectionIDParameterID + maxPacketSizeParameterID + statelessResetTokenParameterID +) + +type transportParameter struct { + Parameter transportParameterID + Value []byte `tls:"head=2"` +} + +type clientHelloTransportParameters struct { + NegotiatedVersion uint32 // actually a protocol.VersionNumber + InitialVersion uint32 // actually a protocol.VersionNumber + Parameters []transportParameter `tls:"head=2"` +} + +type encryptedExtensionsTransportParameters struct { + SupportedVersions []uint32 `tls:"head=1"` // actually a protocol.VersionNumber + Parameters []transportParameter `tls:"head=2"` +} + +type tlsExtensionBody struct { + data []byte +} + +var _ mint.ExtensionBody = &tlsExtensionBody{} + +func (e *tlsExtensionBody) Type() mint.ExtensionType { + return quicTLSExtensionType +} + +func (e *tlsExtensionBody) Marshal() ([]byte, error) { + return e.data, nil +} + +func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) { + e.data = data + return len(data), nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go new file mode 100644 index 0000000..4187804 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go @@ -0,0 +1,122 @@ +package handshake + +import ( + "errors" + "fmt" + "math" + + "github.com/lucas-clemente/quic-go/qerr" + + "github.com/bifurcation/mint" + "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type extensionHandlerClient struct { + params *TransportParameters + paramsChan chan<- TransportParameters + + initialVersion protocol.VersionNumber + supportedVersions []protocol.VersionNumber + version protocol.VersionNumber +} + +var _ mint.AppExtensionHandler = &extensionHandlerClient{} + +func newExtensionHandlerClient( + params *TransportParameters, + paramsChan chan<- TransportParameters, + initialVersion protocol.VersionNumber, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) *extensionHandlerClient { + return &extensionHandlerClient{ + params: params, + paramsChan: paramsChan, + initialVersion: initialVersion, + supportedVersions: supportedVersions, + version: version, + } +} + +func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { + if hType != mint.HandshakeTypeClientHello { + return nil + } + + data, err := syntax.Marshal(clientHelloTransportParameters{ + NegotiatedVersion: uint32(h.version), + InitialVersion: uint32(h.initialVersion), + Parameters: h.params.getTransportParameters(), + }) + if err != nil { + return err + } + return el.Add(&tlsExtensionBody{data}) +} + +func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { + ext := &tlsExtensionBody{} + found := el.Find(ext) + + if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { + if found { + return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) + } + return nil + } + if hType == mint.HandshakeTypeNewSessionTicket { + // the extension it's optional in the NewSessionTicket message + // TODO: handle this + return nil + } + + // hType == mint.HandshakeTypeEncryptedExtensions + if !found { + return errors.New("EncryptedExtensions message didn't contain a QUIC extension") + } + + eetp := &encryptedExtensionsTransportParameters{} + if _, err := syntax.Unmarshal(ext.data, eetp); err != nil { + return err + } + serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions)) + for i, v := range eetp.SupportedVersions { + serverSupportedVersions[i] = protocol.VersionNumber(v) + } + // check that the current version is included in the supported versions + if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) { + return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions") + } + // if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server + if h.version != h.initialVersion { + negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions) + if !ok || h.version != negotiatedVersion { + return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version") + } + } + + // check that the server sent the stateless reset token + var foundStatelessResetToken bool + for _, p := range eetp.Parameters { + if p.Parameter == statelessResetTokenParameterID { + if len(p.Value) != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value)) + } + foundStatelessResetToken = true + // TODO: handle this value + } + } + if !foundStatelessResetToken { + // TODO: return the right error here + return errors.New("server didn't sent stateless_reset_token") + } + params, err := readTransportParamters(eetp.Parameters) + if err != nil { + return err + } + // TODO(#878): remove this when implementing the MAX_STREAM_ID frame + params.MaxStreams = math.MaxUint32 + h.paramsChan <- *params + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go new file mode 100644 index 0000000..49830d8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go @@ -0,0 +1,109 @@ +package handshake + +import ( + "bytes" + "errors" + "fmt" + "math" + + "github.com/lucas-clemente/quic-go/qerr" + + "github.com/bifurcation/mint" + "github.com/bifurcation/mint/syntax" + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type extensionHandlerServer struct { + params *TransportParameters + paramsChan chan<- TransportParameters + + version protocol.VersionNumber + supportedVersions []protocol.VersionNumber +} + +var _ mint.AppExtensionHandler = &extensionHandlerServer{} + +func newExtensionHandlerServer( + params *TransportParameters, + paramsChan chan<- TransportParameters, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) *extensionHandlerServer { + return &extensionHandlerServer{ + params: params, + paramsChan: paramsChan, + version: version, + supportedVersions: supportedVersions, + } +} + +func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { + if hType != mint.HandshakeTypeEncryptedExtensions { + return nil + } + + transportParams := append( + h.params.getTransportParameters(), + // TODO(#855): generate a real token + transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, + ) + supportedVersions := make([]uint32, len(h.supportedVersions)) + for i, v := range h.supportedVersions { + supportedVersions[i] = uint32(v) + } + data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ + SupportedVersions: supportedVersions, + Parameters: transportParams, + }) + if err != nil { + return err + } + return el.Add(&tlsExtensionBody{data}) +} + +func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { + ext := &tlsExtensionBody{} + found := el.Find(ext) + + if hType != mint.HandshakeTypeClientHello { + if found { + return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) + } + return nil + } + + if !found { + return errors.New("ClientHello didn't contain a QUIC extension") + } + chtp := &clientHelloTransportParameters{} + if _, err := syntax.Unmarshal(ext.data, chtp); err != nil { + return err + } + initialVersion := protocol.VersionNumber(chtp.InitialVersion) + negotiatedVersion := protocol.VersionNumber(chtp.NegotiatedVersion) + // check that the negotiated version is the version we're currently using + if negotiatedVersion != h.version { + return qerr.Error(qerr.VersionNegotiationMismatch, "Inconsistent negotiated version") + } + // perform the stateless version negotiation validation: + // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version + // this is the case when the initial version is not contained in the supported versions + if initialVersion != negotiatedVersion && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { + return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") + } + + for _, p := range chtp.Parameters { + if p.Parameter == statelessResetTokenParameterID { + // TODO: return the correct error type + return errors.New("client sent a stateless reset token") + } + } + params, err := readTransportParamters(chtp.Parameters) + if err != nil { + return err + } + // TODO(#878): remove this when implementing the MAX_STREAM_ID frame + params.MaxStreams = math.MaxUint32 + h.paramsChan <- *params + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go new file mode 100644 index 0000000..bda12c2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go @@ -0,0 +1,167 @@ +package handshake + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// errMalformedTag is returned when the tag value cannot be read +var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + StreamFlowControlWindow protocol.ByteCount + ConnectionFlowControlWindow protocol.ByteCount + + MaxStreams uint32 + + OmitConnectionID bool + IdleTimeout time.Duration +} + +// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message +func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) { + params := &TransportParameters{} + if value, ok := tags[TagTCID]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.OmitConnectionID = (v == 0) + } + if value, ok := tags[TagMIDS]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.MaxStreams = v + } + if value, ok := tags[TagICSL]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second) + } + if value, ok := tags[TagSFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.StreamFlowControlWindow = protocol.ByteCount(v) + } + if value, ok := tags[TagCFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.ConnectionFlowControlWindow = protocol.ByteCount(v) + } + return params, nil +} + +// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake. +func (p *TransportParameters) getHelloMap() map[Tag][]byte { + sfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow)) + cfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow)) + mids := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(mids, p.MaxStreams) + icsl := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second)) + + tags := map[Tag][]byte{ + TagICSL: icsl.Bytes(), + TagMIDS: mids.Bytes(), + TagCFCW: cfcw.Bytes(), + TagSFCW: sfcw.Bytes(), + } + if p.OmitConnectionID { + tags[TagTCID] = []byte{0, 0, 0, 0} + } + return tags +} + +// readTransportParameters reads the transport parameters sent in the QUIC TLS extension +func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) { + params := &TransportParameters{} + + var foundInitialMaxStreamData bool + var foundInitialMaxData bool + var foundInitialMaxStreamID bool + var foundIdleTimeout bool + + for _, p := range paramsList { + switch p.Parameter { + case initialMaxStreamDataParameterID: + foundInitialMaxStreamData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value)) + } + params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxDataParameterID: + foundInitialMaxData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) + } + params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxStreamIDParameterID: + foundInitialMaxStreamID = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value)) + } + // TODO: handle this value + case idleTimeoutParameterID: + foundIdleTimeout = true + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) + case omitConnectionIDParameterID: + if len(p.Value) != 0 { + return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) + } + params.OmitConnectionID = true + } + } + + if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) { + return nil, errors.New("missing parameter") + } + return params, nil +} + +// GetTransportParameters gets the parameters needed for the TLS handshake. +func (p *TransportParameters) getTransportParameters() []transportParameter { + initialMaxStreamData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) + initialMaxData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) + initialMaxStreamID := make([]byte, 4) + // TODO: use a reasonable value here + binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32) + idleTimeout := make([]byte, 2) + binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) + maxPacketSize := make([]byte, 2) + binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) + params := []transportParameter{ + {initialMaxStreamDataParameterID, initialMaxStreamData}, + {initialMaxDataParameterID, initialMaxData}, + {initialMaxStreamIDParameterID, initialMaxStreamID}, + {idleTimeoutParameterID, idleTimeout}, + {maxPacketSizeParameterID, maxPacketSize}, + } + if p.OmitConnectionID { + params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) + } + return params +} diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go index c4f468a..4bc8bfc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/packet_number.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go @@ -27,18 +27,14 @@ func delta(a, b PacketNumber) PacketNumber { return a - b } -// GetPacketNumberLengthForPublicHeader gets the length of the packet number for the public header +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForPublicHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) - if diff < (2 << (uint8(PacketNumberLen2)*8 - 2)) { + if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) { return PacketNumberLen2 } - if diff < (2 << (uint8(PacketNumberLen4)*8 - 2)) { - return PacketNumberLen4 - } - // we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner - return PacketNumberLen6 + return PacketNumberLen4 } // GetPacketNumberLength gets the minimum length needed to fully represent the packet number diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go similarity index 75% rename from vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index cf9cf05..dadbf32 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -21,6 +21,22 @@ const ( PacketNumberLen6 PacketNumberLen = 6 ) +// The PacketType is the Long Header Type (only used for the IETF draft header format) +type PacketType uint8 + +const ( + // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet + PacketTypeVersionNegotiation PacketType = 1 + // PacketTypeInitial is the packet type of a Initial packet + PacketTypeInitial PacketType = 2 + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry PacketType = 3 + // PacketTypeHandshake is the packet type of a Cleartext packet + PacketTypeHandshake PacketType = 4 + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT PacketType = 5 +) + // A ConnectionID in QUIC type ConnectionID uint64 @@ -43,12 +59,6 @@ const MaxReceivePacketSize ByteCount = 1452 // Used in QUIC for congestion window computations in bytes. const DefaultTCPMSS ByteCount = 1460 -// InitialStreamFlowControlWindow is the initial stream-level flow control window for sending -const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB - -// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending -const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB - // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. const ClientHelloMinimumSize = 1024 diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go rename to vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go index 8e632cc..697d787 100644 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go @@ -2,12 +2,9 @@ package protocol import "time" -// MaxPacketSize is the maximum packet size, including the public header, that we use for sending packets -// This is the value used by Chromium for a QUIC packet sent using IPv6 (for IPv4 it would be 1370) -const MaxPacketSize ByteCount = 1350 - -// MaxFrameAndPublicHeaderSize is the maximum size of a QUIC frame plus PublicHeader -const MaxFrameAndPublicHeaderSize = MaxPacketSize - 12 /*crypto signature*/ +// MaxPacketSize is the maximum packet size that we use for sending packets. +// It includes the QUIC packet header, but excludes the UDP and IP header. +const MaxPacketSize ByteCount = 1200 // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames @@ -33,37 +30,34 @@ const AckSendDelay = 25 * time.Millisecond // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB +const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB // ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB +const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server // This is the value that Google servers are using -const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB +const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server // This is the value that Google servers are using -const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB +const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client // This is the value that Chromium is using -const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB +const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client // This is the value that Google servers are using -const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB +const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // This is the value that Chromium is using const ConnectionFlowControlMultiplier = 1.5 -// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection -const MaxStreamsPerConnection = 100 - -// MaxIncomingDynamicStreamsPerConnection is the maximum value accepted for the incoming number of dynamic streams per connection -const MaxIncomingDynamicStreamsPerConnection = 100 +// MaxIncomingStreams is the maximum number of streams that a peer may open +const MaxIncomingStreams = 100 // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. const MaxStreamsMultiplier = 1.1 @@ -73,7 +67,7 @@ const MaxStreamsMinimumIncrement = 10 // MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened // note that the number of streams is half this value, since the client can only open streams with open StreamID -const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection +const MaxNewStreamIDDelta = 4 * MaxIncomingStreams // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow @@ -84,21 +78,21 @@ const SkipPacketAveragePeriodLength PacketNumber = 500 // MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation const MaxTrackedSkippedPackets = 10 -// STKExpiryTime is the valid time of a source address token -const STKExpiryTime = 24 * time.Hour +// CookieExpiryTime is the valid time of a cookie +const CookieExpiryTime = 24 * time.Hour // MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow -// MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations -const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow - // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow // MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent const MaxPacketsReceivedBeforeAckSend = 20 +// MaxNonRetransmittablePackets is the maximum number of non-retransmittable packets that we send in a row +const MaxNonRetransmittablePackets = 19 + // RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for const RetransmittablePacketsBeforeAck = 2 @@ -116,18 +110,12 @@ const CryptoParameterMaxLength = 4000 // EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX. const EphermalKeyLifetime = time.Minute -// InitialIdleTimeout is the timeout before the handshake succeeds. -const InitialIdleTimeout = 5 * time.Second +// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout +const MinRemoteIdleTimeout = 5 * time.Second -// DefaultIdleTimeout is the default idle timeout, for the server +// DefaultIdleTimeout is the default idle timeout const DefaultIdleTimeout = 30 * time.Second -// MaxIdleTimeoutServer is the maximum idle timeout that can be negotiated, for the server -const MaxIdleTimeoutServer = 1 * time.Minute - -// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server -const MaxIdleTimeoutClient = 2 * time.Minute - // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. const DefaultHandshakeTimeout = 10 * time.Second diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go new file mode 100644 index 0000000..5ad04f0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go @@ -0,0 +1,114 @@ +package protocol + +import ( + "fmt" +) + +// VersionNumber is a version number as int +type VersionNumber int + +// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions +const ( + gquicVersion0 = 0x51303030 + maxGquicVersion = 0x51303439 +) + +// The version numbers, making grepping easier +const ( + Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota + VersionTLS VersionNumber = 101 + VersionWhatever VersionNumber = 0 // for when the version doesn't matter + VersionUnknown VersionNumber = -1 +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []VersionNumber{ + Version39, +} + +// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake +func (vn VersionNumber) UsesTLS() bool { + return vn == VersionTLS +} + +func (vn VersionNumber) String() string { + switch vn { + case VersionWhatever: + return "whatever" + case VersionUnknown: + return "unknown" + case VersionTLS: + return "TLS dev version (WIP)" + default: + if vn.isGQUIC() { + return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%d", vn) + } +} + +// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters +func (vn VersionNumber) ToAltSvc() string { + if vn.isGQUIC() { + return fmt.Sprintf("%d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%d", vn) +} + +// CryptoStreamID gets the Stream ID of the crypto stream +func (vn VersionNumber) CryptoStreamID() StreamID { + if vn.isGQUIC() { + return 1 + } + return 0 +} + +// UsesMaxDataFrame tells if this version uses MAX_DATA, MAX_STREAM_DATA, BLOCKED and STREAM_BLOCKED instead of WINDOW_UDPATE and BLOCKED frames +func (vn VersionNumber) UsesMaxDataFrame() bool { + return vn.CryptoStreamID() == 0 +} + +// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control +func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool { + if id == vn.CryptoStreamID() { + return false + } + if vn.isGQUIC() && id == 3 { + return false + } + return true +} + +func (vn VersionNumber) isGQUIC() bool { + return vn > gquicVersion0 && vn <= maxGquicVersion +} + +func (vn VersionNumber) toGQUICVersion() int { + return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) +} + +// IsSupportedVersion returns true if the server supports this version +func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { + for _, t := range supported { + if t == v { + return true + } + } + return false +} + +// ChooseSupportedVersion finds the best version in the overlap of ours and theirs +// ours is a slice of versions that we support, sorted by our preference (descending) +// theirs is a slice of versions offered by the peer. The order does not matter. +// The bool returned indicates if a matching version was found. +func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return ourVer, true + } + } + } + return 0, false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go new file mode 100644 index 0000000..35549f6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go @@ -0,0 +1,33 @@ +package utils + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + ReadUintN(b io.ByteReader, length uint8) (uint64, error) + ReadUint64(io.ByteReader) (uint64, error) + ReadUint32(io.ByteReader) (uint32, error) + ReadUint16(io.ByteReader) (uint16, error) + + WriteUint64(*bytes.Buffer, uint64) + WriteUint56(*bytes.Buffer, uint64) + WriteUint48(*bytes.Buffer, uint64) + WriteUint40(*bytes.Buffer, uint64) + WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) + WriteUint16(*bytes.Buffer, uint16) + + ReadUfloat16(io.ByteReader) (uint64, error) + WriteUfloat16(*bytes.Buffer, uint64) +} + +// GetByteOrder gets the ByteOrder to represent values on the wire +// from QUIC 39, values are encoded in big endian, before that in little endian +func GetByteOrder(v protocol.VersionNumber) ByteOrder { + return BigEndian +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go new file mode 100644 index 0000000..9f6c9a6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go @@ -0,0 +1,157 @@ +package utils + +import ( + "bytes" + "fmt" + "io" +) + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian ByteOrder = bigEndian{} + +type bigEndian struct{} + +var _ ByteOrder = &bigEndian{} + +// ReadUintN reads N bytes +func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { + var res uint64 + for i := uint8(0); i < length; i++ { + bt, err := b.ReadByte() + if err != nil { + return 0, err + } + res ^= uint64(bt) << ((length - 1 - i) * 8) + } + return res, nil +} + +// ReadUint64 reads a uint64 +func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) { + var b1, b2, b3, b4, b5, b6, b7, b8 uint8 + var err error + if b8, err = b.ReadByte(); err != nil { + return 0, err + } + if b7, err = b.ReadByte(); err != nil { + return 0, err + } + if b6, err = b.ReadByte(); err != nil { + return 0, err + } + if b5, err = b.ReadByte(); err != nil { + return 0, err + } + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil +} + +// ReadUint32 reads a uint32 +func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +// ReadUint16 reads a uint16 +func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +} + +// WriteUint64 writes a uint64 +func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) { + b.Write([]byte{ + uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint56 writes 56 bit of a uint64 +func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) { + if i >= (1 << 56) { + panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i)) + } + b.Write([]byte{ + uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint48 writes 48 bit of a uint64 +func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) { + if i >= (1 << 48) { + panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i)) + } + b.Write([]byte{ + uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint40 writes 40 bit of a uint64 +func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) { + if i >= (1 << 40) { + panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i)) + } + b.Write([]byte{ + uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) +} + +// WriteUint32 writes a uint32 +func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint24 writes 24 bit of a uint32 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + if i >= (1 << 24) { + panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i)) + } + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint16 writes a uint16 +func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { + b.Write([]byte{uint8(i >> 8), uint8(i)}) +} + +func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) { + return readUfloat16(b, l) +} + +func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) { + writeUfloat16(b, l, val) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go similarity index 64% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go rename to vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go index f6c4e03..71ff95d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/utils.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go @@ -2,11 +2,19 @@ package utils import ( "bytes" + "fmt" "io" ) +// LittleEndian is the little-endian implementation of ByteOrder. +var LittleEndian ByteOrder = littleEndian{} + +type littleEndian struct{} + +var _ ByteOrder = &littleEndian{} + // ReadUintN reads N bytes -func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { +func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { var res uint64 for i := uint8(0); i < length; i++ { bt, err := b.ReadByte() @@ -19,7 +27,7 @@ func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { } // ReadUint64 reads a uint64 -func ReadUint64(b io.ByteReader) (uint64, error) { +func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) { var b1, b2, b3, b4, b5, b6, b7, b8 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -50,7 +58,7 @@ func ReadUint64(b io.ByteReader) (uint64, error) { } // ReadUint32 reads a uint32 -func ReadUint32(b io.ByteReader) (uint32, error) { +func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) { var b1, b2, b3, b4 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -69,7 +77,7 @@ func ReadUint32(b io.ByteReader) (uint32, error) { } // ReadUint16 reads a uint16 -func ReadUint16(b io.ByteReader) (uint16, error) { +func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) { var b1, b2 uint8 var err error if b1, err = b.ReadByte(); err != nil { @@ -82,7 +90,7 @@ func ReadUint16(b io.ByteReader) (uint16, error) { } // WriteUint64 writes a uint64 -func WriteUint64(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) { b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56), @@ -90,7 +98,10 @@ func WriteUint64(b *bytes.Buffer, i uint64) { } // WriteUint56 writes 56 bit of a uint64 -func WriteUint56(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) { + if i >= (1 << 56) { + panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), @@ -98,7 +109,10 @@ func WriteUint56(b *bytes.Buffer, i uint64) { } // WriteUint48 writes 48 bit of a uint64 -func WriteUint48(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) { + if i >= (1 << 48) { + panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), @@ -106,7 +120,10 @@ func WriteUint48(b *bytes.Buffer, i uint64) { } // WriteUint40 writes 40 bit of a uint64 -func WriteUint40(b *bytes.Buffer, i uint64) { +func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) { + if i >= (1 << 40) { + panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), @@ -114,23 +131,27 @@ func WriteUint40(b *bytes.Buffer, i uint64) { } // WriteUint32 writes a uint32 -func WriteUint32(b *bytes.Buffer, i uint32) { +func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) { b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)}) } // WriteUint24 writes 24 bit of a uint32 -func WriteUint24(b *bytes.Buffer, i uint32) { +func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) { + if i >= (1 << 24) { + panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i)) + } b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)}) } // WriteUint16 writes a uint16 -func WriteUint16(b *bytes.Buffer, i uint16) { +func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) { b.Write([]byte{uint8(i), uint8(i >> 8)}) } -// Uint32Slice attaches the methods of sort.Interface to []uint32, sorting in increasing order. -type Uint32Slice []uint32 +func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) { + return readUfloat16(b, l) +} -func (s Uint32Slice) Len() int { return len(s) } -func (s Uint32Slice) Less(i, j int) bool { return s[i] < s[j] } -func (s Uint32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) { + writeUfloat16(b, l, val) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go index c2252e6..b4af4e7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/connection_id.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "encoding/binary" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // GenerateConnectionID generates a connection ID using cryptographic random diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go index 8abdb51..8e2ca1b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/float16.go @@ -21,9 +21,9 @@ const uFloat16MantissaBits = 16 - uFloat16ExponentBits const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12 const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000 -// ReadUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation -func ReadUfloat16(b io.ByteReader) (uint64, error) { - val, err := ReadUint16(b) +// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation +func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) { + val, err := byteOrder.ReadUint16(b) if err != nil { return 0, err } @@ -50,8 +50,8 @@ func ReadUfloat16(b io.ByteReader) (uint64, error) { return res, nil } -// WriteUfloat16 writes a float in the QUIC-float16 format from its uint64 representation -func WriteUfloat16(b *bytes.Buffer, value uint64) { +// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation +func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) { var result uint16 if value < (uint64(1) << uFloat16MantissaEffectiveBits) { // Fast path: either the value is denormalized, or has exponent zero. @@ -82,5 +82,5 @@ func WriteUfloat16(b *bytes.Buffer, value uint64) { result = (uint16(value) + (exponent << uFloat16MantissaBits)) } - WriteUint16(b, result) + byteOrder.WriteUint16(b, result) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go index 9128510..342d8dd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "strings" "time" ) @@ -79,14 +80,14 @@ func init() { } func readLoggingEnv() { - switch os.Getenv(logEnv) { + switch strings.ToLower(os.Getenv(logEnv)) { case "": return - case "DEBUG": + case "debug": logLevel = LogLevelDebug - case "INFO": + case "info": logLevel = LogLevelInfo - case "ERROR": + case "error": logLevel = LogLevelError default: fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go index 6e23df5..c984a3c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go @@ -4,7 +4,7 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // InfDuration is a duration of infinite length diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go index 09800b6..f49b0c4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go @@ -1,6 +1,6 @@ package utils -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // PacketInterval is an interval from one PacketNumber to the other // +gen linkedlist diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go index c918b62..3c8325b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go @@ -1,6 +1,6 @@ package utils -import "github.com/lucas-clemente/quic-go/protocol" +import "github.com/lucas-clemente/quic-go/internal/protocol" // ByteInterval is an interval from one ByteCount to the other // +gen linkedlist diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go similarity index 78% rename from vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go index ceeba48..2d60baa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -1,12 +1,12 @@ -package frames +package wire import ( "bytes" "errors" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) var ( @@ -25,7 +25,7 @@ var ( type AckFrame struct { LargestAcked protocol.PacketNumber LowestAcked protocol.PacketNumber - AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last // time when the LargestAcked was receiveid // this field Will not be set for received ACKs frames @@ -57,13 +57,13 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, missingSequenceNumberDeltaLen = 1 } - largestAcked, err := utils.ReadUintN(r, largestAckedLen) + largestAcked, err := utils.GetByteOrder(version).ReadUintN(r, largestAckedLen) if err != nil { return nil, err } frame.LargestAcked = protocol.PacketNumber(largestAcked) - delay, err := utils.ReadUfloat16(r) + delay, err := utils.GetByteOrder(version).ReadUfloat16(r) if err != nil { return nil, err } @@ -81,7 +81,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, ErrInvalidAckRanges } - ackBlockLength, err := utils.ReadUintN(r, missingSequenceNumberDeltaLen) + ackBlockLength, err := utils.GetByteOrder(version).ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } @@ -95,8 +95,8 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, if hasMissingRanges { ackRange := AckRange{ - FirstPacketNumber: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, - LastPacketNumber: frame.LargestAcked, + First: protocol.PacketNumber(largestAcked-ackBlockLength) + 1, + Last: frame.LargestAcked, } frame.AckRanges = append(frame.AckRanges, ackRange) @@ -109,7 +109,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, err } - ackBlockLength, err = utils.ReadUintN(r, missingSequenceNumberDeltaLen) + ackBlockLength, err = utils.GetByteOrder(version).ReadUintN(r, missingSequenceNumberDeltaLen) if err != nil { return nil, err } @@ -117,14 +117,14 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, length := protocol.PacketNumber(ackBlockLength) if inLongBlock { - frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber -= protocol.PacketNumber(gap) + length - frame.AckRanges[len(frame.AckRanges)-1].LastPacketNumber -= protocol.PacketNumber(gap) + frame.AckRanges[len(frame.AckRanges)-1].First -= protocol.PacketNumber(gap) + length + frame.AckRanges[len(frame.AckRanges)-1].Last -= protocol.PacketNumber(gap) } else { lastRangeComplete = false ackRange := AckRange{ - LastPacketNumber: frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber - protocol.PacketNumber(gap) - 1, + Last: frame.AckRanges[len(frame.AckRanges)-1].First - protocol.PacketNumber(gap) - 1, } - ackRange.FirstPacketNumber = ackRange.LastPacketNumber - length + 1 + ackRange.First = ackRange.Last - length + 1 frame.AckRanges = append(frame.AckRanges, ackRange) } @@ -135,13 +135,13 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, inLongBlock = (ackBlockLength == 0) } - // if the last range was not complete, FirstPacketNumber and LastPacketNumber make no sense + // if the last range was not complete, First and Last make no sense // remove the range from frame.AckRanges if !lastRangeComplete { frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1] } - frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber + frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].First } else { if frame.LargestAcked == 0 { frame.LowestAcked = 0 @@ -167,7 +167,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, return nil, err } // First Timestamp - _, err = utils.ReadUint32(r) + _, err = utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } @@ -180,13 +180,12 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, } // Time Since Previous Timestamp - _, err = utils.ReadUint16(r) + _, err = utils.GetByteOrder(version).ReadUint16(r) if err != nil { return nil, err } } } - return frame, nil } @@ -215,15 +214,15 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(f.LargestAcked)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(f.LargestAcked)) + utils.GetByteOrder(version).WriteUint16(b, uint16(f.LargestAcked)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(f.LargestAcked)) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.LargestAcked)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(f.LargestAcked)) + utils.GetByteOrder(version).WriteUint48(b, uint64(f.LargestAcked)&(1<<48-1)) } f.DelayTime = time.Since(f.PacketReceivedTime) - utils.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) + utils.GetByteOrder(version).WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond)) var numRanges uint64 var numRangesWritten uint64 @@ -239,13 +238,13 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error if !f.HasMissingRanges() { firstAckBlockLength = f.LargestAcked - f.LowestAcked + 1 } else { - if f.LargestAcked != f.AckRanges[0].LastPacketNumber { + if f.LargestAcked != f.AckRanges[0].Last { return errInconsistentAckLargestAcked } - if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].FirstPacketNumber { + if f.LowestAcked != f.AckRanges[len(f.AckRanges)-1].First { return errInconsistentAckLowestAcked } - firstAckBlockLength = f.LargestAcked - f.AckRanges[0].FirstPacketNumber + 1 + firstAckBlockLength = f.LargestAcked - f.AckRanges[0].First + 1 numRangesWritten++ } @@ -253,11 +252,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(firstAckBlockLength)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(firstAckBlockLength)) + utils.GetByteOrder(version).WriteUint16(b, uint16(firstAckBlockLength)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(firstAckBlockLength)) + utils.GetByteOrder(version).WriteUint32(b, uint32(firstAckBlockLength)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(firstAckBlockLength)) + utils.GetByteOrder(version).WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1)) } for i, ackRange := range f.AckRanges { @@ -265,8 +264,8 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error continue } - length := ackRange.LastPacketNumber - ackRange.FirstPacketNumber + 1 - gap := f.AckRanges[i-1].FirstPacketNumber - ackRange.LastPacketNumber - 1 + length := ackRange.Last - ackRange.First + 1 + gap := f.AckRanges[i-1].First - ackRange.Last - 1 num := gap/0xFF + 1 if gap%0xFF == 0 { @@ -279,11 +278,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(length)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(length)) + utils.GetByteOrder(version).WriteUint16(b, uint16(length)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(length)) + utils.GetByteOrder(version).WriteUint32(b, uint32(length)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(length)) + utils.GetByteOrder(version).WriteUint48(b, uint64(length)&(1<<48-1)) } numRangesWritten++ } else { @@ -304,11 +303,11 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen1: b.WriteByte(uint8(lengthWritten)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(lengthWritten)) + utils.GetByteOrder(version).WriteUint16(b, uint16(lengthWritten)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(lengthWritten)) + utils.GetByteOrder(version).WriteUint32(b, uint32(lengthWritten)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, lengthWritten) + utils.GetByteOrder(version).WriteUint48(b, lengthWritten&(1<<48-1)) } numRangesWritten++ @@ -326,7 +325,6 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error } b.WriteByte(0) // no timestamps - return nil } @@ -363,13 +361,13 @@ func (f *AckFrame) validateAckRanges() bool { return false } - if f.AckRanges[0].LastPacketNumber != f.LargestAcked { + if f.AckRanges[0].Last != f.LargestAcked { return false } // check the validity of every single ACK range for _, ackRange := range f.AckRanges { - if ackRange.FirstPacketNumber > ackRange.LastPacketNumber { + if ackRange.First > ackRange.Last { return false } } @@ -380,10 +378,10 @@ func (f *AckFrame) validateAckRanges() bool { continue } lastAckRange := f.AckRanges[i-1] - if lastAckRange.FirstPacketNumber <= ackRange.FirstPacketNumber { + if lastAckRange.First <= ackRange.First { return false } - if lastAckRange.FirstPacketNumber <= ackRange.LastPacketNumber+1 { + if lastAckRange.First <= ackRange.Last+1 { return false } } @@ -405,7 +403,7 @@ func (f *AckFrame) numWritableNackRanges() uint64 { } lastAckRange := f.AckRanges[i-1] - gap := lastAckRange.FirstPacketNumber - ackRange.LastPacketNumber - 1 + gap := lastAckRange.First - ackRange.Last - 1 rangeLength := 1 + uint64(gap)/0xFF if uint64(gap)%0xFF == 0 { rangeLength-- @@ -426,7 +424,7 @@ func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen { if f.HasMissingRanges() { for _, ackRange := range f.AckRanges { - rangeLength := ackRange.LastPacketNumber - ackRange.FirstPacketNumber + 1 + rangeLength := ackRange.Last - ackRange.First + 1 if rangeLength > maxRangeLength { maxRangeLength = rangeLength } @@ -457,7 +455,7 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { if f.HasMissingRanges() { // TODO: this could be implemented as a binary search for _, ackRange := range f.AckRanges { - if p >= ackRange.FirstPacketNumber && p <= ackRange.LastPacketNumber { + if p >= ackRange.First && p <= ackRange.Last { return true } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go new file mode 100644 index 0000000..c561762 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go @@ -0,0 +1,9 @@ +package wire + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// AckRange is an ACK range +type AckRange struct { + First protocol.PacketNumber + Last protocol.PacketNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go new file mode 100644 index 0000000..08dc051 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame.go @@ -0,0 +1,35 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A BlockedFrame is a BLOCKED frame +type BlockedFrame struct{} + +// ParseBlockedFrame parses a BLOCKED frame +func ParseBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*BlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &BlockedFrame{}, nil +} + +func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesMaxDataFrame() { + return (&blockedFrameLegacy{}).Write(b, version) + } + typeByte := uint8(0x08) + b.WriteByte(typeByte) + return nil +} + +// MinLength of a written frame +func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesMaxDataFrame() { // writing this frame would result in a legacy BLOCKED being written, which is longer + return 1 + 4, nil + } + return 1, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go new file mode 100644 index 0000000..d60ca4c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type blockedFrameLegacy struct { + StreamID protocol.StreamID +} + +// ParseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) +// The frame returned is +// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream +// * a BLOCKED frame, if the BLOCKED applies to the connection +func ParseBlockedFrameLegacy(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + streamID, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + if streamID == 0 { + return &BlockedFrame{}, nil + } + return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil +} + +//Write writes a BLOCKED frame +func (f *blockedFrameLegacy) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x05) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go similarity index 62% rename from vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go index 5a7ed04..432c6a8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go @@ -1,4 +1,4 @@ -package frames +package wire import ( "bytes" @@ -6,8 +6,8 @@ import ( "io" "math" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -18,7 +18,7 @@ type ConnectionCloseFrame struct { } // ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame -func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) { +func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { frame := &ConnectionCloseFrame{} // read the TypeByte @@ -27,23 +27,27 @@ func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) { return nil, err } - errorCode, err := utils.ReadUint32(r) + errorCode, err := utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } frame.ErrorCode = qerr.ErrorCode(errorCode) - reasonPhraseLen, err := utils.ReadUint16(r) + reasonPhraseLen, err := utils.GetByteOrder(version).ReadUint16(r) if err != nil { return nil, err } - if reasonPhraseLen > uint16(protocol.MaxPacketSize) { - return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long") + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the whole reason phrase would result in EOF when attempting to READ + if int(reasonPhraseLen) > r.Len() { + return nil, io.EOF } reasonPhrase := make([]byte, reasonPhraseLen) if _, err := io.ReadFull(r, reasonPhrase); err != nil { + // this should never happen, since we already checked the reasonPhraseLen earlier return nil, err } frame.ReasonPhrase = string(reasonPhrase) @@ -59,14 +63,14 @@ func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protoc // Write writes an CONNECTION_CLOSE frame. func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x02) - utils.WriteUint32(b, uint32(f.ErrorCode)) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) if len(f.ReasonPhrase) > math.MaxUint16 { return errors.New("ConnectionFrame: ReasonPhrase too long") } reasonPhraseLen := uint16(len(f.ReasonPhrase)) - utils.WriteUint16(b, reasonPhraseLen) + utils.GetByteOrder(version).WriteUint16(b, reasonPhraseLen) b.WriteString(f.ReasonPhrase) return nil diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/frames/frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go index 464e669..f31f5bf 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame.go @@ -1,9 +1,9 @@ -package frames +package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // A Frame in QUIC diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go similarity index 66% rename from vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go index e00a6cf..5332210 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/goaway_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/goaway_frame.go @@ -1,11 +1,11 @@ -package frames +package wire import ( "bytes" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -17,27 +17,26 @@ type GoawayFrame struct { } // ParseGoawayFrame parses a GOAWAY frame -func ParseGoawayFrame(r *bytes.Reader) (*GoawayFrame, error) { +func ParseGoawayFrame(r *bytes.Reader, version protocol.VersionNumber) (*GoawayFrame, error) { frame := &GoawayFrame{} - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { return nil, err } - errorCode, err := utils.ReadUint32(r) + errorCode, err := utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } frame.ErrorCode = qerr.ErrorCode(errorCode) - lastGoodStream, err := utils.ReadUint32(r) + lastGoodStream, err := utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } frame.LastGoodStream = protocol.StreamID(lastGoodStream) - reasonPhraseLen, err := utils.ReadUint16(r) + reasonPhraseLen, err := utils.GetByteOrder(version).ReadUint16(r) if err != nil { return nil, err } @@ -51,19 +50,15 @@ func ParseGoawayFrame(r *bytes.Reader) (*GoawayFrame, error) { return nil, err } frame.ReasonPhrase = string(reasonPhrase) - return frame, nil } func (f *GoawayFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x03) - b.WriteByte(typeByte) - - utils.WriteUint32(b, uint32(f.ErrorCode)) - utils.WriteUint32(b, uint32(f.LastGoodStream)) - utils.WriteUint16(b, uint16(len(f.ReasonPhrase))) + b.WriteByte(0x03) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.ErrorCode)) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.LastGoodStream)) + utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.ReasonPhrase))) b.WriteString(f.ReasonPhrase) - return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go new file mode 100644 index 0000000..96066cc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -0,0 +1,111 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// Header is the header of a QUIC packet. +// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. +type Header struct { + Raw []byte + ConnectionID protocol.ConnectionID + OmitConnectionID bool + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + Version protocol.VersionNumber // VersionNumber sent by the client + SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server + + // only needed for the gQUIC Public Header + VersionFlag bool + ResetFlag bool + DiversificationNonce []byte + + // only needed for the IETF Header + Type protocol.PacketType + IsLongHeader bool + KeyPhase int + + // only needed for logging + isPublicHeader bool +} + +// ParseHeaderSentByServer parses the header for a packet that was sent by the server. +func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + _ = b.UnreadByte() // unread the type byte + + var isPublicHeader bool + // As a client, we know the version of the packet that the server sent, except for Version Negotiation Packets. + if typeByte == 0x81 { // IETF draft Version Negotiation Packet + isPublicHeader = false + } else if typeByte&0xcf == 0x9 { // gQUIC Version Negotiation Packet + // IETF QUIC Version Negotiation Packets are sent with the Long Header (indicated by the 0x80 bit) + // gQUIC always has 0x80 unset + isPublicHeader = true + } else { // not a Version Negotiation Packet + // the client knows the version that this packet was sent with + isPublicHeader = !version.UsesTLS() + } + return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) +} + +// ParseHeaderSentByClient parses the header for a packet that was sent by the client. +func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + _ = b.UnreadByte() // unread the type byte + + // If this is a gQUIC header 0x80 and 0x40 will be set to 0. + // If this is an IETF QUIC header there are two options: + // * either 0x80 will be 1 (for the Long Header) + // * or 0x40 (the Connection ID Flag) will be 0 (for the Short Header), since we don't the client to omit it + isPublicHeader := typeByte&0xc0 == 0 + + return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) +} + +func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) { + // This is a gQUIC Public Header. + if isPublicHeader { + hdr, err := parsePublicHeader(b, sentBy) + if err != nil { + return nil, err + } + hdr.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return hdr, nil + } + return parseHeader(b, sentBy) +} + +// Write writes the Header. +func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { + if !version.UsesTLS() { + h.isPublicHeader = true // save that this is a Public Header, so we can log it correctly later + return h.writePublicHeader(b, pers, version) + } + return h.writeHeader(b) +} + +// GetLength determines the length of the Header. +func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesTLS() { + return h.getPublicHeaderLength(pers) + } + return h.getHeaderLength() +} + +// Log logs the Header +func (h *Header) Log() { + if h.isPublicHeader { + h.logPublicHeader() + } else { + h.logHeader() + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go new file mode 100644 index 0000000..3db67cc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go @@ -0,0 +1,170 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// parseHeader parses the header. +func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + if typeByte&0x80 > 0 { + return parseLongHeader(b, packetSentBy, typeByte) + } + return parseShortHeader(b, typeByte) +} + +func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { + connID, err := utils.BigEndian.ReadUint64(b) + if err != nil { + return nil, err + } + pn, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + packetType := protocol.PacketType(typeByte & 0x7f) + if sentBy == protocol.PerspectiveClient && (packetType != protocol.PacketTypeInitial && packetType != protocol.PacketTypeHandshake && packetType != protocol.PacketType0RTT) { + if packetType == protocol.PacketTypeVersionNegotiation { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "sent by the client") + } + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType)) + } + if sentBy == protocol.PerspectiveServer && (packetType != protocol.PacketTypeVersionNegotiation && packetType != protocol.PacketTypeRetry && packetType != protocol.PacketTypeHandshake) { + return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", packetType)) + } + h := &Header{ + Type: packetType, + IsLongHeader: true, + ConnectionID: protocol.ConnectionID(connID), + PacketNumber: protocol.PacketNumber(pn), + PacketNumberLen: protocol.PacketNumberLen4, + Version: protocol.VersionNumber(v), + } + if h.Type == protocol.PacketTypeVersionNegotiation { + if b.Len() == 0 { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") + } + h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, qerr.InvalidVersionNegotiationPacket + } + h.SupportedVersions[i] = protocol.VersionNumber(v) + } + } + return h, nil +} + +func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { + hasConnID := typeByte&0x40 > 0 + var connID uint64 + if hasConnID { + var err error + connID, err = utils.BigEndian.ReadUint64(b) + if err != nil { + return nil, err + } + } + pnLen := 1 << ((typeByte & 0x3) - 1) + pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen)) + if err != nil { + return nil, err + } + return &Header{ + KeyPhase: int(typeByte&0x20) >> 5, + OmitConnectionID: !hasConnID, + ConnectionID: protocol.ConnectionID(connID), + PacketNumber: protocol.PacketNumber(pn), + PacketNumberLen: protocol.PacketNumberLen(pnLen), + }, nil +} + +// writeHeader writes the Header. +func (h *Header) writeHeader(b *bytes.Buffer) error { + if h.IsLongHeader { + return h.writeLongHeader(b) + } + return h.writeShortHeader(b) +} + +// TODO: add support for the key phase +func (h *Header) writeLongHeader(b *bytes.Buffer) error { + b.WriteByte(byte(0x80 ^ h.Type)) + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + return nil +} + +func (h *Header) writeShortHeader(b *bytes.Buffer) error { + typeByte := byte(h.KeyPhase << 5) + if !h.OmitConnectionID { + typeByte ^= 0x40 + } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + typeByte ^= 0x1 + case protocol.PacketNumberLen2: + typeByte ^= 0x2 + case protocol.PacketNumberLen4: + typeByte ^= 0x3 + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + b.WriteByte(typeByte) + + if !h.OmitConnectionID { + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) + } + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + } + return nil +} + +// getHeaderLength gets the length of the Header in bytes. +func (h *Header) getHeaderLength() (protocol.ByteCount, error) { + if h.IsLongHeader { + return 1 + 8 + 4 + 4, nil + } + + length := protocol.ByteCount(1) // type byte + if !h.OmitConnectionID { + length += 8 + } + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { + return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + length += protocol.ByteCount(h.PacketNumberLen) + return length, nil +} + +func (h *Header) logHeader() { + if h.IsLongHeader { + utils.Debugf(" Long Header{Type: %#x, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) + } else { + connID := "(omitted)" + if !h.OmitConnectionID { + connID = fmt.Sprintf("%#x", h.ConnectionID) + } + utils.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go new file mode 100644 index 0000000..0e72ea9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go @@ -0,0 +1,28 @@ +package wire + +import "github.com/lucas-clemente/quic-go/internal/utils" + +// LogFrame logs a frame, either sent or received +func LogFrame(frame Frame, sent bool) { + if !utils.Debug() { + return + } + dir := "<-" + if sent { + dir = "->" + } + switch f := frame.(type) { + case *StreamFrame: + utils.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *StopWaitingFrame: + if sent { + utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) + } else { + utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) + } + case *AckFrame: + utils.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + default: + utils.Debugf("\t%s %#v", dir, frame) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go new file mode 100644 index 0000000..cd3ff65 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go @@ -0,0 +1,51 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxDataFrame carries flow control information for the connection +type MaxDataFrame struct { + ByteOffset protocol.ByteCount +} + +// ParseMaxDataFrame parses a MAX_DATA frame +func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) { + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &MaxDataFrame{} + byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + if err != nil { + return nil, err + } + frame.ByteOffset = protocol.ByteCount(byteOffset) + return frame, nil +} + +//Write writes a MAX_STREAM_DATA frame +func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesMaxDataFrame() { + // write a gQUIC WINDOW_UPDATE frame (with stream ID 0, which means connection-level there) + return (&windowUpdateFrame{ + StreamID: 0, + ByteOffset: f.ByteOffset, + }).Write(b, version) + } + b.WriteByte(0x4) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + return nil +} + +// MinLength of a written frame +func (f *MaxDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesMaxDataFrame() { // writing this frame would result in a gQUIC WINDOW_UPDATE being written, which is longer + return 1 + 4 + 8, nil + } + return 1 + 8, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go new file mode 100644 index 0000000..56c44c9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go @@ -0,0 +1,56 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamDataFrame carries flow control information for a stream +type MaxStreamDataFrame struct { + StreamID protocol.StreamID + ByteOffset protocol.ByteCount +} + +// ParseMaxStreamDataFrame parses a MAX_STREAM_DATA frame +func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) { + frame := &MaxStreamDataFrame{} + + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + + byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) + if err != nil { + return nil, err + } + frame.ByteOffset = protocol.ByteCount(byteOffset) + return frame, nil +} + +// Write writes a MAX_STREAM_DATA frame +func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesMaxDataFrame() { + return (&windowUpdateFrame{ + StreamID: f.StreamID, + ByteOffset: f.ByteOffset, + }).Write(b, version) + } + b.WriteByte(0x5) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + return nil +} + +// MinLength of a written frame +func (f *MaxStreamDataFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + return 1 + 4 + 8, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go similarity index 76% rename from vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go index 8486af5..2a09c33 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/ping_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go @@ -1,16 +1,16 @@ -package frames +package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // A PingFrame is a ping frame type PingFrame struct{} // ParsePingFrame parses a Ping frame -func ParsePingFrame(r *bytes.Reader) (*PingFrame, error) { +func ParsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) { frame := &PingFrame{} _, err := r.ReadByte() diff --git a/vendor/github.com/lucas-clemente/quic-go/public_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go similarity index 57% rename from vendor/github.com/lucas-clemente/quic-go/public_header.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go index 59ddc6c..ba5c8e6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/public_header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go @@ -1,61 +1,45 @@ -package quic +package wire import ( "bytes" "errors" + "fmt" + "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) var ( - errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") + errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") ) -// The PublicHeader of a QUIC packet. Warning: This struct should not be considered stable and will change soon. -type PublicHeader struct { - Raw []byte - ConnectionID protocol.ConnectionID - VersionFlag bool - ResetFlag bool - TruncateConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - VersionNumber protocol.VersionNumber // VersionNumber sent by the client - SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server - DiversificationNonce []byte -} - -// Write writes a public header. Warning: This API should not be considered stable and will change soon. -func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error { - publicFlagByte := uint8(0x00) - +// writePublicHeader writes a Public Header. +func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + publicFlagByte := uint8(0x00) if h.VersionFlag { publicFlagByte |= 0x01 } if h.ResetFlag { publicFlagByte |= 0x02 } - if !h.TruncateConnectionID { + if !h.OmitConnectionID { publicFlagByte |= 0x08 } - if len(h.DiversificationNonce) > 0 { if len(h.DiversificationNonce) != 32 { return errors.New("invalid diversification nonce length") } publicFlagByte |= 0x04 } - // only set PacketNumberLen bits if a packet number will be written if h.hasPacketNumber(pers) { switch h.PacketNumberLen { @@ -69,59 +53,50 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe publicFlagByte |= 0x30 } } - b.WriteByte(publicFlagByte) - if !h.TruncateConnectionID { - utils.WriteUint64(b, uint64(h.ConnectionID)) + if !h.OmitConnectionID { + utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) } - if h.VersionFlag && pers == protocol.PerspectiveClient { - utils.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber)) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) } - if len(h.DiversificationNonce) > 0 { b.Write(h.DiversificationNonce) } - // if we're a server, and the VersionFlag is set, we must not include anything else in the packet if !h.hasPacketNumber(pers) { return nil } - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { - return errPacketNumberLenNotSet - } - switch h.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(h.PacketNumber)) + utils.GetByteOrder(version).WriteUint16(b, uint16(h.PacketNumber)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(h.PacketNumber)) + utils.GetByteOrder(version).WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(h.PacketNumber)) + utils.GetByteOrder(version).WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) default: - return errPacketNumberLenNotSet + return errors.New("PublicHeader: PacketNumberLen not set") } return nil } -// ParsePublicHeader parses a QUIC packet's public header. +// parsePublicHeader parses a QUIC packet's Public Header. // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. -// Warning: This API should not be considered stable and will change soon. -func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { - header := &PublicHeader{} +func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { + header := &Header{} // First byte publicFlagByte, err := b.ReadByte() if err != nil { return nil, err } - header.VersionFlag = publicFlagByte&0x01 > 0 header.ResetFlag = publicFlagByte&0x02 > 0 + header.VersionFlag = publicFlagByte&0x01 > 0 // TODO: activate this check once Chrome sends the correct value // see https://github.com/lucas-clemente/quic-go/issues/232 @@ -129,11 +104,10 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub // return nil, errors.New("diversification nonces should only be sent by servers") // } - header.TruncateConnectionID = publicFlagByte&0x08 == 0 - if header.TruncateConnectionID && packetSentBy == protocol.PerspectiveClient { - return nil, errReceivedTruncatedConnectionID + header.OmitConnectionID = publicFlagByte&0x08 == 0 + if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient { + return nil, errReceivedOmittedConnectionID } - if header.hasPacketNumber(packetSentBy) { switch publicFlagByte & 0x30 { case 0x30: @@ -148,9 +122,9 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub } // Connection ID - if !header.TruncateConnectionID { + if !header.OmitConnectionID { var connID uint64 - connID, err = utils.ReadUint64(b) + connID, err = utils.BigEndian.ReadUint64(b) if err != nil { return nil, err } @@ -166,91 +140,85 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub // see https://github.com/lucas-clemente/quic-go/issues/232 if !header.VersionFlag && !header.ResetFlag { header.DiversificationNonce = make([]byte, 32) - // this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number - _, err = b.Read(header.DiversificationNonce) - if err != nil { + if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil { return nil, err } } } // Version (optional) - if !header.ResetFlag { - if header.VersionFlag { - if packetSentBy == protocol.PerspectiveClient { - var versionTag uint32 - versionTag, err = utils.ReadUint32(b) - if err != nil { - return nil, err - } - header.VersionNumber = protocol.VersionTagToNumber(versionTag) - } else { // parse the version negotiaton packet - if b.Len()%4 != 0 { - return nil, qerr.InvalidVersionNegotiationPacket - } - header.SupportedVersions = make([]protocol.VersionNumber, 0) - for { - var versionTag uint32 - versionTag, err = utils.ReadUint32(b) - if err != nil { - break - } - v := protocol.VersionTagToNumber(versionTag) - header.SupportedVersions = append(header.SupportedVersions, v) - } + if !header.ResetFlag && header.VersionFlag { + if packetSentBy == protocol.PerspectiveServer { // parse the version negotiaton packet + if b.Len() == 0 { + return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") } + if b.Len()%4 != 0 { + return nil, qerr.InvalidVersionNegotiationPacket + } + header.SupportedVersions = make([]protocol.VersionNumber, 0) + for { + var versionTag uint32 + versionTag, err = utils.BigEndian.ReadUint32(b) + if err != nil { + break + } + v := protocol.VersionNumber(versionTag) + header.SupportedVersions = append(header.SupportedVersions, v) + } + // a version negotiation packet doesn't have a packet number + return header, nil } + // packet was sent by the client. Read the version number + var versionTag uint32 + versionTag, err = utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + header.Version = protocol.VersionNumber(versionTag) } // Packet number if header.hasPacketNumber(packetSentBy) { - packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) + packetNumber, err := utils.BigEndian.ReadUintN(b, uint8(header.PacketNumberLen)) if err != nil { return nil, err } header.PacketNumber = protocol.PacketNumber(packetNumber) } - return header, nil } -// GetLength gets the length of the publicHeader in bytes. +// getPublicHeaderLength gets the length of the publicHeader in bytes. // It can only be called for regular packets. -func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { +func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) { if h.VersionFlag && h.ResetFlag { return 0, errResetAndVersionFlagSet } - if h.VersionFlag && pers == protocol.PerspectiveServer { return 0, errGetLengthNotForVersionNegotiation } length := protocol.ByteCount(1) // 1 byte for public flags - if h.hasPacketNumber(pers) { if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { return 0, errPacketNumberLenNotSet } length += protocol.ByteCount(h.PacketNumberLen) } - - if !h.TruncateConnectionID { + if !h.OmitConnectionID { length += 8 // 8 bytes for the connection ID } - // Version Number in packets sent by the client if h.VersionFlag { length += 4 } - length += protocol.ByteCount(len(h.DiversificationNonce)) - return length, nil } -// hasPacketNumber determines if this PublicHeader will contain a packet number +// hasPacketNumber determines if this Public Header will contain a packet number // this depends on the ResetFlag, the VersionFlag and who sent the packet -func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { +func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { if h.ResetFlag { return false } @@ -259,3 +227,15 @@ func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { } return true } + +func (h *Header) logPublicHeader() { + connID := "(omitted)" + if !h.OmitConnectionID { + connID = fmt.Sprintf("%#x", h.ConnectionID) + } + ver := "(unset)" + if h.Version != 0 { + ver = fmt.Sprintf("%s", h.Version) + } + utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go new file mode 100644 index 0000000..6adc9f6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/public_reset.go @@ -0,0 +1,65 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A PublicReset is a PUBLIC_RESET +type PublicReset struct { + RejectedPacketNumber protocol.PacketNumber + Nonce uint64 +} + +// WritePublicReset writes a Public Reset +func WritePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { + b := &bytes.Buffer{} + b.WriteByte(0x0a) + utils.BigEndian.WriteUint64(b, uint64(connectionID)) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagPRST)) + utils.LittleEndian.WriteUint32(b, 2) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRNON)) + utils.LittleEndian.WriteUint32(b, 8) + utils.LittleEndian.WriteUint32(b, uint32(handshake.TagRSEQ)) + utils.LittleEndian.WriteUint32(b, 16) + utils.LittleEndian.WriteUint64(b, nonceProof) + utils.LittleEndian.WriteUint64(b, uint64(rejectedPacketNumber)) + return b.Bytes() +} + +// ParsePublicReset parses a Public Reset +func ParsePublicReset(r *bytes.Reader) (*PublicReset, error) { + pr := PublicReset{} + msg, err := handshake.ParseHandshakeMessage(r) + if err != nil { + return nil, err + } + if msg.Tag != handshake.TagPRST { + return nil, errors.New("wrong public reset tag") + } + + // The RSEQ tag is mandatory according to the gQUIC wire spec. + // However, Google doesn't send RSEQ in their Public Resets. + // Therefore, we'll treat RSEQ as an optional field. + if rseq, ok := msg.Data[handshake.TagRSEQ]; ok { + if len(rseq) != 8 { + return nil, errors.New("invalid RSEQ tag") + } + pr.RejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) + } + + rnon, ok := msg.Data[handshake.TagRNON] + if !ok { + return nil, errors.New("RNON missing") + } + if len(rnon) != 8 { + return nil, errors.New("invalid RNON tag") + } + pr.Nonce = binary.LittleEndian.Uint64(rnon) + return &pr, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go similarity index 60% rename from vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go index ea2531c..04086f8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/rst_stream_frame.go @@ -1,10 +1,10 @@ -package frames +package wire import ( "bytes" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" ) // A RstStreamFrame in QUIC @@ -17,9 +17,9 @@ type RstStreamFrame struct { //Write writes a RST_STREAM frame func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { b.WriteByte(0x01) - utils.WriteUint32(b, uint32(f.StreamID)) - utils.WriteUint64(b, uint64(f.ByteOffset)) - utils.WriteUint32(b, f.ErrorCode) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + utils.GetByteOrder(version).WriteUint32(b, f.ErrorCode) return nil } @@ -29,31 +29,29 @@ func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.Byt } // ParseRstStreamFrame parses a RST_STREAM frame -func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) { +func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { frame := &RstStreamFrame{} // read the TypeByte - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { return nil, err } - sid, err := utils.ReadUint32(r) + sid, err := utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } frame.StreamID = protocol.StreamID(sid) - byteOffset, err := utils.ReadUint64(r) + byteOffset, err := utils.GetByteOrder(version).ReadUint64(r) if err != nil { return nil, err } frame.ByteOffset = protocol.ByteCount(byteOffset) - frame.ErrorCode, err = utils.ReadUint32(r) + frame.ErrorCode, err = utils.GetByteOrder(version).ReadUint32(r) if err != nil { return nil, err } - return frame, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go similarity index 78% rename from vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go index 91f937a..9eb068d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/stop_waiting_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_waiting_frame.go @@ -1,11 +1,11 @@ -package frames +package wire import ( "bytes" "errors" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -13,7 +13,8 @@ import ( type StopWaitingFrame struct { LeastUnacked protocol.PacketNumber PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber + // PacketNumber is the packet number of the packet that this StopWaitingFrame will be sent with + PacketNumber protocol.PacketNumber } var ( @@ -23,34 +24,28 @@ var ( ) func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - // packetNumber is the packet number of the packet that this StopWaitingFrame will be sent with - typeByte := uint8(0x06) - b.WriteByte(typeByte) - // make sure the PacketNumber was set if f.PacketNumber == protocol.PacketNumber(0) { return errPacketNumberNotSet } - if f.LeastUnacked > f.PacketNumber { return errLeastUnackedHigherThanPacketNumber } + b.WriteByte(0x06) leastUnackedDelta := uint64(f.PacketNumber - f.LeastUnacked) - switch f.PacketNumberLen { case protocol.PacketNumberLen1: b.WriteByte(uint8(leastUnackedDelta)) case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(leastUnackedDelta)) + utils.GetByteOrder(version).WriteUint16(b, uint16(leastUnackedDelta)) case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(leastUnackedDelta)) + utils.GetByteOrder(version).WriteUint32(b, uint32(leastUnackedDelta)) case protocol.PacketNumberLen6: - utils.WriteUint48(b, leastUnackedDelta) + utils.GetByteOrder(version).WriteUint48(b, leastUnackedDelta&(1<<48-1)) default: return errPacketNumberLenNotSet } - return nil } @@ -62,7 +57,6 @@ func (f *StopWaitingFrame) MinLength(version protocol.VersionNumber) (protocol.B return 0, errPacketNumberLenNotSet } minLength += protocol.ByteCount(f.PacketNumberLen) - return minLength, nil } @@ -71,21 +65,17 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, frame := &StopWaitingFrame{} // read the TypeByte - _, err := r.ReadByte() - if err != nil { + if _, err := r.ReadByte(); err != nil { return nil, err } - leastUnackedDelta, err := utils.ReadUintN(r, uint8(packetNumberLen)) + leastUnackedDelta, err := utils.GetByteOrder(version).ReadUintN(r, uint8(packetNumberLen)) if err != nil { return nil, err } - - if leastUnackedDelta > uint64(packetNumber) { + if leastUnackedDelta >= uint64(packetNumber) { return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta") } - frame.LeastUnacked = protocol.PacketNumber(uint64(packetNumber) - leastUnackedDelta) - return frame, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go new file mode 100644 index 0000000..981c0ec --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_blocked_frame.go @@ -0,0 +1,44 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A StreamBlockedFrame in QUIC +type StreamBlockedFrame struct { + StreamID protocol.StreamID +} + +// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame +func ParseStreamBlockedFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamBlockedFrame, error) { + frame := &StreamBlockedFrame{} + + // read the TypeByte + if _, err := r.ReadByte(); err != nil { + return nil, err + } + sid, err := utils.GetByteOrder(version).ReadUint32(r) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + return frame, nil +} + +// Write writes a STREAM_BLOCKED frame +func (f *StreamBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesMaxDataFrame() { + return (&blockedFrameLegacy{StreamID: f.StreamID}).Write(b, version) + } + b.WriteByte(0x09) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + return nil +} + +// MinLength of a written frame +func (f *StreamBlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + return 1 + 4, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go similarity index 69% rename from vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go rename to vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go index 7dd6223..75be888 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frames/stream_frame.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -1,11 +1,12 @@ -package frames +package wire import ( "bytes" "errors" + "io" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" ) @@ -24,7 +25,7 @@ var ( ) // ParseStreamFrame reads a stream frame. The type byte must not have been read yet. -func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { +func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { frame := &StreamFrame{} typeByte, err := r.ReadByte() @@ -34,19 +35,19 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { frame.FinBit = typeByte&0x40 > 0 frame.DataLenPresent = typeByte&0x20 > 0 - offsetLen := typeByte & 0x1C >> 2 + offsetLen := typeByte & 0x1c >> 2 if offsetLen != 0 { offsetLen++ } - streamIDLen := typeByte&0x03 + 1 + streamIDLen := typeByte&0x3 + 1 - sid, err := utils.ReadUintN(r, streamIDLen) + sid, err := utils.GetByteOrder(version).ReadUintN(r, streamIDLen) if err != nil { return nil, err } frame.StreamID = protocol.StreamID(sid) - offset, err := utils.ReadUintN(r, offsetLen) + offset, err := utils.GetByteOrder(version).ReadUintN(r, offsetLen) if err != nil { return nil, err } @@ -54,14 +55,17 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { var dataLen uint16 if frame.DataLenPresent { - dataLen, err = utils.ReadUint16(r) + dataLen, err = utils.GetByteOrder(version).ReadUint16(r) if err != nil { return nil, err } } - if dataLen > uint16(protocol.MaxPacketSize) { - return nil, qerr.Error(qerr.InvalidStreamData, "data len too large") + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if int(dataLen) > r.Len() { + return nil, io.EOF } if !frame.DataLenPresent { @@ -70,11 +74,8 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { } if dataLen != 0 { frame.Data = make([]byte, dataLen) - n, err := r.Read(frame.Data) - if n != int(dataLen) { - return nil, errors.New("BUG: StreamFrame could not read dataLen bytes") - } - if err != nil { + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier return nil, err } } @@ -82,11 +83,9 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { if frame.Offset+frame.DataLen() < frame.Offset { return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") } - if !frame.FinBit && frame.DataLen() == 0 { return nil, qerr.EmptyStreamFrameNoFin } - return frame, nil } @@ -97,17 +96,14 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err } typeByte := uint8(0x80) // sets the leftmost bit to 1 - if f.FinBit { typeByte ^= 0x40 } - if f.DataLenPresent { typeByte ^= 0x20 } offsetLength := f.getOffsetLength() - if offsetLength > 0 { typeByte ^= (uint8(offsetLength) - 1) << 2 } @@ -121,11 +117,11 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err case 1: b.WriteByte(uint8(f.StreamID)) case 2: - utils.WriteUint16(b, uint16(f.StreamID)) + utils.GetByteOrder(version).WriteUint16(b, uint16(f.StreamID)) case 3: - utils.WriteUint24(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint24(b, uint32(f.StreamID)) case 4: - utils.WriteUint32(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) default: return errInvalidStreamIDLen } @@ -133,29 +129,28 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err switch offsetLength { case 0: case 2: - utils.WriteUint16(b, uint16(f.Offset)) + utils.GetByteOrder(version).WriteUint16(b, uint16(f.Offset)) case 3: - utils.WriteUint24(b, uint32(f.Offset)) + utils.GetByteOrder(version).WriteUint24(b, uint32(f.Offset)) case 4: - utils.WriteUint32(b, uint32(f.Offset)) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.Offset)) case 5: - utils.WriteUint40(b, uint64(f.Offset)) + utils.GetByteOrder(version).WriteUint40(b, uint64(f.Offset)) case 6: - utils.WriteUint48(b, uint64(f.Offset)) + utils.GetByteOrder(version).WriteUint48(b, uint64(f.Offset)) case 7: - utils.WriteUint56(b, uint64(f.Offset)) + utils.GetByteOrder(version).WriteUint56(b, uint64(f.Offset)) case 8: - utils.WriteUint64(b, uint64(f.Offset)) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.Offset)) default: return errInvalidOffsetLen } if f.DataLenPresent { - utils.WriteUint16(b, uint16(len(f.Data))) + utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.Data))) } b.Write(f.Data) - return nil } @@ -202,7 +197,6 @@ func (f *StreamFrame) MinLength(protocol.VersionNumber) (protocol.ByteCount, err if f.DataLenPresent { length += 2 } - return length, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go new file mode 100644 index 0000000..92afb3b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go @@ -0,0 +1,51 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC +func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { + fullReply := &bytes.Buffer{} + ph := Header{ + ConnectionID: connID, + PacketNumber: 1, + VersionFlag: true, + } + if err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil { + utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil + } + for _, v := range versions { + utils.BigEndian.WriteUint32(fullReply, uint32(v)) + } + return fullReply.Bytes() +} + +// ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft +func ComposeVersionNegotiation( + connID protocol.ConnectionID, + pn protocol.PacketNumber, + versionOffered protocol.VersionNumber, + versions []protocol.VersionNumber, +) []byte { + fullReply := &bytes.Buffer{} + ph := Header{ + IsLongHeader: true, + Type: protocol.PacketTypeVersionNegotiation, + ConnectionID: connID, + PacketNumber: pn, + Version: versionOffered, + } + if err := ph.writeHeader(fullReply); err != nil { + utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil + } + for _, v := range versions { + utils.BigEndian.WriteUint32(fullReply, uint32(v)) + } + return fullReply.Bytes() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go new file mode 100644 index 0000000..20d7b66 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/window_update_frame.go @@ -0,0 +1,35 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type windowUpdateFrame struct { + StreamID protocol.StreamID + ByteOffset protocol.ByteCount +} + +// ParseWindowUpdateFrame parses a WINDOW_UPDATE frame +// The frame returned is +// * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream +// * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection +func ParseWindowUpdateFrame(r *bytes.Reader, version protocol.VersionNumber) (Frame, error) { + f, err := ParseMaxStreamDataFrame(r, version) + if err != nil { + return nil, err + } + if f.StreamID == 0 { + return &MaxDataFrame{ByteOffset: f.ByteOffset}, nil + } + return f, nil +} + +func (f *windowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x4) + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + utils.GetByteOrder(version).WriteUint64(b, uint64(f.ByteOffset)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go index 71ca9a3..8ece95a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_number_generator.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "math" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" ) // The packetNumberGenerator generates the packet number for the next packet diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index 28c29ac..1a63715 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -6,15 +6,15 @@ import ( "fmt" "github.com/lucas-clemente/quic-go/ackhandler" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type packedPacket struct { - number protocol.PacketNumber + header *wire.Header raw []byte - frames []frames.Frame + frames []wire.Frame encryptionLevel protocol.EncryptionLevel } @@ -25,18 +25,17 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - connectionParameters handshake.ConnectionParametersManager streamFramer *streamFramer - controlFrames []frames.Frame - stopWaiting *frames.StopWaitingFrame - ackFrame *frames.AckFrame - leastUnacked protocol.PacketNumber + controlFrames []wire.Frame + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool } func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, - connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber, @@ -44,7 +43,6 @@ func newPacketPacker(connectionID protocol.ConnectionID, return &packetPacker{ cryptoSetup: cryptoSetup, connectionID: connectionID, - connectionParameters: connectionParameters, perspective: perspective, version: version, streamFramer: streamFramer, @@ -53,13 +51,13 @@ func newPacketPacker(connectionID protocol.ConnectionID, } // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame -func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*packedPacket, error) { - frames := []frames.Frame{ccf} +func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { + frames := []wire.Frame{ccf} encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) - raw, err := p.writeAndSealPacket(ph, frames, sealer) + header := p.getHeader(encLevel) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -71,18 +69,18 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { return nil, errors.New("packet packer BUG: no ack frame queued") } encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) - frames := []frames.Frame{p.ackFrame} + header := p.getHeader(encLevel) + frames := []wire.Frame{p.ackFrame} if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) p.stopWaiting = nil } p.ackFrame = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -101,14 +99,14 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* if p.stopWaiting == nil { return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") } - ph := p.getPublicHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen - frames := append([]frames.Frame{p.stopWaiting}, packet.Frames...) + header := p.getHeader(packet.EncryptionLevel) + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen + frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) p.stopWaiting = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: packet.EncryptionLevel, @@ -124,17 +122,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = publicHeader.PacketNumber - p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen } - maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength + maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) if err != nil { return nil, err @@ -151,12 +149,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.stopWaiting = nil p.ackFrame = nil - raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer) + raw, err := p.writeAndSealPacket(header, payloadFrames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + header: header, raw: raw, frames: payloadFrames, encryptionLevel: encLevel, @@ -165,19 +163,19 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } - maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength - frames := []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} - raw, err := p.writeAndSealPacket(publicHeader, frames, sealer) + maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength + frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} + raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + header: header, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -187,9 +185,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { func (p *packetPacker) composeNextPacket( maxFrameSize protocol.ByteCount, canSendStreamFrames bool, -) ([]frames.Frame, error) { +) ([]wire.Frame, error) { var payloadLength protocol.ByteCount - var payloadFrames []frames.Frame + var payloadFrames []wire.Frame // STOP_WAITING and ACK will always fit if p.stopWaiting != nil { @@ -253,47 +251,64 @@ func (p *packetPacker) composeNextPacket( return payloadFrames, nil } -func (p *packetPacker) QueueControlFrame(frame frames.Frame) { +func (p *packetPacker) QueueControlFrame(frame wire.Frame) { switch f := frame.(type) { - case *frames.StopWaitingFrame: + case *wire.StopWaitingFrame: p.stopWaiting = f - case *frames.AckFrame: + case *wire.AckFrame: p.ackFrame = f default: p.controlFrames = append(p.controlFrames, f) } } -func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *PublicHeader { +func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() - packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked) - publicHeader := &PublicHeader{ - ConnectionID: p.connectionID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, - TruncateConnectionID: p.connectionParameters.TruncateConnectionID(), + packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) + + var isLongHeader bool + if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { + // TODO: set the Long Header type + packetNumberLen = protocol.PacketNumberLen4 + isLongHeader = true } - if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { - publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() - } - if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { - publicHeader.VersionFlag = true - publicHeader.VersionNumber = p.version + header := &wire.Header{ + ConnectionID: p.connectionID, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, + IsLongHeader: isLongHeader, } - return publicHeader + if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { + header.OmitConnectionID = true + } + if !p.version.UsesTLS() { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { + header.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + } + if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { + header.VersionFlag = true + header.Version = p.version + } + } else { + header.Type = p.cryptoSetup.GetNextPacketType() + if encLevel != protocol.EncryptionForwardSecure { + header.Version = p.version + } + } + return header } func (p *packetPacker) writeAndSealPacket( - publicHeader *PublicHeader, - payloadFrames []frames.Frame, + header *wire.Header, + payloadFrames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) - if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil { + if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } payloadStartIndex := buffer.Len() @@ -303,16 +318,16 @@ func (p *packetPacker) writeAndSealPacket( return nil, err } } - if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize { + if protocol.ByteCount(buffer.Len()+sealer.Overhead()) > protocol.MaxPacketSize { return nil, errors.New("PacketPacker BUG: packet too large") } raw = raw[0:buffer.Len()] - _ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex]) - raw = raw[0 : buffer.Len()+12] + _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex]) + raw = raw[0 : buffer.Len()+sealer.Overhead()] num := p.packetNumberGenerator.Pop() - if num != publicHeader.PacketNumber { + if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } @@ -329,3 +344,7 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { p.leastUnacked = leastUnacked } + +func (p *packetPacker) SetOmitConnectionID() { + p.omitConnectionID = true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index c92e6a5..f891e37 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -2,17 +2,16 @@ package quic import ( "bytes" - "errors" "fmt" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type unpackedPacket struct { encryptionLevel protocol.EncryptionLevel - frames []frames.Frame + frames []wire.Frame } type quicAEAD interface { @@ -24,10 +23,10 @@ type packetUnpacker struct { aead quicAEAD } -func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { buf := getPacketBuffer() defer putPacketBuffer(buf) - decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, publicHeaderBinary) + decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary) if err != nil { // Wrap err in quicError so that public reset is sent by session return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) @@ -38,7 +37,7 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da return nil, qerr.MissingPayload } - fs := make([]frames.Frame, 0, 2) + fs := make([]wire.Frame, 0, 2) // Read all frames in the packet for r.Len() > 0 { @@ -48,61 +47,73 @@ func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *PublicHeader, da } r.UnreadByte() - var frame frames.Frame + var frame wire.Frame if typeByte&0x80 == 0x80 { - frame, err = frames.ParseStreamFrame(r) + frame, err = wire.ParseStreamFrame(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidStreamData, err.Error()) } else { - streamID := frame.(*frames.StreamFrame).StreamID - if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted { + streamID := frame.(*wire.StreamFrame).StreamID + if streamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID)) } } } else if typeByte&0xc0 == 0x40 { - frame, err = frames.ParseAckFrame(r, u.version) + frame, err = wire.ParseAckFrame(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidAckData, err.Error()) } - } else if typeByte&0xe0 == 0x20 { - err = errors.New("unimplemented: CONGESTION_FEEDBACK") - } else { - switch typeByte { - case 0x01: - frame, err = frames.ParseRstStreamFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) - } - case 0x02: - frame, err = frames.ParseConnectionCloseFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) - } - case 0x03: - frame, err = frames.ParseGoawayFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidGoawayData, err.Error()) - } - case 0x04: - frame, err = frames.ParseWindowUpdateFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) - } - case 0x05: - frame, err = frames.ParseBlockedFrame(r) - if err != nil { - err = qerr.Error(qerr.InvalidBlockedData, err.Error()) - } - case 0x06: - frame, err = frames.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) - if err != nil { - err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) - } - case 0x07: - frame, err = frames.ParsePingFrame(r) - default: - err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) + } else if typeByte == 0x01 { + frame, err = wire.ParseRstStreamFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidRstStreamData, err.Error()) } + } else if typeByte == 0x02 { + frame, err = wire.ParseConnectionCloseFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error()) + } + } else if typeByte == 0x3 { + frame, err = wire.ParseGoawayFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidGoawayData, err.Error()) + } + } else if u.version.UsesMaxDataFrame() && typeByte == 0x4 { // in IETF QUIC, 0x4 is a MAX_DATA frame + frame, err = wire.ParseMaxDataFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + } else if typeByte == 0x4 { // in gQUIC, 0x4 is a WINDOW_UPDATE frame + frame, err = wire.ParseWindowUpdateFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + } else if u.version.UsesMaxDataFrame() && typeByte == 0x5 { // in IETF QUIC, 0x5 is a MAX_STREAM_DATA frame + frame, err = wire.ParseMaxStreamDataFrame(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) + } + } else if typeByte == 0x5 { // in gQUIC, 0x5 is a BLOCKED frame + frame, err = wire.ParseBlockedFrameLegacy(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + } else if typeByte == 0x6 { + frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) + } + } else if typeByte == 0x7 { + frame, err = wire.ParsePingFrame(r, u.version) + } else if u.version.UsesMaxDataFrame() && typeByte == 0x8 { // in IETF QUIC, 0x4 is a BLOCKED frame + frame, err = wire.ParseBlockedFrame(r, u.version) + } else if u.version.UsesMaxDataFrame() && typeByte == 0x9 { // in IETF QUIC, 0x4 is a STREAM_BLOCKED frame + frame, err = wire.ParseBlockedFrameLegacy(r, u.version) + if err != nil { + err = qerr.Error(qerr.InvalidBlockedData, err.Error()) + } + } else { + err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) } if err != nil { return nil, err diff --git a/vendor/github.com/lucas-clemente/quic-go/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/protocol/version.go deleted file mode 100644 index 388162e..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/protocol/version.go +++ /dev/null @@ -1,55 +0,0 @@ -package protocol - -// VersionNumber is a version number as int -type VersionNumber int - -// The version numbers, making grepping easier -const ( - Version35 VersionNumber = 35 + iota - Version36 - Version37 - VersionWhatever VersionNumber = 0 // for when the version doesn't matter - VersionUnsupported VersionNumber = -1 -) - -// SupportedVersions lists the versions that the server supports -// must be in sorted descending order -var SupportedVersions = []VersionNumber{ - Version37, Version36, Version35, -} - -// VersionNumberToTag maps version numbers ('32') to tags ('Q032') -func VersionNumberToTag(vn VersionNumber) uint32 { - v := uint32(vn) - return 'Q' + ((v/100%10)+'0')<<8 + ((v/10%10)+'0')<<16 + ((v%10)+'0')<<24 -} - -// VersionTagToNumber is built from VersionNumberToTag in init() -func VersionTagToNumber(v uint32) VersionNumber { - return VersionNumber(((v>>8)&0xff-'0')*100 + ((v>>16)&0xff-'0')*10 + ((v>>24)&0xff - '0')) -} - -// IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { - for _, t := range supported { - if t == v { - return true - } - } - return false -} - -// ChooseSupportedVersion finds the best version in the overlap of ours and theirs -// ours is a slice of versions that we support, sorted by our preference (descending) -// theirs is a slice of versions offered by the peer. The order does not matter -// if no suitable version is found, it returns VersionUnsupported -func ChooseSupportedVersion(ours, theirs []VersionNumber) VersionNumber { - for _, ourVer := range ours { - for _, theirVer := range theirs { - if ourVer == theirVer { - return ourVer - } - } - } - return VersionUnsupported -} diff --git a/vendor/github.com/lucas-clemente/quic-go/public_reset.go b/vendor/github.com/lucas-clemente/quic-go/public_reset.go deleted file mode 100644 index 958db9c..0000000 --- a/vendor/github.com/lucas-clemente/quic-go/public_reset.go +++ /dev/null @@ -1,62 +0,0 @@ -package quic - -import ( - "bytes" - "encoding/binary" - "errors" - - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" -) - -type publicReset struct { - rejectedPacketNumber protocol.PacketNumber - nonce uint64 -} - -func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { - b := &bytes.Buffer{} - b.WriteByte(0x0a) - utils.WriteUint64(b, uint64(connectionID)) - utils.WriteUint32(b, uint32(handshake.TagPRST)) - utils.WriteUint32(b, 2) - utils.WriteUint32(b, uint32(handshake.TagRNON)) - utils.WriteUint32(b, 8) - utils.WriteUint32(b, uint32(handshake.TagRSEQ)) - utils.WriteUint32(b, 16) - utils.WriteUint64(b, nonceProof) - utils.WriteUint64(b, uint64(rejectedPacketNumber)) - return b.Bytes() -} - -func parsePublicReset(r *bytes.Reader) (*publicReset, error) { - pr := publicReset{} - msg, err := handshake.ParseHandshakeMessage(r) - if err != nil { - return nil, err - } - if msg.Tag != handshake.TagPRST { - return nil, errors.New("wrong public reset tag") - } - - rseq, ok := msg.Data[handshake.TagRSEQ] - if !ok { - return nil, errors.New("RSEQ missing") - } - if len(rseq) != 8 { - return nil, errors.New("invalid RSEQ tag") - } - pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) - - rnon, ok := msg.Data[handshake.TagRNON] - if !ok { - return nil, errors.New("RNON missing") - } - if len(rnon) != 8 { - return nil, errors.New("invalid RNON tag") - } - pr.nonce = binary.LittleEndian.Uint64(rnon) - - return &pr, nil -} diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index 76f07ba..fb73ccb 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -8,10 +8,11 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/crypto" - "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/internal/crypto" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -19,6 +20,7 @@ import ( type packetHandler interface { Session handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber run() error closeRemote(error) } @@ -88,14 +90,15 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, errorChan: make(chan struct{}), } go s.serve() + utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } -var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool { - if stk == nil { +var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { + if cookie == nil { return false } - if time.Now().After(stk.sentTime.Add(protocol.STKExpiryTime)) { + if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) { return false } var sourceAddr string @@ -104,7 +107,7 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool { } else { sourceAddr = clientAddr.String() } - return sourceAddr == stk.remoteAddr + return sourceAddr == cookie.RemoteAddr } // populateServerConfig populates fields in the quic.Config with their default values, if none are set @@ -118,15 +121,19 @@ func populateServerConfig(config *Config) *Config { versions = protocol.SupportedVersions } - vsa := defaultAcceptSTK - if config.AcceptSTK != nil { - vsa = config.AcceptSTK + vsa := defaultAcceptCookie + if config.AcceptCookie != nil { + vsa = config.AcceptCookie } handshakeTimeout := protocol.DefaultHandshakeTimeout if config.HandshakeTimeout != 0 { handshakeTimeout = config.HandshakeTimeout } + idleTimeout := protocol.DefaultIdleTimeout + if config.IdleTimeout != 0 { + idleTimeout = config.IdleTimeout + } maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow if maxReceiveStreamFlowControlWindow == 0 { @@ -140,7 +147,9 @@ func populateServerConfig(config *Config) *Config { return &Config{ Versions: versions, HandshakeTimeout: handshakeTimeout, - AcceptSTK: vsa, + IdleTimeout: idleTimeout, + AcceptCookie: vsa, + KeepAlive: config.KeepAlive, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, } @@ -181,14 +190,19 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionsMutex.Lock() + var wg sync.WaitGroup for _, session := range s.sessions { if session != nil { - s.sessionsMutex.Unlock() - _ = session.Close(nil) - s.sessionsMutex.Lock() + wg.Add(1) + go func(sess packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = sess.Close(nil) + wg.Done() + }(session) } } s.sessionsMutex.Unlock() + wg.Wait() if s.conn == nil { return nil @@ -205,25 +219,31 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient) + hdr, err := wire.ParseHeaderSentByClient(r) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] + connID := hdr.ConnectionID s.sessionsMutex.RLock() - session, ok := s.sessions[hdr.ConnectionID] + session, sessionKnown := s.sessions[connID] s.sessionsMutex.RUnlock() + if sessionKnown && session == nil { + // Late packet for closed session + return nil + } + // ignore all Public Reset packets if hdr.ResetFlag { - if ok { - var pr *publicReset - pr, err = parsePublicReset(r) + if sessionKnown { + var pr *wire.PublicReset + pr, err = wire.ParsePublicReset(r) if err != nil { utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") } else { - utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber) + utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) } } else { utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) @@ -231,35 +251,46 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return nil } + // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset + // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. + // TODO(#943): implement sending of IETF draft style stateless resets + if !sessionKnown && (!hdr.VersionFlag && hdr.Type != protocol.PacketTypeInitial) { + _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) + return err + } + // a session is only created once the client sent a supported version // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated // it is safe to drop it - if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { + if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { return nil } - // Send Version Negotiation Packet if the client is speaking a different protocol version - if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { + // send a Version Negotiation Packet if the client is speaking a different protocol version + // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet + if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { // drop packets that are too small to be valid first packets if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) { return errors.New("dropping small packet with unknown version") } - utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) - _, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + utils.Infof("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) + if _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr); err != nil { + return err + } + } + // send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header + if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + _, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, hdr.Version, s.config.Versions), remoteAddr) return err } - if !ok { - if !hdr.VersionFlag { - _, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) - return err - } - version := hdr.VersionNumber + if !sessionKnown { + version := hdr.Version if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") } - utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) + utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) var handshakeChan <-chan handshakeEvent session, handshakeChan, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, @@ -273,13 +304,13 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } s.sessionsMutex.Lock() - s.sessions[hdr.ConnectionID] = session + s.sessions[connID] = session s.sessionsMutex.Unlock() go func() { // session.run() returns as soon as the session is closed _ = session.run() - s.removeConnection(hdr.ConnectionID) + s.removeConnection(connID) }() go func() { @@ -295,15 +326,11 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet s.sessionQueue <- session }() } - if session == nil { - // Late packet for closed session - return nil - } session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, }) return nil } @@ -319,20 +346,3 @@ func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Unlock() }) } - -func composeVersionNegotiation(connectionID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { - fullReply := &bytes.Buffer{} - responsePublicHeader := PublicHeader{ - ConnectionID: connectionID, - PacketNumber: 1, - VersionFlag: true, - } - err := responsePublicHeader.Write(fullReply, protocol.VersionWhatever, protocol.PerspectiveServer) - if err != nil { - utils.Errorf("error composing version negotiation packet: %s", err.Error()) - } - for _, v := range versions { - utils.WriteUint32(fullReply, protocol.VersionNumberToTag(v)) - } - return fullReply.Bytes() -} diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index 376aa7f..06d6916 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -1,6 +1,7 @@ package quic import ( + "context" "crypto/tls" "errors" "fmt" @@ -10,23 +11,23 @@ import ( "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/handshake" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) type unpacker interface { - Unpack(publicHeaderBinary []byte, hdr *PublicHeader, data []byte) (*unpackedPacket, error) + Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) } type receivedPacket struct { - remoteAddr net.Addr - publicHeader *PublicHeader - data []byte - rcvTime time.Time + remoteAddr net.Addr + header *wire.Header + data []byte + rcvTime time.Time } var ( @@ -54,12 +55,12 @@ type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber - tlsConf *tls.Config config *Config conn connection - streamsMap *streamsMap + streamsMap *streamsMap + cryptoStream streamI rttStats *congestion.RTTStats @@ -67,7 +68,7 @@ type session struct { receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer - flowControlManager flowcontrol.FlowControlManager + connFlowController flowcontrol.ConnectionFlowController unpacker unpacker packer *packetPacker @@ -78,16 +79,18 @@ type session struct { sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. closeChan chan closeError - // runClosed is closed once the run loop exits - // it is used to block Close() and WaitUntilClosed() - runClosed chan struct{} closeOnce sync.Once + ctx context.Context + ctxCancel context.CancelFunc + // when we receive too many undecryptable packets during the handshake, we send a Public reset // but only after a time of protocol.PublicResetTimeout has passed undecryptablePackets []*receivedPacket receivedTooManyUndecrytablePacketsTime time.Time + // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them + paramsChan <-chan handshake.TransportParameters // this channel is passed to the CryptoSetup and receives the current encryption level // it is closed as soon as the handshake is complete aeadChanged <-chan protocol.EncryptionLevel @@ -100,8 +103,6 @@ type session struct { // it receives at most 3 handshake events: 2 when the encryption level changes, and one error handshakeChan chan<- handshakeEvent - connectionParameters handshake.ConnectionParametersManager - lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets @@ -110,6 +111,8 @@ type session struct { sessionCreationTime time.Time lastNetworkActivityTime time.Time + peerParams *handshake.TransportParameters + timer *utils.Timer // keepAlivePingSent stores whether a Ping frame was sent to the peer or not // it is reset as soon as we receive a packet from the peer @@ -134,7 +137,7 @@ func newSession( version: v, config: config, } - return s.setup(sCfg, "", nil) + return s.setup(sCfg, "", tlsConf, v, nil) } // declare this as a variable, such that we can it mock it in the tests @@ -145,34 +148,38 @@ var newClientSession = func( connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, - negotiatedVersions []protocol.VersionNumber, + initialVersion protocol.VersionNumber, + negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton ) (packetHandler, <-chan handshakeEvent, error) { s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, - tlsConf: tlsConf, config: config, } - return s.setup(nil, hostname, negotiatedVersions) + return s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) } func (s *session) setup( scfg *handshake.ServerConfig, hostname string, + tlsConf *tls.Config, + initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { aeadChanged := make(chan protocol.EncryptionLevel, 2) + paramsChan := make(chan handshake.TransportParameters) s.aeadChanged = aeadChanged + s.paramsChan = paramsChan handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan - s.runClosed = make(chan struct{}) s.handshakeCompleteChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) s.timer = utils.NewTimer() now := time.Now() @@ -180,49 +187,84 @@ func (s *session) setup( s.sessionCreationTime = now s.rttStats = &congestion.RTTStats{} - s.connectionParameters = handshake.NewConnectionParamatersManager(s.perspective, s.version, - s.config.MaxReceiveStreamFlowControlWindow, s.config.MaxReceiveConnectionFlowControlWindow) + transportParams := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, + } s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) - s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler() - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) - s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ReceiveConnectionFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + s.rttStats, + ) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) + s.cryptoStream = s.newStream(s.version.CryptoStreamID()) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) var err error if s.perspective == protocol.PerspectiveServer { - cryptoStream, _ := s.GetOrOpenStream(1) - _, _ = s.AcceptStream() // don't expose the crypto stream - verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool { - var stk *STK - if hstk != nil { - stk = &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime} - } - return s.config.AcceptSTK(clientAddr, stk) + verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { + return s.config.AcceptCookie(clientAddr, cookie) + } + if s.version.UsesTLS() { + s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( + s.cryptoStream, + s.connectionID, + tlsConf, + s.conn.RemoteAddr(), + transportParams, + paramsChan, + aeadChanged, + verifySourceAddr, + s.config.Versions, + s.version, + ) + } else { + s.cryptoSetup, err = newCryptoSetup( + s.cryptoStream, + s.connectionID, + s.conn.RemoteAddr(), + s.version, + scfg, + transportParams, + s.config.Versions, + verifySourceAddr, + paramsChan, + aeadChanged, + ) } - s.cryptoSetup, err = newCryptoSetup( - s.connectionID, - s.conn.RemoteAddr(), - s.version, - scfg, - cryptoStream, - s.connectionParameters, - s.config.Versions, - verifySourceAddr, - aeadChanged, - ) } else { - cryptoStream, _ := s.OpenStream() - s.cryptoSetup, err = newCryptoSetupClient( - hostname, - s.connectionID, - s.version, - cryptoStream, - s.tlsConf, - s.connectionParameters, - aeadChanged, - &handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation}, - negotiatedVersions, - ) + transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission + if s.version.UsesTLS() { + s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( + s.cryptoStream, + s.connectionID, + hostname, + tlsConf, + transportParams, + paramsChan, + aeadChanged, + initialVersion, + s.config.Versions, + s.version, + ) + } else { + s.cryptoSetup, err = newCryptoSetupClient( + s.cryptoStream, + hostname, + s.connectionID, + s.version, + tlsConf, + transportParams, + paramsChan, + aeadChanged, + initialVersion, + negotiatedVersions, + ) + } } if err != nil { return nil, nil, err @@ -230,7 +272,6 @@ func (s *session) setup( s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, - s.connectionParameters, s.streamFramer, s.perspective, s.version, @@ -242,7 +283,8 @@ func (s *session) setup( // run the session main loop func (s *session) run() error { - // Start the crypto stream handler + defer s.ctxCancel() + go func() { if err := s.cryptoSetup.HandleCryptoStream(); err != nil { s.Close(err) @@ -285,11 +327,14 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(p.publicHeader.Raw) + putPacketBuffer(p.header.Raw) + case p := <-s.paramsChan: + s.processTransportParameters(&p) case l, ok := <-aeadChanged: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true aeadChanged = nil // prevent this case from ever being selected again + s.sentPacketHandler.SetHandshakeComplete() close(s.handshakeChan) close(s.handshakeCompleteChan) } else { @@ -305,9 +350,9 @@ runLoop: s.sentPacketHandler.OnAlarm() } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 { + if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 { // send the PING frame since there is no activity in the session - s.packer.QueueControlFrame(&frames.PingFrame{}) + s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true } @@ -317,13 +362,16 @@ runLoop: if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } - if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) - } if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout { s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time.")) } - s.garbageCollectStreams() + if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { + s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + } + + if err := s.streamsMap.DeleteClosedStreams(); err != nil { + s.closeLocal(err) + } } // only send the error the handshakeChan when the handshake is not completed yet @@ -333,20 +381,19 @@ runLoop: s.handshakeChan <- handshakeEvent{err: closeErr.err} } s.handleCloseError(closeErr) - close(s.runClosed) return closeErr.err } -func (s *session) WaitUntilClosed() { - <-s.runClosed +func (s *session) Context() context.Context { + return s.ctx } func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2) + deadline = s.lastNetworkActivityTime.Add(s.peerParams.IdleTimeout / 2) } else { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout()) + deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout) } if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { @@ -366,16 +413,9 @@ func (s *session) maybeResetTimer() { s.timer.Reset(deadline) } -func (s *session) idleTimeout() time.Duration { - if s.handshakeComplete { - return s.connectionParameters.GetIdleConnectionStateLifetime() - } - return protocol.InitialIdleTimeout -} - func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { - diversificationNonce := p.publicHeader.DiversificationNonce + diversificationNonce := p.header.DiversificationNonce if len(diversificationNonce) > 0 { s.cryptoSetup.SetDiversificationNonce(diversificationNonce) } @@ -388,7 +428,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.lastNetworkActivityTime = p.rcvTime s.keepAlivePingSent = false - hdr := p.publicHeader + hdr := p.header data := p.data // Calculate packet number @@ -405,6 +445,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } else { utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) } + hdr.Log() } // if the decryption failed, this might be a packet sent by an attacker // don't update the remote address @@ -428,30 +469,35 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { return err } - return s.handleFrames(packet.frames) + return s.handleFrames(packet.frames, packet.encryptionLevel) } -func (s *session) handleFrames(fs []frames.Frame) error { +func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error - frames.LogFrame(ff, false) + wire.LogFrame(ff, false) switch frame := ff.(type) { - case *frames.StreamFrame: + case *wire.StreamFrame: err = s.handleStreamFrame(frame) - case *frames.AckFrame: - err = s.handleAckFrame(frame) - case *frames.ConnectionCloseFrame: + case *wire.AckFrame: + err = s.handleAckFrame(frame, encLevel) + case *wire.ConnectionCloseFrame: s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) - case *frames.GoawayFrame: + case *wire.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") - case *frames.StopWaitingFrame: - err = s.receivedPacketHandler.ReceivedStopWaiting(frame) - case *frames.RstStreamFrame: + case *wire.StopWaitingFrame: + // LeastUnacked is guaranteed to have LeastUnacked > 0 + // therefore this will never underflow + s.receivedPacketHandler.SetLowerLimit(frame.LeastUnacked - 1) + case *wire.RstStreamFrame: err = s.handleRstStreamFrame(frame) - case *frames.WindowUpdateFrame: - err = s.handleWindowUpdateFrame(frame) - case *frames.BlockedFrame: - case *frames.PingFrame: + case *wire.MaxDataFrame: + s.handleMaxDataFrame(frame) + case *wire.MaxStreamDataFrame: + err = s.handleMaxStreamDataFrame(frame) + case *wire.BlockedFrame: + case *wire.StreamBlockedFrame: + case *wire.PingFrame: default: return errors.New("Session BUG: unexpected frame type") } @@ -483,7 +529,10 @@ func (s *session) handlePacket(p *receivedPacket) { } } -func (s *session) handleStreamFrame(frame *frames.StreamFrame) error { +func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { + if frame.StreamID == s.version.CryptoStreamID() { + return s.cryptoStream.AddStreamFrame(frame) + } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -496,21 +545,23 @@ func (s *session) handleStreamFrame(frame *frames.StreamFrame) error { return str.AddStreamFrame(frame) } -func (s *session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { - if frame.StreamID != 0 { - str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) - if err != nil { - return err - } - if str == nil { - return errWindowUpdateOnClosedStream - } - } - _, err := s.flowControlManager.UpdateWindow(frame.StreamID, frame.ByteOffset) - return err +func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { + s.connFlowController.UpdateSendWindow(frame.ByteOffset) } -func (s *session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { +func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { + str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + return errWindowUpdateOnClosedStream + } + str.UpdateSendWindow(frame.ByteOffset) + return nil +} + +func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -518,13 +569,11 @@ func (s *session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } - - str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) + return str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode), frame.ByteOffset) } -func (s *session) handleAckFrame(frame *frames.AckFrame) error { - return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) +func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { + return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime) } func (s *session) closeLocal(e error) { @@ -543,7 +592,7 @@ func (s *session) closeRemote(e error) { // It waits until the run loop has stopped before returning func (s *session) Close(e error) error { s.closeLocal(e) - <-s.runClosed + <-s.ctx.Done() return nil } @@ -564,6 +613,7 @@ func (s *session) handleCloseError(closeErr closeError) error { utils.Errorf("Closing session with error: %s", closeErr.err.Error()) } + s.cryptoStream.Cancel(quicErr) s.streamsMap.CloseWithError(quicErr) if closeErr.err == errCloseSessionForNewVersion { @@ -575,20 +625,34 @@ func (s *session) handleCloseError(closeErr closeError) error { return nil } - if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment { + if quicErr.ErrorCode == qerr.DecryptionFailure || + quicErr == handshake.ErrHOLExperiment || + quicErr == handshake.ErrNSTPExperiment { return s.sendPublicReset(s.lastRcvdPacketNumber) } return s.sendConnectionClose(quicErr) } +func (s *session) processTransportParameters(params *handshake.TransportParameters) { + s.peerParams = params + s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) + if params.OmitConnectionID { + s.packer.SetOmitConnectionID() + } + s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) + s.streamsMap.Range(func(str streamI) { + str.UpdateSendWindow(params.StreamFlowControlWindow) + }) +} + func (s *session) sendPacket() error { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - // Get WindowUpdate frames + // Get MAX_DATA and MAX_STREAM_DATA frames // this call triggers the flow controller to increase the flow control windows, if necessary - windowUpdateFrames := s.getWindowUpdateFrames() - for _, wuf := range windowUpdateFrames { - s.packer.QueueControlFrame(wuf) + windowUpdates := s.getWindowUpdates() + for _, f := range windowUpdates { + s.packer.QueueControlFrame(f) } ack := s.receivedPacketHandler.GetAckFrame() @@ -639,15 +703,10 @@ func (s *session) sendPacket() error { utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) // resend the frames that were in the packet for _, frame := range retransmitPacket.GetFramesForRetransmission() { + // TODO: only retransmit WINDOW_UPDATEs if they actually enlarge the window switch f := frame.(type) { - case *frames.StreamFrame: + case *wire.StreamFrame: s.streamFramer.AddFrameForRetransmission(f) - case *frames.WindowUpdateFrame: - // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream - currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) - if err == nil && f.ByteOffset >= currentOffset { - s.packer.QueueControlFrame(f) - } default: s.packer.QueueControlFrame(frame) } @@ -662,6 +721,10 @@ func (s *session) sendPacket() error { s.packer.QueueControlFrame(swf) } } + // add a retransmittable frame + if s.sentPacketHandler.ShouldSendRetransmittablePacket() { + s.packer.QueueControlFrame(&wire.PingFrame{}) + } packet, err := s.packer.PackPacket() if err != nil || packet == nil { return err @@ -671,10 +734,10 @@ func (s *session) sendPacket() error { } // send every window update twice - for _, f := range windowUpdateFrames { + for _, f := range windowUpdates { s.packer.QueueControlFrame(f) } - windowUpdateFrames = nil + windowUpdates = nil ack = nil } } @@ -682,7 +745,7 @@ func (s *session) sendPacket() error { func (s *session) sendPackedPacket(packet *packedPacket) error { defer putPacketBuffer(packet.raw) err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ - PacketNumber: packet.number, + PacketNumber: packet.header.PacketNumber, Frames: packet.frames, Length: protocol.ByteCount(len(packet.raw)), EncryptionLevel: packet.encryptionLevel, @@ -696,7 +759,7 @@ func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) - packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{ + packet, err := s.packer.PackConnectionClose(&wire.ConnectionCloseFrame{ ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage, }) @@ -712,9 +775,10 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.number, len(packet.raw), s.connectionID, packet.encryptionLevel) + utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) + packet.header.Log() for _, frame := range packet.frames { - frames.LogFrame(frame, true) + wire.LogFrame(frame, true) } } @@ -748,42 +812,33 @@ func (s *session) WaitUntilHandshakeComplete() error { } func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { - s.packer.QueueControlFrame(&frames.RstStreamFrame{ + s.packer.QueueControlFrame(&wire.RstStreamFrame{ StreamID: id, ByteOffset: offset, }) s.scheduleSending() } -func (s *session) newStream(id protocol.StreamID) *stream { - // TODO: find a better solution for determining which streams contribute to connection level flow control - if id == 1 || id == 3 { - s.flowControlManager.NewStream(id, false) - } else { - s.flowControlManager.NewStream(id, true) +func (s *session) newStream(id protocol.StreamID) streamI { + var initialSendWindow protocol.ByteCount + if s.peerParams != nil { + initialSendWindow = s.peerParams.StreamFlowControlWindow } - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) -} - -// garbageCollectStreams goes through all streams and removes EOF'ed streams -// from the streams map. -func (s *session) garbageCollectStreams() { - s.streamsMap.Iterate(func(str *stream) (bool, error) { - id := str.StreamID() - if str.finished() { - err := s.streamsMap.RemoveStream(id) - if err != nil { - return false, err - } - s.flowControlManager.RemoveStream(id) - } - return true, nil - }) + flowController := flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + initialSendWindow, + s.rttStats, + ) + return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.Write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) + return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending @@ -796,7 +851,7 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.publicHeader, len(p.data)) + utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { @@ -805,10 +860,10 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.receivedTooManyUndecrytablePacketsTime = time.Now() s.maybeResetTimer() } - utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.publicHeader.PacketNumber) + utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) return } - utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) + utils.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.undecryptablePackets = append(s.undecryptablePackets, p) } @@ -819,11 +874,20 @@ func (s *session) tryDecryptingQueuedPackets() { s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *session) getWindowUpdateFrames() []*frames.WindowUpdateFrame { - updates := s.flowControlManager.GetWindowUpdates() - res := make([]*frames.WindowUpdateFrame, len(updates)) - for i, u := range updates { - res[i] = &frames.WindowUpdateFrame{StreamID: u.StreamID, ByteOffset: u.Offset} +func (s *session) getWindowUpdates() []wire.Frame { + var res []wire.Frame + s.streamsMap.Range(func(str streamI) { + if offset := str.GetWindowUpdate(); offset != 0 { + res = append(res, &wire.MaxStreamDataFrame{ + StreamID: str.StreamID(), + ByteOffset: offset, + }) + } + }) + if offset := s.connFlowController.GetWindowUpdate(); offset != 0 { + res = append(res, &wire.MaxDataFrame{ + ByteOffset: offset, + }) } return res } @@ -836,3 +900,7 @@ func (s *session) LocalAddr() net.Addr { func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } + +func (s *session) GetVersion() protocol.VersionNumber { + return s.version +} diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go index 45cf01a..806e7fc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -1,24 +1,46 @@ package quic import ( + "context" "fmt" "io" "net" "sync" "time" - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) +type streamI interface { + Stream + + AddStreamFrame(*wire.StreamFrame) error + RegisterRemoteError(error, protocol.ByteCount) error + LenOfDataForWriting() protocol.ByteCount + GetDataForWriting(maxBytes protocol.ByteCount) []byte + GetWriteOffset() protocol.ByteCount + Finished() bool + Cancel(error) + ShouldSendFin() bool + SentFin() + // methods needed for flow control + GetWindowUpdate() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + IsFlowControlBlocked() bool +} + // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { mutex sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + streamID protocol.StreamID onData func() // onReset is a callback that should send a RST_STREAM @@ -52,9 +74,13 @@ type stream struct { writeChan chan struct{} writeDeadline time.Time - flowControlManager flowcontrol.FlowControlManager + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber } +var _ Stream = &stream{} +var _ streamI = &stream{} + type deadlineError struct{} func (deadlineError) Error() string { return "deadline exceeded" } @@ -67,16 +93,21 @@ var errDeadline net.Error = &deadlineError{} func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), - flowControlManager flowcontrol.FlowControlManager) *stream { - return &stream{ - onData: onData, - onReset: onReset, - streamID: StreamID, - flowControlManager: flowControlManager, - frameQueue: newStreamFrameSorter(), - readChan: make(chan struct{}, 1), - writeChan: make(chan struct{}, 1), + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *stream { + s := &stream{ + onData: onData, + onReset: onReset, + streamID: StreamID, + flowController: flowController, + frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + writeChan: make(chan struct{}, 1), + version: version, } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s } // Read implements io.Reader. It is not thread safe! @@ -154,7 +185,7 @@ func (s *stream) Read(p []byte) (int, error) { // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream if !s.resetRemotely.Get() { - s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) + s.flowController.AddBytesRead(protocol.ByteCount(m)) } s.onData() // so that a possible WINDOW_UPDATE is sent @@ -223,7 +254,11 @@ func (s *stream) Write(p []byte) (int, error) { return len(p), nil } -func (s *stream) lenOfDataForWriting() protocol.ByteCount { +func (s *stream) GetWriteOffset() protocol.ByteCount { + return s.writeOffset +} + +func (s *stream) LenOfDataForWriting() protocol.ByteCount { s.mutex.Lock() var l protocol.ByteCount if s.err == nil { @@ -233,7 +268,7 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount { return l } -func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { +func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() defer s.mutex.Unlock() @@ -241,6 +276,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { return nil } + // TODO(#657): Flow control for the crypto stream + if s.streamID != s.version.CryptoStreamID() { + maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) + } + if maxBytes == 0 { + return nil + } + var ret []byte if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { ret = s.dataForWriting[:maxBytes] @@ -251,12 +294,14 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.signalWrite() } s.writeOffset += protocol.ByteCount(len(ret)) + s.flowController.AddBytesSent(protocol.ByteCount(len(ret))) return ret } // Close implements io.Closer func (s *stream) Close() error { s.finishedWriting.Set(true) + s.ctxCancel() s.onData() return nil } @@ -268,29 +313,27 @@ func (s *stream) shouldSendReset() bool { return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() } -func (s *stream) shouldSendFin() bool { +func (s *stream) ShouldSendFin() bool { s.mutex.Lock() res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil s.mutex.Unlock() return res } -func (s *stream) sentFin() { +func (s *stream) SentFin() { s.finSent.Set(true) } // AddStreamFrame adds a new stream frame -func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { +func (s *stream) AddStreamFrame(frame *wire.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() - err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) - if err != nil { + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.FinBit); err != nil { return err } s.mutex.Lock() defer s.mutex.Unlock() - err = s.frameQueue.Push(frame) - if err != nil && err != errDuplicateStreamData { + if err := s.frameQueue.Push(frame); err != nil && err != errDuplicateStreamData { return err } s.signalRead() @@ -344,7 +387,7 @@ func (s *stream) SetDeadline(t time.Time) error { // CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset func (s *stream) CloseRemote(offset protocol.ByteCount) { - s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) + s.AddStreamFrame(&wire.StreamFrame{FinBit: true, Offset: offset}) } // Cancel is called by session to indicate that an error occurred @@ -352,6 +395,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { func (s *stream) Cancel(err error) { s.mutex.Lock() s.cancelled.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err @@ -368,6 +412,7 @@ func (s *stream) Reset(err error) { } s.mutex.Lock() s.resetLocally.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err @@ -382,29 +427,34 @@ func (s *stream) Reset(err error) { } // resets the stream remotely -func (s *stream) RegisterRemoteError(err error) { +func (s *stream) RegisterRemoteError(err error, offset protocol.ByteCount) error { if s.resetRemotely.Get() { - return + return nil } s.mutex.Lock() s.resetRemotely.Set(true) + s.ctxCancel() // errors must not be changed! if s.err == nil { s.err = err s.signalWrite() } + if err := s.flowController.UpdateHighestReceived(offset, true); err != nil { + return err + } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) s.rstSent.Set(true) } s.mutex.Unlock() + return nil } func (s *stream) finishedWriteAndSentFin() bool { return s.finishedWriting.Get() && s.finSent.Get() } -func (s *stream) finished() bool { +func (s *stream) Finished() bool { return s.cancelled.Get() || (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || (s.resetRemotely.Get() && s.rstSent.Get()) || @@ -412,6 +462,22 @@ func (s *stream) finished() bool { (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) } +func (s *stream) Context() context.Context { + return s.ctx +} + func (s *stream) StreamID() protocol.StreamID { return s.streamID } + +func (s *stream) UpdateSendWindow(n protocol.ByteCount) { + s.flowController.UpdateSendWindow(n) +} + +func (s *stream) IsFlowControlBlocked() bool { + return s.flowController.IsBlocked() +} + +func (s *stream) GetWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go b/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go index 4a50150..e3a3a80 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_frame_sorter.go @@ -3,13 +3,13 @@ package quic import ( "errors" - "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFrameSorter struct { - queuedFrames map[protocol.ByteCount]*frames.StreamFrame + queuedFrames map[protocol.ByteCount]*wire.StreamFrame readPosition protocol.ByteCount gaps *utils.ByteIntervalList } @@ -23,13 +23,13 @@ var ( func newStreamFrameSorter() *streamFrameSorter { s := streamFrameSorter{ gaps: utils.NewByteIntervalList(), - queuedFrames: make(map[protocol.ByteCount]*frames.StreamFrame), + queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame), } s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } -func (s *streamFrameSorter) Push(frame *frames.StreamFrame) error { +func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error { if frame.DataLen() == 0 { if frame.FinBit { s.queuedFrames[frame.Offset] = frame @@ -143,7 +143,7 @@ func (s *streamFrameSorter) Push(frame *frames.StreamFrame) error { return nil } -func (s *streamFrameSorter) Pop() *frames.StreamFrame { +func (s *streamFrameSorter) Pop() *wire.StreamFrame { frame := s.Head() if frame != nil { s.readPosition += frame.DataLen() @@ -152,7 +152,7 @@ func (s *streamFrameSorter) Pop() *frames.StreamFrame { return frame } -func (s *streamFrameSorter) Head() *frames.StreamFrame { +func (s *streamFrameSorter) Head() *wire.StreamFrame { frame, ok := s.queuedFrames[s.readPosition] if ok { return frame diff --git a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index 20f82e3..8928e49 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -1,38 +1,43 @@ package quic import ( - "github.com/lucas-clemente/quic-go/flowcontrol" - "github.com/lucas-clemente/quic-go/frames" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { - streamsMap *streamsMap + streamsMap *streamsMap + cryptoStream streamI - flowControlManager flowcontrol.FlowControlManager + connFlowController flowcontrol.ConnectionFlowController - retransmissionQueue []*frames.StreamFrame - blockedFrameQueue []*frames.BlockedFrame + retransmissionQueue []*wire.StreamFrame + blockedFrameQueue []wire.Frame } -func newStreamFramer(streamsMap *streamsMap, flowControlManager flowcontrol.FlowControlManager) *streamFramer { +func newStreamFramer( + cryptoStream streamI, + streamsMap *streamsMap, + cfc flowcontrol.ConnectionFlowController, +) *streamFramer { return &streamFramer{ streamsMap: streamsMap, - flowControlManager: flowControlManager, + cryptoStream: cryptoStream, + connFlowController: cfc, } } -func (f *streamFramer) AddFrameForRetransmission(frame *frames.StreamFrame) { +func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { f.retransmissionQueue = append(f.retransmissionQueue, frame) } -func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*frames.StreamFrame { +func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { fs, currentLen := f.maybePopFramesForRetransmission(maxLen) return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) } -func (f *streamFramer) PopBlockedFrame() *frames.BlockedFrame { +func (f *streamFramer) PopBlockedFrame() wire.Frame { if len(f.blockedFrameQueue) == 0 { return nil } @@ -46,28 +51,24 @@ func (f *streamFramer) HasFramesForRetransmission() bool { } func (f *streamFramer) HasCryptoStreamFrame() bool { - // TODO(#657): Flow control - cs, _ := f.streamsMap.GetOrOpenStream(1) - return cs.lenOfDataForWriting() > 0 + return f.cryptoStream.LenOfDataForWriting() > 0 } // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. -// TODO(#657): Flow control -func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *frames.StreamFrame { +func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { if !f.HasCryptoStreamFrame() { return nil } - cs, _ := f.streamsMap.GetOrOpenStream(1) - frame := &frames.StreamFrame{ - StreamID: 1, - Offset: cs.writeOffset, + frame := &wire.StreamFrame{ + StreamID: f.cryptoStream.StreamID(), + Offset: f.cryptoStream.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error - frame.Data = cs.getDataForWriting(maxLen - frameHeaderBytes) + frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) return frame } -func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) { +func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { for len(f.retransmissionQueue) > 0 { frame := f.retransmissionQueue[0] frame.DataLenPresent = true @@ -93,63 +94,48 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount return } -func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*frames.StreamFrame) { - frame := &frames.StreamFrame{DataLenPresent: true} +func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []*wire.StreamFrame) { + frame := &wire.StreamFrame{DataLenPresent: true} var currentLen protocol.ByteCount - fn := func(s *stream) (bool, error) { - if s == nil || s.streamID == 1 /* crypto stream is handled separately */ { + fn := func(s streamI) (bool, error) { + if s == nil { return true, nil } - frame.StreamID = s.streamID + frame.StreamID = s.StreamID() + frame.Offset = s.GetWriteOffset() // not perfect, but thread-safe since writeOffset is only written when getting data - frame.Offset = s.writeOffset frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error if currentLen+frameHeaderBytes > maxBytes { return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here } maxLen := maxBytes - currentLen - frameHeaderBytes - var sendWindowSize protocol.ByteCount - lenStreamData := s.lenOfDataForWriting() - if lenStreamData != 0 { - sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID) - maxLen = utils.MinByteCount(maxLen, sendWindowSize) - } - - if maxLen == 0 { - return true, nil - } - var data []byte - if lenStreamData != 0 { - // Only getDataForWriting() if we didn't have data earlier, so that we - // don't send without FC approval (if a Write() raced). - data = s.getDataForWriting(maxLen) + if s.LenOfDataForWriting() > 0 { + data = s.GetDataForWriting(maxLen) } // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. - shouldSendFin := s.shouldSendFin() + shouldSendFin := s.ShouldSendFin() if data == nil && !shouldSendFin { return true, nil } if shouldSendFin { frame.FinBit = true - s.sentFin() + s.SentFin() } frame.Data = data - f.flowControlManager.AddBytesSent(s.streamID, protocol.ByteCount(len(data))) // Finally, check if we are now FC blocked and should queue a BLOCKED frame - if f.flowControlManager.RemainingConnectionWindowSize() == 0 { - // We are now connection-level FC blocked - f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0}) - } else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 { - // We are now stream-level FC blocked - f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()}) + if !frame.FinBit && s.IsFlowControlBlocked() { + f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.StreamBlockedFrame{StreamID: s.StreamID()}) + } + if f.connFlowController.IsBlocked() { + f.blockedFrameQueue = append(f.blockedFrameQueue, &wire.BlockedFrame{}) } res = append(res, frame) @@ -159,17 +145,16 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] return false, nil } - frame = &frames.StreamFrame{DataLenPresent: true} + frame = &wire.StreamFrame{DataLenPresent: true} return true, nil } f.streamsMap.RoundRobinIterate(fn) - return } // maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. -func maybeSplitOffFrame(frame *frames.StreamFrame, n protocol.ByteCount) *frames.StreamFrame { +func maybeSplitOffFrame(frame *wire.StreamFrame, n protocol.ByteCount) *wire.StreamFrame { if n >= frame.DataLen() { return nil } @@ -179,7 +164,7 @@ func maybeSplitOffFrame(frame *frames.StreamFrame, n protocol.ByteCount) *frames frame.Offset += n }() - return &frames.StreamFrame{ + return &wire.StreamFrame{ FinBit: false, StreamID: frame.StreamID, Offset: frame.Offset, diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/lucas-clemente/quic-go/streams_map.go index 74be17e..df5b4c9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -5,21 +5,20 @@ import ( "fmt" "sync" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" ) type streamsMap struct { mutex sync.RWMutex - perspective protocol.Perspective - connectionParameters handshake.ConnectionParametersManager + perspective protocol.Perspective - streams map[protocol.StreamID]*stream + streams map[protocol.StreamID]streamI // needed for round-robin scheduling openStreams []protocol.StreamID - roundRobinIndex uint32 + roundRobinIndex int nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID @@ -33,32 +32,42 @@ type streamsMap struct { numOutgoingStreams uint32 numIncomingStreams uint32 + maxIncomingStreams uint32 + maxOutgoingStreams uint32 } -type streamLambda func(*stream) (bool, error) -type newStreamLambda func(protocol.StreamID) *stream +type streamLambda func(streamI) (bool, error) +type newStreamLambda func(protocol.StreamID) streamI -var ( - errMapAccess = errors.New("streamsMap: Error accessing the streams map") -) +var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { + // add some tolerance to the maximum incoming streams value + maxStreams := uint32(protocol.MaxIncomingStreams) + maxIncomingStreams := utils.MaxUint32( + maxStreams+protocol.MaxStreamsMinimumIncrement, + uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), + ) sm := streamsMap{ - perspective: pers, - streams: map[protocol.StreamID]*stream{}, - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - connectionParameters: connectionParameters, + perspective: pers, + streams: make(map[protocol.StreamID]streamI), + openStreams: make([]protocol.StreamID, 0), + newStream: newStream, + maxIncomingStreams: maxIncomingStreams, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex + nextOddStream := protocol.StreamID(1) + if ver.CryptoStreamID() == protocol.StreamID(1) { + nextOddStream = 3 + } if pers == protocol.PerspectiveClient { - sm.nextStream = 1 + sm.nextStream = nextOddStream sm.nextStreamToAccept = 2 } else { sm.nextStream = 2 - sm.nextStreamToAccept = 1 + sm.nextStreamToAccept = nextOddStream } return &sm @@ -66,7 +75,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. -func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { +func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { m.mutex.RLock() s, ok := m.streams[id] m.mutex.RUnlock() @@ -124,8 +133,8 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return m.streams[id], nil } -func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { - if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { +func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { + if m.numIncomingStreams >= m.maxIncomingStreams { return nil, qerr.TooManyOpenStreams } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { @@ -147,9 +156,9 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { return s, nil } -func (m *streamsMap) openStreamImpl() (*stream, error) { +func (m *streamsMap) openStreamImpl() (streamI, error) { id := m.nextStream - if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { + if m.numOutgoingStreams >= m.maxOutgoingStreams { return nil, qerr.TooManyOpenStreams } @@ -166,7 +175,7 @@ func (m *streamsMap) openStreamImpl() (*stream, error) { } // OpenStream opens the next available stream -func (m *streamsMap) OpenStream() (*stream, error) { +func (m *streamsMap) OpenStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -176,7 +185,7 @@ func (m *streamsMap) OpenStream() (*stream, error) { return m.openStreamImpl() } -func (m *streamsMap) OpenStreamSync() (*stream, error) { +func (m *streamsMap) OpenStreamSync() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -197,10 +206,10 @@ func (m *streamsMap) OpenStreamSync() (*stream, error) { // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened -func (m *streamsMap) AcceptStream() (*stream, error) { +func (m *streamsMap) AcceptStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - var str *stream + var str streamI for { var ok bool if m.closeErr != nil { @@ -216,50 +225,64 @@ func (m *streamsMap) AcceptStream() (*stream, error) { return str, nil } -func (m *streamsMap) Iterate(fn streamLambda) error { +func (m *streamsMap) DeleteClosedStreams() error { m.mutex.Lock() defer m.mutex.Unlock() - openStreams := append([]protocol.StreamID{}, m.openStreams...) - - for _, streamID := range openStreams { - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err + var numDeletedStreams int + // for every closed stream, the streamID is replaced by 0 in the openStreams slice + for i, streamID := range m.openStreams { + str, ok := m.streams[streamID] + if !ok { + return errMapAccess } - if !cont { - break + if !str.Finished() { + continue + } + numDeletedStreams++ + m.openStreams[i] = 0 + if streamID%2 == 0 { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } + delete(m.streams, streamID) + } + + if numDeletedStreams == 0 { + return nil + } + + // remove all 0s (representing closed streams) from the openStreams slice + // and adjust the roundRobinIndex + var j int + for i, id := range m.openStreams { + if i != j { + m.openStreams[j] = m.openStreams[i] + } + if id != 0 { + j++ + } else if j < m.roundRobinIndex { + m.roundRobinIndex-- } } + m.openStreams = m.openStreams[:len(m.openStreams)-numDeletedStreams] + m.openStreamOrErrCond.Signal() return nil } // RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false // It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the crypto- and the header-stream (StreamIDs 1 and 3) +// It prioritizes the the header-stream (StreamID 3) func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() - numStreams := uint32(len(m.streams)) + numStreams := len(m.streams) startIndex := m.roundRobinIndex - for _, i := range []protocol.StreamID{1, 3} { - cont, err := m.iterateFunc(i, fn) - if err != nil && err != errMapAccess { - return err - } - if !cont { - return nil - } - } - - for i := uint32(0); i < numStreams; i++ { + for i := 0; i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] - if streamID == 1 || streamID == 3 { - continue - } - cont, err := m.iterateFunc(streamID, fn) if err != nil { return err @@ -272,6 +295,18 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { return nil } +// Range executes a callback for all streams, in pseudo-random order +func (m *streamsMap) Range(cb func(s streamI)) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for _, s := range m.streams { + if s != nil { + cb(s) + } + } +} + func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (bool, error) { str, ok := m.streams[streamID] if !ok { @@ -280,7 +315,7 @@ func (m *streamsMap) iterateFunc(streamID protocol.StreamID, fn streamLambda) (b return fn(str) } -func (m *streamsMap) putStream(s *stream) error { +func (m *streamsMap) putStream(s streamI) error { id := s.StreamID() if _, ok := m.streams[id]; ok { return fmt.Errorf("a stream with ID %d already exists", id) @@ -291,36 +326,6 @@ func (m *streamsMap) putStream(s *stream) error { return nil } -// Attention: this function must only be called if a mutex has been acquired previously -func (m *streamsMap) RemoveStream(id protocol.StreamID) error { - s, ok := m.streams[id] - if !ok || s == nil { - return fmt.Errorf("attempted to remove non-existing stream: %d", id) - } - - if id%2 == 0 { - m.numOutgoingStreams-- - } else { - m.numIncomingStreams-- - } - - for i, s := range m.openStreams { - if s == id { - // delete the streamID from the openStreams slice - m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] - // adjust round-robin index, if necessary - if uint32(i) < m.roundRobinIndex { - m.roundRobinIndex-- - } - break - } - } - - delete(m.streams, id) - m.openStreamOrErrCond.Signal() - return nil -} - func (m *streamsMap) CloseWithError(err error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -331,3 +336,9 @@ func (m *streamsMap) CloseWithError(err error) { m.streams[s].Cancel(err) } } + +func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.maxOutgoingStreams = limit +} diff --git a/vendor/golang.org/x/crypto/curve25519/curve25519.go b/vendor/golang.org/x/crypto/curve25519/curve25519.go index 2d14c2a..cb8fbc5 100644 --- a/vendor/golang.org/x/crypto/curve25519/curve25519.go +++ b/vendor/golang.org/x/crypto/curve25519/curve25519.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// We have a implementation in amd64 assembly so this code is only run on +// We have an implementation in amd64 assembly so this code is only run on // non-amd64 platforms. The amd64 assembly does not support gccgo. // +build !amd64 gccgo appengine diff --git a/vendor/vendor.json b/vendor/vendor.json index 354683e..cb3ce1f 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -62,6 +62,18 @@ "revision": "c91e78db502ff629614837aacb7aa4efa61c651a", "revisionTime": "2016-04-30T09:49:23Z" }, + { + "checksumSHA1": "vN3qcK0Nkk4AhOfhTEA8ZUMu2RY=", + "path": "github.com/aead/chacha20", + "revision": "8d6ce0550041f9d97e7f15ec27ed489f8bbbb0fb", + "revisionTime": "2017-06-14T05:10:14Z" + }, + { + "checksumSHA1": "zmmUB5W1Oz/czit5mixwbVQq31A=", + "path": "github.com/aead/chacha20/chacha", + "revision": "8d6ce0550041f9d97e7f15ec27ed489f8bbbb0fb", + "revisionTime": "2017-06-14T05:10:14Z" + }, { "checksumSHA1": "30PBqj9BW03KCVqASvLg3bR+xYc=", "path": "github.com/agl/ed25519/edwards25519", @@ -74,6 +86,18 @@ "revision": "5312a61534124124185d41f09206b9fef1d88403", "revisionTime": "2017-01-16T20:05:12Z" }, + { + "checksumSHA1": "p0SvpHmpHEGr2eYb1tohV4EZhD0=", + "path": "github.com/bifurcation/mint", + "revision": "64af8ab8ccb81bd5d4eab356f79ba0939117d9f6", + "revisionTime": "2017-10-31T22:03:52Z" + }, + { + "checksumSHA1": "usbuF7R80ixs5RS8ZM99C6OTDlc=", + "path": "github.com/bifurcation/mint/syntax", + "revision": "64af8ab8ccb81bd5d4eab356f79ba0939117d9f6", + "revisionTime": "2017-10-31T22:03:52Z" + }, { "checksumSHA1": "xqVDKHGnakGlcRhmWd1j9JYmfLc=", "path": "github.com/dchest/siphash", @@ -147,88 +171,85 @@ "revisionTime": "2016-09-12T19:31:07Z" }, { - "checksumSHA1": "hUI9uYDnlXeOY+SEAPViyVpgq6I=", + "checksumSHA1": "2RFzGcdTeQrFkkhT70WhQcMWF6c=", + "origin": "github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/aes12", "path": "github.com/lucas-clemente/aes12", - "revision": "25700e67be5c860bcc999137275b9ef8b65932bd", - "revisionTime": "2016-12-15T15:22:28Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { "checksumSHA1": "ne1X+frkx5fJcpz9FaZPuUZ7amM=", + "origin": "github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/fnv128a", "path": "github.com/lucas-clemente/fnv128a", - "revision": "393af48d391698c6ae4219566bfbdfef67269997", - "revisionTime": "2016-05-04T15:23:51Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "Z/PMIFIZ3o0/HabtyLwTHKKdvb4=", + "checksumSHA1": "8wGsJmhHz2l6XAcpoJ4NiF2zXjY=", "path": "github.com/lucas-clemente/quic-go", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { "checksumSHA1": "OA9E+y7g05x/mWJJHmA7oPxWKQo=", + "origin": "github.com/lucas-clemente/quic-go/vendor/github.com/lucas-clemente/quic-go-certificates", "path": "github.com/lucas-clemente/quic-go-certificates", - "revision": "d2f86524cced5186554df90d92529757d22c1cb6", - "revisionTime": "2016-08-23T09:51:56Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "quVHUmzDMdxO14W2xY1PizK1GME=", + "checksumSHA1": "w1JPHEDB2mQW4qpKZbWj5zsavuM=", "path": "github.com/lucas-clemente/quic-go/ackhandler", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "bHgFCrGrrdQLBrDN/iwhWnTY5c8=", + "checksumSHA1": "X4/8oG5h9ffPOXFCYl39QUNuN9U=", "path": "github.com/lucas-clemente/quic-go/congestion", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "mOdxhLhYusZpEUM4XYweq6/e7wc=", - "path": "github.com/lucas-clemente/quic-go/crypto", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "checksumSHA1": "lWOh0Q0bY/dd3G/MZDCkzk1dVTo=", + "path": "github.com/lucas-clemente/quic-go/internal/crypto", + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "IjJn3XeLM+MZ0vWkl2RxK+8y7ac=", - "path": "github.com/lucas-clemente/quic-go/flowcontrol", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "checksumSHA1": "cab7WtoBeOlbQGMEoTaKAjEbqZg=", + "path": "github.com/lucas-clemente/quic-go/internal/flowcontrol", + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "BvjNmqw285B/O0u9CapEiOG6nNU=", - "path": "github.com/lucas-clemente/quic-go/frames", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "checksumSHA1": "gWAXju/s95yWXKYLgCnxx+Ed22M=", + "path": "github.com/lucas-clemente/quic-go/internal/handshake", + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "T6xXdLavEsHmF+yhvMOjKn/1RJU=", - "path": "github.com/lucas-clemente/quic-go/h2quic", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "checksumSHA1": "VjW23wuTXH3REwjcwhfdQrJTUDI=", + "path": "github.com/lucas-clemente/quic-go/internal/protocol", + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "dnDC7JMEhC/8UtQnkExChY+zIeY=", - "path": "github.com/lucas-clemente/quic-go/handshake", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" - }, - { - "checksumSHA1": "I106jIrBkNesQcoPrgUJ8e2JUNY=", + "checksumSHA1": "V9xXEL18b0TrZoe+dqHAmK0beCY=", "path": "github.com/lucas-clemente/quic-go/internal/utils", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { - "checksumSHA1": "UoLAi6qeJrVxnEWxEyRs7VKLp78=", - "path": "github.com/lucas-clemente/quic-go/protocol", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "checksumSHA1": "Zs24W5nNpq5Adhl7069+094yW3A=", + "path": "github.com/lucas-clemente/quic-go/internal/wire", + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { "checksumSHA1": "RaG0jfP+lFzgedW98Bfp0Uri7EY=", "path": "github.com/lucas-clemente/quic-go/qerr", - "revision": "811315e31a0c190e7a9e86c84102e86c9ed2a072", - "revisionTime": "2017-07-29T00:10:52Z" + "revision": "214e95c655a1832cc64a182544708a0c4f70eea3", + "revisionTime": "2017-11-13T03:10:14Z" }, { "checksumSHA1": "ynJSWoF6v+3zMnh9R0QmmG6iGV8=", @@ -261,10 +282,10 @@ "revisionTime": "2017-07-28T12:36:07Z" }, { - "checksumSHA1": "MlEHIE/60sB86Lmf0MPTIXHzKzE=", + "checksumSHA1": "IQkUIOnvlf0tYloFx9mLaXSvXWQ=", "path": "golang.org/x/crypto/curve25519", - "revision": "558b6879de74bc843225cde5686419267ff707ca", - "revisionTime": "2017-07-28T12:36:07Z" + "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94", + "revisionTime": "2017-09-21T17:41:56Z" }, { "checksumSHA1": "i3dNaI+oCYeDGIsNj7LwecTsIAs=", @@ -427,6 +448,14 @@ "path": "gopkg.in/xtaci/smux.v1", "revision": "427dd804ce9fb0a9e7b27a628f68a124fb0d67a6", "revisionTime": "2016-11-29T15:03:00Z" + }, + { + "path": "hel", + "revision": "" + }, + { + "path": "help", + "revision": "" } ], "rootPath": "github.com/ginuerzh/gost" From 4cdf5d5b8b69e0712c0b8cefb926c92602779ac6 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 21 Nov 2017 13:45:26 +0800 Subject: [PATCH 2/3] add AES encryption support for QUIC --- cmd/gost/main.go | 22 +++++++++ quic.go | 117 ++++++++++++++++++++++++++++++++++++++++++++--- tls.go | 5 ++ ws.go | 11 +++++ 4 files changed, 149 insertions(+), 6 deletions(-) diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 7fd8de3..6cc244d 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/json" @@ -223,6 +224,18 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { TLSConfig: tlsCfg, KeepAlive: toBool(node.Values.Get("keepalive")), } + + timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + config.Timeout = time.Duration(timeout) * time.Second + + idle, _ := strconv.Atoi(node.Values.Get("idle")) + config.IdleTimeout = time.Duration(idle) * time.Second + + if key := node.Values.Get("key"); key != "" { + sum := sha256.Sum256([]byte(key)) + config.Key = sum[:] + } + tr = gost.QUICTransporter(config) case "http2": tr = gost.HTTP2Transporter(tlsCfg) @@ -371,6 +384,15 @@ func (r *route) serve() error { } timeout, _ := strconv.Atoi(node.Values.Get("timeout")) config.Timeout = time.Duration(timeout) * time.Second + + idle, _ := strconv.Atoi(node.Values.Get("idle")) + config.IdleTimeout = time.Duration(idle) * time.Second + + if key := node.Values.Get("key"); key != "" { + sum := sha256.Sum256([]byte(key)) + config.Key = sum[:] + } + ln, err = gost.QUICListener(node.Addr, config) case "http2": ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) diff --git a/quic.go b/quic.go index 56a61c8..5353c77 100644 --- a/quic.go +++ b/quic.go @@ -1,8 +1,12 @@ package gost import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" "crypto/tls" "errors" + "io" "net" "sync" "time" @@ -55,10 +59,17 @@ func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Co session, ok := tr.sessions[addr] if !ok { - conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + var cc *net.UDPConn + cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return } + conn = cc + + if tr.config != nil && tr.config.Key != nil { + conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key} + } + session = &quicSession{conn: conn} tr.sessions[addr] = session } @@ -107,7 +118,7 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) } func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { - udpConn, ok := conn.(*net.UDPConn) + udpConn, ok := conn.(net.PacketConn) if !ok { return nil, errors.New("quic: wrong connection type") } @@ -118,6 +129,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC quicConfig := &quic.Config{ HandshakeTimeout: config.Timeout, KeepAlive: config.KeepAlive, + IdleTimeout: config.IdleTimeout, } session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) if err != nil { @@ -133,9 +145,11 @@ func (tr *quicTransporter) Multiplex() bool { // QUICConfig is the config for QUIC client and server type QUICConfig struct { - TLSConfig *tls.Config - Timeout time.Duration - KeepAlive bool + TLSConfig *tls.Config + Timeout time.Duration + KeepAlive bool + IdleTimeout time.Duration + Key []byte } type quicListener struct { @@ -152,13 +166,31 @@ func QUICListener(addr string, config *QUICConfig) (Listener, error) { quicConfig := &quic.Config{ HandshakeTimeout: config.Timeout, KeepAlive: config.KeepAlive, + IdleTimeout: config.IdleTimeout, } tlsConfig := config.TLSConfig if tlsConfig == nil { tlsConfig = DefaultTLSConfig } - ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig) + + var conn net.PacketConn + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + lconn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + conn = lconn + + if config.Key != nil { + conn = &quicCipherConn{UDPConn: lconn, key: config.Key} + } + + ln, err := quic.Listen(conn, tlsConfig, quicConfig) if err != nil { return nil, err } @@ -241,3 +273,76 @@ func (c *quicConn) LocalAddr() net.Addr { func (c *quicConn) RemoteAddr() net.Addr { return c.raddr } + +type quicCipherConn struct { + *net.UDPConn + key []byte +} + +func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { + n, addr, err = conn.UDPConn.ReadFrom(data) + if err != nil { + return + } + b, err := conn.decrypt(data[:n]) + if err != nil { + return + } + + copy(data, b) + + return len(b), addr, nil +} + +func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { + b, err := conn.encrypt(data) + if err != nil { + return + } + + _, err = conn.UDPConn.WriteTo(b, addr) + if err != nil { + return + } + + return len(b), nil +} + +func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, data, nil), nil +} + +func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + return gcm.Open(nil, nonce, ciphertext, nil) +} diff --git a/tls.go b/tls.go index f444330..35791d6 100644 --- a/tls.go +++ b/tls.go @@ -56,6 +56,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) diff --git a/ws.go b/ws.go index fa9c043..0b9c61f 100644 --- a/ws.go +++ b/ws.go @@ -158,6 +158,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) @@ -193,6 +198,7 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( session = s tr.sessions[opts.Addr] = session } + cc, err := session.GetConn() if err != nil { session.Close() @@ -281,6 +287,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) + ok = false + } if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) From e0066c9c0fe671e0b50f208c41305cf3c80d00b1 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 21 Nov 2017 14:31:18 +0800 Subject: [PATCH 3/3] Node filter only applies to node list of more than one node --- chain.go | 2 +- selector.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/chain.go b/chain.go index ca7d15b..9dc672b 100644 --- a/chain.go +++ b/chain.go @@ -53,7 +53,7 @@ func (c *Chain) NodeGroups() []*NodeGroup { } // LastNode returns the last node of the node list. -// If the chain is empty, an empty node is returns. +// If the chain is empty, an empty node will be returned. // If the last node is a node group, the first node in the group will be returned. func (c *Chain) LastNode() Node { if c.IsEmpty() { diff --git a/selector.go b/selector.go index bea2c65..15bbe33 100644 --- a/selector.go +++ b/selector.go @@ -16,7 +16,6 @@ var ( // NodeSelector as a mechanism to pick nodes and mark their status. type NodeSelector interface { Select(nodes []Node, opts ...SelectOption) (Node, error) - // Mark(node Node) } type defaultSelector struct { @@ -130,7 +129,7 @@ type FailFilter struct { // Filter filters nodes. func (f *FailFilter) Filter(nodes []Node) []Node { - if f.MaxFails <= 0 { + if len(nodes) <= 1 || f.MaxFails <= 0 { return nodes } nl := []Node{}