From ab87ca05fa5553dab59aa1623624ddefa66b408b Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Sat, 11 Feb 2017 18:34:35 +0800 Subject: [PATCH] update quic-go --- .../aes12/{cipher 2.go => cipher_2.go} | 0 .../github.com/lucas-clemente/quic-go/LICENSE | 2 + .../lucas-clemente/quic-go/README.md | 24 +- .../quic-go/ackhandler/interfaces.go | 4 +- .../quic-go/ackhandler/packet.go | 32 +- .../ackhandler/received_packet_handler.go | 149 +++--- .../ackhandler/received_packet_history.go | 79 ++- .../quic-go/ackhandler/sent_packet_handler.go | 4 +- .../lucas-clemente/quic-go/appveyor.yml | 4 +- .../lucas-clemente/quic-go/client.go | 219 ++++++++ .../lucas-clemente/quic-go/codecov.yml | 2 + .../quic-go/crypto/cert_chain.go | 84 +++ .../quic-go/crypto/cert_compression.go | 126 ++++- .../quic-go/crypto/cert_manager.go | 131 +++++ .../crypto/chacha20poly1305_aead_test.go | 71 --- .../quic-go/crypto/key_derivation.go | 38 +- .../quic-go/crypto/proof_source.go | 92 ---- .../quic-go/crypto/server_proof.go | 66 +++ .../lucas-clemente/quic-go/crypto/signer.go | 8 - .../flowcontrol/flow_control_manager.go | 83 ++- .../quic-go/flowcontrol/flow_controller.go | 93 +++- .../quic-go/flowcontrol/interface.go | 2 + .../quic-go/frames/ack_frame.go | 12 +- .../lucas-clemente/quic-go/frames/log.go | 17 +- .../lucas-clemente/quic-go/h2quic/client.go | 293 +++++++++++ .../quic-go/h2quic/gzipreader.go | 35 ++ .../lucas-clemente/quic-go/h2quic/request.go | 51 +- .../quic-go/h2quic/request_body.go | 29 ++ .../quic-go/h2quic/request_writer.go | 200 ++++++++ .../lucas-clemente/quic-go/h2quic/response.go | 111 ++++ .../h2quic/response_setuncompressed.go | 9 + .../h2quic/response_setuncompressed_go16.go | 9 + .../quic-go/h2quic/response_writer.go | 20 + .../quic-go/h2quic/roundtrip.go | 135 +++++ .../lucas-clemente/quic-go/h2quic/server.go | 16 +- .../connection_parameters_manager.go | 265 ++++++---- .../quic-go/handshake/crypto_setup_client.go | 485 ++++++++++++++++++ .../handshake/crypto_setup_interface.go | 16 + ...crypto_setup.go => crypto_setup_server.go} | 149 ++++-- .../quic-go/handshake/handshake_message.go | 11 +- .../quic-go/handshake/server_config.go | 20 +- .../quic-go/handshake/server_config_client.go | 148 ++++++ .../lucas-clemente/quic-go/handshake/tags.go | 2 + .../lucas-clemente/quic-go/packet_packer.go | 86 ++-- .../lucas-clemente/quic-go/packet_unpacker.go | 4 - .../quic-go/protocol/encryption_level.go | 14 + .../quic-go/protocol/perspective.go | 10 + .../quic-go/protocol/protocol.go | 6 + .../quic-go/protocol/server_parameters.go | 61 ++- .../quic-go/protocol/version.go | 26 +- .../lucas-clemente/quic-go/public_header.go | 182 +++++-- .../lucas-clemente/quic-go/public_reset.go | 38 ++ .../lucas-clemente/quic-go/server.go | 85 ++- .../lucas-clemente/quic-go/session.go | 298 +++++++---- .../lucas-clemente/quic-go/stream.go | 160 ++++-- .../lucas-clemente/quic-go/stream_framer.go | 2 +- .../lucas-clemente/quic-go/streams_map.go | 97 +++- .../lucas-clemente/quic-go/udp_conn.go | 14 +- .../lucas-clemente/quic-go/unpacked_packet.go | 27 + .../quic-go/utils/atomic_bool.go | 22 + .../quic-go/utils/connection_id.go | 18 + .../lucas-clemente/quic-go/utils/host.go | 27 + .../lucas-clemente/quic-go/utils/minmax.go | 8 + .../lucas-clemente/quic-go/utils/stream.go | 17 + .../lucas-clemente/quic-go/utils/utils.go | 11 - cmd/gost/vendor/vendor.json | 88 ++-- 66 files changed, 3786 insertions(+), 861 deletions(-) rename cmd/gost/vendor/github.com/lucas-clemente/aes12/{cipher 2.go => cipher_2.go} (100%) create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/client.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go delete mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go delete mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/proof_source.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go delete mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/signer.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed_go16.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_interface.go rename cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/{crypto_setup.go => crypto_setup_server.go} (60%) create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/unpacked_packet.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/atomic_bool.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/connection_id.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/host.go create mode 100644 cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/stream.go diff --git a/cmd/gost/vendor/github.com/lucas-clemente/aes12/cipher 2.go b/cmd/gost/vendor/github.com/lucas-clemente/aes12/cipher_2.go similarity index 100% rename from cmd/gost/vendor/github.com/lucas-clemente/aes12/cipher 2.go rename to cmd/gost/vendor/github.com/lucas-clemente/aes12/cipher_2.go diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/LICENSE b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/LICENSE index 9cf1062..51378be 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/LICENSE +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/LICENSE @@ -1,5 +1,7 @@ MIT License +Copyright (c) 2016 the quic-go authors & Google, Inc. + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/README.md b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/README.md index 93fe544..9d871f5 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/README.md +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/README.md @@ -14,6 +14,7 @@ quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) p Done: - Basic protocol with support for QUIC version 34-36 +- QUIC client - HTTP/2 support - Crypto (RSA / ECDSA certificates, Curve25519 for key exchange, AES-GCM or Chacha20-Poly1305 as stream cipher) - Loss detection and retransmission (currently fast retransmission & RTO) @@ -22,11 +23,10 @@ Done: Major TODOs: -- Security, especially DOS protections +- Security, especially DoS protections - Performance - Better packet loss detection - Connection migration -- QUIC client ## Guides @@ -38,20 +38,26 @@ Running tests: go test ./... -Running the example server: +### Running the example server go run example/main.go -www /var/www/ Using the `quic_client` from chromium: - quic_client --quic-version=32 --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io + quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io Using Chrome: /Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io +### Using the example client + + go run example/client/main.go https://quic.clemente.io + ## Usage +### As a server + See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go: ```go @@ -59,6 +65,16 @@ http.Handle("/", http.FileServer(http.Dir(wwwDir))) h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) ``` +### As a client + +See the [example client](example/client/main.go). Use a `QuicRoundTripper` as a `Transport` in a `http.Client`. + +```go +http.Client{ + Transport: &h2quic.QuicRoundTripper{}, +} +``` + ## Building on Windows Due to the low Windows timer resolution (see [StackOverflow question](http://stackoverflow.com/questions/37706834/high-resolution-timers-millisecond-precision-in-go-on-windows)) available with Go 1.6.x, some optimizations might not work when compiled with this version of the compiler. Please use Go 1.7 on Windows. diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go index ee7ac59..1d1b4be 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go @@ -28,8 +28,8 @@ type SentPacketHandler interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { - ReceivedPacket(packetNumber protocol.PacketNumber) error + ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedStopWaiting(*frames.StopWaitingFrame) error - GetAckFrame(dequeue bool) (*frames.AckFrame, error) + GetAckFrame() *frames.AckFrame } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go index 6fe567c..d748547 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/packet.go @@ -19,31 +19,17 @@ type Packet struct { SendTime time.Time } -// GetStreamFramesForRetransmission gets all the streamframes for retransmission -func (p *Packet) GetStreamFramesForRetransmission() []*frames.StreamFrame { - var streamFrames []*frames.StreamFrame +// GetFramesForRetransmission gets all the frames for retransmission +func (p *Packet) GetFramesForRetransmission() []frames.Frame { + var fs []frames.Frame for _, frame := range p.Frames { - if streamFrame, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { - streamFrames = append(streamFrames, streamFrame) - } - } - return streamFrames -} - -// GetControlFramesForRetransmission gets all the control frames for retransmission -func (p *Packet) GetControlFramesForRetransmission() []frames.Frame { - var controlFrames []frames.Frame - for _, frame := range p.Frames { - // omit ACKs - if _, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { + switch frame.(type) { + case *frames.AckFrame: + continue + case *frames.StopWaitingFrame: continue } - - _, isAck := frame.(*frames.AckFrame) - _, isStopWaiting := frame.(*frames.StopWaitingFrame) - if !isAck && !isStopWaiting { - controlFrames = append(controlFrames, frame) - } + fs = append(fs, frame) } - return controlFrames + return fs } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go index dcc1b30..daebbfb 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_handler.go @@ -6,45 +6,48 @@ import ( "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" ) var ( // ErrDuplicatePacket occurres when a duplicate packet is received ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet") - // ErrMapAccess occurs when a NACK contains invalid NACK ranges - ErrMapAccess = qerr.Error(qerr.InvalidAckData, "Packet does not exist in PacketHistory") // ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting") ) -var ( - errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") - errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "") -) +var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") type receivedPacketHandler struct { - largestInOrderObserved protocol.PacketNumber - largestObserved protocol.PacketNumber - ignorePacketsBelow protocol.PacketNumber - currentAckFrame *frames.AckFrame - stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent + largestObserved protocol.PacketNumber + ignorePacketsBelow protocol.PacketNumber + largestObservedReceivedTime time.Time packetHistory *receivedPacketHistory - receivedTimes map[protocol.PacketNumber]time.Time - lowestInReceivedTimes protocol.PacketNumber + ackSendDelay time.Duration + + packetsReceivedSinceLastAck int + retransmittablePacketsReceivedSinceLastAck int + ackQueued bool + ackAlarm time.Time + ackAlarmResetCallback func(time.Time) + lastAck *frames.AckFrame } // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler() ReceivedPacketHandler { +func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { + // create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182 + timer := time.NewTimer(0) + <-timer.C + return &receivedPacketHandler{ - receivedTimes: make(map[protocol.PacketNumber]time.Time), - packetHistory: newReceivedPacketHistory(), + packetHistory: newReceivedPacketHistory(), + ackAlarmResetCallback: ackAlarmResetCallback, + ackSendDelay: protocol.AckSendDelay, } } -func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error { +func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error { if packetNumber == 0 { return errInvalidPacketNumber } @@ -55,30 +58,21 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe return ErrPacketSmallerThanLastStopWaiting } - _, ok := h.receivedTimes[packetNumber] - if packetNumber <= h.largestInOrderObserved || ok { + if h.packetHistory.IsDuplicate(packetNumber) { return ErrDuplicatePacket } - h.packetHistory.ReceivedPacket(packetNumber) - - h.stateChanged = true - h.currentAckFrame = nil + err := h.packetHistory.ReceivedPacket(packetNumber) + if err != nil { + return err + } if packetNumber > h.largestObserved { h.largestObserved = packetNumber + h.largestObservedReceivedTime = time.Now() } - if packetNumber == h.largestInOrderObserved+1 { - h.largestInOrderObserved = packetNumber - } - - h.receivedTimes[packetNumber] = time.Now() - - if protocol.PacketNumber(len(h.receivedTimes)) > protocol.MaxTrackedReceivedPackets { - return errTooManyOutstandingReceivedPackets - } - + h.maybeQueueAck(packetNumber, shouldInstigateAck) return nil } @@ -89,55 +83,84 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) } h.ignorePacketsBelow = f.LeastUnacked - 1 - h.garbageCollectReceivedTimes() - - // the LeastUnacked is the smallest packet number of any packet for which the sender is still awaiting an ack. So the largestInOrderObserved is one less than that - if f.LeastUnacked > h.largestInOrderObserved { - h.largestInOrderObserved = f.LeastUnacked - 1 - } h.packetHistory.DeleteBelow(f.LeastUnacked) - return nil } -func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) { - if !h.stateChanged { - return nil, nil +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { + var ackAlarmSet bool + h.packetsReceivedSinceLastAck++ + + if shouldInstigateAck { + h.retransmittablePacketsReceivedSinceLastAck++ } - if dequeue { - h.stateChanged = false + // always ack the first packet + if h.lastAck == nil { + h.ackQueued = true } - if h.currentAckFrame != nil { - return h.currentAckFrame, nil + // Always send an ack every 20 packets in order to allow the peer to discard + // information from the SentPacketManager and provide an RTT measurement. + if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend { + h.ackQueued = true } - packetReceivedTime, ok := h.receivedTimes[h.largestObserved] - if !ok { - return nil, ErrMapAccess + // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK + // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() + if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { + h.ackQueued = true + } + + // check if a new missing range above the previously was created + if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked { + h.ackQueued = true + } + + if !h.ackQueued && shouldInstigateAck { + if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck { + h.ackQueued = true + } else { + if h.ackAlarm.IsZero() { + h.ackAlarm = time.Now().Add(h.ackSendDelay) + ackAlarmSet = true + } + } + } + + if h.ackQueued { + // cancel the ack alarm + h.ackAlarm = time.Time{} + ackAlarmSet = false + } + + if ackAlarmSet { + h.ackAlarmResetCallback(h.ackAlarm) + } +} + +func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) { + return nil } ackRanges := h.packetHistory.GetAckRanges() - h.currentAckFrame = &frames.AckFrame{ + ack := &frames.AckFrame{ LargestAcked: h.largestObserved, LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, - PacketReceivedTime: packetReceivedTime, + PacketReceivedTime: h.largestObservedReceivedTime, } if len(ackRanges) > 1 { - h.currentAckFrame.AckRanges = ackRanges + ack.AckRanges = ackRanges } - return h.currentAckFrame, nil -} + h.lastAck = ack + h.ackAlarm = time.Time{} + h.ackQueued = false + h.packetsReceivedSinceLastAck = 0 + h.retransmittablePacketsReceivedSinceLastAck = 0 -func (h *receivedPacketHandler) garbageCollectReceivedTimes() { - for i := h.lowestInReceivedTimes; i <= h.ignorePacketsBelow; i++ { - delete(h.receivedTimes, i) - } - if h.ignorePacketsBelow > h.lowestInReceivedTimes { - h.lowestInReceivedTimes = h.ignorePacketsBelow + 1 - } + return ack } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go index 14bb08f..d45fe6f 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/received_packet_history.go @@ -1,40 +1,54 @@ package ackhandler import ( - "sync" - "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/utils" ) type receivedPacketHistory struct { ranges *utils.PacketIntervalList - mutex sync.RWMutex + // the map is used as a replacement for a set here. The bool is always supposed to be set to true + receivedPacketNumbers map[protocol.PacketNumber]bool + lowestInReceivedPacketNumbers protocol.PacketNumber } +var ( + errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges") + errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets") +) + // newReceivedPacketHistory creates a new received packet history func newReceivedPacketHistory() *receivedPacketHistory { return &receivedPacketHistory{ - ranges: utils.NewPacketIntervalList(), + ranges: utils.NewPacketIntervalList(), + receivedPacketNumbers: make(map[protocol.PacketNumber]bool), } } // ReceivedPacket registers a packet with PacketNumber p and updates the ranges -func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { - h.mutex.Lock() - defer h.mutex.Unlock() +func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { + if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges { + return errTooManyOutstandingReceivedAckRanges + } + + if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets { + return errTooManyOutstandingReceivedPackets + } + + h.receivedPacketNumbers[p] = true if h.ranges.Len() == 0 { h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) - return + return nil } for el := h.ranges.Back(); el != nil; el = el.Prev() { // p already included in an existing range. Nothing to do here if p >= el.Value.Start && p <= el.Value.End { - return + return nil } var rangeExtended bool @@ -52,46 +66,61 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges prev.Value.End = el.Value.End h.ranges.Remove(el) - return + return nil } - return // if the two ranges were not merge, we're done here + return nil // if the two ranges were not merge, we're done here } // create a new range at the end if p > el.Value.End { h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) - return + return nil } } // create a new range at the beginning h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) + + return nil } +// DeleteBelow deletes all entries below the leastUnacked packet number func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) { - h.mutex.Lock() - defer h.mutex.Unlock() + h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked) nextEl := h.ranges.Front() for el := h.ranges.Front(); nextEl != nil; el = nextEl { nextEl = el.Next() if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End { + for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range + delete(h.receivedPacketNumbers, i) + } el.Value.Start = leastUnacked - } - if el.Value.End < leastUnacked { // delete a whole range + } else if el.Value.End < leastUnacked { // delete a whole range + for i := el.Value.Start; i <= el.Value.End; i++ { + delete(h.receivedPacketNumbers, i) + } h.ranges.Remove(el) - } else { + } else { // no ranges affected. Nothing to do return } } } +// IsDuplicate determines if a packet should be regarded as a duplicate packet +// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed +func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool { + if p < h.lowestInReceivedPacketNumbers { + return true + } + + _, ok := h.receivedPacketNumbers[p] + return ok +} + // GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { - h.mutex.RLock() - defer h.mutex.RUnlock() - if h.ranges.Len() == 0 { return nil } @@ -104,3 +133,13 @@ func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { return ackRanges } + +func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange { + ackRange := frames.AckRange{} + if h.ranges.Len() > 0 { + r := h.ranges.Back().Value + ackRange.FirstPacketNumber = r.Start + ackRange.LastPacketNumber = r.End + } + return ackRange +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go index 531a99a..686fd9d 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/ackhandler/sent_packet_handler.go @@ -47,9 +47,7 @@ type sentPacketHandler struct { } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler() SentPacketHandler { - rttStats := &congestion.RTTStats{} - +func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/appveyor.yml b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/appveyor.yml index 837facc..4ed6cfa 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/appveyor.yml +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/appveyor.yml @@ -13,8 +13,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go install: - rmdir c:\go /s /q - - appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.1.windows-amd64.zip - - 7z x go1.7.1.windows-amd64.zip -y -oC:\ > NUL + - appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.5.windows-amd64.zip + - 7z x go1.7.5.windows-amd64.zip -y -oC:\ > NUL - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - echo %PATH% - echo %GOPATH% diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/client.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/client.go new file mode 100644 index 0000000..e5e6abe --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/client.go @@ -0,0 +1,219 @@ +package quic + +import ( + "bytes" + "crypto/tls" + "errors" + "net" + "strings" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" +) + +// A Client of QUIC +type Client struct { + addr *net.UDPAddr + conn *net.UDPConn + hostname string + + connectionID protocol.ConnectionID + version protocol.VersionNumber + versionNegotiated bool + closed uint32 // atomic bool + + tlsConfig *tls.Config + cryptoChangeCallback CryptoChangeCallback + versionNegotiateCallback VersionNegotiateCallback + + session packetHandler +} + +// VersionNegotiateCallback is called once the client has a negotiated version +type VersionNegotiateCallback func() error + +var errHostname = errors.New("Invalid hostname") + +var ( + errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") +) + +// NewClient makes a new client +func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { + udpAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + return nil, err + } + + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + + connectionID, err := utils.GenerateConnectionID() + if err != nil { + return nil, err + } + + hostname, _, err := net.SplitHostPort(host) + if err != nil { + return nil, err + } + + client := &Client{ + addr: udpAddr, + conn: conn, + hostname: hostname, + version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default + connectionID: connectionID, + tlsConfig: tlsConfig, + cryptoChangeCallback: cryptoChangeCallback, + versionNegotiateCallback: versionNegotiateCallback, + } + + utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) + + err = client.createNewSession(nil) + if err != nil { + return nil, err + } + + return client, nil +} + +// Listen listens +func (c *Client) Listen() error { + for { + data := getPacketBuffer() + data = data[:protocol.MaxPacketSize] + + n, _, err := c.conn.ReadFromUDP(data) + if err != nil { + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return nil + } + return err + } + data = data[:n] + + err = c.handlePacket(data) + if err != nil { + utils.Errorf("error handling packet: %s", err.Error()) + c.session.Close(err) + return err + } + } +} + +// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs) +func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) { + return c.session.OpenStream(id) +} + +// Close closes the connection +func (c *Client) Close(e error) error { + // Only close once + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return nil + } + + _ = c.session.Close(e) + return c.conn.Close() +} + +func (c *Client) handlePacket(packet []byte) error { + if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { + return qerr.PacketTooLarge + } + + rcvTime := time.Now() + + r := bytes.NewReader(packet) + + hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) + if err != nil { + return qerr.Error(qerr.InvalidPacketHeader, err.Error()) + } + hdr.Raw = packet[:len(packet)-r.Len()] + + // ignore delayed / duplicated version negotiation packets + if c.versionNegotiated && hdr.VersionFlag { + return nil + } + + // this is the first packet after the client sent a packet with the VersionFlag set + // if the server doesn't send a version negotiation packet, it supports the suggested version + if !hdr.VersionFlag && !c.versionNegotiated { + c.versionNegotiated = true + err = c.versionNegotiateCallback() + if err != nil { + return err + } + } + + if hdr.VersionFlag { + var hasCommonVersion bool // check if we're supporting any of the offered versions + for _, v := range hdr.SupportedVersions { + // check if the server sent the offered version in supported versions + if v == c.version { + return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.") + } + if v != protocol.VersionUnsupported { + hasCommonVersion = true + } + } + if !hasCommonVersion { + utils.Infof("No common version found.") + return qerr.InvalidVersion + } + + ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) + if !ok { + return qerr.VersionNegotiationMismatch + } + + utils.Infof("Switching to QUIC version %d", highestSupportedVersion) + c.version = highestSupportedVersion + c.versionNegotiated = true + + c.session.Close(errCloseSessionForNewVersion) + err = c.createNewSession(hdr.SupportedVersions) + if err != nil { + return err + } + err = c.versionNegotiateCallback() + if err != nil { + return err + } + + return nil // version negotiation packets have no payload + } + + c.session.handlePacket(&receivedPacket{ + remoteAddr: c.addr, + publicHeader: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, + }) + return nil +} + +func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { + var err error + c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) + if err != nil { + return err + } + + go c.session.run() + return nil +} + +func (c *Client) streamCallback(session *Session, stream utils.Stream) {} + +func (c *Client) closeCallback(id protocol.ConnectionID) { + utils.Infof("Connection %x closed.", id) +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/codecov.yml b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/codecov.yml index 4e4e039..8fa7519 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/codecov.yml +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/codecov.yml @@ -2,6 +2,8 @@ coverage: round: nearest ignore: - ackhandler/packet_linkedlist.go + - h2quic/gzipreader.go + - h2quic/response.go - utils/byteinterval_linkedlist.go - utils/packetinterval_linkedlist.go status: diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go new file mode 100644 index 0000000..0fd905a --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_chain.go @@ -0,0 +1,84 @@ +package crypto + +import ( + "crypto/tls" + "errors" + "strings" +) + +// A CertChain holds a certificate and a private key +type CertChain interface { + SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) + GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) + GetLeafCert(sni string) ([]byte, error) +} + +// proofSource stores a key and a certificate for the server proof +type certChain struct { + config *tls.Config +} + +var _ CertChain = &certChain{} + +var errNoMatchingCertificate = errors.New("no matching certificate found") + +// NewCertChain loads the key and cert from files +func NewCertChain(tlsConfig *tls.Config) CertChain { + return &certChain{config: tlsConfig} +} + +// SignServerProof signs CHLO and server config for use in the server proof +func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { + cert, err := c.getCertForSNI(sni) + if err != nil { + return nil, err + } + + return signServerProof(cert, chlo, serverConfigData) +} + +// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc +func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { + cert, err := c.getCertForSNI(sni) + if err != nil { + return nil, err + } + return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) +} + +// GetLeafCert gets the leaf certificate +func (c *certChain) GetLeafCert(sni string) ([]byte, error) { + cert, err := c.getCertForSNI(sni) + if err != nil { + return nil, err + } + return cert.Certificate[0], nil +} + +func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { + if c.config.GetCertificate != nil { + cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) + if err != nil { + return nil, err + } + if cert != nil { + return cert, nil + } + } + + if len(c.config.NameToCertificate) != 0 { + if cert, ok := c.config.NameToCertificate[sni]; ok { + return cert, nil + } + wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) + if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok { + return cert, nil + } + } + + if len(c.config.Certificates) != 0 { + return &c.config.Certificates[0], nil + } + + return nil, errNoMatchingCertificate +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go index 8fdb257..f7676d5 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_compression.go @@ -22,8 +22,8 @@ const ( type entry struct { t entryType - h uint64 - i uint32 + h uint64 // set hash + i uint32 // index } func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { @@ -41,7 +41,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by chainHashes := make([]uint64, len(chain)) for i := range chain { - chainHashes[i] = hashCert(chain[i]) + chainHashes[i] = HashCert(chain[i]) } entries := buildEntries(chain, chainHashes, cachedHashes, setHashes) @@ -89,6 +89,111 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by return res.Bytes(), nil } +func decompressChain(data []byte) ([][]byte, error) { + var chain [][]byte + var entries []entry + r := bytes.NewReader(data) + + var numCerts int + var hasCompressedCerts bool + for { + entryTypeByte, err := r.ReadByte() + if entryTypeByte == 0 { + break + } + + et := entryType(entryTypeByte) + if err != nil { + return nil, err + } + + numCerts++ + + switch et { + case entryCached: + // we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain + return nil, errors.New("unexpected cached certificate") + case entryCommon: + e := entry{t: entryCommon} + e.h, err = utils.ReadUint64(r) + if err != nil { + return nil, err + } + e.i, err = utils.ReadUint32(r) + if err != nil { + return nil, err + } + certSet, ok := certSets[e.h] + if !ok { + return nil, errors.New("unknown certSet") + } + if e.i >= uint32(len(certSet)) { + return nil, errors.New("certificate not found in certSet") + } + entries = append(entries, e) + chain = append(chain, certSet[e.i]) + case entryCompressed: + hasCompressedCerts = true + entries = append(entries, entry{t: entryCompressed}) + chain = append(chain, nil) + default: + return nil, errors.New("unknown entryType") + } + } + + if numCerts == 0 { + return make([][]byte, 0, 0), nil + } + + if hasCompressedCerts { + uncompressedLength, err := utils.ReadUint32(r) + if err != nil { + fmt.Println(4) + return nil, err + } + + zlibDict := buildZlibDictForEntries(entries, chain) + gz, err := zlib.NewReaderDict(r, zlibDict) + if err != nil { + return nil, err + } + defer gz.Close() + + var totalLength uint32 + var certIndex int + for totalLength < uncompressedLength { + lenBytes := make([]byte, 4) + _, err := gz.Read(lenBytes) + if err != nil { + return nil, err + } + certLen := binary.LittleEndian.Uint32(lenBytes) + + cert := make([]byte, certLen) + n, err := gz.Read(cert) + if uint32(n) != certLen && err != nil { + return nil, err + } + + for { + if certIndex >= len(entries) { + return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate") + } + if entries[certIndex].t == entryCompressed { + chain[certIndex] = cert + certIndex++ + break + } + certIndex++ + } + + totalLength += 4 + certLen + } + } + + return chain, nil +} + func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry { res := make([]entry, len(chain)) chainLoop: @@ -149,8 +254,19 @@ func splitHashes(hashes []byte) ([]uint64, error) { return res, nil } -func hashCert(cert []byte) uint64 { - h := fnv.New64() +func getCommonCertificateHashes() []byte { + ccs := make([]byte, 8*len(certSets), 8*len(certSets)) + i := 0 + for certSetHash := range certSets { + binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash) + i++ + } + return ccs +} + +// HashCert calculates the FNV1a hash of a certificate +func HashCert(cert []byte) uint64 { + h := fnv.New64a() h.Write(cert) return h.Sum64() } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go new file mode 100644 index 0000000..5622784 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/cert_manager.go @@ -0,0 +1,131 @@ +package crypto + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "hash/fnv" + "time" + + "github.com/lucas-clemente/quic-go/qerr" +) + +// CertManager manages the certificates sent by the server +type CertManager interface { + SetData([]byte) error + GetCommonCertificateHashes() []byte + GetLeafCert() []byte + GetLeafCertHash() (uint64, error) + VerifyServerProof(proof, chlo, serverConfigData []byte) bool + Verify(hostname string) error +} + +type certManager struct { + chain []*x509.Certificate + config *tls.Config +} + +var _ CertManager = &certManager{} + +var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded") + +// NewCertManager creates a new CertManager +func NewCertManager(tlsConfig *tls.Config) CertManager { + return &certManager{config: tlsConfig} +} + +// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain +func (c *certManager) SetData(data []byte) error { + byteChain, err := decompressChain(data) + if err != nil { + return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") + } + + chain := make([]*x509.Certificate, len(byteChain), len(byteChain)) + for i, data := range byteChain { + cert, err := x509.ParseCertificate(data) + if err != nil { + return err + } + chain[i] = cert + } + + c.chain = chain + return nil +} + +func (c *certManager) GetCommonCertificateHashes() []byte { + return getCommonCertificateHashes() +} + +// GetLeafCert returns the leaf certificate of the certificate chain +// it returns nil if the certificate chain has not yet been set +func (c *certManager) GetLeafCert() []byte { + if len(c.chain) == 0 { + return nil + } + return c.chain[0].Raw +} + +// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate +func (c *certManager) GetLeafCertHash() (uint64, error) { + leafCert := c.GetLeafCert() + if leafCert == nil { + return 0, errNoCertificateChain + } + + h := fnv.New64a() + _, err := h.Write(leafCert) + if err != nil { + return 0, err + } + return h.Sum64(), nil +} + +// VerifyServerProof verifies the signature of the server config +// it should only be called after the certificate chain has been set, otherwise it returns false +func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool { + if len(c.chain) == 0 { + return false + } + + return verifyServerProof(proof, c.chain[0], chlo, serverConfigData) +} + +// Verify verifies the certificate chain +func (c *certManager) Verify(hostname string) error { + if len(c.chain) == 0 { + return errNoCertificateChain + } + + if c.config != nil && c.config.InsecureSkipVerify { + return nil + } + + leafCert := c.chain[0] + + var opts x509.VerifyOptions + if c.config != nil { + opts.Roots = c.config.RootCAs + opts.DNSName = c.config.ServerName + if c.config.Time == nil { + opts.CurrentTime = time.Now() + } else { + opts.CurrentTime = c.config.Time() + } + } else { + opts.DNSName = hostname + } + + // the first certificate is the leaf certificate, all others are intermediates + if len(c.chain) > 1 { + intermediates := x509.NewCertPool() + for i := 1; i < len(c.chain); i++ { + intermediates.AddCert(c.chain[i]) + } + opts.Intermediates = intermediates + } + + _, err := leafCert.Verify(opts) + return err +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go deleted file mode 100644 index 9d5197b..0000000 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/chacha20poly1305_aead_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// +build ignore - -package crypto - -import ( - "crypto/rand" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Chacha20poly1305", func() { - var ( - alice, bob AEAD - keyAlice, keyBob, ivAlice, ivBob []byte - ) - - BeforeEach(func() { - keyAlice = make([]byte, 32) - keyBob = make([]byte, 32) - ivAlice = make([]byte, 4) - ivBob = make([]byte, 4) - rand.Reader.Read(keyAlice) - rand.Reader.Read(keyBob) - rand.Reader.Read(ivAlice) - rand.Reader.Read(ivBob) - var err error - alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice) - Expect(err).ToNot(HaveOccurred()) - bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob) - Expect(err).ToNot(HaveOccurred()) - }) - - It("seals and opens", func() { - b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) - text, err := bob.Open(nil, b, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(text).To(Equal([]byte("foobar"))) - }) - - It("seals and opens reverse", func() { - b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) - text, err := alice.Open(nil, b, 42, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(text).To(Equal([]byte("foobar"))) - }) - - It("has the proper length", func() { - b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad")) - Expect(b).To(HaveLen(6 + 12)) - }) - - It("fails with wrong aad", func() { - b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad")) - _, err := bob.Open(nil, b, 42, []byte("aad2")) - Expect(err).To(HaveOccurred()) - }) - - It("rejects wrong key and iv sizes", func() { - var err error - e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs" - _, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice) - Expect(err).To(MatchError(e)) - _, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:]) - Expect(err).To(MatchError(e)) - }) -}) diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go index 60648d8..470137f 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/key_derivation.go @@ -21,15 +21,21 @@ import ( // } // DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance -func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) { - otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16) +func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) { + var swap bool + if pers == protocol.PerspectiveClient { + swap = true + } + otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap) if err != nil { return nil, err } return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) } -func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int) ([]byte, []byte, []byte, []byte, error) { +// deriveKeys derives the keys and the IVs +// swap should be set true if generating the values for the client, and false for the server +func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) { var info bytes.Buffer if forwardSecure { info.Write([]byte("QUIC forward secure key expansion\x00")) @@ -47,17 +53,33 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol if _, err := io.ReadFull(r, s); err != nil { return nil, nil, nil, nil, err } - otherKey := s[:keyLen] - myKey := s[keyLen : 2*keyLen] - otherIV := s[2*keyLen : 2*keyLen+4] - myIV := s[2*keyLen+4:] + + key1 := s[:keyLen] + key2 := s[keyLen : 2*keyLen] + iv1 := s[2*keyLen : 2*keyLen+4] + iv2 := s[2*keyLen+4:] + + var otherKey, myKey []byte + var otherIV, myIV []byte if !forwardSecure { - if err := diversify(myKey, myIV, divNonce); err != nil { + if err := diversify(key2, iv2, divNonce); err != nil { return nil, nil, nil, nil, err } } + if swap { + otherKey = key2 + myKey = key1 + otherIV = iv2 + myIV = iv1 + } else { + otherKey = key1 + myKey = key2 + otherIV = iv1 + myIV = iv2 + } + return otherKey, myKey, otherIV, myIV, nil } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/proof_source.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/proof_source.go deleted file mode 100644 index 6af8072..0000000 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/proof_source.go +++ /dev/null @@ -1,92 +0,0 @@ -package crypto - -import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/tls" - "errors" - "strings" -) - -// proofSource stores a key and a certificate for the server proof -type proofSource struct { - config *tls.Config -} - -// NewProofSource loads the key and cert from files -func NewProofSource(tlsConfig *tls.Config) (Signer, error) { - return &proofSource{config: tlsConfig}, nil -} - -// SignServerProof signs CHLO and server config for use in the server proof -func (ps *proofSource) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { - cert, err := ps.getCertForSNI(sni) - if err != nil { - return nil, err - } - - hash := sha256.New() - hash.Write([]byte("QUIC CHLO and server config signature\x00")) - chloHash := sha256.Sum256(chlo) - hash.Write([]byte{32, 0, 0, 0}) - hash.Write(chloHash[:]) - hash.Write(serverConfigData) - - key, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, errors.New("expected PrivateKey to implement crypto.Signer") - } - - opts := crypto.SignerOpts(crypto.SHA256) - - if _, ok = key.(*rsa.PrivateKey); ok { - opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} - } - - return key.Sign(rand.Reader, hash.Sum(nil), opts) -} - -// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc -func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { - cert, err := ps.getCertForSNI(sni) - if err != nil { - return nil, err - } - return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes) -} - -// GetLeafCert gets the leaf certificate -func (ps *proofSource) GetLeafCert(sni string) ([]byte, error) { - cert, err := ps.getCertForSNI(sni) - if err != nil { - return nil, err - } - return cert.Certificate[0], nil -} - -func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) { - if ps.config.GetCertificate != nil { - cert, err := ps.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) - if err != nil { - return nil, err - } - if cert != nil { - return cert, nil - } - } - if len(ps.config.NameToCertificate) != 0 { - if cert, ok := ps.config.NameToCertificate[sni]; ok { - return cert, nil - } - wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) - if cert, ok := ps.config.NameToCertificate[wildcardSNI]; ok { - return cert, nil - } - } - if len(ps.config.Certificates) != 0 { - return &ps.config.Certificates[0], nil - } - return nil, errors.New("no matching certificate found") -} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go new file mode 100644 index 0000000..456ad32 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/server_proof.go @@ -0,0 +1,66 @@ +package crypto + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/asn1" + "errors" + "math/big" +) + +type ecdsaSignature struct { + R, S *big.Int +} + +// signServerProof signs CHLO and server config for use in the server proof +func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) { + hash := sha256.New() + hash.Write([]byte("QUIC CHLO and server config signature\x00")) + chloHash := sha256.Sum256(chlo) + hash.Write([]byte{32, 0, 0, 0}) + hash.Write(chloHash[:]) + hash.Write(serverConfigData) + + key, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, errors.New("expected PrivateKey to implement crypto.Signer") + } + + opts := crypto.SignerOpts(crypto.SHA256) + + if _, ok = key.(*rsa.PrivateKey); ok { + opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} + } + + return key.Sign(rand.Reader, hash.Sum(nil), opts) +} + +// verifyServerProof verifies the server proof signature +func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool { + hash := sha256.New() + hash.Write([]byte("QUIC CHLO and server config signature\x00")) + chloHash := sha256.Sum256(chlo) + hash.Write([]byte{32, 0, 0, 0}) + hash.Write(chloHash[:]) + hash.Write(serverConfigData) + + // RSA + if cert.PublicKeyAlgorithm == x509.RSA { + opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256} + err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts) + return err == nil + } + + // ECDSA + signature := &ecdsaSignature{} + rest, err := asn1.Unmarshal(proof, signature) + if err != nil || len(rest) != 0 { + return false + } + return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S) +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/signer.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/signer.go deleted file mode 100644 index 0d9ba4e..0000000 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/crypto/signer.go +++ /dev/null @@ -1,8 +0,0 @@ -package crypto - -// A Signer holds a certificate and a private key -type Signer interface { - SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) - GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error) - GetLeafCert(sni string) ([]byte, error) -} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go index 0ab6363..a6cc577 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go @@ -2,38 +2,39 @@ package flowcontrol import ( "errors" + "fmt" "sync" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/utils" ) type flowControlManager struct { - connectionParametersManager *handshake.ConnectionParametersManager + connectionParameters handshake.ConnectionParametersManager + rttStats *congestion.RTTStats + streamFlowController map[protocol.StreamID]*flowController contributesToConnectionFlowControl map[protocol.StreamID]bool mutex sync.RWMutex } -var ( - // ErrStreamFlowControlViolation is a stream flow control violation - ErrStreamFlowControlViolation = errors.New("Stream level flow control violation") - // ErrConnectionFlowControlViolation is a connection level flow control violation - ErrConnectionFlowControlViolation = errors.New("Connection level flow control violation") -) +var _ FlowControlManager = &flowControlManager{} var errMapAccess = errors.New("Error accessing the flowController map.") // NewFlowControlManager creates a new flow control manager -func NewFlowControlManager(connectionParametersManager *handshake.ConnectionParametersManager) FlowControlManager { +func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { fcm := flowControlManager{ - connectionParametersManager: connectionParametersManager, + connectionParameters: connectionParameters, + rttStats: rttStats, streamFlowController: make(map[protocol.StreamID]*flowController), contributesToConnectionFlowControl: make(map[protocol.StreamID]bool), } // initialize connection level flow controller - fcm.streamFlowController[0] = newFlowController(0, connectionParametersManager) + fcm.streamFlowController[0] = newFlowController(0, connectionParameters, rttStats) fcm.contributesToConnectionFlowControl[0] = false return &fcm } @@ -47,7 +48,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo return } - f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParametersManager) + f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParameters, f.rttStats) f.contributesToConnectionFlowControl[streamID] = contributesToConnectionFlow } @@ -59,6 +60,48 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { f.mutex.Unlock() } +// ResetStream should be called when receiving a RstStreamFrame +// it updates the byte offset to the value in the RstStreamFrame +// streamID must not be 0 here +func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + streamFlowController, err := f.getFlowController(streamID) + if err != nil { + return err + } + increment, err := streamFlowController.UpdateHighestReceived(byteOffset) + if err != nil { + return qerr.StreamDataAfterTermination + } + + if streamFlowController.CheckFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) + } + + if f.contributesToConnectionFlowControl[streamID] { + connectionFlowController := f.streamFlowController[0] + connectionFlowController.IncrementHighestReceived(increment) + if connectionFlowController.CheckFlowControlViolation() { + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) + } + } + + return nil +} + +func (f *flowControlManager) GetBytesSent(streamID protocol.StreamID) (protocol.ByteCount, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + fc, err := f.getFlowController(streamID) + if err != nil { + return 0, err + } + return fc.GetBytesSent(), nil +} + // UpdateHighestReceived updates the highest received byte offset for a stream // it adds the number of additional bytes to connection level flow control // streamID must not be 0 here @@ -70,17 +113,19 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b if err != nil { return err } - increment := streamFlowController.UpdateHighestReceived(byteOffset) + // UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered + // this error can be ignored here + increment, _ := streamFlowController.UpdateHighestReceived(byteOffset) if streamFlowController.CheckFlowControlViolation() { - return ErrStreamFlowControlViolation + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow)) } if f.contributesToConnectionFlowControl[streamID] { connectionFlowController := f.streamFlowController[0] connectionFlowController.IncrementHighestReceived(increment) if connectionFlowController.CheckFlowControlViolation() { - return ErrConnectionFlowControlViolation + return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow)) } } @@ -117,6 +162,16 @@ func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { return res } +func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + flowController, err := f.getFlowController(streamID) + if err != nil { + return 0, err + } + return flowController.receiveFlowControlWindow, nil +} + // streamID must not be 0 here func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { // Only lock the part reading from the map, since send-windows are only accessed from the session goroutine. diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go index 21020ad..4dadf7c 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go @@ -1,37 +1,52 @@ package flowcontrol import ( + "errors" + "time" + + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" ) type flowController struct { streamID protocol.StreamID - connectionParametersManager *handshake.ConnectionParametersManager + connectionParameters handshake.ConnectionParametersManager + rttStats *congestion.RTTStats bytesSent protocol.ByteCount sendFlowControlWindow protocol.ByteCount - bytesRead protocol.ByteCount - highestReceived protocol.ByteCount - receiveFlowControlWindow protocol.ByteCount - receiveFlowControlWindowIncrement protocol.ByteCount + lastWindowUpdateTime time.Time + + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveFlowControlWindow protocol.ByteCount + receiveFlowControlWindowIncrement protocol.ByteCount + maxReceiveFlowControlWindowIncrement protocol.ByteCount } +// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously +var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") + // newFlowController gets a new flow controller -func newFlowController(streamID protocol.StreamID, connectionParametersManager *handshake.ConnectionParametersManager) *flowController { +func newFlowController(streamID protocol.StreamID, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { fc := flowController{ - streamID: streamID, - connectionParametersManager: connectionParametersManager, + streamID: streamID, + connectionParameters: connectionParameters, + rttStats: rttStats, } if streamID == 0 { - fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveConnectionFlowControlWindow() + fc.receiveFlowControlWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow + fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() } else { - fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveStreamFlowControlWindow() + fc.receiveFlowControlWindow = connectionParameters.GetReceiveStreamFlowControlWindow() fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow + fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() } return &fc @@ -40,9 +55,9 @@ func newFlowController(streamID protocol.StreamID, connectionParametersManager * func (c *flowController) getSendFlowControlWindow() protocol.ByteCount { if c.sendFlowControlWindow == 0 { if c.streamID == 0 { - return c.connectionParametersManager.GetSendConnectionFlowControlWindow() + return c.connectionParameters.GetSendConnectionFlowControlWindow() } - return c.connectionParametersManager.GetSendStreamFlowControlWindow() + return c.connectionParameters.GetSendStreamFlowControlWindow() } return c.sendFlowControlWindow } @@ -51,6 +66,10 @@ func (c *flowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } +func (c *flowController) GetBytesSent() protocol.ByteCount { + return c.bytesSent +} + // UpdateSendWindow should be called after receiving a WindowUpdateFrame // it returns true if the window was actually updated func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { @@ -76,13 +95,19 @@ func (c *flowController) SendWindowOffset() protocol.ByteCount { // UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher // Should **only** be used for the stream-level FlowController -func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) protocol.ByteCount { +// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before +// This error occurs every time StreamFrames get reordered and has to be ignored in that case +// It should only be treated as an error when resetting a stream +func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) { + if byteOffset == c.highestReceived { + return 0, nil + } if byteOffset > c.highestReceived { increment := byteOffset - c.highestReceived c.highestReceived = byteOffset - return increment + return increment, nil } - return 0 + return 0, ErrReceivedSmallerByteOffset } // IncrementHighestReceived adds an increment to the highestReceived value @@ -99,14 +124,52 @@ func (c *flowController) AddBytesRead(n protocol.ByteCount) { // if so, it returns true and the offset of the window func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) { diff := c.receiveFlowControlWindow - c.bytesRead + // Chromium implements the same threshold if diff < (c.receiveFlowControlWindowIncrement / 2) { + c.maybeAdjustWindowIncrement() + c.lastWindowUpdateTime = time.Now() + c.receiveFlowControlWindow = c.bytesRead + c.receiveFlowControlWindowIncrement + return true, c.receiveFlowControlWindow } + return false, 0 } +// maybeAdjustWindowIncrement increases the receiveFlowControlWindowIncrement if we're sending WindowUpdates too often +func (c *flowController) maybeAdjustWindowIncrement() { + if c.lastWindowUpdateTime.IsZero() { + return + } + + rtt := c.rttStats.SmoothedRTT() + if rtt == 0 { + return + } + + timeSinceLastWindowUpdate := time.Now().Sub(c.lastWindowUpdateTime) + + // interval between the window updates is sufficiently large, no need to increase the increment + if timeSinceLastWindowUpdate >= 2*rtt { + return + } + + oldWindowSize := c.receiveFlowControlWindowIncrement + c.receiveFlowControlWindowIncrement = utils.MinByteCount(2*c.receiveFlowControlWindowIncrement, c.maxReceiveFlowControlWindowIncrement) + + // debug log, if the window size was actually increased + if oldWindowSize < c.receiveFlowControlWindowIncrement { + newWindowSize := c.receiveFlowControlWindowIncrement / (1 << 10) + if c.streamID == 0 { + utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize) + } else { + utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize) + } + } +} + func (c *flowController) CheckFlowControlViolation() bool { if c.highestReceived > c.receiveFlowControlWindow { return true diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go index 3f4b089..e1ea3fa 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go @@ -13,9 +13,11 @@ type FlowControlManager interface { NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) RemoveStream(streamID protocol.StreamID) // methods needed for receiving data + ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error GetWindowUpdates() []WindowUpdate + GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) // methods needed for sending data AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go index b68b448..0380541 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/ack_frame.go @@ -27,8 +27,10 @@ type AckFrame struct { LowestAcked protocol.PacketNumber AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last + // time when the LargestAcked was receiveid + // this field Will not be set for received ACKs frames + PacketReceivedTime time.Time DelayTime time.Duration - PacketReceivedTime time.Time // only for received packets. Will not be modified for received ACKs frames } // ParseAckFrame reads an ACK frame @@ -83,7 +85,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, if err != nil { return nil, err } - if ackBlockLength < 1 { + if frame.LargestAcked > 0 && ackBlockLength < 1 { return nil, ErrInvalidFirstAckRange } @@ -141,7 +143,11 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber } else { - frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength) + if frame.LargestAcked == 0 { + frame.LowestAcked = 0 + } else { + frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength) + } } if !frame.validateAckRanges() { diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/log.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/log.go index ac548c6..1918db1 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/log.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/frames/log.go @@ -11,9 +11,18 @@ func LogFrame(frame Frame, sent bool) { if sent { dir = "->" } - if sf, ok := frame.(*StreamFrame); ok { - utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, sf.StreamID, sf.FinBit, sf.Offset, sf.DataLen(), sf.Offset+sf.DataLen()) - return + switch f := frame.(type) { + case *StreamFrame: + utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *StopWaitingFrame: + if sent { + utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) + } else { + utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) + } + case *AckFrame: + utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + default: + utils.Debugf("\t%s %#v", dir, frame) } - utils.Debugf("\t%s %#v", dir, frame) } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go new file mode 100644 index 0000000..ca18021 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/client.go @@ -0,0 +1,293 @@ +package h2quic + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" + + quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" +) + +type quicClient interface { + OpenStream(protocol.StreamID) (utils.Stream, error) + Close(error) error + Listen() error +} + +// Client is a HTTP2 client doing QUIC requests +type Client struct { + mutex sync.RWMutex + cryptoChangedCond sync.Cond + + t *QuicRoundTripper + + hostname string + encryptionLevel protocol.EncryptionLevel + + client quicClient + headerStream utils.Stream + headerErr *qerr.QuicError + highestOpenedStream protocol.StreamID + requestWriter *requestWriter + + responses map[protocol.StreamID]chan *http.Response +} + +var _ h2quicClient = &Client{} + +// NewClient creates a new client +func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { + c := &Client{ + t: t, + hostname: authorityAddr("https", hostname), + highestOpenedStream: 3, + responses: make(map[protocol.StreamID]chan *http.Response), + } + c.cryptoChangedCond = sync.Cond{L: &c.mutex} + + var err error + c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) + if err != nil { + return nil, err + } + + go c.client.Listen() + return c, nil +} + +func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) { + utils.Debugf("Handling stream %d", stream.StreamID()) +} + +func (c *Client) cryptoChangeCallback(isForwardSecure bool) { + c.cryptoChangedCond.L.Lock() + defer c.cryptoChangedCond.L.Unlock() + + if isForwardSecure { + c.encryptionLevel = protocol.EncryptionForwardSecure + utils.Debugf("is forward secure") + } else { + c.encryptionLevel = protocol.EncryptionSecure + utils.Debugf("is secure") + } + c.cryptoChangedCond.Broadcast() +} + +func (c *Client) versionNegotiateCallback() error { + var err error + // once the version has been negotiated, open the header stream + c.headerStream, err = c.client.OpenStream(3) + if err != nil { + return err + } + c.requestWriter = newRequestWriter(c.headerStream) + go c.handleHeaderStream() + return nil +} + +func (c *Client) handleHeaderStream() { + decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) + h2framer := http2.NewFramer(nil, c.headerStream) + + var lastStream protocol.StreamID + + for { + frame, err := h2framer.ReadFrame() + if err != nil { + c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame") + break + } + lastStream = protocol.StreamID(frame.Header().StreamID) + hframe, ok := frame.(*http2.HeadersFrame) + if !ok { + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") + break + } + mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} + mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) + if err != nil { + c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") + break + } + + c.mutex.RLock() + headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] + c.mutex.RUnlock() + if !ok { + c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) + break + } + + rsp, err := responseFromHeaders(mhframe) + if err != nil { + c.headerErr = qerr.Error(qerr.InternalError, err.Error()) + } + headerChan <- rsp + } + + // stop all running request + utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) + c.mutex.Lock() + for _, responseChan := range c.responses { + responseChan <- nil + } + c.mutex.Unlock() +} + +// Do executes a request and returns a response +func (c *Client) Do(req *http.Request) (*http.Response, error) { + // TODO: add port to address, if it doesn't have one + if req.URL.Scheme != "https" { + return nil, errors.New("quic http2: unsupported scheme") + } + if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { + utils.Debugf("%s vs %s", req.Host, c.hostname) + return nil, errors.New("h2quic Client BUG: Do called for the wrong client") + } + + hasBody := (req.Body != nil) + + c.mutex.Lock() + c.highestOpenedStream += 2 + dataStreamID := c.highestOpenedStream + for c.encryptionLevel != protocol.EncryptionForwardSecure { + c.cryptoChangedCond.Wait() + } + hdrChan := make(chan *http.Response) + c.responses[dataStreamID] = hdrChan + c.mutex.Unlock() + + // TODO: think about what to do with a TooManyOpenStreams error. Wait and retry? + dataStream, err := c.client.OpenStream(dataStreamID) + if err != nil { + c.Close(err) + return nil, err + } + + var requestedGzip bool + if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { + requestedGzip = true + } + // TODO: add support for trailers + endStream := !hasBody + err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip) + if err != nil { + c.Close(err) + return nil, err + } + + resc := make(chan error, 1) + if hasBody { + go func() { + resc <- c.writeRequestBody(dataStream, req.Body) + }() + } + + var res *http.Response + + var receivedResponse bool + var bodySent bool + + if !hasBody { + bodySent = true + } + + for !(bodySent && receivedResponse) { + select { + case res = <-hdrChan: + receivedResponse = true + c.mutex.Lock() + delete(c.responses, dataStreamID) + c.mutex.Unlock() + if res == nil { // an error occured on the header stream + c.Close(c.headerErr) + return nil, c.headerErr + } + case err := <-resc: + bodySent = true + if err != nil { + return nil, err + } + } + } + + // TODO: correctly set this variable + var streamEnded bool + isHead := (req.Method == "HEAD") + + res = setLength(res, isHead, streamEnded) + + if streamEnded || isHead { + res.Body = noBody + } else { + res.Body = dataStream + if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &gzipReader{body: res.Body} + setUncompressed(res) + } + } + + res.Request = req + + return res, nil +} + +func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) { + defer func() { + cerr := body.Close() + if err == nil { + // TODO: what to do with dataStream here? Maybe reset it? + err = cerr + } + }() + + _, err = io.Copy(dataStream, body) + if err != nil { + // TODO: what to do with dataStream here? Maybe reset it? + return err + } + return dataStream.Close() +} + +// Close closes the client +func (c *Client) Close(e error) { + _ = c.client.Close(e) +} + +// copied from net/transport.go + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go new file mode 100644 index 0000000..91c226b --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/gzipreader.go @@ -0,0 +1,35 @@ +package h2quic + +// copied from net/transport.go + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +import ( + "compress/gzip" + "io" +) + +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go index f2e0fa5..911485e 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request.go @@ -1,15 +1,18 @@ package h2quic import ( + "crypto/tls" "errors" "net/http" "net/url" + "strconv" + "strings" "golang.org/x/net/http2/hpack" ) func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { - var path, authority, method string + var path, authority, method, contentLengthStr string httpHeaders := http.Header{} for _, h := range headers { @@ -20,6 +23,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { method = h.Value case ":authority": authority = h.Value + case "content-length": + contentLengthStr = h.Value default: if !h.IsPseudo() { httpHeaders.Add(h.Name, h.Value) @@ -27,6 +32,11 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { } } + // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 + if len(httpHeaders["Cookie"]) > 0 { + httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) + } + if len(path) == 0 || len(authority) == 0 || len(method) == 0 { return nil, errors.New(":path, :authority and :method must not be empty") } @@ -36,16 +46,35 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { return nil, err } + var contentLength int64 + if len(contentLengthStr) > 0 { + contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, err + } + } + return &http.Request{ - Method: method, - URL: u, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Header: httpHeaders, - Body: nil, - // ContentLength: -1, - Host: authority, - RequestURI: path, + Method: method, + URL: u, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Header: httpHeaders, + Body: nil, + ContentLength: contentLength, + Host: authority, + RequestURI: path, + TLS: &tls.ConnectionState{}, }, nil } + +func hostnameFromRequest(req *http.Request) string { + if len(req.Host) > 0 { + return req.Host + } + if req.URL != nil { + return req.URL.Host + } + return "" +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go new file mode 100644 index 0000000..41ff5c6 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_body.go @@ -0,0 +1,29 @@ +package h2quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/utils" +) + +type requestBody struct { + requestRead bool + dataStream utils.Stream +} + +// make sure the requestBody can be used as a http.Request.Body +var _ io.ReadCloser = &requestBody{} + +func newRequestBody(stream utils.Stream) *requestBody { + return &requestBody{dataStream: stream} +} + +func (b *requestBody) Read(p []byte) (int, error) { + b.requestRead = true + return b.dataStream.Read(p) +} + +func (b *requestBody) Close() error { + // stream's Close() closes the write side, not the read side + return nil +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go new file mode 100644 index 0000000..e837b0f --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go @@ -0,0 +1,200 @@ +package h2quic + +import ( + "bytes" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/lex/httplex" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" +) + +type requestWriter struct { + mutex sync.Mutex + headerStream utils.Stream + + henc *hpack.Encoder + hbuf bytes.Buffer // HPACK encoder writes into this +} + +const defaultUserAgent = "quic-go" + +func newRequestWriter(headerStream utils.Stream) *requestWriter { + rw := &requestWriter{ + headerStream: headerStream, + } + rw.henc = hpack.NewEncoder(&rw.hbuf) + return rw +} + +func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error { + // TODO: add support for trailers + // TODO: add support for gzip compression + // TODO: write continuation frames, if the header frame is too long + + w.mutex.Lock() + defer w.mutex.Unlock() + + w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) + h2framer := http2.NewFramer(w.headerStream, nil) + return h2framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: uint32(dataStreamID), + EndHeaders: true, + EndStream: endStream, + BlockFragment: w.hbuf.Bytes(), + Priority: http2.PriorityParam{Weight: 0xff}, + }) +} + +// the rest of this files is copied from http2.Transport +func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { + w.hbuf.Reset() + + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httplex.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + w.writeHeader(":authority", host) + w.writeHeader(":method", req.Method) + if req.Method != "CONNECT" { + w.writeHeader(":path", path) + w.writeHeader(":scheme", req.URL.Scheme) + } + if trailers != "" { + w.writeHeader("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + lowKey := strings.ToLower(k) + switch lowKey { + case "host", "content-length": + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + case "user-agent": + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } + for _, v := range vv { + w.writeHeader(lowKey, v) + } + } + if shouldSendReqContentLength(req.Method, contentLength) { + w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + w.writeHeader("accept-encoding", "gzip") + } + if !didUA { + w.writeHeader("user-agent", defaultUserAgent) + } + return w.hbuf.Bytes(), nil +} + +func (w *requestWriter) writeHeader(name, value string) { + utils.Debugf("http2: Transport encoding header %q = %q", name, value) + w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + +func validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" +} + +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go new file mode 100644 index 0000000..13efdf8 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response.go @@ -0,0 +1,111 @@ +package h2quic + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "net/http" + "net/textproto" + "strconv" + "strings" + + "golang.org/x/net/http2" +) + +// copied from net/http2/transport.go + +var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") +var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) + +// from the handleResponse function +func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { + if f.Truncated { + return nil, errResponseHeaderListSize + } + + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("missing status pseudo header") + } + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed non-numeric status pseudo header") + } + + if statusCode == 100 { + // TODO: handle this + + // traceGot100Continue(cs.trace) + // if cs.on100 != nil { + // cs.on100() // forces any write delay timer to fire + // } + // cs.pastHeaders = false // do it all again + // return nil, nil + } + + header := make(http.Header) + res := &http.Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + http.StatusText(statusCode), + } + for _, hf := range f.RegularFields() { + key := http.CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(http.Header) + res.Trailer = t + } + foreachHeaderElement(hf.Value, func(v string) { + t[http.CanonicalHeaderKey(v)] = nil + }) + } else { + header[key] = append(header[key], hf.Value) + } + } + + return res, nil +} + +// continuation of the handleResponse function +func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { + if !streamEnded || isHead { + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { + res.ContentLength = clen64 + } else { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } + } + return res +} + +// copied from net/http/server.go + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 2616 section 2.1 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed.go new file mode 100644 index 0000000..191a248 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed.go @@ -0,0 +1,9 @@ +// +build go1.7 + +package h2quic + +import "net/http" + +func setUncompressed(res *http.Response) { + res.Uncompressed = true +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed_go16.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed_go16.go new file mode 100644 index 0000000..7359f04 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_setuncompressed_go16.go @@ -0,0 +1,9 @@ +// +build !go1.7 + +package h2quic + +import "net/http" + +func setUncompressed(res *http.Response) { + // http.Response.Uncompressed was introduced in go 1.7 +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go index 0b5e930..7bd804f 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go @@ -21,6 +21,7 @@ type responseWriter struct { headerStreamMutex *sync.Mutex header http.Header + status int // status code passed to WriteHeader headerWritten bool } @@ -43,6 +44,7 @@ func (w *responseWriter) WriteHeader(status int) { return } w.headerWritten = true + w.status = status var headers bytes.Buffer enc := hpack.NewEncoder(&headers) @@ -72,6 +74,9 @@ func (w *responseWriter) Write(p []byte) (int, error) { if !w.headerWritten { w.WriteHeader(200) } + if !bodyAllowedForStatus(w.status) { + return 0, http.ErrBodyNotAllowed + } return w.dataStream.Write(p) } @@ -79,3 +84,18 @@ func (w *responseWriter) Flush() {} // test that we implement http.Flusher var _ http.Flusher = &responseWriter{} + +// copied from http2/http2.go +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 2616, section 4.4. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go new file mode 100644 index 0000000..85faf8e --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go @@ -0,0 +1,135 @@ +package h2quic + +import ( + "crypto/tls" + "errors" + "fmt" + "net/http" + "strings" + "sync" + + "golang.org/x/net/lex/httplex" +) + +type h2quicClient interface { + Do(*http.Request) (*http.Response, error) +} + +// QuicRoundTripper implements the http.RoundTripper interface +type QuicRoundTripper struct { + mutex sync.Mutex + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + clients map[string]h2quicClient +} + +var _ http.RoundTripper = &QuicRoundTripper{} + +// RoundTrip does a round trip +func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("quic: nil Request.URL") + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("quic: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("quic: nil Request.Header") + } + + if req.URL.Scheme == "https" { + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("quic: invalid http header field name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) + } + } + } + } else { + closeRequestBody(req) + return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) + } + + if req.Method != "" && !validMethod(req.Method) { + closeRequestBody(req) + return nil, fmt.Errorf("quic: invalid method %q", req.Method) + } + + hostname := authorityAddr("https", hostnameFromRequest(req)) + client, err := r.getClient(hostname) + if err != nil { + return nil, err + } + return client.Do(req) +} + +func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.clients == nil { + r.clients = make(map[string]h2quicClient) + } + + client, ok := r.clients[hostname] + if !ok { + var err error + client, err = NewClient(r, r.TLSClientConfig, hostname) + if err != nil { + return nil, err + } + r.clients[hostname] = client + } + return client, nil +} + +func (r *QuicRoundTripper) disableCompression() bool { + return r.DisableCompression +} + +func closeRequestBody(req *http.Request) { + if req.Body != nil { + req.Body.Close() + } +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// copied from net/http/http.go +func isNotToken(r rune) bool { + return !httplex.IsTokenRune(r) +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go index 4e301b5..8b591ec 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/h2quic/server.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "errors" "fmt" - "io/ioutil" "net" "net/http" "runtime" @@ -113,6 +112,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) { if _, ok := err.(*qerr.QuicError); !ok { utils.Errorf("error handling h2 request: %s", err.Error()) } + session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) return } } @@ -124,7 +124,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, if err != nil { return err } - h2headersFrame := h2frame.(*http2.HeadersFrame) + h2headersFrame, ok := h2frame.(*http2.HeadersFrame) + if !ok { + return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame") + } if !h2headersFrame.HeadersEnded() { return errors.New("http2 header continuation not implemented") } @@ -152,13 +155,15 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, return err } + var streamEnded bool if h2headersFrame.StreamEnded() { dataStream.CloseRemote(0) + streamEnded = true _, _ = dataStream.Read([]byte{0}) // read the eof } - // stream's Close() closes the write side, not the read side - req.Body = ioutil.NopCloser(dataStream) + reqBody := newRequestBody(dataStream) + req.Body = reqBody responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) @@ -187,6 +192,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, responseWriter.WriteHeader(200) } if responseWriter.dataStream != nil { + if !streamEnded && !reqBody.requestRead { + responseWriter.dataStream.Reset(nil) + } responseWriter.dataStream.Close() } if s.CloseAfterFirstRequest { diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go index df26119..5b8c816 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go @@ -2,7 +2,6 @@ package handshake import ( "bytes" - "encoding/binary" "errors" "sync" "time" @@ -12,23 +11,51 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) -// ConnectionParametersManager stores the connection parameters -// Warning: Writes may only be done from the crypto stream, see the comment -// in GetSHLOMap(). -type ConnectionParametersManager struct { - params map[Tag][]byte - mutex sync.RWMutex +// ConnectionParametersManager negotiates and stores the connection parameters +// A ConnectionParametersManager can be used for a server as well as a client +// For the server: +// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation +// 2. call GetHelloMap to get the values to send in the SHLO +// For the client: +// 1. call GetHelloMap to get the values to send in a CHLO +// 2. call SetFromMap with the values received in the SHLO +type ConnectionParametersManager interface { + SetFromMap(map[Tag][]byte) error + GetHelloMap() (map[Tag][]byte, error) - flowControlNegotiated bool // have the flow control parameters for sending already been negotiated - - maxStreamsPerConnection uint32 - idleConnectionStateLifetime time.Duration - sendStreamFlowControlWindow protocol.ByteCount - sendConnectionFlowControlWindow protocol.ByteCount - receiveStreamFlowControlWindow protocol.ByteCount - receiveConnectionFlowControlWindow protocol.ByteCount + GetSendStreamFlowControlWindow() protocol.ByteCount + GetSendConnectionFlowControlWindow() protocol.ByteCount + GetReceiveStreamFlowControlWindow() protocol.ByteCount + GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount + GetReceiveConnectionFlowControlWindow() protocol.ByteCount + GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount + GetMaxOutgoingStreams() uint32 + GetMaxIncomingStreams() uint32 + GetIdleConnectionStateLifetime() time.Duration + TruncateConnectionID() bool } +type connectionParametersManager struct { + mutex sync.RWMutex + + version protocol.VersionNumber + perspective protocol.Perspective + + flowControlNegotiated bool + hasReceivedMaxIncomingDynamicStreams bool + + truncateConnectionID bool + maxStreamsPerConnection uint32 + maxIncomingDynamicStreamsPerConnection uint32 + idleConnectionStateLifetime time.Duration + sendStreamFlowControlWindow protocol.ByteCount + sendConnectionFlowControlWindow protocol.ByteCount + receiveStreamFlowControlWindow protocol.ByteCount + receiveConnectionFlowControlWindow protocol.ByteCount +} + +var _ ConnectionParametersManager = &connectionParametersManager{} + var errTagNotInConnectionParameterMap = errors.New("ConnectionParametersManager: Tag not found in ConnectionsParameter map") // ErrMalformedTag is returned when the tag value cannot be read @@ -38,58 +65,82 @@ var ( ) // NewConnectionParamatersManager creates a new connection parameters manager -func NewConnectionParamatersManager() *ConnectionParametersManager { - return &ConnectionParametersManager{ - params: make(map[Tag][]byte), - idleConnectionStateLifetime: protocol.DefaultIdleTimeout, +func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber) ConnectionParametersManager { + h := &connectionParametersManager{ + perspective: pers, + version: v, sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - maxStreamsPerConnection: protocol.MaxStreamsPerConnection, } + + if h.perspective == protocol.PerspectiveServer { + h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout + h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent + h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective + } else { + h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient + h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent + h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective + } + + return h } // SetFromMap reads all params -func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { +func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error { h.mutex.Lock() defer h.mutex.Unlock() - for key, value := range params { - switch key { - case TagTCID: - h.params[key] = value - case TagMSPC: - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) - case TagICSL: - clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) - case TagSFCW: - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) - case TagCFCW: - if h.flowControlNegotiated { - return ErrFlowControlRenegotiationNotSupported - } - sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return ErrMalformedTag - } - h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) + if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer { + clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag } + h.truncateConnectionID = (clientValue == 0) + } + if value, ok := params[TagMSPC]; ok { + clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag + } + h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) + } + if value, ok := params[TagMIDS]; ok { + clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag + } + h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue) + h.hasReceivedMaxIncomingDynamicStreams = true + } + if value, ok := params[TagICSL]; ok { + clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag + } + h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) + } + if value, ok := params[TagSFCW]; ok { + if h.flowControlNegotiated { + return ErrFlowControlRenegotiationNotSupported + } + sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag + } + h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) + } + if value, ok := params[TagCFCW]; ok { + if h.flowControlNegotiated { + return ErrFlowControlRenegotiationNotSupported + } + sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return ErrMalformedTag + } + h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) } _, containsSFCW := params[TagSFCW] @@ -101,102 +152,132 @@ func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { return nil } -func (h *ConnectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { +func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) } -func (h *ConnectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { - return utils.MinDuration(clientValue, protocol.MaxIdleTimeout) +func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { + return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) } -// getRawValue gets the byte-slice for a tag -func (h *ConnectionParametersManager) getRawValue(tag Tag) ([]byte, error) { - h.mutex.RLock() - rawValue, ok := h.params[tag] - h.mutex.RUnlock() - - if !ok { - return nil, errTagNotInConnectionParameterMap +func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { + if h.perspective == protocol.PerspectiveServer { + return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer) } - return rawValue, nil + return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient) } -// GetSHLOMap gets all values (except crypto values) needed for the SHLO -func (h *ConnectionParametersManager) GetSHLOMap() map[Tag][]byte { +// GetHelloMap gets all parameters needed for the Hello message +func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) { sfcw := bytes.NewBuffer([]byte{}) utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) cfcw := bytes.NewBuffer([]byte{}) utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) mspc := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mspc, h.GetMaxStreamsPerConnection()) - mids := bytes.NewBuffer([]byte{}) - utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreams) + utils.WriteUint32(mspc, h.maxStreamsPerConnection) icsl := bytes.NewBuffer([]byte{}) utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) - return map[Tag][]byte{ + tags := map[Tag][]byte{ TagICSL: icsl.Bytes(), TagMSPC: mspc.Bytes(), - TagMIDS: mids.Bytes(), TagCFCW: cfcw.Bytes(), TagSFCW: sfcw.Bytes(), } + + if h.version > protocol.Version34 { + mids := bytes.NewBuffer([]byte{}) + utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) + tags[TagMIDS] = mids.Bytes() + } + + return tags, nil } // GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *ConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { +func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.sendStreamFlowControlWindow } // GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *ConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { +func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.sendConnectionFlowControlWindow } // GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *ConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { +func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.receiveStreamFlowControlWindow } +// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data +func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { + if h.perspective == protocol.PerspectiveServer { + return protocol.MaxReceiveStreamFlowControlWindowServer + } + return protocol.MaxReceiveStreamFlowControlWindowClient +} + // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *ConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { +func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.receiveConnectionFlowControlWindow } -// GetMaxStreamsPerConnection gets the maximum number of streams per connection -func (h *ConnectionParametersManager) GetMaxStreamsPerConnection() uint32 { +// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data +func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { + if h.perspective == protocol.PerspectiveServer { + return protocol.MaxReceiveConnectionFlowControlWindowServer + } + return protocol.MaxReceiveConnectionFlowControlWindowClient +} + +// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection +func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 { h.mutex.RLock() defer h.mutex.RUnlock() + + if h.version > protocol.Version34 && h.hasReceivedMaxIncomingDynamicStreams { + return h.maxIncomingDynamicStreamsPerConnection + } return h.maxStreamsPerConnection } +// GetMaxIncomingStreams get the maximum number of incoming streams per connection +func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 { + h.mutex.RLock() + defer h.mutex.RUnlock() + + var val uint32 + if h.version <= protocol.Version34 { + val = h.maxStreamsPerConnection + } else { + val = protocol.MaxIncomingDynamicStreamsPerConnection + } + + return utils.MaxUint32(val+protocol.MaxStreamsMinimumIncrement, uint32(float64(val)*protocol.MaxStreamsMultiplier)) +} + // GetIdleConnectionStateLifetime gets the idle timeout -func (h *ConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { +func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { h.mutex.RLock() defer h.mutex.RUnlock() return h.idleConnectionStateLifetime } // TruncateConnectionID determines if the client requests truncated ConnectionIDs -func (h *ConnectionParametersManager) TruncateConnectionID() bool { - rawValue, err := h.getRawValue(TagTCID) - if err != nil { +func (h *connectionParametersManager) TruncateConnectionID() bool { + if h.perspective == protocol.PerspectiveClient { return false } - if len(rawValue) != 4 { - return false - } - value := binary.LittleEndian.Uint32(rawValue) - if value == 0 { - return true - } - return false + + h.mutex.RLock() + defer h.mutex.RUnlock() + return h.truncateConnectionID } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go new file mode 100644 index 0000000..0fc46e5 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_client.go @@ -0,0 +1,485 @@ +package handshake + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" +) + +type cryptoSetupClient struct { + mutex sync.RWMutex + + hostname string + connID protocol.ConnectionID + version protocol.VersionNumber + negotiatedVersions []protocol.VersionNumber + + cryptoStream utils.Stream + + serverConfig *serverConfigClient + + stk []byte + sno []byte + nonc []byte + proof []byte + diversificationNonce []byte + chloForSignature []byte + lastSentCHLO []byte + certManager crypto.CertManager + + clientHelloCounter int + serverVerified bool // has the certificate chain and the proof already been verified + keyDerivation KeyDerivationFunction + + receivedSecurePacket bool + secureAEAD crypto.AEAD + forwardSecureAEAD crypto.AEAD + aeadChanged chan struct{} + + connectionParameters ConnectionParametersManager +} + +var _ crypto.AEAD = &cryptoSetupClient{} +var _ CryptoSetup = &cryptoSetupClient{} + +var ( + errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available") + errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated") + errConflictingDiversificationNonces = errors.New("Received two different diversification nonces") +) + +// NewCryptoSetupClient creates a new CryptoSetup instance for a client +func NewCryptoSetupClient( + hostname string, + connID protocol.ConnectionID, + version protocol.VersionNumber, + cryptoStream utils.Stream, + tlsConfig *tls.Config, + connectionParameters ConnectionParametersManager, + aeadChanged chan struct{}, + negotiatedVersions []protocol.VersionNumber, +) (CryptoSetup, error) { + return &cryptoSetupClient{ + hostname: hostname, + connID: connID, + version: version, + cryptoStream: cryptoStream, + certManager: crypto.NewCertManager(tlsConfig), + connectionParameters: connectionParameters, + keyDerivation: crypto.DeriveKeysAESGCM, + aeadChanged: aeadChanged, + negotiatedVersions: negotiatedVersions, + }, nil +} + +func (h *cryptoSetupClient) HandleCryptoStream() error { + for { + err := h.maybeUpgradeCrypto() + if err != nil { + return err + } + + // send CHLOs until the forward secure encryption is established + if h.forwardSecureAEAD == nil { + err = h.sendCHLO() + if err != nil { + return err + } + } + + var shloData bytes.Buffer + + messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &shloData)) + if err != nil { + return qerr.HandshakeFailed + } + + if messageTag != TagSHLO && messageTag != TagREJ { + return qerr.InvalidCryptoMessageType + } + + if messageTag == TagSHLO { + utils.Debugf("Got SHLO:\n%s", printHandshakeMessage(cryptoData)) + err = h.handleSHLOMessage(cryptoData) + if err != nil { + return err + } + } + + if messageTag == TagREJ { + err = h.handleREJMessage(cryptoData) + if err != nil { + return err + } + } + } +} + +func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { + utils.Debugf("Got REJ:\n%s", printHandshakeMessage(cryptoData)) + + var err error + + if stk, ok := cryptoData[TagSTK]; ok { + h.stk = stk + } + + if sno, ok := cryptoData[TagSNO]; ok { + h.sno = sno + } + + // TODO: what happens if the server sends a different server config in two packets? + if scfg, ok := cryptoData[TagSCFG]; ok { + h.serverConfig, err = parseServerConfig(scfg) + if err != nil { + return err + } + + if h.serverConfig.IsExpired() { + return qerr.CryptoServerConfigExpired + } + + // now that we have a server config, we can use its OBIT value to generate a client nonce + if len(h.nonc) == 0 { + err = h.generateClientNonce() + if err != nil { + return err + } + } + } + + if proof, ok := cryptoData[TagPROF]; ok { + h.proof = proof + h.chloForSignature = h.lastSentCHLO + } + + if crt, ok := cryptoData[TagCERT]; ok { + err := h.certManager.SetData(crt) + if err != nil { + return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid") + } + + err = h.certManager.Verify(h.hostname) + if err != nil { + utils.Infof("Certificate validation failed: %s", err.Error()) + return qerr.ProofInvalid + } + } + + if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { + validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) + if !validProof { + utils.Infof("Server proof verification failed") + return qerr.ProofInvalid + } + + h.serverVerified = true + } + + return nil +} + +func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { + h.mutex.Lock() + defer h.mutex.Unlock() + + if !h.receivedSecurePacket { + return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") + } + + if sno, ok := cryptoData[TagSNO]; ok { + h.sno = sno + } + + serverPubs, ok := cryptoData[TagPUBS] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") + } + + verTag, ok := cryptoData[TagVER] + if !ok { + return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") + } + if !h.validateVersionList(verTag) { + return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + } + + nonce := append(h.nonc, h.sno...) + + ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) + if err != nil { + return err + } + + leafCert := h.certManager.GetLeafCert() + + h.forwardSecureAEAD, err = h.keyDerivation( + true, + ephermalSharedSecret, + nonce, + h.connID, + h.lastSentCHLO, + h.serverConfig.Get(), + leafCert, + nil, + protocol.PerspectiveClient, + ) + if err != nil { + return err + } + + err = h.connectionParameters.SetFromMap(cryptoData) + if err != nil { + return qerr.InvalidCryptoMessageParameter + } + + h.aeadChanged <- struct{}{} + + return nil +} + +func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { + if len(h.negotiatedVersions) == 0 { + return true + } + if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { + return false + } + + b := bytes.NewReader(verTags) + for _, negotiatedVersion := range h.negotiatedVersions { + verTag, err := utils.ReadUint32(b) + if err != nil { // should never occur, since the length was already checked + return false + } + ver := protocol.VersionTagToNumber(verTag) + if !protocol.IsSupportedVersion(ver) { + ver = protocol.VersionUnsupported + } + if ver != negotiatedVersion { + return false + } + } + return true +} + +func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { + if h.forwardSecureAEAD != nil { + data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) + if err == nil { + return data, nil + } + return nil, err + } + + if h.secureAEAD != nil { + data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData) + if err == nil { + h.receivedSecurePacket = true + return data, nil + } + if h.receivedSecurePacket { + return nil, err + } + } + + return (&crypto.NullAEAD{}).Open(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupClient) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + if h.forwardSecureAEAD != nil { + return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) + } + if h.secureAEAD != nil { + return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) + } + return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupClient) DiversificationNonce() []byte { + panic("not needed for cryptoSetupClient") +} + +func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { + if len(h.diversificationNonce) == 0 { + h.diversificationNonce = data + return h.maybeUpgradeCrypto() + } + if !bytes.Equal(h.diversificationNonce, data) { + return errConflictingDiversificationNonces + } + return nil +} + +func (h *cryptoSetupClient) LockForSealing() { + +} + +func (h *cryptoSetupClient) UnlockForSealing() { + +} + +func (h *cryptoSetupClient) HandshakeComplete() bool { + h.mutex.RLock() + complete := h.forwardSecureAEAD != nil + h.mutex.RUnlock() + return complete +} + +func (h *cryptoSetupClient) sendCHLO() error { + h.clientHelloCounter++ + if h.clientHelloCounter > protocol.MaxClientHellos { + return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos)) + } + + b := &bytes.Buffer{} + + tags, err := h.getTags() + if err != nil { + return err + } + h.addPadding(tags) + + utils.Debugf("Sending CHLO:\n%s", printHandshakeMessage(tags)) + WriteHandshakeMessage(b, TagCHLO, tags) + + _, err = h.cryptoStream.Write(b.Bytes()) + if err != nil { + return err + } + + h.lastSentCHLO = b.Bytes() + + return nil +} + +func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { + tags, err := h.connectionParameters.GetHelloMap() + if err != nil { + return nil, err + } + tags[TagSNI] = []byte(h.hostname) + tags[TagPDMD] = []byte("X509") + + ccs := h.certManager.GetCommonCertificateHashes() + if len(ccs) > 0 { + tags[TagCCS] = ccs + } + + versionTag := make([]byte, 4, 4) + binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) + tags[TagVER] = versionTag + + if len(h.stk) > 0 { + tags[TagSTK] = h.stk + } + + if len(h.sno) > 0 { + tags[TagSNO] = h.sno + } + + if h.serverConfig != nil { + tags[TagSCID] = h.serverConfig.ID + + leafCert := h.certManager.GetLeafCert() + if leafCert != nil { + certHash, _ := h.certManager.GetLeafCertHash() + xlct := make([]byte, 8, 8) + binary.LittleEndian.PutUint64(xlct, certHash) + + tags[TagNONC] = h.nonc + tags[TagXLCT] = xlct + tags[TagKEXS] = []byte("C255") + tags[TagAEAD] = []byte("AESG") + tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended + } + } + + return tags, nil +} + +// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize +func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) { + var size int + for _, tag := range tags { + size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data + } + paddingSize := protocol.ClientHelloMinimumSize - size + if paddingSize > 0 { + tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize) + } +} + +func (h *cryptoSetupClient) maybeUpgradeCrypto() error { + if !h.serverVerified { + return nil + } + + h.mutex.Lock() + defer h.mutex.Unlock() + + leafCert := h.certManager.GetLeafCert() + + if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) { + var err error + var nonce []byte + if h.sno == nil { + nonce = h.nonc + } else { + nonce = append(h.nonc, h.sno...) + } + + h.secureAEAD, err = h.keyDerivation( + false, + h.serverConfig.sharedSecret, + nonce, + h.connID, + h.lastSentCHLO, + h.serverConfig.Get(), + leafCert, + h.diversificationNonce, + protocol.PerspectiveClient, + ) + if err != nil { + return err + } + + h.aeadChanged <- struct{}{} + } + + return nil +} + +func (h *cryptoSetupClient) generateClientNonce() error { + if len(h.nonc) > 0 { + return errClientNonceAlreadyExists + } + + nonc := make([]byte, 32) + binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix())) + + if len(h.serverConfig.obit) != 8 { + return errNoObitForClientNonce + } + + copy(nonc[4:12], h.serverConfig.obit) + + _, err := rand.Read(nonc[12:]) + if err != nil { + return err + } + + h.nonc = nonc + return nil +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_interface.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_interface.go new file mode 100644 index 0000000..4822cfb --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_interface.go @@ -0,0 +1,16 @@ +package handshake + +import "github.com/lucas-clemente/quic-go/protocol" + +// CryptoSetup is a crypto setup +type CryptoSetup interface { + HandleCryptoStream() error + Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + LockForSealing() + UnlockForSealing() + HandshakeComplete() bool + // TODO: clean up this interface + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go similarity index 60% rename from cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup.go rename to cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go index 56e9764..213d1d2 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/crypto_setup_server.go @@ -3,6 +3,7 @@ package handshake import ( "bytes" "crypto/rand" + "encoding/binary" "io" "net" "sync" @@ -14,13 +15,13 @@ import ( ) // KeyDerivationFunction is used for key derivation -type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (crypto.AEAD, error) +type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX type KeyExchangeFunction func() crypto.KeyExchange -// The CryptoSetup handles all things crypto for the Session -type CryptoSetup struct { +// The CryptoSetupServer handles all things crypto for the Session +type cryptoSetupServer struct { connID protocol.ConnectionID ip net.IP version protocol.VersionNumber @@ -38,38 +39,38 @@ type CryptoSetup struct { cryptoStream utils.Stream - connectionParametersManager *ConnectionParametersManager + connectionParameters ConnectionParametersManager mutex sync.RWMutex } -var _ crypto.AEAD = &CryptoSetup{} +var _ crypto.AEAD = &cryptoSetupServer{} -// NewCryptoSetup creates a new CryptoSetup instance +// NewCryptoSetup creates a new CryptoSetup instance for a server func NewCryptoSetup( connID protocol.ConnectionID, ip net.IP, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, - connectionParametersManager *ConnectionParametersManager, + connectionParametersManager ConnectionParametersManager, aeadChanged chan struct{}, -) (*CryptoSetup, error) { - return &CryptoSetup{ - connID: connID, - ip: ip, - version: version, - scfg: scfg, - keyDerivation: crypto.DeriveKeysAESGCM, - keyExchange: getEphermalKEX, - cryptoStream: cryptoStream, - connectionParametersManager: connectionParametersManager, - aeadChanged: aeadChanged, +) (CryptoSetup, error) { + return &cryptoSetupServer{ + connID: connID, + ip: ip, + version: version, + scfg: scfg, + keyDerivation: crypto.DeriveKeysAESGCM, + keyExchange: getEphermalKEX, + cryptoStream: cryptoStream, + connectionParameters: connectionParametersManager, + aeadChanged: aeadChanged, }, nil } // HandleCryptoStream reads and writes messages on the crypto stream -func (h *CryptoSetup) HandleCryptoStream() error { +func (h *cryptoSetupServer) HandleCryptoStream() error { for { var chloData bytes.Buffer messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) @@ -92,7 +93,7 @@ func (h *CryptoSetup) HandleCryptoStream() error { } } -func (h *CryptoSetup) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) { +func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) { sniSlice, ok := cryptoData[TagSNI] if !ok { return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") @@ -102,9 +103,31 @@ func (h *CryptoSetup) handleMessage(chloData []byte, cryptoData map[Tag][]byte) return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") } + // prevent version downgrade attacks + // see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples + verSlice, ok := cryptoData[TagVER] + if !ok { + return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag") + } + if len(verSlice) != 4 { + return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag") + } + verTag := binary.LittleEndian.Uint32(verSlice) + ver := protocol.VersionTagToNumber(verTag) + // If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack. + if ver != h.version && protocol.IsSupportedVersion(ver) { + return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + } + var reply []byte var err error - if !h.isInchoateCHLO(cryptoData) { + + certUncompressed, err := h.scfg.certChain.GetLeafCert(sni) + if err != nil { + return false, err + } + + if !h.isInchoateCHLO(cryptoData, certUncompressed) { // We have a CHLO with a proper server config ID, do a 0-RTT handshake reply, err = h.handleCHLO(sni, chloData, cryptoData) if err != nil { @@ -130,7 +153,7 @@ func (h *CryptoSetup) handleMessage(chloData []byte, cryptoData map[Tag][]byte) } // Open a message -func (h *CryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { +func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { h.mutex.RLock() defer h.mutex.RUnlock() @@ -158,7 +181,7 @@ func (h *CryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, } // Seal a message, call LockForSealing() before! -func (h *CryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { +func (h *cryptoSetupServer) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { if h.receivedForwardSecurePacket { return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) } else if h.secureAEAD != nil { @@ -168,12 +191,20 @@ func (h *CryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumber, } } -func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { +func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { + if _, ok := cryptoData[TagPUBS]; !ok { + return true + } scid, ok := cryptoData[TagSCID] if !ok || !bytes.Equal(h.scfg.ID, scid) { return true } - if _, ok := cryptoData[TagPUBS]; !ok { + xlctTag, ok := cryptoData[TagXLCT] + if !ok || len(xlctTag) != 8 { + return true + } + xlct := binary.LittleEndian.Uint64(xlctTag) + if crypto.HashCert(cert) != xlct { return true } if err := h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]); err != nil { @@ -183,7 +214,7 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { return false } -func (h *CryptoSetup) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { +func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { if len(chlo) < protocol.ClientHelloMinimumSize { return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") } @@ -219,10 +250,11 @@ func (h *CryptoSetup) handleInchoateCHLO(sni string, chlo []byte, cryptoData map var serverReply bytes.Buffer WriteHandshakeMessage(&serverReply, TagREJ, replyMap) + utils.Debugf("Sending REJ:\n%s", printHandshakeMessage(replyMap)) return serverReply.Bytes(), nil } -func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) { +func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) { // We have a CHLO matching our server config, we can continue with the 0-RTT handshake sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { @@ -232,13 +264,13 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b h.mutex.Lock() defer h.mutex.Unlock() - certUncompressed, err := h.scfg.signer.GetLeafCert(sni) + certUncompressed, err := h.scfg.certChain.GetLeafCert(sni) if err != nil { return nil, err } - nonce := make([]byte, 32) - if _, err = rand.Read(nonce); err != nil { + serverNonce := make([]byte, 32) + if _, err = rand.Read(serverNonce); err != nil { return nil, err } @@ -247,15 +279,32 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b return nil, err } + clientNonce := cryptoData[TagNONC] + err = h.validateClientNonce(clientNonce) + if err != nil { + return nil, err + } + + aead := cryptoData[TagAEAD] + if !bytes.Equal(aead, []byte("AESG")) { + return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS") + } + + kexs := cryptoData[TagKEXS] + if !bytes.Equal(kexs, []byte("C255")) { + return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS") + } + h.secureAEAD, err = h.keyDerivation( false, sharedSecret, - cryptoData[TagNONC], + clientNonce, h.connID, data, h.scfg.Get(), certUncompressed, h.diversificationNonce, + protocol.PerspectiveServer, ) if err != nil { return nil, err @@ -263,13 +312,14 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer - fsNonce.Write(cryptoData[TagNONC]) - fsNonce.Write(nonce) + fsNonce.Write(clientNonce) + fsNonce.Write(serverNonce) ephermalKex := h.keyExchange() ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { return nil, err } + h.forwardSecureAEAD, err = h.keyDerivation( true, ephermalSharedSecret, @@ -279,24 +329,29 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b h.scfg.Get(), certUncompressed, nil, + protocol.PerspectiveServer, ) if err != nil { return nil, err } - err = h.connectionParametersManager.SetFromMap(cryptoData) + err = h.connectionParameters.SetFromMap(cryptoData) if err != nil { return nil, err } - replyMap := h.connectionParametersManager.GetSHLOMap() + replyMap, err := h.connectionParameters.GetHelloMap() + if err != nil { + return nil, err + } // add crypto parameters replyMap[TagPUBS] = ephermalKex.PublicKey() - replyMap[TagSNO] = nonce + replyMap[TagSNO] = serverNonce replyMap[TagVER] = protocol.SupportedVersionsAsTags var reply bytes.Buffer WriteHandshakeMessage(&reply, TagSHLO, replyMap) + utils.Debugf("Sending SHLO:\n%s", printHandshakeMessage(replyMap)) h.aeadChanged <- struct{}{} @@ -304,24 +359,38 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b } // DiversificationNonce returns a diversification nonce if required in the next packet to be Seal'ed. See LockForSealing()! -func (h *CryptoSetup) DiversificationNonce() []byte { +func (h *cryptoSetupServer) DiversificationNonce() []byte { if h.receivedForwardSecurePacket || h.secureAEAD == nil { return nil } return h.diversificationNonce } +func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error { + panic("not needed for cryptoSetupServer") +} + // LockForSealing should be called before Seal(). It is needed so that diversification nonces can be obtained before packets are sealed, and the AEADs are not changed in the meantime. -func (h *CryptoSetup) LockForSealing() { +func (h *cryptoSetupServer) LockForSealing() { h.mutex.RLock() } // UnlockForSealing should be called after Seal() is complete, see LockForSealing(). -func (h *CryptoSetup) UnlockForSealing() { +func (h *cryptoSetupServer) UnlockForSealing() { h.mutex.RUnlock() } // HandshakeComplete returns true after the first forward secure packet was received form the client. -func (h *CryptoSetup) HandshakeComplete() bool { +func (h *cryptoSetupServer) HandshakeComplete() bool { return h.receivedForwardSecurePacket } + +func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { + if len(nonce) != 32 { + return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") + } + if !bytes.Equal(nonce[4:12], h.scfg.obit) { + return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching") + } + return nil +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go index 013c44b..32f0265 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/handshake_message.go @@ -95,12 +95,19 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte) func printHandshakeMessage(data map[Tag][]byte) string { var res string + var pad string for k, v := range data { if k == TagPAD { - continue + pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(k), len(v)) + } else { + res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v)) } - res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v)) } + + if len(pad) > 0 { + res += pad + } + return res } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go index dc33e97..cd15b20 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config.go @@ -10,13 +10,14 @@ import ( // ServerConfig is a server config type ServerConfig struct { kex crypto.KeyExchange - signer crypto.Signer + certChain crypto.CertChain ID []byte + obit []byte stkSource crypto.StkSource } // NewServerConfig creates a new server config -func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfig, error) { +func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) { id := make([]byte, 16) _, err := rand.Read(id) if err != nil { @@ -27,6 +28,12 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi if _, err = rand.Read(stkSecret); err != nil { return nil, err } + + obit := make([]byte, 8) + if _, err = rand.Read(obit); err != nil { + return nil, err + } + stkSource, err := crypto.NewStkSource(stkSecret) if err != nil { return nil, err @@ -34,8 +41,9 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi return &ServerConfig{ kex: kex, - signer: signer, + certChain: certChain, ID: id, + obit: obit, stkSource: stkSource, }, nil } @@ -48,7 +56,7 @@ func (s *ServerConfig) Get() []byte { TagKEXS: []byte("C255"), TagAEAD: []byte("AESG"), TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...), - TagOBIT: {0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, + TagOBIT: s.obit, TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, }) return serverConfig.Bytes() @@ -56,10 +64,10 @@ func (s *ServerConfig) Get() []byte { // Sign the server config and CHLO with the server's keyData func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) { - return s.signer.SignServerProof(sni, chlo, s.Get()) + return s.certChain.SignServerProof(sni, chlo, s.Get()) } // GetCertsCompressed returns the certificate data func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) { - return s.signer.GetCertsCompressed(sni, commonSetHashes, compressedHashes) + return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes) } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go new file mode 100644 index 0000000..1da6551 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/server_config_client.go @@ -0,0 +1,148 @@ +package handshake + +import ( + "bytes" + "encoding/binary" + "errors" + "math" + "time" + + "github.com/lucas-clemente/quic-go/crypto" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" +) + +type serverConfigClient struct { + raw []byte + ID []byte + obit []byte + expiry time.Time + + kex crypto.KeyExchange + sharedSecret []byte +} + +var ( + errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG") +) + +// parseServerConfig parses a server config +func parseServerConfig(data []byte) (*serverConfigClient, error) { + tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data)) + if err != nil { + return nil, err + } + if tag != TagSCFG { + return nil, errMessageNotServerConfig + } + + scfg := &serverConfigClient{raw: data} + err = scfg.parseValues(tagMap) + if err != nil { + return nil, err + } + + return scfg, nil +} + +func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { + // SCID + scfgID, ok := tagMap[TagSCID] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID") + } + if len(scfgID) != 16 { + return qerr.Error(qerr.CryptoInvalidValueLength, "SCID") + } + s.ID = scfgID + + // KEXS + // TODO: allow for P256 in the list + // TODO: setup Key Exchange + kexs, ok := tagMap[TagKEXS] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS") + } + if len(kexs)%4 != 0 { + return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS") + } + if !bytes.Equal(kexs, []byte("C255")) { + return qerr.Error(qerr.CryptoNoSupport, "KEXS") + } + + // AEAD + aead, ok := tagMap[TagAEAD] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD") + } + if len(aead)%4 != 0 { + return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD") + } + var aesgFound bool + for i := 0; i < len(aead)/4; i++ { + if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) { + aesgFound = true + break + } + } + if !aesgFound { + return qerr.Error(qerr.CryptoNoSupport, "AEAD") + } + + // PUBS + // TODO: save this value + pubs, ok := tagMap[TagPUBS] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") + } + if len(pubs) != 35 { + return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") + } + + var err error + s.kex, err = crypto.NewCurve25519KEX() + if err != nil { + return err + } + + // the PUBS value is always prepended by []byte{0x20, 0x00, 0x00} + s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:]) + if err != nil { + return err + } + + // OBIT + obit, ok := tagMap[TagOBIT] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT") + } + if len(obit) != 8 { + return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT") + } + s.obit = obit + + // EXPY + expy, ok := tagMap[TagEXPY] + if !ok { + return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY") + } + if len(expy) != 8 { + return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY") + } + // make sure that the value doesn't overflow an int64 + // furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here + expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2) + s.expiry = time.Unix(int64(expyTimestamp), 0) + + // TODO: implement VER + + return nil +} + +func (s *serverConfigClient) IsExpired() bool { + return s.expiry.Before(time.Now()) +} + +func (s *serverConfigClient) Get() []byte { + return s.raw +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go index 92d4d84..56a07f6 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/handshake/tags.go @@ -59,6 +59,8 @@ const ( // TagNONC is the client nonce TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24 + // TagXLCT is the expected leaf certificate + TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24 // TagSCID is the server config ID TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24 diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index 359fa73..fdf307a 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -18,58 +18,72 @@ type packedPacket struct { type packetPacker struct { connectionID protocol.ConnectionID + perspective protocol.Perspective version protocol.VersionNumber - cryptoSetup *handshake.CryptoSetup + cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - connectionParametersManager *handshake.ConnectionParametersManager + connectionParameters handshake.ConnectionParametersManager streamFramer *streamFramer controlFrames []frames.Frame } -func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake.CryptoSetup, connectionParametersHandler *handshake.ConnectionParametersManager, streamFramer *streamFramer, version protocol.VersionNumber) *packetPacker { +func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker { return &packetPacker{ - cryptoSetup: cryptoSetup, - connectionID: connectionID, - connectionParametersManager: connectionParametersHandler, - version: version, - streamFramer: streamFramer, - packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), + cryptoSetup: cryptoSetup, + connectionID: connectionID, + connectionParameters: connectionParameters, + perspective: perspective, + version: version, + streamFramer: streamFramer, + packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), } } -func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { - return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false) +// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame +func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { + // in case the connection is closed, all queued control frames aren't of any use anymore + // discard them and queue the ConnectionCloseFrame + p.controlFrames = []frames.Frame{ccf} + return p.packPacket(nil, leastUnacked) } -func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) { - return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck) +// PackPacket packs a new packet +// the stopWaitingFrame is *guaranteed* to be included in the next packet +// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise +func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { + p.controlFrames = append(p.controlFrames, controlFrames...) + return p.packPacket(stopWaitingFrame, leastUnacked) } -func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) { - if len(controlFrames) > 0 { - p.controlFrames = append(p.controlFrames, controlFrames...) - } - - currentPacketNumber := p.packetNumberGenerator.Peek() - +func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { // cryptoSetup needs to be locked here, so that the AEADs are not changed between // calling DiversificationNonce() and Seal(). p.cryptoSetup.LockForSealing() defer p.cryptoSetup.UnlockForSealing() + currentPacketNumber := p.packetNumberGenerator.Peek() packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked) responsePublicHeader := &PublicHeader{ ConnectionID: p.connectionID, PacketNumber: currentPacketNumber, PacketNumberLen: packetNumberLen, - TruncateConnectionID: p.connectionParametersManager.TruncateConnectionID(), - DiversificationNonce: p.cryptoSetup.DiversificationNonce(), + TruncateConnectionID: p.connectionParameters.TruncateConnectionID(), } - publicHeaderLength, err := responsePublicHeader.GetLength() + if p.perspective == protocol.PerspectiveServer { + responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + } + + // TODO: stop sending version numbers once a version has been negotiated + if p.perspective == protocol.PerspectiveClient { + responsePublicHeader.VersionFlag = true + responsePublicHeader.VersionNumber = p.version + } + + publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective) if err != nil { return nil, err } @@ -79,9 +93,15 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con stopWaitingFrame.PacketNumberLen = packetNumberLen } + // we're packing a ConnectionClose, don't add any StreamFrames + var isConnectionClose bool + if len(p.controlFrames) == 1 { + _, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame) + } + var payloadFrames []frames.Frame - if onlySendOneControlFrame { - payloadFrames = []frames.Frame{controlFrames[0]} + if isConnectionClose { + payloadFrames = []frames.Frame{p.controlFrames[0]} } else { payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength) if err != nil { @@ -94,26 +114,14 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con return nil, nil } // Don't send out packets that only contain a StopWaitingFrame - if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil { + if len(payloadFrames) == 1 && stopWaitingFrame != nil { return nil, nil } - // Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested - if !maySendOnlyAck { - if len(payloadFrames) == 1 { - if _, ok := payloadFrames[0].(*frames.AckFrame); ok { - return nil, nil - } - } else if len(payloadFrames) == 2 && stopWaitingFrame != nil { - if _, ok := payloadFrames[1].(*frames.AckFrame); ok { - return nil, nil - } - } - } raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) - if err = responsePublicHeader.WritePublicHeader(buffer, p.version); err != nil { + if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil { return nil, err } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index 3434751..cec85ed 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -11,10 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) -type unpackedPacket struct { - frames []frames.Frame -} - type packetUnpacker struct { version protocol.VersionNumber aead crypto.AEAD diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go new file mode 100644 index 0000000..3622c9f --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/encryption_level.go @@ -0,0 +1,14 @@ +package protocol + +// EncryptionLevel is the encryption level +// Default value is Unencrypted +type EncryptionLevel int + +const ( + // Unencrypted is not encrypted + Unencrypted EncryptionLevel = iota + // EncryptionSecure is encrypted, but not forward secure + EncryptionSecure + // EncryptionForwardSecure is forward secure + EncryptionForwardSecure +) diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go new file mode 100644 index 0000000..6aa3b70 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/perspective.go @@ -0,0 +1,10 @@ +package protocol + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go index aae2e0a..bb36e3b 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/protocol.go @@ -64,3 +64,9 @@ const MaxRetransmissionTime = 60 * time.Second // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. const ClientHelloMinimumSize = 1024 + +// MaxClientHellos is the maximum number of times we'll send a client hello +// The value 3 accounts for: +// * one failure due to an incorrect or missing source-address token +// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token +const MaxClientHellos = 3 diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go index 99894a0..1198182 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/server_parameters.go @@ -3,31 +3,48 @@ package protocol import "time" // DefaultMaxCongestionWindow is the default for the max congestion window -const DefaultMaxCongestionWindow PacketNumber = 1000 +const DefaultMaxCongestionWindow = 1000 // InitialCongestionWindow is the initial congestion window in QUIC packets -const InitialCongestionWindow PacketNumber = 32 +const InitialCongestionWindow = 32 // MaxUndecryptablePackets limits the number of undecryptable packets that a // session queues for later until it sends a public reset. const MaxUndecryptablePackets = 10 -// AckSendDelay is the maximal time delay applied to packets containing only ACKs -const AckSendDelay = 5 * time.Millisecond +// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet +// This is the value Chromium is using +const AckSendDelay = 25 * time.Millisecond // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB +const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB -// ReceiveConnectionFlowControlWindow is the stream-level flow control window for receiving data +// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data // This is the value that Google servers are using -const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) * 1.5 // 1.5 MB +const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB + +// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data +// This is the value that Google servers are using +const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB + +// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data +// This is the value that Google servers are using +const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB + +// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client +// This is the value that Chromium is using +const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB + +// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server +// This is the value that Google servers are using +const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB // MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection const MaxStreamsPerConnection = 100 -// MaxIncomingDynamicStreams is the maximum value accepted for the incoming number of dynamic streams per connection -const MaxIncomingDynamicStreams = 100 +// MaxIncomingDynamicStreamsPerConnection is the maximum value accepted for the incoming number of dynamic streams per connection +const MaxIncomingDynamicStreamsPerConnection = 100 // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. const MaxStreamsMultiplier = 1.1 @@ -60,8 +77,17 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow // MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow +// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked +const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow + +// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent +const MaxPacketsReceivedBeforeAckSend = 20 + +// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for +const RetransmittablePacketsBeforeAck = 2 + // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames -// prevents DOS attacks against the streamFrameSorter +// prevents DoS attacks against the streamFrameSorter const MaxStreamFrameSorterGaps = 1000 // CryptoMaxParams is the upper limit for the number of parameters in a crypto message. @@ -69,7 +95,7 @@ const MaxStreamFrameSorterGaps = 1000 const CryptoMaxParams = 128 // CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message. -const CryptoParameterMaxLength = 2000 +const CryptoParameterMaxLength = 4000 // EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX. const EphermalKeyLifetime = time.Minute @@ -77,14 +103,21 @@ const EphermalKeyLifetime = time.Minute // InitialIdleTimeout is the timeout before the handshake succeeds. const InitialIdleTimeout = 5 * time.Second -// DefaultIdleTimeout is the default idle timeout. +// DefaultIdleTimeout is the default idle timeout, for the server const DefaultIdleTimeout = 30 * time.Second -// MaxIdleTimeout is the maximum idle timeout that can be negotiated. -const MaxIdleTimeout = 1 * time.Minute +// MaxIdleTimeoutServer is the maximum idle timeout that can be negotiated, for the server +const MaxIdleTimeoutServer = 1 * time.Minute + +// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server +const MaxIdleTimeoutClient = 2 * time.Minute // MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. const MaxTimeForCryptoHandshake = 10 * time.Second +// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed +// after this time all information about the old connection will be deleted +const ClosedSessionDeleteTimeout = time.Minute + // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space const NumCachedCertificates = 128 diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/version.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/version.go index cd0cd47..3c1a70b 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/version.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/protocol/version.go @@ -14,10 +14,12 @@ const ( Version34 VersionNumber = 34 + iota Version35 Version36 - VersionWhatever = 0 // for when the version doesn't matter + VersionWhatever = 0 // for when the version doesn't matter + VersionUnsupported = -1 ) // SupportedVersions lists the versions that the server supports +// must be in sorted order var SupportedVersions = []VersionNumber{ Version34, Version35, Version36, } @@ -49,6 +51,28 @@ func IsSupportedVersion(v VersionNumber) bool { return false } +// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions +// the versions in other do not need to be ordered +// it returns true and the version number, if there is one, otherwise false +func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) { + var otherSupported []VersionNumber + for _, ver := range other { + if ver != VersionUnsupported { + otherSupported = append(otherSupported, ver) + } + } + + for i := len(SupportedVersions) - 1; i >= 0; i-- { + for _, ver := range otherSupported { + if ver == SupportedVersions[i] { + return true, ver + } + } + } + + return false, 0 +} + func init() { var b bytes.Buffer for _, v := range SupportedVersions { diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_header.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_header.go index 475c1b6..9613ea2 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_header.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_header.go @@ -3,7 +3,6 @@ package quic import ( "bytes" "errors" - "io" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -11,11 +10,11 @@ import ( ) var ( - errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") - errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") - errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") - errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets") + errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") + errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") + errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") + errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") + errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") ) // The PublicHeader of a QUIC packet @@ -27,16 +26,19 @@ type PublicHeader struct { TruncateConnectionID bool PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber - VersionNumber protocol.VersionNumber + VersionNumber protocol.VersionNumber // VersionNumber sent by the client + SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server DiversificationNonce []byte } -// WritePublicHeader writes a public header -func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.VersionNumber) error { +// Write writes a public header +func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error { publicFlagByte := uint8(0x00) + if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + if h.VersionFlag { publicFlagByte |= 0x01 } @@ -54,7 +56,8 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi publicFlagByte |= 0x04 } - if !h.ResetFlag && !h.VersionFlag { + // only set PacketNumberLen bits if a packet number will be written + if h.hasPacketNumber(pers) { switch h.PacketNumberLen { case protocol.PacketNumberLen1: publicFlagByte |= 0x00 @@ -73,30 +76,42 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi utils.WriteUint64(b, uint64(h.ConnectionID)) } + if h.VersionFlag && pers == protocol.PerspectiveClient { + utils.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber)) + } + if len(h.DiversificationNonce) > 0 { b.Write(h.DiversificationNonce) } - if !h.ResetFlag && !h.VersionFlag { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.WriteUint32(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen6: - utils.WriteUint48(b, uint64(h.PacketNumber)) - default: - return errPacketNumberLenNotSet - } + // if we're a server, and the VersionFlag is set, we must not include anything else in the packet + if !h.hasPacketNumber(pers) { + return nil + } + + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { + return errPacketNumberLenNotSet + } + + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.WriteUint32(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen6: + utils.WriteUint48(b, uint64(h.PacketNumber)) + default: + return errPacketNumberLenNotSet } return nil } // ParsePublicHeader parses a QUIC packet's public header -func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { +// the packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient +func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { header := &PublicHeader{} // First byte @@ -117,15 +132,17 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { return nil, errReceivedTruncatedConnectionID } - switch publicFlagByte & 0x30 { - case 0x30: - header.PacketNumberLen = protocol.PacketNumberLen6 - case 0x20: - header.PacketNumberLen = protocol.PacketNumberLen4 - case 0x10: - header.PacketNumberLen = protocol.PacketNumberLen2 - case 0x00: - header.PacketNumberLen = protocol.PacketNumberLen1 + if header.hasPacketNumber(packetSentBy) { + switch publicFlagByte & 0x30 { + case 0x30: + header.PacketNumberLen = protocol.PacketNumberLen6 + case 0x20: + header.PacketNumberLen = protocol.PacketNumberLen4 + case 0x10: + header.PacketNumberLen = protocol.PacketNumberLen2 + case 0x00: + header.PacketNumberLen = protocol.PacketNumberLen1 + } } // Connection ID @@ -133,46 +150,111 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { if err != nil { return nil, err } + header.ConnectionID = protocol.ConnectionID(connID) if header.ConnectionID == 0 { return nil, errInvalidConnectionID } - // Version (optional) - if header.VersionFlag { - var versionTag uint32 - versionTag, err = utils.ReadUint32(b) - if err != nil { - return nil, err + if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { + // TODO: remove the if once the Google servers send the correct value + // assume that a packet doesn't contain a diversification nonce if the version flag or the reset flag is set, no matter what the public flag says + // see https://github.com/lucas-clemente/quic-go/issues/232 + if !header.VersionFlag && !header.ResetFlag { + header.DiversificationNonce = make([]byte, 32) + // this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number + _, err = b.Read(header.DiversificationNonce) + if err != nil { + return nil, err + } + } + } + + // Version (optional) + if !header.ResetFlag { + if header.VersionFlag { + if packetSentBy == protocol.PerspectiveClient { + var versionTag uint32 + versionTag, err = utils.ReadUint32(b) + if err != nil { + return nil, err + } + header.VersionNumber = protocol.VersionTagToNumber(versionTag) + } else { // parse the version negotiaton packet + if b.Len()%4 != 0 { + return nil, qerr.InvalidVersionNegotiationPacket + } + header.SupportedVersions = make([]protocol.VersionNumber, 0) + for { + var versionTag uint32 + versionTag, err = utils.ReadUint32(b) + if err != nil { + break + } + v := protocol.VersionTagToNumber(versionTag) + if !protocol.IsSupportedVersion(v) { + v = protocol.VersionUnsupported + } + header.SupportedVersions = append(header.SupportedVersions, v) + } + } } - header.VersionNumber = protocol.VersionTagToNumber(versionTag) } // Packet number - packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) - if err != nil { - return nil, err + if header.hasPacketNumber(packetSentBy) { + packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) + if err != nil { + return nil, err + } + header.PacketNumber = protocol.PacketNumber(packetNumber) } - header.PacketNumber = protocol.PacketNumber(packetNumber) return header, nil } // GetLength gets the length of the publicHeader in bytes // can only be called for regular packets -func (h *PublicHeader) GetLength() (protocol.ByteCount, error) { - if h.VersionFlag || h.ResetFlag { - return 0, errGetLengthOnlyForRegularPackets +func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { + if h.VersionFlag && h.ResetFlag { + return 0, errResetAndVersionFlagSet + } + + if h.VersionFlag && pers == protocol.PerspectiveServer { + return 0, errGetLengthNotForVersionNegotiation } length := protocol.ByteCount(1) // 1 byte for public flags - if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { - return 0, errPacketNumberLenNotSet + + if h.hasPacketNumber(pers) { + if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { + return 0, errPacketNumberLenNotSet + } + length += protocol.ByteCount(h.PacketNumberLen) } + if !h.TruncateConnectionID { length += 8 // 8 bytes for the connection ID } + + // Version Number in packets sent by the client + if h.VersionFlag { + length += 4 + } + length += protocol.ByteCount(len(h.DiversificationNonce)) - length += protocol.ByteCount(h.PacketNumberLen) + return length, nil } + +// hasPacketNumber determines if this PublicHeader will contain a packet number +// this depends on the ResetFlag, the VersionFlag and who sent the packet +func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { + if h.ResetFlag { + return false + } + if h.VersionFlag && packetSentBy == protocol.PerspectiveServer { + return false + } + return true +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_reset.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_reset.go index 7cceb2e..b1f60d4 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_reset.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/public_reset.go @@ -2,12 +2,19 @@ package quic import ( "bytes" + "encoding/binary" + "errors" "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" ) +type publicReset struct { + rejectedPacketNumber protocol.PacketNumber + nonce uint64 +} + func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { b := &bytes.Buffer{} b.WriteByte(0x0a) @@ -22,3 +29,34 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p utils.WriteUint64(b, uint64(rejectedPacketNumber)) return b.Bytes() } + +func parsePublicReset(r *bytes.Reader) (*publicReset, error) { + pr := publicReset{} + tag, tagMap, err := handshake.ParseHandshakeMessage(r) + if err != nil { + return nil, err + } + if tag != handshake.TagPRST { + return nil, errors.New("wrong public reset tag") + } + + rseq, ok := tagMap[handshake.TagRSEQ] + if !ok { + return nil, errors.New("RSEQ missing") + } + if len(rseq) != 8 { + return nil, errors.New("invalid RSEQ tag") + } + pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) + + rnon, ok := tagMap[handshake.TagRNON] + if !ok { + return nil, errors.New("RNON missing") + } + if len(rnon) != 8 { + return nil, errors.New("invalid RNON tag") + } + pr.nonce = binary.LittleEndian.Uint64(rnon) + + return &pr, nil +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/server.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/server.go index 54ccc6d..40b3e44 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/server.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "crypto/tls" + "errors" "net" "strings" "sync" @@ -18,6 +19,7 @@ import ( // packetHandler handles packets type packetHandler interface { handlePacket(*receivedPacket) + OpenStream(protocol.StreamID) (utils.Stream, error) run() Close(error) error } @@ -29,11 +31,12 @@ type Server struct { conn *net.UDPConn connMutex sync.Mutex - signer crypto.Signer - scfg *handshake.ServerConfig + certChain crypto.CertChain + scfg *handshake.ServerConfig - sessions map[protocol.ConnectionID]packetHandler - sessionsMutex sync.RWMutex + sessions map[protocol.ConnectionID]packetHandler + sessionsMutex sync.RWMutex + deleteClosedSessionsAfter time.Duration streamCallback StreamCallback @@ -42,16 +45,13 @@ type Server struct { // NewServer makes a new server func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) { - signer, err := crypto.NewProofSource(tlsConfig) - if err != nil { - return nil, err - } + certChain := crypto.NewCertChain(tlsConfig) kex, err := crypto.NewCurve25519KEX() if err != nil { return nil, err } - scfg, err := handshake.NewServerConfig(kex, signer) + scfg, err := handshake.NewServerConfig(kex, certChain) if err != nil { return nil, err } @@ -62,12 +62,13 @@ func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, } return &Server{ - addr: udpAddr, - signer: signer, - scfg: scfg, - streamCallback: cb, - sessions: map[protocol.ConnectionID]packetHandler{}, - newSession: newSession, + addr: udpAddr, + certChain: certChain, + scfg: scfg, + streamCallback: cb, + sessions: map[protocol.ConnectionID]packetHandler{}, + newSession: newSession, + deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, }, nil } @@ -135,12 +136,39 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r) + hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] + s.sessionsMutex.RLock() + session, ok := s.sessions[hdr.ConnectionID] + s.sessionsMutex.RUnlock() + + // ignore all Public Reset packets + if hdr.ResetFlag { + if ok { + var pr *publicReset + pr, err = parsePublicReset(r) + if err != nil { + utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") + } else { + utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber) + } + } else { + utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) + } + return nil + } + + // a session is only created once the client sent a supported version + // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated + // it is safe to drop it + if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) { + return nil + } + // Send Version Negotiation Packet if the client is speaking a different protocol version if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) { utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) @@ -148,15 +176,20 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet return err } - s.sessionsMutex.RLock() - session, ok := s.sessions[hdr.ConnectionID] - s.sessionsMutex.RUnlock() - if !ok { - utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, hdr.VersionNumber, remoteAddr) + if !hdr.VersionFlag { + _, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) + return err + } + version := hdr.VersionNumber + if !protocol.IsSupportedVersion(version) { + return errors.New("Server BUG: negotiated version not supported") + } + + utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) session, err = s.newSession( &udpConn{conn: conn, currentAddr: remoteAddr}, - hdr.VersionNumber, + version, hdr.ConnectionID, s.scfg, s.streamCallback, @@ -187,6 +220,12 @@ func (s *Server) closeCallback(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[id] = nil s.sessionsMutex.Unlock() + + time.AfterFunc(s.deleteClosedSessionsAfter, func() { + s.sessionsMutex.Lock() + delete(s.sessions, id) + s.sessionsMutex.Unlock() + }) } func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { @@ -196,7 +235,7 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { PacketNumber: 1, VersionFlag: true, } - err := responsePublicHeader.WritePublicHeader(fullReply, protocol.Version35) + err := responsePublicHeader.Write(fullReply, protocol.Version35, protocol.PerspectiveServer) if err != nil { utils.Errorf("error composing version negotiation packet: %s", err.Error()) } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/session.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/session.go index 2f75299..ee67fe8 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/session.go @@ -1,14 +1,15 @@ package quic import ( + "crypto/tls" "errors" "fmt" "net" - "runtime" "sync/atomic" "time" "github.com/lucas-clemente/quic-go/ackhandler" + "github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/handshake" @@ -31,26 +32,35 @@ type receivedPacket struct { var ( errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") + errSessionAlreadyClosed = errors.New("Cannot close Session. It was already closed before.") ) // StreamCallback gets a stream frame and returns a reply frame type StreamCallback func(*Session, utils.Stream) +// CryptoChangeCallback is called every time the encryption level changes +// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that +type CryptoChangeCallback func(isForwardSecure bool) + // closeCallback is called when a session is closed type closeCallback func(id protocol.ConnectionID) // A Session is a QUIC session type Session struct { connectionID protocol.ConnectionID + perspective protocol.Perspective version protocol.VersionNumber - streamCallback StreamCallback - closeCallback closeCallback + streamCallback StreamCallback + closeCallback closeCallback + cryptoChangeCallback CryptoChangeCallback conn connection streamsMap *streamsMap + rttStats *congestion.RTTStats + sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler streamFramer *streamFramer @@ -60,21 +70,22 @@ type Session struct { unpacker unpacker packer *packetPacker - cryptoSetup *handshake.CryptoSetup + cryptoSetup handshake.CryptoSetup receivedPackets chan *receivedPacket sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. // If the value is not nil, the error is sent as a CONNECTION_CLOSE. closeChan chan *qerr.QuicError + runClosed chan struct{} closed uint32 // atomic bool undecryptablePackets []*receivedPacket aeadChanged chan struct{} - delayedAckOriginTime time.Time + nextAckScheduledTime time.Time - connectionParametersManager *handshake.ConnectionParametersManager + connectionParameters handshake.ConnectionParametersManager lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -91,56 +102,90 @@ type Session struct { // newSession makes a new session func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) { - connectionParametersManager := handshake.NewConnectionParamatersManager() - flowControlManager := flowcontrol.NewFlowControlManager(connectionParametersManager) - - var sentPacketHandler ackhandler.SentPacketHandler - var receivedPacketHandler ackhandler.ReceivedPacketHandler - - sentPacketHandler = ackhandler.NewSentPacketHandler() - receivedPacketHandler = ackhandler.NewReceivedPacketHandler() - - now := time.Now() session := &Session{ conn: conn, connectionID: connectionID, + perspective: protocol.PerspectiveServer, version: v, - streamCallback: streamCallback, - closeCallback: closeCallback, - - connectionParametersManager: connectionParametersManager, - sentPacketHandler: sentPacketHandler, - receivedPacketHandler: receivedPacketHandler, - flowControlManager: flowControlManager, - - receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets), - closeChan: make(chan *qerr.QuicError, 1), - sendingScheduled: make(chan struct{}, 1), - undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets), - aeadChanged: make(chan struct{}, 1), - - timer: time.NewTimer(0), - lastNetworkActivityTime: now, - sessionCreationTime: now, + streamCallback: streamCallback, + closeCallback: closeCallback, + cryptoChangeCallback: func(bool) {}, + connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } - session.streamsMap = newStreamsMap(session.newStream) - + session.setup() cryptoStream, _ := session.GetOrOpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) + session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged) if err != nil { return nil, err } - session.streamFramer = newStreamFramer(session.streamsMap, flowControlManager) - session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParametersManager, session.streamFramer, v) - session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} + session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParameters, session.streamFramer, session.perspective, session.version) + session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: session.version} return session, err } +func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { + session := &Session{ + conn: &udpConn{conn: conn, currentAddr: addr}, + connectionID: connectionID, + perspective: protocol.PerspectiveClient, + version: v, + + streamCallback: streamCallback, + closeCallback: closeCallback, + cryptoChangeCallback: cryptoChangeCallback, + connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), + } + + session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) + session.setup() + + cryptoStream, _ := session.OpenStream(1) + var err error + session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions) + if err != nil { + return nil, err + } + + session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParameters, session.streamFramer, session.perspective, session.version) + session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: session.version} + + return session, err +} + +// setup is called from newSession and newClientSession and initializes values that are independent of the perspective +func (s *Session) setup() { + s.rttStats = &congestion.RTTStats{} + flowControlManager := flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) + + var sentPacketHandler ackhandler.SentPacketHandler + sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) + + now := time.Now() + + s.sentPacketHandler = sentPacketHandler + s.flowControlManager = flowControlManager + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) + + s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) + s.closeChan = make(chan *qerr.QuicError, 1) + s.sendingScheduled = make(chan struct{}, 1) + s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) + s.aeadChanged = make(chan struct{}, 1) + s.runClosed = make(chan struct{}, 1) + + s.timer = time.NewTimer(0) + s.lastNetworkActivityTime = now + s.sessionCreationTime = now + + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) + s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) +} + // run the session main loop func (s *Session) run() { // Start the crypto stream handler @@ -150,6 +195,7 @@ func (s *Session) run() { } }() +runLoop: for { // Close immediately if requested select { @@ -157,7 +203,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop default: } @@ -169,7 +215,7 @@ func (s *Session) run() { if errForConnClose != nil { s.sendConnectionClose(errForConnClose) } - return + break runLoop case <-s.timer.C: s.timerRead = true // We do all the interesting stuff after the switch statement, so @@ -186,35 +232,36 @@ func (s *Session) run() { // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) - if s.delayedAckOriginTime.IsZero() { - s.delayedAckOriginTime = p.rcvTime - } case <-s.aeadChanged: s.tryDecryptingQueuedPackets() + s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete()) } if err != nil { - s.Close(err) + s.close(err) } if err := s.sendPacket(); err != nil { - s.Close(err) + s.close(err) } if time.Now().Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } if !s.cryptoSetup.HandshakeComplete() && time.Now().Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { - s.Close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) + s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) } s.garbageCollectStreams() } + + s.closeCallback(s.connectionID) + s.runClosed <- struct{}{} } func (s *Session) maybeResetTimer() { nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) - if !s.delayedAckOriginTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, s.delayedAckOriginTime.Add(protocol.AckSendDelay)) + if !s.nextAckScheduledTime.IsZero() { + nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) } if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() { nextDeadline = utils.MinTime(nextDeadline, rtoTime) @@ -242,12 +289,19 @@ func (s *Session) maybeResetTimer() { func (s *Session) idleTimeout() time.Duration { if s.cryptoSetup.HandshakeComplete() { - return s.connectionParametersManager.GetIdleConnectionStateLifetime() + return s.connectionParameters.GetIdleConnectionStateLifetime() } return protocol.InitialIdleTimeout } func (s *Session) handlePacketImpl(p *receivedPacket) error { + if s.perspective == protocol.PerspectiveClient { + diversificationNonce := p.publicHeader.DiversificationNonce + if len(diversificationNonce) > 0 { + s.cryptoSetup.SetDiversificationNonce(diversificationNonce) + } + } + if p.rcvTime.IsZero() { // To simplify testing p.rcvTime = time.Now() @@ -264,13 +318,19 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error { hdr.PacketNumber, ) if utils.Debug() { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x @ %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, time.Now().Format("15:04:05.000")) } - // TODO: Only do this after authenticating - s.conn.setCurrentRemoteAddr(p.remoteAddr) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) + // if the decryption failed, this might be a packet sent by an attacker + // don't update the remote address + if quicErr, ok := err.(*qerr.QuicError); ok && quicErr.ErrorCode == qerr.DecryptionFailure { + return err + } + if s.perspective == protocol.PerspectiveServer { + // update the remote address, even if unpacking failed for any other reason than a decryption error + s.conn.setCurrentRemoteAddr(p.remoteAddr) + } if err != nil { return err } @@ -279,7 +339,7 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error { // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) - err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber) + err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable()) // ignore duplicate packets if err == ackhandler.ErrDuplicatePacket { utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber) @@ -305,7 +365,6 @@ func (s *Session) handleFrames(fs []frames.Frame) error { switch frame := ff.(type) { case *frames.StreamFrame: err = s.handleStreamFrame(frame) - // TODO: send RstStreamFrame case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: @@ -357,7 +416,8 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return err } if str == nil { - // Stream is closed, ignore + // Stream is closed and already garbage collected + // ignore this StreamFrame return nil } err = str.AddStreamFrame(frame) @@ -381,7 +441,6 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error return err } -// TODO: Handle frame.byteOffset func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { @@ -390,8 +449,9 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { if str == nil { return errRstStreamOnInvalidStream } - s.closeStreamWithError(str, fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) - return nil + + str.RegisterRemoteError(fmt.Errorf("RST_STREAM received with code %d", frame.ErrorCode)) + return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) } func (s *Session) handleAckFrame(frame *frames.AckFrame) error { @@ -402,13 +462,39 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error { } // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. +// It waits until the run loop has stopped before returning func (s *Session) Close(e error) error { - return s.closeImpl(e, false) + err := s.closeImpl(e, false) + if err == errSessionAlreadyClosed { + return nil + } + + // wait for the run loop to finish + <-s.runClosed + return err +} + +// close the connection. Use this when called from the run loop +func (s *Session) close(e error) error { + err := s.closeImpl(e, false) + if err == errSessionAlreadyClosed { + return nil + } + return err } func (s *Session) closeImpl(e error, remoteClose bool) error { // Only close once if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + return errSessionAlreadyClosed + } + + if e == errCloseSessionForNewVersion { + s.closeStreamsWithError(e) + // when the run loop exits, it will call the closeCallback + // replace it with an noop function to make sure this doesn't have any effect + s.closeCallback = func(protocol.ConnectionID) {} + s.closeChan <- nil return nil } @@ -426,7 +512,6 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { } s.closeStreamsWithError(quicErr) - s.closeCallback(s.connectionID) if remoteClose { // If this is a remote close we don't need to send a CONNECTION_CLOSE @@ -445,15 +530,11 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { func (s *Session) closeStreamsWithError(err error) { s.streamsMap.Iterate(func(str *stream) (bool, error) { - s.closeStreamWithError(str, err) + str.Cancel(err) return true, nil }) } -func (s *Session) closeStreamWithError(str *stream, err error) { - str.RegisterError(err) -} - func (s *Session) sendPacket() error { // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { @@ -471,6 +552,16 @@ func (s *Session) sendPacket() error { var controlFrames []frames.Frame + // get WindowUpdate frames + // this call triggers the flow controller to increase the flow control windows, if necessary + windowUpdateFrames, err := s.getWindowUpdateFrames() + if err != nil { + return err + } + for _, wuf := range windowUpdateFrames { + controlFrames = append(controlFrames, wuf) + } + // check for retransmissions first for { retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() @@ -480,55 +571,41 @@ func (s *Session) sendPacket() error { utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) // resend the frames that were in the packet - controlFrames = append(controlFrames, retransmitPacket.GetControlFramesForRetransmission()...) - for _, streamFrame := range retransmitPacket.GetStreamFramesForRetransmission() { - s.streamFramer.AddFrameForRetransmission(streamFrame) + for _, frame := range retransmitPacket.GetFramesForRetransmission() { + switch frame.(type) { + case *frames.StreamFrame: + s.streamFramer.AddFrameForRetransmission(frame.(*frames.StreamFrame)) + case *frames.WindowUpdateFrame: + // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream + var currentOffset protocol.ByteCount + f := frame.(*frames.WindowUpdateFrame) + currentOffset, err = s.flowControlManager.GetReceiveWindow(f.StreamID) + if err == nil && f.ByteOffset >= currentOffset { + controlFrames = append(controlFrames, frame) + } + default: + controlFrames = append(controlFrames, frame) + } } } - windowUpdateFrames, err := s.getWindowUpdateFrames() - if err != nil { - return err - } - - for _, wuf := range windowUpdateFrames { - controlFrames = append(controlFrames, wuf) - } - - ack, err := s.receivedPacketHandler.GetAckFrame(false) - if err != nil { - return err - } + ack := s.receivedPacketHandler.GetAckFrame() if ack != nil { controlFrames = append(controlFrames, ack) } - - // Check whether we are allowed to send a packet containing only an ACK - maySendOnlyAck := time.Now().Sub(s.delayedAckOriginTime) > protocol.AckSendDelay - if runtime.GOOS == "windows" { - maySendOnlyAck = true - } - hasRetransmission := s.streamFramer.HasFramesForRetransmission() - var stopWaitingFrame *frames.StopWaitingFrame if ack != nil || hasRetransmission { stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) } - packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked(), maySendOnlyAck) + packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked()) if err != nil { return err } if packet == nil { return nil } - - // Pop the ACK frame now that we are sure we're gonna send it - _, err = s.receivedPacketHandler.GetAckFrame(true) - if err != nil { - return err - } - + // send every window update twice for _, f := range windowUpdateFrames { s.packer.QueueControlFrameForNextPacket(f) } @@ -543,13 +620,13 @@ func (s *Session) sendPacket() error { } s.logPacket(packet) - s.delayedAckOriginTime = time.Time{} err = s.conn.write(packet.raw) putPacketBuffer(packet.raw) if err != nil { return err } + s.nextAckScheduledTime = time.Time{} } } @@ -571,7 +648,7 @@ func (s *Session) logPacket(packet *packedPacket) { return } if utils.Debug() { - utils.Debugf("-> Sending packet 0x%x (%d bytes)", packet.number, len(packet.raw)) + utils.Debugf("-> Sending packet 0x%x (%d bytes) @ %s", packet.number, len(packet.raw), time.Now().Format("15:04:05.000")) for _, frame := range packet.frames { frames.LogFrame(frame, true) } @@ -589,12 +666,16 @@ func (s *Session) OpenStream(id protocol.StreamID) (utils.Stream, error) { return s.streamsMap.OpenStream(id) } -func (s *Session) newStreamImpl(id protocol.StreamID) (*stream, error) { - return s.streamsMap.GetOrOpenStream(id) +func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { + s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ + StreamID: id, + ByteOffset: offset, + }) + s.scheduleSending() } func (s *Session) newStream(id protocol.StreamID) (*stream, error) { - stream, err := newStream(id, s.scheduleSending, s.flowControlManager) + stream, err := newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) if err != nil { return nil, err } @@ -646,7 +727,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) { } utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets { - s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) + s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } s.undecryptablePackets = append(s.undecryptablePackets, p) } @@ -667,6 +748,11 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { return res, nil } +func (s *Session) ackAlarmChanged(t time.Time) { + s.nextAckScheduledTime = t + s.maybeResetTimer() +} + // RemoteAddr returns the net.UDPAddr of the client func (s *Session) RemoteAddr() *net.UDPAddr { return s.conn.RemoteAddr() diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream.go index fc73b1e..72d3ec8 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -4,12 +4,10 @@ import ( "fmt" "io" "sync" - "sync/atomic" "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/utils" ) @@ -17,36 +15,47 @@ import ( // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. type stream struct { + mutex sync.Mutex + streamID protocol.StreamID onData func() + // onReset is a callback that should send a RST_STREAM + onReset func(protocol.StreamID, protocol.ByteCount) readPosInFrame int writeOffset protocol.ByteCount readOffset protocol.ByteCount - // Once set, err must not be changed! - err error - mutex sync.Mutex + // Once set, the errors must not be changed! + err error - // eof is set if we are finished reading - eof int32 // really a bool - // closed is set when we are finished writing - closed int32 // really a bool + // cancelled is set when Cancel() is called + cancelled utils.AtomicBool + // finishedReading is set once we read a frame with a FinBit + finishedReading utils.AtomicBool + // finisedWriting is set once Close() is called + finishedWriting utils.AtomicBool + // resetLocally is set if Reset() is called + resetLocally utils.AtomicBool + // resetRemotely is set if RegisterRemoteError() is called + resetRemotely utils.AtomicBool frameQueue *streamFrameSorter newFrameOrErrCond sync.Cond dataForWriting []byte - finSent bool + finSent utils.AtomicBool + rstSent utils.AtomicBool doneWritingOrErrCond sync.Cond flowControlManager flowcontrol.FlowControlManager } // newStream creates a new Stream -func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { +func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { s := &stream{ onData: onData, + onReset: onReset, streamID: StreamID, flowControlManager: flowControlManager, frameQueue: newStreamFrameSorter(), @@ -60,7 +69,10 @@ func newStream(StreamID protocol.StreamID, onData func(), flowControlManager flo // Read implements io.Reader. It is not thread safe! func (s *stream) Read(p []byte) (int, error) { - if atomic.LoadInt32(&s.eof) != 0 { + if s.cancelled.Get() || s.resetLocally.Get() { + return 0, s.err + } + if s.finishedReading.Get() { return 0, io.EOF } @@ -77,7 +89,7 @@ func (s *stream) Read(p []byte) (int, error) { var err error for { // Stop waiting on errors - if s.err != nil { + if s.resetLocally.Get() || s.cancelled.Get() { err = s.err break } @@ -89,11 +101,8 @@ func (s *stream) Read(p []byte) (int, error) { frame = s.frameQueue.Head() } s.mutex.Unlock() - // Here, either frame != nil xor err != nil - if frame == nil { - atomic.StoreInt32(&s.eof, 1) - // We have an err and no data, return the error + if err != nil { return bytesRead, err } @@ -111,7 +120,10 @@ func (s *stream) Read(p []byte) (int, error) { bytesRead += m s.readOffset += protocol.ByteCount(m) - s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) + // when a RST_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely.Get() { + s.flowControlManager.AddBytesRead(s.streamID, protocol.ByteCount(m)) + } s.onData() // so that a possible WINDOW_UPDATE is sent if s.readPosInFrame >= int(frame.DataLen()) { @@ -120,7 +132,7 @@ func (s *stream) Read(p []byte) (int, error) { s.frameQueue.Pop() s.mutex.Unlock() if fin { - atomic.StoreInt32(&s.eof, 1) + s.finishedReading.Set(true) return bytesRead, io.EOF } } @@ -130,6 +142,10 @@ func (s *stream) Read(p []byte) (int, error) { } func (s *stream) Write(p []byte) (int, error) { + if s.resetLocally.Get() { + return 0, s.err + } + s.mutex.Lock() defer s.mutex.Unlock() @@ -159,13 +175,20 @@ func (s *stream) Write(p []byte) (int, error) { func (s *stream) lenOfDataForWriting() protocol.ByteCount { s.mutex.Lock() - l := protocol.ByteCount(len(s.dataForWriting)) + var l protocol.ByteCount + if s.err == nil { + l = protocol.ByteCount(len(s.dataForWriting)) + } s.mutex.Unlock() return l } func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() + if s.err != nil { + s.mutex.Unlock() + return nil + } if s.dataForWriting == nil { s.mutex.Unlock() return nil @@ -186,35 +209,33 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { // Close implements io.Closer func (s *stream) Close() error { - atomic.StoreInt32(&s.closed, 1) + s.finishedWriting.Set(true) s.onData() return nil } +func (s *stream) shouldSendReset() bool { + if s.rstSent.Get() { + return false + } + return (s.resetLocally.Get() || s.resetRemotely.Get()) && !s.finishedWriteAndSentFin() +} + func (s *stream) shouldSendFin() bool { s.mutex.Lock() - res := atomic.LoadInt32(&s.closed) != 0 && !s.finSent && s.err == nil && s.dataForWriting == nil + res := s.finishedWriting.Get() && !s.finSent.Get() && s.err == nil && s.dataForWriting == nil s.mutex.Unlock() return res } func (s *stream) sentFin() { - s.mutex.Lock() - s.finSent = true - s.mutex.Unlock() + s.finSent.Set(true) } // AddStreamFrame adds a new stream frame func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { maxOffset := frame.Offset + frame.DataLen() err := s.flowControlManager.UpdateHighestReceived(s.streamID, maxOffset) - - if err == flowcontrol.ErrStreamFlowControlViolation { - return qerr.FlowControlReceivedTooMuchData - } - if err == flowcontrol.ErrConnectionFlowControlViolation { - return qerr.FlowControlReceivedTooMuchData - } if err != nil { return err } @@ -234,32 +255,69 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) } -// RegisterError is called by session to indicate that an error occurred and the -// stream should be closed. -func (s *stream) RegisterError(err error) { - atomic.StoreInt32(&s.closed, 1) +// Cancel is called by session to indicate that an error occurred +// The stream should will be closed immediately +func (s *stream) Cancel(err error) { s.mutex.Lock() - defer s.mutex.Unlock() - if s.err != nil { // s.err must not be changed! + s.cancelled.Set(true) + // errors must not be changed! + if s.err == nil { + s.err = err + s.newFrameOrErrCond.Signal() + s.doneWritingOrErrCond.Signal() + } + s.mutex.Unlock() +} + +// resets the stream locally +func (s *stream) Reset(err error) { + if s.resetLocally.Get() { return } - s.err = err - s.doneWritingOrErrCond.Signal() - s.newFrameOrErrCond.Signal() -} - -func (s *stream) finishedReading() bool { - return atomic.LoadInt32(&s.eof) != 0 -} - -func (s *stream) finishedWriting() bool { s.mutex.Lock() - defer s.mutex.Unlock() - return s.err != nil || (atomic.LoadInt32(&s.closed) != 0 && s.finSent) + s.resetLocally.Set(true) + // errors must not be changed! + if s.err == nil { + s.err = err + s.newFrameOrErrCond.Signal() + s.doneWritingOrErrCond.Signal() + } + if s.shouldSendReset() { + s.onReset(s.streamID, s.writeOffset) + s.rstSent.Set(true) + } + s.mutex.Unlock() +} + +// resets the stream remotely +func (s *stream) RegisterRemoteError(err error) { + if s.resetRemotely.Get() { + return + } + s.mutex.Lock() + s.resetRemotely.Set(true) + // errors must not be changed! + if s.err == nil { + s.err = err + s.doneWritingOrErrCond.Signal() + } + if s.shouldSendReset() { + s.onReset(s.streamID, s.writeOffset) + s.rstSent.Set(true) + } + s.mutex.Unlock() +} + +func (s *stream) finishedWriteAndSentFin() bool { + return s.finishedWriting.Get() && s.finSent.Get() } func (s *stream) finished() bool { - return s.finishedReading() && s.finishedWriting() + return s.cancelled.Get() || + (s.finishedReading.Get() && s.finishedWriteAndSentFin()) || + (s.resetRemotely.Get() && s.rstSent.Get()) || + (s.finishedReading.Get() && s.rstSent.Get()) || + (s.finishedWriteAndSentFin() && s.resetRemotely.Get()) } func (s *stream) StreamID() protocol.StreamID { diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream_framer.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream_framer.go index b702c28..45c0722 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream_framer.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/stream_framer.go @@ -119,7 +119,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] if f.flowControlManager.RemainingConnectionWindowSize() == 0 { // We are now connection-level FC blocked f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0}) - } else if sendWindowSize-frame.DataLen() == 0 { + } else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 { // We are now stream-level FC blocked f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()}) } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/streams_map.go index 8b41b86..7666497 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -5,24 +5,31 @@ import ( "fmt" "sync" + "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" - "github.com/lucas-clemente/quic-go/utils" ) type streamsMap struct { mutex sync.RWMutex + perspective protocol.Perspective + connectionParameters handshake.ConnectionParametersManager + streams map[protocol.StreamID]*stream openStreams []protocol.StreamID highestStreamOpenedByClient protocol.StreamID streamsOpenedAfterLastGarbageCollect int - newStream newStreamLambda - maxNumStreams int + newStream newStreamLambda - roundRobinIndex int + maxOutgoingStreams uint32 + numOutgoingStreams uint32 + maxIncomingStreams uint32 + numIncomingStreams uint32 + + roundRobinIndex uint32 } type streamLambda func(*stream) (bool, error) @@ -32,14 +39,13 @@ var ( errMapAccess = errors.New("streamsMap: Error accessing the streams map") ) -func newStreamsMap(newStream newStreamLambda) *streamsMap { - maxNumStreams := utils.Max(int(float32(protocol.MaxIncomingDynamicStreams)*protocol.MaxStreamsMultiplier), int(protocol.MaxIncomingDynamicStreams)) - +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { return &streamsMap{ - streams: map[protocol.StreamID]*stream{}, - openStreams: make([]protocol.StreamID, 0, maxNumStreams), - newStream: newStream, - maxNumStreams: maxNumStreams, + perspective: pers, + streams: map[protocol.StreamID]*stream{}, + openStreams: make([]protocol.StreamID, 0), + newStream: newStream, + connectionParameters: connectionParameters, } } @@ -61,12 +67,15 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil } - if len(m.openStreams) == m.maxNumStreams { + if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } - if id%2 == 0 { + if m.perspective == protocol.PerspectiveServer && id%2 == 0 { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) } + if m.perspective == protocol.PerspectiveClient && id%2 == 1 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByClient { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByClient)) } @@ -76,10 +85,17 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { return nil, err } + if m.perspective == protocol.PerspectiveServer { + m.numIncomingStreams++ + } else { + m.numOutgoingStreams++ + } + if id > m.highestStreamOpenedByClient { m.highestStreamOpenedByClient = id } + // maybe trigger garbage collection of streams map m.streamsOpenedAfterLastGarbageCollect++ if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { m.garbageCollectClosedStreams() @@ -91,7 +107,36 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { // OpenStream opens a stream from the server's side func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { - panic("OpenStream: not implemented") + if m.perspective == protocol.PerspectiveServer && id%2 == 1 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + } + if m.perspective == protocol.PerspectiveClient && id%2 == 0 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if ok { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is already open", id)) + } + if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { + return nil, qerr.TooManyOpenStreams + } + + s, err := m.newStream(id) + if err != nil { + return nil, err + } + + if m.perspective == protocol.PerspectiveServer { + m.numOutgoingStreams++ + } else { + m.numIncomingStreams++ + } + + m.putStream(s) + return s, nil } func (m *streamsMap) Iterate(fn streamLambda) error { @@ -117,7 +162,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() - numStreams := len(m.openStreams) + numStreams := uint32(len(m.openStreams)) startIndex := m.roundRobinIndex for _, i := range []protocol.StreamID{1, 3} { @@ -130,7 +175,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { } } - for i := 0; i < numStreams; i++ { + for i := uint32(0); i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] if streamID == 1 || streamID == 3 { @@ -180,13 +225,18 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { } m.streams[id] = nil + if id%2 == 0 { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } for i, s := range m.openStreams { if s == id { // delete the streamID from the openStreams slice m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] // adjust round-robin index, if necessary - if i < m.roundRobinIndex { + if uint32(i) < m.roundRobinIndex { m.roundRobinIndex-- } break @@ -196,14 +246,6 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { return nil } -// NumberOfStreams gets the number of open streams -func (m *streamsMap) NumberOfStreams() int { - m.mutex.RLock() - n := len(m.openStreams) - m.mutex.RUnlock() - return n -} - // garbageCollectClosedStreams deletes nil values in the streams if they are smaller than protocol.MaxNewStreamIDDelta than the highest stream opened by the client // note that this garbage collection is relatively expensive, since it iterates over the whole streams map. It should not be called every time a stream is openend or closed func (m *streamsMap) garbageCollectClosedStreams() { @@ -211,7 +253,10 @@ func (m *streamsMap) garbageCollectClosedStreams() { if str != nil { continue } - if id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { + + // server-side streams can be gargage collected immediately + // client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams + if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { delete(m.streams, id) } } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/udp_conn.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/udp_conn.go index 2c1bafe..efc646d 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/udp_conn.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/udp_conn.go @@ -1,6 +1,9 @@ package quic -import "net" +import ( + "net" + "sync" +) type connection interface { write([]byte) error @@ -9,6 +12,8 @@ type connection interface { } type udpConn struct { + mutex sync.RWMutex + conn *net.UDPConn currentAddr *net.UDPAddr } @@ -21,9 +26,14 @@ func (c *udpConn) write(p []byte) error { } func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { + c.mutex.Lock() c.currentAddr = addr.(*net.UDPAddr) + c.mutex.Unlock() } func (c *udpConn) RemoteAddr() *net.UDPAddr { - return c.currentAddr + c.mutex.RLock() + addr := c.currentAddr + c.mutex.RUnlock() + return addr } diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/unpacked_packet.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/unpacked_packet.go new file mode 100644 index 0000000..8079204 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/unpacked_packet.go @@ -0,0 +1,27 @@ +package quic + +import "github.com/lucas-clemente/quic-go/frames" + +type unpackedPacket struct { + frames []frames.Frame +} + +func (u *unpackedPacket) IsRetransmittable() bool { + for _, f := range u.frames { + switch f.(type) { + case *frames.StreamFrame: + return true + case *frames.RstStreamFrame: + return true + case *frames.WindowUpdateFrame: + return true + case *frames.BlockedFrame: + return true + case *frames.PingFrame: + return true + case *frames.GoawayFrame: + return true + } + } + return false +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/atomic_bool.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/atomic_bool.go new file mode 100644 index 0000000..cf46425 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/connection_id.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/connection_id.go new file mode 100644 index 0000000..e0227d0 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/connection_id.go @@ -0,0 +1,18 @@ +package utils + +import ( + "crypto/rand" + "encoding/binary" + + "github.com/lucas-clemente/quic-go/protocol" +) + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID() (protocol.ConnectionID, error) { + b := make([]byte, 8, 8) + _, err := rand.Read(b) + if err != nil { + return 0, err + } + return protocol.ConnectionID(binary.LittleEndian.Uint64(b)), nil +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/host.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/host.go new file mode 100644 index 0000000..a1d6453 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/host.go @@ -0,0 +1,27 @@ +package utils + +import ( + "net/url" + "strings" +) + +// HostnameFromAddr determines the hostname in an address string +func HostnameFromAddr(addr string) (string, error) { + p, err := url.Parse(addr) + if err != nil { + return "", err + } + h := p.Host + + // copied from https://golang.org/src/net/http/transport.go + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + + return h, nil +} + +// copied from https://golang.org/src/net/http/http.go +func hasPort(s string) bool { + return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/minmax.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/minmax.go index ec22139..6e23df5 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/minmax.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/minmax.go @@ -34,6 +34,14 @@ func MaxUint64(a, b uint64) uint64 { return a } +// MinUint64 returns the maximum of two uint64 +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + // Min returns the minimum of two Ints func Min(a, b int) int { if a < b { diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/stream.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/stream.go new file mode 100644 index 0000000..b776a89 --- /dev/null +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/stream.go @@ -0,0 +1,17 @@ +package utils + +import ( + "io" + + "github.com/lucas-clemente/quic-go/protocol" +) + +// Stream is the interface for QUIC streams +type Stream interface { + io.Reader + io.Writer + io.Closer + StreamID() protocol.StreamID + CloseRemote(offset protocol.ByteCount) + Reset(error) +} diff --git a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/utils.go b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/utils.go index cf25679..f6c4e03 100644 --- a/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/utils.go +++ b/cmd/gost/vendor/github.com/lucas-clemente/quic-go/utils/utils.go @@ -3,19 +3,8 @@ package utils import ( "bytes" "io" - - "github.com/lucas-clemente/quic-go/protocol" ) -// Stream is the interface for QUIC streams -type Stream interface { - io.Reader - io.Writer - io.Closer - StreamID() protocol.StreamID - CloseRemote(offset protocol.ByteCount) -} - // ReadUintN reads N bytes func ReadUintN(b io.ByteReader, length uint8) (uint64, error) { var res uint64 diff --git a/cmd/gost/vendor/vendor.json b/cmd/gost/vendor/vendor.json index aaf0b34..def1c54 100644 --- a/cmd/gost/vendor/vendor.json +++ b/cmd/gost/vendor/vendor.json @@ -14,6 +14,12 @@ "revision": "ec07b4f69a3f70b1dd2a8ad77230deb1ba5d6953", "revisionTime": "2015-11-07T02:50:05Z" }, + { + "checksumSHA1": "aIhLeVAIrsjs63CwqmU3+GU8yT4=", + "path": "github.com/ginuerzh/gosocks4", + "revision": "fc196f9d34e35f19a813bff2f092a275131c23bc", + "revisionTime": "2017-02-09T14:09:51Z" + }, { "checksumSHA1": "5TwW96Afcvo+zm0tAn+DSNIQreQ=", "path": "github.com/ginuerzh/gosocks5", @@ -21,10 +27,10 @@ "revisionTime": "2017-01-19T05:34:58Z" }, { - "checksumSHA1": "ero0DQrGYph2eFDyEjDnOQV8NHo=", + "checksumSHA1": "1iyn4OEHEJknBi+IiZuUaJi6Ifw=", "path": "github.com/ginuerzh/gost", - "revision": "72d8a598d50f8517d922f233e8f8d37011fcb18f", - "revisionTime": "2017-01-25T04:23:26Z" + "revision": "333291e9bc766a76d6df5243a7022c5f028be17c", + "revisionTime": "2017-02-05T06:35:38Z" }, { "checksumSHA1": "+XIOnTW0rv8Kr/amkXgMraNeUr4=", @@ -75,10 +81,10 @@ "revisionTime": "2016-09-12T19:31:07Z" }, { - "checksumSHA1": "fbae4URna3lp8RtTOutiXIO1JS0=", + "checksumSHA1": "hUI9uYDnlXeOY+SEAPViyVpgq6I=", "path": "github.com/lucas-clemente/aes12", - "revision": "8ee5b5610baca43b60ecfad586b3c40d92a96e0c", - "revisionTime": "2016-08-23T09:51:02Z" + "revision": "25700e67be5c860bcc999137275b9ef8b65932bd", + "revisionTime": "2016-12-15T15:22:28Z" }, { "checksumSHA1": "ne1X+frkx5fJcpz9FaZPuUZ7amM=", @@ -87,10 +93,10 @@ "revisionTime": "2016-05-04T15:23:51Z" }, { - "checksumSHA1": "xbX/mARowOKpW3S1G8hmaDlWdp8=", + "checksumSHA1": "KF0chHPN90e/Ct/WJJYsuHKdiEM=", "path": "github.com/lucas-clemente/quic-go", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { "checksumSHA1": "OA9E+y7g05x/mWJJHmA7oPxWKQo=", @@ -99,64 +105,64 @@ "revisionTime": "2016-08-23T09:51:56Z" }, { - "checksumSHA1": "1+iOTf/w8VXGqQao0FMNEE2RFFg=", + "checksumSHA1": "NtA/2oLQBR72S8l43PNUug2TCog=", "path": "github.com/lucas-clemente/quic-go/ackhandler", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { "checksumSHA1": "8zoU6uLKP2Czs96VgmNMubNcWKk=", "path": "github.com/lucas-clemente/quic-go/congestion", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "vzBxE7JViWonmrSndgmRuye8ntA=", + "checksumSHA1": "XUdlanAUfQt+UjszRgtMeziyiG8=", "path": "github.com/lucas-clemente/quic-go/crypto", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "CaOt7EZEuyWQ073FITB8qQfFswA=", + "checksumSHA1": "qeA/SEEEJISmxMiT7wVH87x8bYs=", "path": "github.com/lucas-clemente/quic-go/flowcontrol", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "VvhIOTKtMkPZ7pdrCPHlDQI2wIw=", + "checksumSHA1": "Uj2jJdb2wytCw/DJqmVD7wvq4BU=", "path": "github.com/lucas-clemente/quic-go/frames", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "IJDUjsrJP5dtHvxVNT32x4SQQVk=", + "checksumSHA1": "Q+A6dMo3fn7olbOo64BocN5Iun8=", "path": "github.com/lucas-clemente/quic-go/h2quic", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "1DQTvvwcUUrmMKkN4ASjX5+iGqs=", + "checksumSHA1": "ItcGJhHoCWJVDtJagkdbtFkqIMo=", "path": "github.com/lucas-clemente/quic-go/handshake", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "qp5LXpuvIAgW3BffRzHVZQk1WfE=", + "checksumSHA1": "5qERDU1QlAgmMVbyY/zkWTzG7Po=", "path": "github.com/lucas-clemente/quic-go/protocol", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { "checksumSHA1": "ss57bkTclCnmt9fVosYie/ehkoo=", "path": "github.com/lucas-clemente/quic-go/qerr", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { - "checksumSHA1": "j/A6Rfz4BWpt5rsR2j5kR0H2ZTI=", + "checksumSHA1": "LUpOEK8o7qCBWHt5DdPd+3QG6Y4=", "path": "github.com/lucas-clemente/quic-go/utils", - "revision": "ef977ee0591f72543f8323cd12a585f5406ff971", - "revisionTime": "2016-10-14T09:35:10Z" + "revision": "268841f0cc2962070a8bd662551b150acbee369b", + "revisionTime": "2017-02-04T02:12:35Z" }, { "checksumSHA1": "ynJSWoF6v+3zMnh9R0QmmG6iGV8=", @@ -185,14 +191,14 @@ { "checksumSHA1": "dwOedwBJ1EIK9+S3t108Bx054Y8=", "path": "golang.org/x/crypto/curve25519", - "revision": "1150b8bd09e53aea1d415621adae9bad665061a1", - "revisionTime": "2016-10-21T22:59:10Z" + "revision": "bed12803fa9663d7aa2c2346b0c634ad2dcd43b7", + "revisionTime": "2017-02-01T20:15:17Z" }, { "checksumSHA1": "4D8hxMIaSDEW5pCQk22Xj4DcDh4=", "path": "golang.org/x/crypto/hkdf", - "revision": "1150b8bd09e53aea1d415621adae9bad665061a1", - "revisionTime": "2016-10-21T22:59:10Z" + "revision": "bed12803fa9663d7aa2c2346b0c634ad2dcd43b7", + "revisionTime": "2017-02-01T20:15:17Z" }, { "checksumSHA1": "1MGpGDQqnUoRpv7VEcQrXOBydXE=",