This commit is contained in:
MacType 2018-12-07 15:39:33 +08:00
commit 2c290a1a0d
855 changed files with 234894 additions and 29541 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "vendor/github.com/shadowsocks/go-shadowsocks2"]
path = vendor/github.com/shadowsocks/go-shadowsocks2
url = https://github.com/shadowsocks/go-shadowsocks2.git

View File

@ -34,6 +34,8 @@ Wiki站点: <https://docs.ginuerzh.xyz/gost/>
Google讨论组: <https://groups.google.com/d/forum/go-gost>
Telegram讨论群: <https://t.me/gogost>
安装
------

View File

@ -31,6 +31,7 @@ Wiki: <https://docs.ginuerzh.xyz/gost/en/>
Google group: <https://groups.google.com/d/forum/go-gost>
Telegram group: <https://t.me/gogost>
Installation
------

View File

@ -124,7 +124,7 @@ type Bypass struct {
matchers []Matcher
reversed bool
period time.Duration // the period for live reloading
mux sync.Mutex
mux sync.RWMutex
}
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
}
}
bp.mux.Lock()
defer bp.mux.Unlock()
bp.mux.RLock()
defer bp.mux.RUnlock()
var matched bool
for _, matcher := range bp.matchers {
@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {
// AddMatchers appends matchers to the bypass matcher list.
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
bp.mux.Lock()
defer bp.mux.Unlock()
bp.matchers = append(bp.matchers, matchers...)
}
// Matchers return the bypass matcher list.
func (bp *Bypass) Matchers() []Matcher {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.matchers
}
// Reversed reports whether the rules of the bypass are reversed.
func (bp *Bypass) Reversed() bool {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.reversed
}
// Reload parses config from r, then live reloads the bypass.
func (bp *Bypass) Reload(r io.Reader) error {
var matchers []Matcher
var period time.Duration
var reversed bool
scanner := bufio.NewScanner(r)
for scanner.Scan() {
@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}
}
@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.reversed, _ = strconv.ParseBool(ss[1])
reversed, _ = strconv.ParseBool(ss[1])
continue
}
}
@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
defer bp.mux.Unlock()
bp.matchers = matchers
bp.period = period
bp.reversed = reversed
return nil
}
// Period returns the reload period
func (bp *Bypass) Period() time.Duration {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.period
}
func (bp *Bypass) String() string {
bp.mux.RLock()
defer bp.mux.RUnlock()
b := &bytes.Buffer{}
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
for _, m := range bp.Matchers() {
fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
fmt.Fprintf(b, "reload: %v\n", bp.period)
for _, m := range bp.matchers {
b.WriteString(m.String())
b.WriteByte('\n')
}

View File

@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
}
// Nodes returns the proxy nodes that the chain holds.
// If a node is a node group, the first node in the group will be returned.
// The first node in each group will be returned.
func (c *Chain) Nodes() (nodes []Node) {
for _, group := range c.nodeGroups {
if ns := group.Nodes(); len(ns) > 0 {
@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
return Node{}
}
group := c.nodeGroups[len(c.nodeGroups)-1]
return group.nodes[0].Clone()
return group.GetNode(0)
}
// LastNodeGroup returns the last group of the group list.
@ -136,13 +136,14 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
return nil, err
}
addr = c.resolve(addr, options.Resolver, options.Hosts)
ipAddr := c.resolve(addr, options.Resolver, options.Hosts)
if route.IsEmpty() {
return net.DialTimeout("tcp", addr, options.Timeout)
return net.DialTimeout("tcp", ipAddr, options.Timeout)
}
conn, err := route.getConn(addr)
conn, err := route.getConn(ipAddr)
if err != nil {
return nil, err
}
@ -172,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
}
// Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
@ -215,18 +215,22 @@ func (c *Chain) getConn(addr string) (conn net.Conn, err error) {
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
node.group.ResetDeadNode(node.ID)
if len(nodes) > 1 {
node.ResetDead() // don't reset the last node as we are going to check if it will connect successfully.
node.group.ResetDeadNode(node.ID) // don't reset the last node as we are going to check if it will connect successfully.
}
preNode := node
@ -235,17 +239,19 @@ func (c *Chain) getConn(addr string) (conn net.Conn, err error) {
cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
if len(nodes) > 1 {
node.ResetDead()
node.group.ResetDeadNode(node.ID)
}
cn = cc
preNode = node
@ -257,14 +263,14 @@ func (c *Chain) getConn(addr string) (conn net.Conn, err error) {
cc, err = node.Client.Connect(conn, addr)
if err != nil {
if _, ok := err.(*net.OpError); ok {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
}
conn.Close()
return
}
conn = cc
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)
return
}

View File

@ -27,8 +27,8 @@ func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn,
}
// Connect connects to the address addr via the proxy over connection conn.
func (c *Client) Connect(conn net.Conn, addr string) (net.Conn, error) {
return c.Connector.Connect(conn, addr)
func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return c.Connector.Connect(conn, addr, options...)
}
// DefaultClient is a standard HTTP proxy client.
@ -51,7 +51,7 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) {
// Connector is responsible for connecting to the destination address.
type Connector interface {
Connect(conn net.Conn, addr string) (net.Conn, error)
Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error)
}
// Transporter is responsible for handshaking with the proxy server.
@ -96,7 +96,7 @@ type DialOptions struct {
Chain *Chain
}
// DialOption allows a common way to set dial options.
// DialOption allows a common way to set DialOptions.
type DialOption func(opts *DialOptions)
// TimeoutDialOption specifies the timeout used by Transporter.Dial
@ -127,7 +127,7 @@ type HandshakeOptions struct {
QUICConfig *QUICConfig
}
// HandshakeOption allows a common way to set handshake options.
// HandshakeOption allows a common way to set HandshakeOptions.
type HandshakeOption func(opts *HandshakeOptions)
// AddrHandshakeOption specifies the server address
@ -199,3 +199,18 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
opts.QUICConfig = config
}
}
// ConnectOptions describes the options for Connector.Connect.
type ConnectOptions struct {
Addr string
}
// ConnectOption allows a common way to set ConnectOptions.
type ConnectOption func(opts *ConnectOptions)
// AddrConnectOption specifies the corresponding address of the target.
func AddrConnectOption(addr string) ConnectOption {
return func(opts *ConnectOptions) {
opts.Addr = addr
}
}

View File

@ -173,55 +173,6 @@ func parseIP(s string, port string) (ips []string) {
return
}
type peerConfig struct {
Strategy string `json:"strategy"`
Filters []string `json:"filters"`
MaxFails int `json:"max_fails"`
FailTimeout int `json:"fail_timeout"`
Nodes []string `json:"nodes"`
Bypass *bypass `json:"bypass"` // global bypass
}
type bypass struct {
Reverse bool `json:"reverse"`
Patterns []string `json:"patterns"`
}
func loadPeerConfig(peer string) (config peerConfig, err error) {
if peer == "" {
return
}
content, err := ioutil.ReadFile(peer)
if err != nil {
return
}
err = json.Unmarshal(content, &config)
return
}
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = 1
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = 30 // seconds
}
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "random":
return &gost.RandomStrategy{}
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
default:
return &gost.RoundStrategy{}
}
}
func parseBypass(s string) *gost.Bypass {
if s == "" {
return nil

View File

@ -6,6 +6,7 @@ import (
"flag"
"fmt"
"net"
// _ "net/http/pprof"
"os"
"runtime"
@ -102,65 +103,23 @@ func (r *route) initChain() (*gost.Chain, error) {
ngroup.ID = gid
gid++
// parse the base node
// parse the base nodes
nodes, err := parseChainNode(ns)
if err != nil {
return nil, err
}
nid := 1 // node ID
for i := range nodes {
nodes[i].ID = nid
nid++
}
ngroup.AddNode(nodes...)
// parse peer nodes if exists
peerCfg, err := loadPeerConfig(nodes[0].Get("peer"))
if err != nil {
log.Log(err)
}
peerCfg.Validate()
strategy := peerCfg.Strategy
// overwrite the strategry in the peer config if `strategy` param exists.
if s := nodes[0].Get("strategy"); s != "" {
strategy = s
}
ngroup.Options = append(ngroup.Options,
gost.WithFilter(&gost.FailFilter{
MaxFails: peerCfg.MaxFails,
FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second,
}),
gost.WithStrategy(parseStrategy(strategy)),
)
for _, s := range peerCfg.Nodes {
nodes, err = parseChainNode(s)
if err != nil {
return nil, err
}
for i := range nodes {
nodes[i].ID = nid
nid++
}
ngroup.AddNode(nodes...)
}
var bypass *gost.Bypass
// global bypass
if peerCfg.Bypass != nil {
bypass = gost.NewBypassPatterns(peerCfg.Bypass.Reverse, peerCfg.Bypass.Patterns...)
}
nodes = ngroup.Nodes()
for i := range nodes {
if nodes[i].Bypass == nil {
nodes[i].Bypass = bypass // use global bypass if local bypass does not exist.
}
}
go gost.PeriodReload(&peerConfig{
group: ngroup,
baseNodes: nodes,
}, nodes[0].Get("peer"))
chain.AddNodeGroup(ngroup)
}
@ -510,6 +469,7 @@ func (r *route) serve() error {
gost.HostsHandlerOption(hosts),
gost.RetryHandlerOption(node.GetInt("retry")),
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
gost.ProbeResistHandlerOption(node.Get("probe_resist")),
)
srv := &gost.Server{Listener: ln}

164
cmd/gost/peer.go Normal file
View File

@ -0,0 +1,164 @@
package main
import (
"bufio"
"bytes"
"encoding/json"
"io"
"io/ioutil"
"strconv"
"strings"
"time"
"github.com/ginuerzh/gost"
)
type peerConfig struct {
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
FailTimeout time.Duration `json:"fail_timeout"`
period time.Duration // the period for live reloading
Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
}
type bypass struct {
Reverse bool `json:"reverse"`
Patterns []string `json:"patterns"`
}
func parsePeerConfig(cfg string, group *gost.NodeGroup, baseNodes []gost.Node) *peerConfig {
pc := &peerConfig{
group: group,
baseNodes: baseNodes,
}
go gost.PeriodReload(pc, cfg)
return pc
}
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = 1
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = 30 // seconds
}
}
func (cfg *peerConfig) Reload(r io.Reader) error {
if err := cfg.parse(r); err != nil {
return err
}
cfg.Validate()
group := cfg.group
strategy := cfg.Strategy
if len(cfg.baseNodes) > 0 {
// overwrite the strategry in the peer config if `strategy` param exists.
if s := cfg.baseNodes[0].Get("strategy"); s != "" {
strategy = s
}
}
group.SetSelector(
nil,
gost.WithFilter(&gost.FailFilter{
MaxFails: cfg.MaxFails,
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
}),
gost.WithStrategy(parseStrategy(strategy)),
)
gNodes := cfg.baseNodes
nid := len(gNodes) + 1
for _, s := range cfg.Nodes {
nodes, err := parseChainNode(s)
if err != nil {
return err
}
for i := range nodes {
nodes[i].ID = nid
nid++
}
gNodes = append(gNodes, nodes...)
}
group.SetNodes(gNodes...)
return nil
}
func (cfg *peerConfig) parse(r io.Reader) error {
data, err := ioutil.ReadAll(r)
if err != nil {
return err
}
// compatible with JSON format
if err := json.NewDecoder(bytes.NewReader(data)).Decode(cfg); err == nil {
return nil
}
split := func(line string) []string {
if line == "" {
return nil
}
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}
cfg.Nodes = nil
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
line := scanner.Text()
ss := split(line)
if len(ss) < 2 {
continue
}
switch ss[0] {
case "strategy":
cfg.Strategy = ss[1]
case "max_fails":
cfg.MaxFails, _ = strconv.Atoi(ss[1])
case "fail_timeout":
cfg.FailTimeout, _ = time.ParseDuration(ss[1])
case "reload":
cfg.period, _ = time.ParseDuration(ss[1])
case "peer":
cfg.Nodes = append(cfg.Nodes, ss[1])
}
}
return scanner.Err()
}
func (cfg *peerConfig) Period() time.Duration {
return cfg.period
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "random":
return &gost.RandomStrategy{}
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
default:
return &gost.RoundStrategy{}
}
}

View File

@ -1,18 +0,0 @@
{
"strategy": "round",
"max_fails": 3,
"fail_timeout": 30,
"nodes":[
"socks5://:1081",
"socks://:1082",
"socks4a://:1083"
],
"bypass":{
"reverse": false,
"patterns": [
"10.0.0.1",
"192.168.0.0/24",
"*.example.com"
]
}
}

14
cmd/gost/peer.txt Normal file
View File

@ -0,0 +1,14 @@
# strategy for node selecting
strategy random
max_fails 1
fail_timeout 30s
# period for live reloading
reload 10s
# peers
peer http://:18080
peer socks://:11080
peer ss://chacha20:123456@:18338

View File

@ -22,7 +22,7 @@ func ForwardConnector() Connector {
return &forwardConnector{}
}
func (c *forwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}
@ -41,6 +41,9 @@ func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
group: NewNodeGroup(),
}
if raddr == "" {
raddr = ":0" // dummy address
}
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
@ -104,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
@ -113,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)
defer cc.Close()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@ -136,6 +139,9 @@ func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
group: NewNodeGroup(),
}
if raddr == "" {
raddr = ":0" // dummy address
}
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
@ -185,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
@ -206,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)
@ -285,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
@ -295,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn)
@ -363,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)

16
gost.go
View File

@ -7,6 +7,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"time"
@ -14,7 +15,7 @@ import (
)
// Version is the gost version.
const Version = "2.6"
const Version = "2.6.1"
// Debug is a flag that enables the debug log.
var Debug bool
@ -100,3 +101,16 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
return
}
type readWriter struct {
r io.Reader
w io.Writer
}
func (rw *readWriter) Read(p []byte) (n int, err error) {
return rw.r.Read(p)
}
func (rw *readWriter) Write(p []byte) (n int, err error) {
return rw.w.Write(p)
}

View File

@ -20,18 +20,19 @@ type Handler interface {
// HandlerOptions describes the options for Handler.
type HandlerOptions struct {
Addr string
Chain *Chain
Users []*url.Userinfo
TLSConfig *tls.Config
Whitelist *Permissions
Blacklist *Permissions
Strategy Strategy
Bypass *Bypass
Retries int
Timeout time.Duration
Resolver Resolver
Hosts *Hosts
Addr string
Chain *Chain
Users []*url.Userinfo
TLSConfig *tls.Config
Whitelist *Permissions
Blacklist *Permissions
Strategy Strategy
Bypass *Bypass
Retries int
Timeout time.Duration
Resolver Resolver
Hosts *Hosts
ProbeResist string
}
// HandlerOption allows a common way to set handler options.
@ -121,6 +122,13 @@ func HostsHandlerOption(hosts *Hosts) HandlerOption {
}
}
// ProbeResistHandlerOption adds the probe resistance for HTTP proxy.
func ProbeResistHandlerOption(pr string) HandlerOption {
return func(opts *HandlerOptions) {
opts.ProbeResist = pr
}
}
type autoHandler struct {
options *HandlerOptions
}
@ -145,7 +153,7 @@ func (h *autoHandler) Handle(conn net.Conn) {
br := bufio.NewReader(conn)
b, err := br.Peek(1)
if err != nil {
log.Log(err)
log.Logf("[auto] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
conn.Close()
return
}

View File

@ -5,6 +5,7 @@ import (
"io"
"net"
"strings"
"sync"
"time"
"github.com/go-log/log"
@ -25,6 +26,7 @@ type Host struct {
type Hosts struct {
hosts []Host
period time.Duration
mux sync.RWMutex
}
// NewHosts creates a Hosts with optional list of host
@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.mux.Lock()
defer h.mux.Unlock()
h.hosts = append(h.hosts, host...)
}
@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil {
return
}
h.mux.RLock()
defer h.mux.RUnlock()
for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
// Reload parses config from r, then live reloads the hosts.
func (h *Hosts) Reload(r io.Reader) error {
var period time.Duration
var hosts []Host
scanner := bufio.NewScanner(r)
@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {
// reload option
if strings.ToLower(ss[0]) == "reload" {
h.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}
@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
return err
}
h.mux.Lock()
h.period = period
h.hosts = hosts
h.mux.Unlock()
return nil
}
// Period returns the reload period
func (h *Hosts) Period() time.Duration {
h.mux.RLock()
defer h.mux.RUnlock()
return h.period
}

129
http.go
View File

@ -8,6 +8,8 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"time"
@ -24,7 +26,7 @@ func HTTPConnector(user *url.Userinfo) Connector {
return &httpConnector{User: user}
}
func (c *httpConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
@ -111,16 +113,6 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), req.Host, string(dump))
}
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp := "HTTP/1.1 400 Bad Request\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n"
conn.Write([]byte(resp))
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, resp)
}
return
}
// try to get the actual host.
if v := req.Header.Get("Gost-Target"); v != "" {
if host, err := decodeServerName(v); err == nil {
@ -128,25 +120,37 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
}
}
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
}
resp.Header.Add("Proxy-Agent", "gost/"+Version)
if !Can("tcp", req.Host, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[http] Unauthorized to tcp connect to %s", req.Host)
b := []byte("HTTP/1.1 403 Forbidden\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n")
conn.Write(b)
log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s",
conn.RemoteAddr(), req.Host, req.Host)
resp.StatusCode = http.StatusForbidden
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b))
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(dump))
}
resp.Write(conn)
return
}
if h.options.Bypass.Contains(req.Host) {
log.Logf("[http] [bypass] %s", req.Host)
b := []byte("HTTP/1.1 403 Forbidden\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n")
conn.Write(b)
resp.StatusCode = http.StatusForbidden
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b))
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(dump))
}
resp.Write(conn)
return
}
@ -155,16 +159,80 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", conn.RemoteAddr(), req.Host, u, p)
}
if !authenticate(u, p, h.options.Users...) {
log.Logf("[http] %s <- %s : proxy authentication required", conn.RemoteAddr(), req.Host)
resp := "HTTP/1.1 407 Proxy Authentication Required\r\n" +
"Proxy-Authenticate: Basic realm=\"gost\"\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n"
conn.Write([]byte(resp))
// probing resistance is enabled
if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 {
switch ss[0] {
case "code":
resp.StatusCode, _ = strconv.Atoi(ss[1])
case "web":
url := ss[1]
if !strings.HasPrefix(url, "http") {
url = "http://" + url
}
if r, err := http.Get(url); err == nil {
resp = r
}
case "host":
cc, err := net.Dial("tcp", ss[1])
if err == nil {
defer cc.Close()
req.Write(cc)
log.Logf("[http] %s <-> %s : forward to %s", conn.LocalAddr(), req.Host, ss[1])
transport(conn, cc)
log.Logf("[http] %s >-< %s : forward to %s", conn.LocalAddr(), req.Host, ss[1])
return
}
case "file":
f, _ := os.Open(ss[1])
if f != nil {
resp.StatusCode = http.StatusOK
if finfo, _ := f.Stat(); finfo != nil {
resp.ContentLength = finfo.Size()
}
resp.Body = f
}
}
}
if resp.StatusCode == 0 {
log.Logf("[http] %s <- %s : proxy authentication required", conn.RemoteAddr(), req.Host)
resp.StatusCode = http.StatusProxyAuthRequired
resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"")
} else {
resp.Header = http.Header{}
resp.Header.Set("Server", "nginx/1.14.1")
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
if resp.ContentLength > 0 {
resp.Header.Set("Content-Type", "text/html")
}
if resp.StatusCode == http.StatusOK {
resp.Header.Set("Connection", "keep-alive")
}
}
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(dump))
}
resp.Write(conn)
return
}
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp.StatusCode = http.StatusBadRequest
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(dump))
}
resp.Write(conn)
return
}
req.Header.Del("Proxy-Authorization")
// req.Header.Del("Proxy-Connection")
host := req.Host
if _, port, _ := net.SplitHostPort(host); port == "" {
@ -212,13 +280,14 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
if err != nil {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), host, err)
resp.StatusCode = http.StatusServiceUnavailable
b := []byte("HTTP/1.1 503 Service unavailable\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n")
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), host, string(b))
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), host, string(dump))
}
conn.Write(b)
resp.Write(conn)
return
}
defer cc.Close()

120
http2.go
View File

@ -2,14 +2,18 @@ package gost
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@ -28,7 +32,7 @@ func HTTP2Connector(user *url.Userinfo) Connector {
return &http2Connector{User: user}
}
func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
cc, ok := conn.(*http2ClientConn)
if !ok {
return nil, errors.New("wrong connection type")
@ -75,6 +79,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) {
w: pw,
closed: make(chan struct{}),
}
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr)
@ -307,14 +312,79 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
return
}
resp := &http.Response{
ProtoMajor: 2,
ProtoMinor: 0,
Header: http.Header{},
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if Debug && (u != "" || p != "") {
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
}
if !authenticate(u, p, h.options.Users...) {
log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, target)
w.Header().Set("Proxy-Authenticate", "Basic realm=\"gost\"")
w.WriteHeader(http.StatusProxyAuthRequired)
// probing resistance is enabled
if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 {
switch ss[0] {
case "code":
resp.StatusCode, _ = strconv.Atoi(ss[1])
case "web":
url := ss[1]
if !strings.HasPrefix(url, "http") {
url = "http://" + url
}
if r, err := http.Get(url); err == nil {
resp = r
}
case "host":
cc, err := net.Dial("tcp", ss[1])
if err == nil {
defer cc.Close()
log.Logf("[http2] %s <-> %s : forward to %s", r.RemoteAddr, target, ss[1])
if err := h.forwardRequest(w, r, cc); err != nil {
log.Logf("[http2] %s - %s : %s", r.RemoteAddr, target, err)
}
log.Logf("[http2] %s >-< %s : forward to %s", r.RemoteAddr, target, ss[1])
return
}
case "file":
f, _ := os.Open(ss[1])
if f != nil {
resp.StatusCode = http.StatusOK
if finfo, _ := f.Stat(); finfo != nil {
resp.ContentLength = finfo.Size()
}
resp.Body = f
}
}
}
if resp.StatusCode == 0 {
log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, target)
resp.StatusCode = http.StatusProxyAuthRequired
resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"")
} else {
w.Header().Del("Proxy-Agent")
resp.Header = http.Header{}
resp.Header.Set("Server", "nginx/1.14.1")
resp.Header.Set("Date", time.Now().Format(http.TimeFormat))
if resp.ContentLength > 0 {
resp.Header.Set("Content-Type", "text/html")
}
if resp.StatusCode == http.StatusOK {
resp.Header.Set("Connection", "keep-alive")
}
}
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http2] %s <- %s\n%s", r.RemoteAddr, target, string(dump))
}
h.writeResponse(w, resp)
resp.Body.Close()
return
}
@ -358,47 +428,41 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
}
log.Logf("[http2] %s <-> %s", r.RemoteAddr, target)
errc := make(chan error, 2)
go func() {
_, err := io.Copy(cc, r.Body)
errc <- err
}()
go func() {
_, err := io.Copy(flushWriter{w}, cc)
errc <- err
}()
select {
case <-errc:
// glog.V(LWARNING).Infoln("exit", err)
}
transport(&readWriter{r: r.Body, w: flushWriter{w}}, cc)
log.Logf("[http2] %s >-< %s", r.RemoteAddr, target)
return
}
log.Logf("[http2] %s <-> %s", r.RemoteAddr, target)
if err = r.Write(cc); err != nil {
log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err)
if err := h.forwardRequest(w, r, cc); err != nil {
log.Logf("[http2] %s - %s : %s", r.RemoteAddr, target, err)
}
log.Logf("[http2] %s >-< %s", r.RemoteAddr, target)
}
func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) {
if err = r.Write(rw); err != nil {
return
}
resp, err := http.ReadResponse(bufio.NewReader(cc), r)
resp, err := http.ReadResponse(bufio.NewReader(rw), r)
if err != nil {
log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err)
return
}
defer resp.Body.Close()
return h.writeResponse(w, resp)
}
func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error {
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil {
log.Logf("[http2] %s <- %s : %s", r.RemoteAddr, target, err)
}
log.Logf("[http2] %s >-< %s", r.RemoteAddr, target)
_, err := io.Copy(flushWriter{w}, resp.Body)
return err
}
type http2Listener struct {
@ -526,7 +590,7 @@ func H2CListener(addr string) (Listener, error) {
l := &h2Listener{
Listener: tcpKeepAliveListener{ln.(*net.TCPListener)},
server: &http2.Server{
// MaxConcurrentStreams: 1000,
// MaxConcurrentStreams: 1000,
},
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),

View File

@ -89,7 +89,7 @@ var httpProxyTests = []struct {
{"", url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
{"http://:0", nil, nil, "503 Service unavailable"},
{"http://:0", nil, nil, "503 Service Unavailable"},
}
func TestHTTPProxy(t *testing.T) {

147
node.go
View File

@ -5,6 +5,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) {
return
}
// MarkDead marks the node fail status.
func (node *Node) MarkDead() {
atomic.AddUint32(&node.failCount, 1)
atomic.StoreInt64(&node.failTime, time.Now().Unix())
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDead resets the node fail status.
func (node *Node) ResetDead() {
atomic.StoreUint32(&node.failCount, 0)
atomic.StoreInt64(&node.failTime, 0)
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
atomic.StoreInt64(&node.group.nodes[i].failTime, 0)
break
}
}
}
// Clone clones the node, it will prevent data race.
func (node *Node) Clone() Node {
return Node{
@ -167,10 +133,11 @@ func (node *Node) String() string {
// NodeGroup is a group of nodes.
type NodeGroup struct {
ID int
nodes []Node
Options []SelectOption
Selector NodeSelector
ID int
nodes []Node
selectorOptions []SelectOption
selector NodeSelector
mux sync.RWMutex
}
// NewNodeGroup creates a node group
@ -180,40 +147,128 @@ func NewNodeGroup(nodes ...Node) *NodeGroup {
}
}
// AddNode adds node or node list into group
// AddNode appends node or node list into group node.
func (group *NodeGroup) AddNode(node ...Node) {
if group == nil {
return
}
group.mux.Lock()
defer group.mux.Unlock()
group.nodes = append(group.nodes, node...)
}
// SetNodes replaces the group nodes to the specified nodes.
func (group *NodeGroup) SetNodes(nodes ...Node) {
if group == nil {
return
}
group.mux.Lock()
defer group.mux.Unlock()
group.nodes = nodes
}
// SetSelector sets node selector with options for the group.
func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption) {
if group == nil {
return
}
group.Selector = selector
group.Options = opts
group.mux.Lock()
defer group.mux.Unlock()
group.selector = selector
group.selectorOptions = opts
}
// Nodes returns node list in the group
// Nodes returns the node list in the group
func (group *NodeGroup) Nodes() []Node {
if group == nil {
return nil
}
group.mux.RLock()
defer group.mux.RUnlock()
return group.nodes
}
// Next selects the next node from group.
func (group *NodeGroup) copyNodes() []Node {
group.mux.RLock()
defer group.mux.RUnlock()
var nodes []Node
for i := range group.nodes {
nodes = append(nodes, group.nodes[i])
}
return nodes
}
// GetNode returns a copy of the node specified by index in the group.
func (group *NodeGroup) GetNode(i int) Node {
group.mux.RLock()
defer group.mux.RUnlock()
if i < 0 || group == nil || len(group.nodes) <= i {
return Node{}
}
return group.nodes[i].Clone()
}
// MarkDeadNode marks the node with ID nid status to dead.
func (group *NodeGroup) MarkDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.AddUint32(&group.nodes[i].failCount, 1)
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDeadNode resets the node with ID nid status.
func (group *NodeGroup) ResetDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.StoreUint32(&group.nodes[i].failCount, 0)
atomic.StoreInt64(&group.nodes[i].failTime, 0)
break
}
}
}
// Next selects a node from group.
// It also selects IP if the IP list exists.
func (group *NodeGroup) Next() (node Node, err error) {
selector := group.Selector
if group == nil {
return
}
group.mux.RLock()
defer group.mux.RUnlock()
selector := group.selector
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err = selector.Select(group.Nodes(), group.Options...)
node, err = selector.Select(group.nodes, group.selectorOptions...)
if err != nil {
return
}

View File

@ -33,7 +33,7 @@ func (session *quicSession) GetConn() (*quicConn, error) {
}
func (session *quicSession) Close() error {
return session.session.Close(nil)
return session.session.Close()
}
type quicTransporter struct {
@ -226,7 +226,7 @@ func (l *quicListener) sessionLoop(session quic.Session) {
stream, err := session.AcceptStream()
if err != nil {
log.Log("[quic] accept stream:", err)
session.Close(err)
session.Close()
return
}

View File

@ -26,13 +26,12 @@ func PeriodReload(r Reloader, configFile string) error {
finfo, err := f.Stat()
if err != nil {
f.Close()
return err
}
mt := finfo.ModTime()
if !mt.Equal(lastMod) {
if Debug {
log.Log("[reload]", configFile)
}
log.Log("[reload]", configFile)
r.Reload(f)
lastMod = mt
}

View File

@ -3,7 +3,6 @@ package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
@ -13,6 +12,7 @@ import (
"time"
"github.com/go-log/log"
"github.com/miekg/dns"
)
var (
@ -46,14 +46,13 @@ type NameServer struct {
func (ns NameServer) String() string {
addr := ns.Addr
prot := ns.Protocol
host := ns.Hostname
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
}
if prot == "" {
prot = "udp"
}
return fmt.Sprintf("%s/%s %s", addr, prot, host)
return fmt.Sprintf("%s/%s", addr, prot)
}
type resolverCacheItem struct {
@ -68,6 +67,8 @@ type resolver struct {
Timeout time.Duration
TTL time.Duration
period time.Duration
domain string
mux sync.RWMutex
}
// NewResolver create a new Resolver with the given name servers and resolution timeout.
@ -78,95 +79,116 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
TTL: ttl,
mCache: &sync.Map{},
}
r.init()
return r
}
func (r *resolver) init() {
if r.Timeout <= 0 {
r.Timeout = DefaultResolverTimeout
}
if r.TTL == 0 {
r.TTL = DefaultResolverTTL
}
r.Resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
for _, ns := range r.Servers {
conn, err = r.dial(ctx, ns)
if err == nil {
break
}
log.Logf("[resolver] %s : %s", ns, err)
}
return
},
}
return r
}
func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
var d net.Dialer
func (r *resolver) copyServers() []NameServer {
var servers []NameServer
for i := range r.Servers {
servers = append(servers, r.Servers[i])
}
return servers
}
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
if r == nil {
return
}
var domain string
var timeout, ttl time.Duration
var servers []NameServer
r.mux.RLock()
domain = r.domain
timeout = r.Timeout
servers = r.copyServers()
r.mux.RUnlock()
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
}
if !strings.Contains(host, ".") && domain != "" {
host = host + "." + domain
}
ips = r.loadCache(host, ttl)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit %s: %v", host, ips)
}
return
}
for _, ns := range servers {
ips, err = r.resolve(ns, host, timeout)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns, err)
continue
}
if Debug {
log.Logf("[resolver] %s via %s %v", host, ns, ips)
}
if len(ips) > 0 {
break
}
}
r.storeCache(host, ips)
return
}
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
addr := ns.Addr
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
}
client := dns.Client{
Timeout: timeout,
}
switch strings.ToLower(ns.Protocol) {
case "tcp":
return d.DialContext(ctx, "tcp", addr)
client.Net = "tcp"
case "tls":
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{
ServerName: ns.Hostname,
}
if cfg.ServerName == "" {
cfg.InsecureSkipVerify = true
}
return tls.Client(conn, cfg), nil
client.Net = "tcp-tls"
client.TLSConfig = cfg
case "udp":
fallthrough
default:
return d.DialContext(ctx, "udp", addr)
client.Net = "udp"
}
}
func (r *resolver) Resolve(name string) (ips []net.IP, err error) {
if r == nil {
m := dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
mr, _, err := client.Exchange(&m, addr)
if err != nil {
return
}
timeout := r.Timeout
if ip := net.ParseIP(name); ip != nil {
return []net.IP{ip}, nil
}
ips = r.loadCache(name)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit: %s %v", name, ips)
for _, ans := range mr.Answer {
if ar, _ := ans.(*dns.A); ar != nil {
ips = append(ips, ar.A)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
addrs, err := r.Resolver.LookupIPAddr(ctx, name)
for _, addr := range addrs {
ips = append(ips, addr.IP)
}
r.storeCache(name, ips)
if len(ips) > 0 && Debug {
log.Logf("[resolver] %s %v", name, ips)
}
return
}
func (r *resolver) loadCache(name string) []net.IP {
ttl := r.TTL
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
if ttl < 0 {
return nil
}
@ -183,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP {
}
func (r *resolver) storeCache(name string, ips []net.IP) {
ttl := r.TTL
if ttl < 0 || name == "" || len(ips) == 0 {
if name == "" || len(ips) == 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{
@ -194,73 +215,99 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
}
func (r *resolver) Reload(rd io.Reader) error {
var ttl, timeout, period time.Duration
var domain string
var nss []NameServer
scanner := bufio.NewScanner(rd)
for scanner.Scan() {
line := scanner.Text()
split := func(line string) []string {
if line == "" {
return nil
}
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}
scanner := bufio.NewScanner(rd)
for scanner.Scan() {
line := scanner.Text()
ss := split(line)
if len(ss) == 0 {
continue
}
if len(ss) >= 2 {
// timeout option
if strings.ToLower(ss[0]) == "timeout" {
r.Timeout, _ = time.ParseDuration(ss[1])
continue
switch ss[0] {
case "timeout": // timeout option
if len(ss) > 1 {
timeout, _ = time.ParseDuration(ss[1])
}
// ttl option
if strings.ToLower(ss[0]) == "ttl" {
r.TTL, _ = time.ParseDuration(ss[1])
continue
case "ttl": // ttl option
if len(ss) > 1 {
ttl, _ = time.ParseDuration(ss[1])
}
// reload option
if strings.ToLower(ss[0]) == "reload" {
r.period, _ = time.ParseDuration(ss[1])
continue
case "reload": // reload option
if len(ss) > 1 {
period, _ = time.ParseDuration(ss[1])
}
}
var ns NameServer
switch len(ss) {
case 1:
ns.Addr = ss[0]
case 2:
ns.Addr = ss[0]
ns.Protocol = ss[1]
case "domain":
if len(ss) > 1 {
domain = ss[1]
}
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
if len(ss) <= 1 {
break
}
ss = ss[1:]
fallthrough
default:
ns.Addr = ss[0]
ns.Protocol = ss[1]
ns.Hostname = ss[2]
var ns NameServer
switch len(ss) {
case 0:
break
case 1:
ns.Addr = ss[0]
case 2:
ns.Addr = ss[0]
ns.Protocol = ss[1]
default:
ns.Addr = ss[0]
ns.Protocol = ss[1]
ns.Hostname = ss[2]
}
nss = append(nss, ns)
}
nss = append(nss, ns)
}
if err := scanner.Err(); err != nil {
return err
}
r.mux.Lock()
r.Timeout = timeout
r.TTL = ttl
r.domain = domain
r.period = period
r.Servers = nss
r.mux.Unlock()
return nil
}
func (r *resolver) Period() time.Duration {
r.mux.RLock()
defer r.mux.RUnlock()
return r.period
}
@ -269,6 +316,9 @@ func (r *resolver) String() string {
return ""
}
r.mux.RLock()
defer r.mux.RUnlock()
b := &bytes.Buffer{}
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
fmt.Fprintf(b, "TTL %v\n", r.TTL)

View File

@ -94,6 +94,7 @@ type RandomStrategy struct {
Seed int64
rand *rand.Rand
once sync.Once
mux sync.Mutex
}
// Apply applies the random strategy for the nodes.
@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node {
return Node{}
}
return nodes[s.rand.Int()%len(nodes)]
s.mux.Lock()
r := s.rand.Int()
s.mux.Unlock()
return nodes[r%len(nodes)]
}
func (s *RandomStrategy) String() string {

View File

@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
}
tempDelay = 0
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
log.Log("[bypass]", conn.RemoteAddr())
conn.Close()
continue
}
/*
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
log.Log("[bypass]", conn.RemoteAddr())
conn.Close()
continue
}
*/
go h.Handle(conn)
}
@ -90,12 +92,14 @@ type ServerOptions struct {
// ServerOption allows a common way to set server options.
type ServerOption func(opts *ServerOptions)
/*
// BypassServerOption sets the bypass option of ServerOptions.
func BypassServerOption(bypass *Bypass) ServerOption {
return func(opts *ServerOptions) {
opts.Bypass = bypass
}
}
*/
// Listener is a proxy server listener, just like a net.Listener.
type Listener interface {

8
sni.go
View File

@ -28,7 +28,7 @@ func SNIConnector(host string) Connector {
return &sniConnector{host: host}
}
func (c *sniConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *sniConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil
}
@ -112,12 +112,12 @@ func (h *sniHandler) Handle(conn net.Conn) {
defer cc.Close()
if _, err := cc.Write(b); err != nil {
log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), host, err)
log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), addr, err)
}
log.Logf("[sni] %s <-> %s", cc.LocalAddr(), host)
log.Logf("[sni] %s <-> %s", cc.LocalAddr(), addr)
transport(conn, cc)
log.Logf("[sni] %s >-< %s", cc.LocalAddr(), host)
log.Logf("[sni] %s >-< %s", cc.LocalAddr(), addr)
}
// sniSniffConn is a net.Conn that reads from r, fails on Writes,

View File

@ -148,11 +148,11 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
req, err := gosocks5.ReadUserPassRequest(conn)
if err != nil {
log.Log("[socks5]", err)
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err
}
if Debug {
log.Log("[socks5]", req.String())
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
}
valid := false
for _, user := range selector.Users {
@ -168,23 +168,23 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
if len(selector.Users) > 0 && !valid {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil {
log.Log("[socks5]", err)
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err
}
if Debug {
log.Log("[socks5]", resp)
log.Log("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp)
}
log.Log("[socks5] proxy authentication required")
log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr())
return nil, gosocks5.ErrAuthFailure
}
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
if err := resp.Write(conn); err != nil {
log.Log("[socks5]", err)
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err
}
if Debug {
log.Log("[socks5]", resp)
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp)
}
case gosocks5.MethodNoAcceptable:
return nil, gosocks5.ErrBadMethod
@ -203,7 +203,7 @@ func SOCKS5Connector(user *url.Userinfo) Connector {
return &socks5Connector{User: user}
}
func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
selector := &clientSelector{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
User: c.User,
@ -265,7 +265,7 @@ func SOCKS4Connector() Connector {
return &socks4Connector{}
}
func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
taddr, err := net.ResolveTCPAddr("tcp4", addr)
if err != nil {
return nil, err
@ -312,7 +312,7 @@ func SOCKS4AConnector() Connector {
return &socks4aConnector{}
}
func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@ -326,7 +326,7 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
}
if Debug {
log.Logf("[socks4] %s", req)
log.Logf("[socks4a] %s", req)
}
reply, err := gosocks4.ReadReply(conn)
@ -335,11 +335,11 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
}
if Debug {
log.Logf("[socks4] %s", reply)
log.Logf("[socks4a] %s", reply)
}
if reply.Code != gosocks4.Granted {
return nil, fmt.Errorf("[socks4] %d", reply.Code)
return nil, fmt.Errorf("[socks4a] %d", reply.Code)
}
return conn, nil

148
ss.go
View File

@ -9,13 +9,35 @@ import (
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/ginuerzh/gosocks5"
"github.com/go-log/log"
core "github.com/shadowsocks/go-shadowsocks2/core"
socks "github.com/shadowsocks/go-shadowsocks2/socks"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
// Check if shadowsocks2 should be used for AEAD encryption.
func isModernCipher(c string) bool {
c = strings.ToUpper(c)
modern := strings.Contains(c, "AEAD_")
switch c {
case "DUMMY":
fallthrough
case "CHACHA20-IETF-POLY1305":
fallthrough
case "AES-128-GCM":
fallthrough
case "AES-192-GCM":
fallthrough
case "AES-256-GCM":
modern = true
}
return modern
}
// Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write,
// we wrap around it to make io.Copy happy.
type shadowConn struct {
@ -67,7 +89,7 @@ func ShadowConnector(cipher *url.Userinfo) Connector {
return &shadowConnector{Cipher: cipher}
}
func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
rawaddr, err := ss.RawAddr(addr)
if err != nil {
return nil, err
@ -79,16 +101,28 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
password, _ = c.Cipher.Password()
}
cipher, err := ss.NewCipher(method, password)
if err != nil {
return nil, err
}
if isModernCipher(method) {
cipher, err := core.PickCipher(method, []byte{}, password)
if err != nil {
return nil, err
}
conn = cipher.StreamConn(conn)
if _, err = conn.Write(rawaddr); err != nil {
return nil, err
}
return conn, nil
} else {
cipher, err := ss.NewCipher(method, password)
if err != nil {
return nil, err
}
sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher)
if err != nil {
return nil, err
sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher)
if err != nil {
return nil, err
}
return &shadowConn{conn: sc}, nil
}
return &shadowConn{conn: sc}, nil
}
type shadowHandler struct {
@ -116,29 +150,53 @@ func (h *shadowHandler) Init(options ...HandlerOption) {
func (h *shadowHandler) Handle(conn net.Conn) {
defer conn.Close()
var method, password string
var method, password, addr string
users := h.options.Users
if len(users) > 0 {
method = users[0].Username()
password, _ = users[0].Password()
}
cipher, err := ss.NewCipher(method, password)
if err != nil {
log.Log("[ss]", err)
return
}
conn = &shadowConn{conn: ss.NewConn(conn, cipher)}
log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr())
if isModernCipher(method) {
//ss with aead
ciph, err := core.PickCipher(method, []byte{}, password)
if err != nil {
log.Log("[ss]", err)
return
}
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
addr, err := h.getRequest(conn)
if err != nil {
log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr())
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
conn = ciph.StreamConn(conn)
tgt, err := socks.ReadAddr(conn)
if err != nil {
log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
// clear timer
conn.SetReadDeadline(time.Time{})
addr = tgt.String()
} else {
// outdated ss
cipher, err := ss.NewCipher(method, password)
if err != nil {
log.Log("[ss]", err)
return
}
conn = &shadowConn{conn: ss.NewConn(conn, cipher)}
log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr())
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
addr, err = h.getRequest(conn)
if err != nil {
log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
// clear timer
conn.SetReadDeadline(time.Time{})
}
// clear timer
conn.SetReadDeadline(time.Time{})
log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr)
@ -259,18 +317,38 @@ func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Li
method = cipher.Username()
password, _ = cipher.Password()
}
cp, err := ss.NewCipher(method, password)
if err != nil {
ln.Close()
return nil, err
}
l := &shadowUDPListener{
ln: ss.NewSecurePacketConn(ln, cp, false),
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
ttl: ttl,
var l *shadowUDPListener = nil
if isModernCipher(method) {
//modern ss
cp, err := core.PickCipher(method, []byte{}, password)
if err != nil {
ln.Close()
return nil, err
}
l = &shadowUDPListener{
ln: cp.PacketConn(ln),
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
ttl: ttl,
}
} else {
//ancient ss
cp, err := ss.NewCipher(method, password)
if err != nil {
ln.Close()
return nil, err
}
l = &shadowUDPListener{
ln: ss.NewSecurePacketConn(ln, cp, false),
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
ttl: ttl,
}
}
go l.listenLoop()
return l, nil
}

4
ssh.go
View File

@ -39,7 +39,7 @@ func SSHDirectForwardConnector() Connector {
return &sshDirectForwardConnector{}
}
func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string) (net.Conn, error) {
func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) {
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok {
return nil, errors.New("ssh: wrong connection type")
@ -60,7 +60,7 @@ func SSHRemoteForwardConnector() Connector {
return &sshRemoteForwardConnector{}
}
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) {
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok {
return nil, errors.New("ssh: wrong connection type")

View File

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2016 Lucas Clemente
Copyright (c) 2014 cheekybits
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@ -19,3 +19,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

2
vendor/github.com/cheekybits/genny/generic/doc.go generated vendored Normal file
View File

@ -0,0 +1,2 @@
// Package generic contains the generic marker types.
package generic

13
vendor/github.com/cheekybits/genny/generic/generic.go generated vendored Normal file
View File

@ -0,0 +1,13 @@
package generic
// Type is the placeholder type that indicates a generic value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
// var GenericType generic.Type
type Type interface{}
// Number is the placehoder type that indiccates a generic numerical value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
// var GenericType generic.Number
type Number float64

View File

@ -2,7 +2,9 @@ package dissector
import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
)
@ -62,6 +64,11 @@ func (h *ClientHelloHandshake) ReadFrom(r io.Reader) (n int64, err error) {
}
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if length < 34 { // length of version + random
err = fmt.Errorf("bad length, need at least 34 bytes, got %d", length)
return
}
b = make([]byte, length)
nn, err = io.ReadFull(r, b)
n += int64(nn)
@ -69,6 +76,10 @@ func (h *ClientHelloHandshake) ReadFrom(r io.Reader) (n int64, err error) {
return
}
h.Version = Version(binary.BigEndian.Uint16(b[:2]))
if h.Version < tls.VersionTLS12 {
err = fmt.Errorf("bad version: only TLSv1.2 is supported")
return
}
pos := 2
h.Random.Time = binary.BigEndian.Uint32(b[pos : pos+4])
@ -76,41 +87,113 @@ func (h *ClientHelloHandshake) ReadFrom(r io.Reader) (n int64, err error) {
copy(h.Random.Opaque[:], b[pos:pos+28])
pos += 28
sessionLen := int(b[pos])
pos++
h.SessionID = make([]byte, sessionLen)
copy(h.SessionID, b[pos:pos+sessionLen])
pos += sessionLen
cipherLen := int(binary.BigEndian.Uint16(b[pos : pos+2]))
pos += 2
for i := 0; i < cipherLen/2; i++ {
h.CipherSuites = append(h.CipherSuites, CipherSuite(binary.BigEndian.Uint16(b[pos:pos+2])))
pos += 2
nn, err = h.readSession(b[pos:])
if err != nil {
return
}
pos += nn
compLen := int(b[pos])
pos++
for i := 0; i < compLen; i++ {
h.CompressionMethods = append(h.CompressionMethods, CompressionMethod(b[pos]))
pos++
nn, err = h.readCipherSuites(b[pos:])
if err != nil {
return
}
pos += nn
// extLen := int(binary.BigEndian.Uint16(b[pos : pos+2]))
pos += 2
if pos >= len(b) {
nn, err = h.readCompressionMethods(b[pos:])
if err != nil {
return
}
pos += nn
nn, err = h.readExtensions(b[pos:])
if err != nil {
return
}
// pos += nn
return
}
func (h *ClientHelloHandshake) readSession(b []byte) (n int, err error) {
if len(b) == 0 {
err = fmt.Errorf("bad length: data too short for session")
return
}
br := bytes.NewReader(b[pos:])
nlen := int(b[0])
n++
if len(b) < n+nlen {
err = fmt.Errorf("bad length: malformed data for session")
}
if nlen > 0 && n+nlen <= len(b) {
h.SessionID = make([]byte, nlen)
copy(h.SessionID, b[n:n+nlen])
n += nlen
}
return
}
func (h *ClientHelloHandshake) readCipherSuites(b []byte) (n int, err error) {
if len(b) < 2 {
err = fmt.Errorf("bad length: data too short for cipher suites")
return
}
nlen := int(binary.BigEndian.Uint16(b[:2]))
n += 2
if len(b) < n+nlen {
err = fmt.Errorf("bad length: malformed data for cipher suites")
}
for i := 0; i < nlen/2; i++ {
h.CipherSuites = append(h.CipherSuites, CipherSuite(binary.BigEndian.Uint16(b[n:n+2])))
n += 2
}
return
}
func (h *ClientHelloHandshake) readCompressionMethods(b []byte) (n int, err error) {
if len(b) == 0 {
err = fmt.Errorf("bad length: data too short for compression methods")
return
}
nlen := int(b[0])
n++
if len(b) < n+nlen {
err = fmt.Errorf("bad length: malformed data for compression methods")
}
for i := 0; i < nlen; i++ {
h.CompressionMethods = append(h.CompressionMethods, CompressionMethod(b[n]))
n++
}
return
}
func (h *ClientHelloHandshake) readExtensions(b []byte) (n int, err error) {
if len(b) < 2 {
err = fmt.Errorf("bad length: data too short for extensions")
return
}
nlen := int(binary.BigEndian.Uint16(b[:2]))
n += 2
if len(b) < n+nlen {
err = fmt.Errorf("bad length: malformed data for extensions")
return
}
br := bytes.NewReader(b[n:])
for br.Len() > 0 {
cn := br.Len()
var ext Extension
ext, err = ReadExtension(br)
if err != nil {
return
}
h.Extensions = append(h.Extensions, ext)
n += (cn - br.Len())
}
return
}

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Lucas Clemente
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,28 +0,0 @@
# aes12
This package modifies the AES-GCM implementation from Go's standard library to use 12 byte tag sizes. It is not intended for a general audience, and used in [quic-go](https://github.com/lucas-clemente/quic-go).
To make use of the in-place encryption / decryption feature, the `dst` parameter to `Seal` and `Open` should be 16 bytes longer than plaintext, not 12.
Command for testing:
```
go test . --bench=. && GOARCH=386 go test . --bench=.
```
The output (on my machine):
```
BenchmarkAESGCMSeal1K-8 3000000 467 ns/op 2192.37 MB/s
BenchmarkAESGCMOpen1K-8 3000000 416 ns/op 2456.72 MB/s
BenchmarkAESGCMSeal8K-8 500000 2742 ns/op 2986.53 MB/s
BenchmarkAESGCMOpen8K-8 500000 2791 ns/op 2934.65 MB/s
PASS
ok github.com/lucas-clemente/aes12 6.383s
BenchmarkAESGCMSeal1K-8 50000 35233 ns/op 29.06 MB/s
BenchmarkAESGCMOpen1K-8 50000 34529 ns/op 29.66 MB/s
BenchmarkAESGCMSeal8K-8 5000 262678 ns/op 31.19 MB/s
BenchmarkAESGCMOpen8K-8 5000 267296 ns/op 30.65 MB/s
PASS
ok github.com/lucas-clemente/aes12 6.972s
```

View File

@ -1,148 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64
package aes12
import "crypto/subtle"
// The following functions are defined in gcm_amd64.s.
func hasGCMAsm() bool
//go:noescape
func aesEncBlock(dst, src *[16]byte, ks []uint32)
//go:noescape
func gcmAesInit(productTable *[256]byte, ks []uint32)
//go:noescape
func gcmAesData(productTable *[256]byte, data []byte, T *[16]byte)
//go:noescape
func gcmAesEnc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, ks []uint32)
//go:noescape
func gcmAesDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, ks []uint32)
//go:noescape
func gcmAesFinish(productTable *[256]byte, tagMask, T *[16]byte, pLen, dLen uint64)
// aesCipherGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM
// will use the optimised implementation in this file when possible. Instances
// of this type only exist when hasGCMAsm returns true.
type aesCipherGCM struct {
aesCipherAsm
}
// Assert that aesCipherGCM implements the gcmAble interface.
var _ gcmAble = (*aesCipherGCM)(nil)
// NewGCM returns the AES cipher wrapped in Galois Counter Mode. This is only
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *aesCipherGCM) NewGCM(nonceSize int) (AEAD, error) {
g := &gcmAsm{ks: c.enc, nonceSize: nonceSize}
gcmAesInit(&g.productTable, g.ks)
return g, nil
}
type gcmAsm struct {
// ks is the key schedule, the length of which depends on the size of
// the AES key.
ks []uint32
// productTable contains pre-computed multiples of the binary-field
// element used in GHASH.
productTable [256]byte
// nonceSize contains the expected size of the nonce, in bytes.
nonceSize int
}
func (g *gcmAsm) NonceSize() int {
return g.nonceSize
}
func (*gcmAsm) Overhead() int {
return gcmTagSize
}
// Seal encrypts and authenticates plaintext. See the AEAD interface for
// details.
func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte {
if len(nonce) != g.nonceSize {
panic("cipher: incorrect nonce length given to GCM")
}
var counter, tagMask [gcmBlockSize]byte
if len(nonce) == gcmStandardNonceSize {
// Init counter to nonce||1
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
// Otherwise counter = GHASH(nonce)
gcmAesData(&g.productTable, nonce, &counter)
gcmAesFinish(&g.productTable, &tagMask, &counter, uint64(len(nonce)), uint64(0))
}
aesEncBlock(&tagMask, &counter, g.ks)
var tagOut [16]byte
gcmAesData(&g.productTable, data, &tagOut)
ret, out := sliceForAppend(dst, len(plaintext)+gcmTagSize)
if len(plaintext) > 0 {
gcmAesEnc(&g.productTable, out, plaintext, &counter, &tagOut, g.ks)
}
gcmAesFinish(&g.productTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:gcmTagSize])
return ret
}
// Open authenticates and decrypts ciphertext. See the AEAD interface
// for details.
func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(nonce) != g.nonceSize {
panic("cipher: incorrect nonce length given to GCM")
}
if len(ciphertext) < gcmTagSize {
return nil, errOpen
}
tag := ciphertext[len(ciphertext)-gcmTagSize:]
ciphertext = ciphertext[:len(ciphertext)-gcmTagSize]
// See GCM spec, section 7.1.
var counter, tagMask [gcmBlockSize]byte
if len(nonce) == gcmStandardNonceSize {
// Init counter to nonce||1
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
// Otherwise counter = GHASH(nonce)
gcmAesData(&g.productTable, nonce, &counter)
gcmAesFinish(&g.productTable, &tagMask, &counter, uint64(len(nonce)), uint64(0))
}
aesEncBlock(&tagMask, &counter, g.ks)
var expectedTag [16]byte
gcmAesData(&g.productTable, data, &expectedTag)
ret, out := sliceForAppend(dst, len(ciphertext))
if len(ciphertext) > 0 {
gcmAesDec(&g.productTable, out, ciphertext, &counter, &expectedTag, g.ks)
}
gcmAesFinish(&g.productTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))
if subtle.ConstantTimeCompare(expectedTag[:12], tag) != 1 {
for i := range out {
out[i] = 0
}
return nil, errOpen
}
return ret, nil
}

View File

@ -1,285 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// func hasAsm() bool
// returns whether AES-NI is supported
TEXT ·hasAsm(SB),NOSPLIT,$0
XORQ AX, AX
INCL AX
CPUID
SHRQ $25, CX
ANDQ $1, CX
MOVB CX, ret+0(FP)
RET
// func encryptBlockAsm(nr int, xk *uint32, dst, src *byte)
TEXT ·encryptBlockAsm(SB),NOSPLIT,$0
MOVQ nr+0(FP), CX
MOVQ xk+8(FP), AX
MOVQ dst+16(FP), DX
MOVQ src+24(FP), BX
MOVUPS 0(AX), X1
MOVUPS 0(BX), X0
ADDQ $16, AX
PXOR X1, X0
SUBQ $12, CX
JE Lenc196
JB Lenc128
Lenc256:
MOVUPS 0(AX), X1
AESENC X1, X0
MOVUPS 16(AX), X1
AESENC X1, X0
ADDQ $32, AX
Lenc196:
MOVUPS 0(AX), X1
AESENC X1, X0
MOVUPS 16(AX), X1
AESENC X1, X0
ADDQ $32, AX
Lenc128:
MOVUPS 0(AX), X1
AESENC X1, X0
MOVUPS 16(AX), X1
AESENC X1, X0
MOVUPS 32(AX), X1
AESENC X1, X0
MOVUPS 48(AX), X1
AESENC X1, X0
MOVUPS 64(AX), X1
AESENC X1, X0
MOVUPS 80(AX), X1
AESENC X1, X0
MOVUPS 96(AX), X1
AESENC X1, X0
MOVUPS 112(AX), X1
AESENC X1, X0
MOVUPS 128(AX), X1
AESENC X1, X0
MOVUPS 144(AX), X1
AESENCLAST X1, X0
MOVUPS X0, 0(DX)
RET
// func decryptBlockAsm(nr int, xk *uint32, dst, src *byte)
TEXT ·decryptBlockAsm(SB),NOSPLIT,$0
MOVQ nr+0(FP), CX
MOVQ xk+8(FP), AX
MOVQ dst+16(FP), DX
MOVQ src+24(FP), BX
MOVUPS 0(AX), X1
MOVUPS 0(BX), X0
ADDQ $16, AX
PXOR X1, X0
SUBQ $12, CX
JE Ldec196
JB Ldec128
Ldec256:
MOVUPS 0(AX), X1
AESDEC X1, X0
MOVUPS 16(AX), X1
AESDEC X1, X0
ADDQ $32, AX
Ldec196:
MOVUPS 0(AX), X1
AESDEC X1, X0
MOVUPS 16(AX), X1
AESDEC X1, X0
ADDQ $32, AX
Ldec128:
MOVUPS 0(AX), X1
AESDEC X1, X0
MOVUPS 16(AX), X1
AESDEC X1, X0
MOVUPS 32(AX), X1
AESDEC X1, X0
MOVUPS 48(AX), X1
AESDEC X1, X0
MOVUPS 64(AX), X1
AESDEC X1, X0
MOVUPS 80(AX), X1
AESDEC X1, X0
MOVUPS 96(AX), X1
AESDEC X1, X0
MOVUPS 112(AX), X1
AESDEC X1, X0
MOVUPS 128(AX), X1
AESDEC X1, X0
MOVUPS 144(AX), X1
AESDECLAST X1, X0
MOVUPS X0, 0(DX)
RET
// func expandKeyAsm(nr int, key *byte, enc, dec *uint32) {
// Note that round keys are stored in uint128 format, not uint32
TEXT ·expandKeyAsm(SB),NOSPLIT,$0
MOVQ nr+0(FP), CX
MOVQ key+8(FP), AX
MOVQ enc+16(FP), BX
MOVQ dec+24(FP), DX
MOVUPS (AX), X0
// enc
MOVUPS X0, (BX)
ADDQ $16, BX
PXOR X4, X4 // _expand_key_* expect X4 to be zero
CMPL CX, $12
JE Lexp_enc196
JB Lexp_enc128
Lexp_enc256:
MOVUPS 16(AX), X2
MOVUPS X2, (BX)
ADDQ $16, BX
AESKEYGENASSIST $0x01, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x01, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x02, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x02, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x04, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x04, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x08, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x08, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x10, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x10, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x20, X2, X1
CALL _expand_key_256a<>(SB)
AESKEYGENASSIST $0x20, X0, X1
CALL _expand_key_256b<>(SB)
AESKEYGENASSIST $0x40, X2, X1
CALL _expand_key_256a<>(SB)
JMP Lexp_dec
Lexp_enc196:
MOVQ 16(AX), X2
AESKEYGENASSIST $0x01, X2, X1
CALL _expand_key_192a<>(SB)
AESKEYGENASSIST $0x02, X2, X1
CALL _expand_key_192b<>(SB)
AESKEYGENASSIST $0x04, X2, X1
CALL _expand_key_192a<>(SB)
AESKEYGENASSIST $0x08, X2, X1
CALL _expand_key_192b<>(SB)
AESKEYGENASSIST $0x10, X2, X1
CALL _expand_key_192a<>(SB)
AESKEYGENASSIST $0x20, X2, X1
CALL _expand_key_192b<>(SB)
AESKEYGENASSIST $0x40, X2, X1
CALL _expand_key_192a<>(SB)
AESKEYGENASSIST $0x80, X2, X1
CALL _expand_key_192b<>(SB)
JMP Lexp_dec
Lexp_enc128:
AESKEYGENASSIST $0x01, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x02, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x04, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x08, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x10, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x20, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x40, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x80, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x1b, X0, X1
CALL _expand_key_128<>(SB)
AESKEYGENASSIST $0x36, X0, X1
CALL _expand_key_128<>(SB)
Lexp_dec:
// dec
SUBQ $16, BX
MOVUPS (BX), X1
MOVUPS X1, (DX)
DECQ CX
Lexp_dec_loop:
MOVUPS -16(BX), X1
AESIMC X1, X0
MOVUPS X0, 16(DX)
SUBQ $16, BX
ADDQ $16, DX
DECQ CX
JNZ Lexp_dec_loop
MOVUPS -16(BX), X0
MOVUPS X0, 16(DX)
RET
TEXT _expand_key_128<>(SB),NOSPLIT,$0
PSHUFD $0xff, X1, X1
SHUFPS $0x10, X0, X4
PXOR X4, X0
SHUFPS $0x8c, X0, X4
PXOR X4, X0
PXOR X1, X0
MOVUPS X0, (BX)
ADDQ $16, BX
RET
TEXT _expand_key_192a<>(SB),NOSPLIT,$0
PSHUFD $0x55, X1, X1
SHUFPS $0x10, X0, X4
PXOR X4, X0
SHUFPS $0x8c, X0, X4
PXOR X4, X0
PXOR X1, X0
MOVAPS X2, X5
MOVAPS X2, X6
PSLLDQ $0x4, X5
PSHUFD $0xff, X0, X3
PXOR X3, X2
PXOR X5, X2
MOVAPS X0, X1
SHUFPS $0x44, X0, X6
MOVUPS X6, (BX)
SHUFPS $0x4e, X2, X1
MOVUPS X1, 16(BX)
ADDQ $32, BX
RET
TEXT _expand_key_192b<>(SB),NOSPLIT,$0
PSHUFD $0x55, X1, X1
SHUFPS $0x10, X0, X4
PXOR X4, X0
SHUFPS $0x8c, X0, X4
PXOR X4, X0
PXOR X1, X0
MOVAPS X2, X5
PSLLDQ $0x4, X5
PSHUFD $0xff, X0, X3
PXOR X3, X2
PXOR X5, X2
MOVUPS X0, (BX)
ADDQ $16, BX
RET
TEXT _expand_key_256a<>(SB),NOSPLIT,$0
JMP _expand_key_128<>(SB)
TEXT _expand_key_256b<>(SB),NOSPLIT,$0
PSHUFD $0xaa, X1, X1
SHUFPS $0x10, X2, X4
PXOR X4, X2
SHUFPS $0x8c, X2, X4
PXOR X4, X2
PXOR X1, X2
MOVUPS X2, (BX)
ADDQ $16, BX
RET

View File

@ -1,176 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This Go implementation is derived in part from the reference
// ANSI C implementation, which carries the following notice:
//
// rijndael-alg-fst.c
//
// @version 3.0 (December 2000)
//
// Optimised ANSI C code for the Rijndael cipher (now AES)
//
// @author Vincent Rijmen <vincent.rijmen@esat.kuleuven.ac.be>
// @author Antoon Bosselaers <antoon.bosselaers@esat.kuleuven.ac.be>
// @author Paulo Barreto <paulo.barreto@terra.com.br>
//
// This code is hereby placed in the public domain.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ''AS IS'' AND ANY EXPRESS
// OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
// OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// See FIPS 197 for specification, and see Daemen and Rijmen's Rijndael submission
// for implementation details.
// http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf
// http://csrc.nist.gov/archive/aes/rijndael/Rijndael-ammended.pdf
package aes12
// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
var s0, s1, s2, s3, t0, t1, t2, t3 uint32
s0 = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
s1 = uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
s2 = uint32(src[8])<<24 | uint32(src[9])<<16 | uint32(src[10])<<8 | uint32(src[11])
s3 = uint32(src[12])<<24 | uint32(src[13])<<16 | uint32(src[14])<<8 | uint32(src[15])
// First round just XORs input with key.
s0 ^= xk[0]
s1 ^= xk[1]
s2 ^= xk[2]
s3 ^= xk[3]
// Middle rounds shuffle using tables.
// Number of rounds is set by length of expanded key.
nr := len(xk)/4 - 2 // - 2: one above, one more below
k := 4
for r := 0; r < nr; r++ {
t0 = xk[k+0] ^ te0[uint8(s0>>24)] ^ te1[uint8(s1>>16)] ^ te2[uint8(s2>>8)] ^ te3[uint8(s3)]
t1 = xk[k+1] ^ te0[uint8(s1>>24)] ^ te1[uint8(s2>>16)] ^ te2[uint8(s3>>8)] ^ te3[uint8(s0)]
t2 = xk[k+2] ^ te0[uint8(s2>>24)] ^ te1[uint8(s3>>16)] ^ te2[uint8(s0>>8)] ^ te3[uint8(s1)]
t3 = xk[k+3] ^ te0[uint8(s3>>24)] ^ te1[uint8(s0>>16)] ^ te2[uint8(s1>>8)] ^ te3[uint8(s2)]
k += 4
s0, s1, s2, s3 = t0, t1, t2, t3
}
// Last round uses s-box directly and XORs to produce output.
s0 = uint32(sbox0[t0>>24])<<24 | uint32(sbox0[t1>>16&0xff])<<16 | uint32(sbox0[t2>>8&0xff])<<8 | uint32(sbox0[t3&0xff])
s1 = uint32(sbox0[t1>>24])<<24 | uint32(sbox0[t2>>16&0xff])<<16 | uint32(sbox0[t3>>8&0xff])<<8 | uint32(sbox0[t0&0xff])
s2 = uint32(sbox0[t2>>24])<<24 | uint32(sbox0[t3>>16&0xff])<<16 | uint32(sbox0[t0>>8&0xff])<<8 | uint32(sbox0[t1&0xff])
s3 = uint32(sbox0[t3>>24])<<24 | uint32(sbox0[t0>>16&0xff])<<16 | uint32(sbox0[t1>>8&0xff])<<8 | uint32(sbox0[t2&0xff])
s0 ^= xk[k+0]
s1 ^= xk[k+1]
s2 ^= xk[k+2]
s3 ^= xk[k+3]
dst[0], dst[1], dst[2], dst[3] = byte(s0>>24), byte(s0>>16), byte(s0>>8), byte(s0)
dst[4], dst[5], dst[6], dst[7] = byte(s1>>24), byte(s1>>16), byte(s1>>8), byte(s1)
dst[8], dst[9], dst[10], dst[11] = byte(s2>>24), byte(s2>>16), byte(s2>>8), byte(s2)
dst[12], dst[13], dst[14], dst[15] = byte(s3>>24), byte(s3>>16), byte(s3>>8), byte(s3)
}
// Decrypt one block from src into dst, using the expanded key xk.
func decryptBlockGo(xk []uint32, dst, src []byte) {
var s0, s1, s2, s3, t0, t1, t2, t3 uint32
s0 = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
s1 = uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
s2 = uint32(src[8])<<24 | uint32(src[9])<<16 | uint32(src[10])<<8 | uint32(src[11])
s3 = uint32(src[12])<<24 | uint32(src[13])<<16 | uint32(src[14])<<8 | uint32(src[15])
// First round just XORs input with key.
s0 ^= xk[0]
s1 ^= xk[1]
s2 ^= xk[2]
s3 ^= xk[3]
// Middle rounds shuffle using tables.
// Number of rounds is set by length of expanded key.
nr := len(xk)/4 - 2 // - 2: one above, one more below
k := 4
for r := 0; r < nr; r++ {
t0 = xk[k+0] ^ td0[uint8(s0>>24)] ^ td1[uint8(s3>>16)] ^ td2[uint8(s2>>8)] ^ td3[uint8(s1)]
t1 = xk[k+1] ^ td0[uint8(s1>>24)] ^ td1[uint8(s0>>16)] ^ td2[uint8(s3>>8)] ^ td3[uint8(s2)]
t2 = xk[k+2] ^ td0[uint8(s2>>24)] ^ td1[uint8(s1>>16)] ^ td2[uint8(s0>>8)] ^ td3[uint8(s3)]
t3 = xk[k+3] ^ td0[uint8(s3>>24)] ^ td1[uint8(s2>>16)] ^ td2[uint8(s1>>8)] ^ td3[uint8(s0)]
k += 4
s0, s1, s2, s3 = t0, t1, t2, t3
}
// Last round uses s-box directly and XORs to produce output.
s0 = uint32(sbox1[t0>>24])<<24 | uint32(sbox1[t3>>16&0xff])<<16 | uint32(sbox1[t2>>8&0xff])<<8 | uint32(sbox1[t1&0xff])
s1 = uint32(sbox1[t1>>24])<<24 | uint32(sbox1[t0>>16&0xff])<<16 | uint32(sbox1[t3>>8&0xff])<<8 | uint32(sbox1[t2&0xff])
s2 = uint32(sbox1[t2>>24])<<24 | uint32(sbox1[t1>>16&0xff])<<16 | uint32(sbox1[t0>>8&0xff])<<8 | uint32(sbox1[t3&0xff])
s3 = uint32(sbox1[t3>>24])<<24 | uint32(sbox1[t2>>16&0xff])<<16 | uint32(sbox1[t1>>8&0xff])<<8 | uint32(sbox1[t0&0xff])
s0 ^= xk[k+0]
s1 ^= xk[k+1]
s2 ^= xk[k+2]
s3 ^= xk[k+3]
dst[0], dst[1], dst[2], dst[3] = byte(s0>>24), byte(s0>>16), byte(s0>>8), byte(s0)
dst[4], dst[5], dst[6], dst[7] = byte(s1>>24), byte(s1>>16), byte(s1>>8), byte(s1)
dst[8], dst[9], dst[10], dst[11] = byte(s2>>24), byte(s2>>16), byte(s2>>8), byte(s2)
dst[12], dst[13], dst[14], dst[15] = byte(s3>>24), byte(s3>>16), byte(s3>>8), byte(s3)
}
// Apply sbox0 to each byte in w.
func subw(w uint32) uint32 {
return uint32(sbox0[w>>24])<<24 |
uint32(sbox0[w>>16&0xff])<<16 |
uint32(sbox0[w>>8&0xff])<<8 |
uint32(sbox0[w&0xff])
}
// Rotate
func rotw(w uint32) uint32 { return w<<8 | w>>24 }
// Key expansion algorithm. See FIPS-197, Figure 11.
// Their rcon[i] is our powx[i-1] << 24.
func expandKeyGo(key []byte, enc, dec []uint32) {
// Encryption key setup.
var i int
nk := len(key) / 4
for i = 0; i < nk; i++ {
enc[i] = uint32(key[4*i])<<24 | uint32(key[4*i+1])<<16 | uint32(key[4*i+2])<<8 | uint32(key[4*i+3])
}
for ; i < len(enc); i++ {
t := enc[i-1]
if i%nk == 0 {
t = subw(rotw(t)) ^ (uint32(powx[i/nk-1]) << 24)
} else if nk > 6 && i%nk == 4 {
t = subw(t)
}
enc[i] = enc[i-nk] ^ t
}
// Derive decryption key from encryption key.
// Reverse the 4-word round key sets from enc to produce dec.
// All sets but the first and last get the MixColumn transform applied.
if dec == nil {
return
}
n := len(enc)
for i := 0; i < n; i += 4 {
ei := n - i - 4
for j := 0; j < 4; j++ {
x := enc[ei+j]
if i > 0 && i+4 < n {
x = td0[sbox0[x>>24]] ^ td1[sbox0[x>>16&0xff]] ^ td2[sbox0[x>>8&0xff]] ^ td3[sbox0[x&0xff]]
}
dec[i+j] = x
}
}
}

View File

@ -1,68 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes12
import "strconv"
// The AES block size in bytes.
const BlockSize = 16
// A cipher is an instance of AES encryption using a particular key.
type aesCipher struct {
enc []uint32
dec []uint32
}
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/aes: invalid key size " + strconv.Itoa(int(k))
}
// NewCipher creates and returns a new Block.
// The key argument should be the AES key,
// either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256.
func NewCipher(key []byte) (Block, error) {
k := len(key)
switch k {
default:
return nil, KeySizeError(k)
case 16, 24, 32:
break
}
return newCipher(key)
}
// newCipherGeneric creates and returns a new Block
// implemented in pure Go.
func newCipherGeneric(key []byte) (Block, error) {
n := len(key) + 28
c := aesCipher{make([]uint32, n), make([]uint32, n)}
expandKeyGo(key, c.enc, c.dec)
return &c, nil
}
func (c *aesCipher) BlockSize() int { return BlockSize }
func (c *aesCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
encryptBlockGo(c.enc, dst, src)
}
func (c *aesCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
decryptBlockGo(c.dec, dst, src)
}

View File

@ -1,56 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// package aes12 implements standard block cipher modes that can be wrapped
// around low-level block cipher implementations.
// See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html
// and NIST Special Publication 800-38A.
package aes12
// A Block represents an implementation of block cipher
// using a given key. It provides the capability to encrypt
// or decrypt individual blocks. The mode implementations
// extend that capability to streams of blocks.
type Block interface {
// BlockSize returns the cipher's block size.
BlockSize() int
// Encrypt encrypts the first block in src into dst.
// Dst and src may point at the same memory.
Encrypt(dst, src []byte)
// Decrypt decrypts the first block in src into dst.
// Dst and src may point at the same memory.
Decrypt(dst, src []byte)
}
// A Stream represents a stream cipher.
type Stream interface {
// XORKeyStream XORs each byte in the given slice with a byte from the
// cipher's key stream. Dst and src may point to the same memory.
// If len(dst) < len(src), XORKeyStream should panic. It is acceptable
// to pass a dst bigger than src, and in that case, XORKeyStream will
// only update dst[:len(src)] and will not touch the rest of dst.
XORKeyStream(dst, src []byte)
}
// A BlockMode represents a block cipher running in a block-based mode (CBC,
// ECB etc).
type BlockMode interface {
// BlockSize returns the mode's block size.
BlockSize() int
// CryptBlocks encrypts or decrypts a number of blocks. The length of
// src must be a multiple of the block size. Dst and src may point to
// the same memory.
CryptBlocks(dst, src []byte)
}
// Utility routines
func dup(p []byte) []byte {
q := make([]byte, len(p))
copy(q, p)
return q
}

View File

@ -1,79 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes12
// defined in asm_amd64.s
func hasAsm() bool
func encryptBlockAsm(nr int, xk *uint32, dst, src *byte)
func decryptBlockAsm(nr int, xk *uint32, dst, src *byte)
func expandKeyAsm(nr int, key *byte, enc *uint32, dec *uint32)
type aesCipherAsm struct {
aesCipher
}
var useAsm = hasAsm()
func newCipher(key []byte) (Block, error) {
if !useAsm {
return newCipherGeneric(key)
}
n := len(key) + 28
c := aesCipherAsm{aesCipher{make([]uint32, n), make([]uint32, n)}}
rounds := 10
switch len(key) {
case 128 / 8:
rounds = 10
case 192 / 8:
rounds = 12
case 256 / 8:
rounds = 14
}
expandKeyAsm(rounds, &key[0], &c.enc[0], &c.dec[0])
if hasGCMAsm() {
return &aesCipherGCM{c}, nil
}
return &c, nil
}
func (c *aesCipherAsm) BlockSize() int { return BlockSize }
func (c *aesCipherAsm) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
encryptBlockAsm(len(c.enc)/4-1, &c.enc[0], &dst[0], &src[0])
}
func (c *aesCipherAsm) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
decryptBlockAsm(len(c.dec)/4-1, &c.dec[0], &dst[0], &src[0])
}
// expandKey is used by BenchmarkExpand to ensure that the asm implementation
// of key expansion is used for the benchmark when it is available.
func expandKey(key []byte, enc, dec []uint32) {
if useAsm {
rounds := 10 // rounds needed for AES128
switch len(key) {
case 192 / 8:
rounds = 12
case 256 / 8:
rounds = 14
}
expandKeyAsm(rounds, &key[0], &enc[0], &dec[0])
} else {
expandKeyGo(key, enc, dec)
}
}

View File

@ -1,22 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64
package aes12
// newCipher calls the newCipherGeneric function
// directly. Platforms with hardware accelerated
// implementations of AES should implement their
// own version of newCipher (which may then call
// newCipherGeneric if needed).
func newCipher(key []byte) (Block, error) {
return newCipherGeneric(key)
}
// expandKey is used by BenchmarkExpand and should
// call an assembly implementation if one is available.
func expandKey(key []byte, enc, dec []uint32) {
expandKeyGo(key, enc, dec)
}

View File

@ -1,358 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// package aes12 implements AES encryption (formerly Rijndael), as defined in
// U.S. Federal Information Processing Standards Publication 197.
package aes12
// This file contains AES constants - 8720 bytes of initialized data.
// http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf
// AES is based on the mathematical behavior of binary polynomials
// (polynomials over GF(2)) modulo the irreducible polynomial x⁸ + x⁴ + x³ + x + 1.
// Addition of these binary polynomials corresponds to binary xor.
// Reducing mod poly corresponds to binary xor with poly every
// time a 0x100 bit appears.
const poly = 1<<8 | 1<<4 | 1<<3 | 1<<1 | 1<<0 // x⁸ + x⁴ + x³ + x + 1
// Powers of x mod poly in GF(2).
var powx = [16]byte{
0x01,
0x02,
0x04,
0x08,
0x10,
0x20,
0x40,
0x80,
0x1b,
0x36,
0x6c,
0xd8,
0xab,
0x4d,
0x9a,
0x2f,
}
// FIPS-197 Figure 7. S-box substitution values in hexadecimal format.
var sbox0 = [256]byte{
0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
}
// FIPS-197 Figure 14. Inverse S-box substitution values in hexadecimal format.
var sbox1 = [256]byte{
0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
}
// Lookup tables for encryption.
// These can be recomputed by adapting the tests in aes_test.go.
var te0 = [256]uint32{
0xc66363a5, 0xf87c7c84, 0xee777799, 0xf67b7b8d, 0xfff2f20d, 0xd66b6bbd, 0xde6f6fb1, 0x91c5c554,
0x60303050, 0x02010103, 0xce6767a9, 0x562b2b7d, 0xe7fefe19, 0xb5d7d762, 0x4dababe6, 0xec76769a,
0x8fcaca45, 0x1f82829d, 0x89c9c940, 0xfa7d7d87, 0xeffafa15, 0xb25959eb, 0x8e4747c9, 0xfbf0f00b,
0x41adadec, 0xb3d4d467, 0x5fa2a2fd, 0x45afafea, 0x239c9cbf, 0x53a4a4f7, 0xe4727296, 0x9bc0c05b,
0x75b7b7c2, 0xe1fdfd1c, 0x3d9393ae, 0x4c26266a, 0x6c36365a, 0x7e3f3f41, 0xf5f7f702, 0x83cccc4f,
0x6834345c, 0x51a5a5f4, 0xd1e5e534, 0xf9f1f108, 0xe2717193, 0xabd8d873, 0x62313153, 0x2a15153f,
0x0804040c, 0x95c7c752, 0x46232365, 0x9dc3c35e, 0x30181828, 0x379696a1, 0x0a05050f, 0x2f9a9ab5,
0x0e070709, 0x24121236, 0x1b80809b, 0xdfe2e23d, 0xcdebeb26, 0x4e272769, 0x7fb2b2cd, 0xea75759f,
0x1209091b, 0x1d83839e, 0x582c2c74, 0x341a1a2e, 0x361b1b2d, 0xdc6e6eb2, 0xb45a5aee, 0x5ba0a0fb,
0xa45252f6, 0x763b3b4d, 0xb7d6d661, 0x7db3b3ce, 0x5229297b, 0xdde3e33e, 0x5e2f2f71, 0x13848497,
0xa65353f5, 0xb9d1d168, 0x00000000, 0xc1eded2c, 0x40202060, 0xe3fcfc1f, 0x79b1b1c8, 0xb65b5bed,
0xd46a6abe, 0x8dcbcb46, 0x67bebed9, 0x7239394b, 0x944a4ade, 0x984c4cd4, 0xb05858e8, 0x85cfcf4a,
0xbbd0d06b, 0xc5efef2a, 0x4faaaae5, 0xedfbfb16, 0x864343c5, 0x9a4d4dd7, 0x66333355, 0x11858594,
0x8a4545cf, 0xe9f9f910, 0x04020206, 0xfe7f7f81, 0xa05050f0, 0x783c3c44, 0x259f9fba, 0x4ba8a8e3,
0xa25151f3, 0x5da3a3fe, 0x804040c0, 0x058f8f8a, 0x3f9292ad, 0x219d9dbc, 0x70383848, 0xf1f5f504,
0x63bcbcdf, 0x77b6b6c1, 0xafdada75, 0x42212163, 0x20101030, 0xe5ffff1a, 0xfdf3f30e, 0xbfd2d26d,
0x81cdcd4c, 0x180c0c14, 0x26131335, 0xc3ecec2f, 0xbe5f5fe1, 0x359797a2, 0x884444cc, 0x2e171739,
0x93c4c457, 0x55a7a7f2, 0xfc7e7e82, 0x7a3d3d47, 0xc86464ac, 0xba5d5de7, 0x3219192b, 0xe6737395,
0xc06060a0, 0x19818198, 0x9e4f4fd1, 0xa3dcdc7f, 0x44222266, 0x542a2a7e, 0x3b9090ab, 0x0b888883,
0x8c4646ca, 0xc7eeee29, 0x6bb8b8d3, 0x2814143c, 0xa7dede79, 0xbc5e5ee2, 0x160b0b1d, 0xaddbdb76,
0xdbe0e03b, 0x64323256, 0x743a3a4e, 0x140a0a1e, 0x924949db, 0x0c06060a, 0x4824246c, 0xb85c5ce4,
0x9fc2c25d, 0xbdd3d36e, 0x43acacef, 0xc46262a6, 0x399191a8, 0x319595a4, 0xd3e4e437, 0xf279798b,
0xd5e7e732, 0x8bc8c843, 0x6e373759, 0xda6d6db7, 0x018d8d8c, 0xb1d5d564, 0x9c4e4ed2, 0x49a9a9e0,
0xd86c6cb4, 0xac5656fa, 0xf3f4f407, 0xcfeaea25, 0xca6565af, 0xf47a7a8e, 0x47aeaee9, 0x10080818,
0x6fbabad5, 0xf0787888, 0x4a25256f, 0x5c2e2e72, 0x381c1c24, 0x57a6a6f1, 0x73b4b4c7, 0x97c6c651,
0xcbe8e823, 0xa1dddd7c, 0xe874749c, 0x3e1f1f21, 0x964b4bdd, 0x61bdbddc, 0x0d8b8b86, 0x0f8a8a85,
0xe0707090, 0x7c3e3e42, 0x71b5b5c4, 0xcc6666aa, 0x904848d8, 0x06030305, 0xf7f6f601, 0x1c0e0e12,
0xc26161a3, 0x6a35355f, 0xae5757f9, 0x69b9b9d0, 0x17868691, 0x99c1c158, 0x3a1d1d27, 0x279e9eb9,
0xd9e1e138, 0xebf8f813, 0x2b9898b3, 0x22111133, 0xd26969bb, 0xa9d9d970, 0x078e8e89, 0x339494a7,
0x2d9b9bb6, 0x3c1e1e22, 0x15878792, 0xc9e9e920, 0x87cece49, 0xaa5555ff, 0x50282878, 0xa5dfdf7a,
0x038c8c8f, 0x59a1a1f8, 0x09898980, 0x1a0d0d17, 0x65bfbfda, 0xd7e6e631, 0x844242c6, 0xd06868b8,
0x824141c3, 0x299999b0, 0x5a2d2d77, 0x1e0f0f11, 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a,
}
var te1 = [256]uint32{
0xa5c66363, 0x84f87c7c, 0x99ee7777, 0x8df67b7b, 0x0dfff2f2, 0xbdd66b6b, 0xb1de6f6f, 0x5491c5c5,
0x50603030, 0x03020101, 0xa9ce6767, 0x7d562b2b, 0x19e7fefe, 0x62b5d7d7, 0xe64dabab, 0x9aec7676,
0x458fcaca, 0x9d1f8282, 0x4089c9c9, 0x87fa7d7d, 0x15effafa, 0xebb25959, 0xc98e4747, 0x0bfbf0f0,
0xec41adad, 0x67b3d4d4, 0xfd5fa2a2, 0xea45afaf, 0xbf239c9c, 0xf753a4a4, 0x96e47272, 0x5b9bc0c0,
0xc275b7b7, 0x1ce1fdfd, 0xae3d9393, 0x6a4c2626, 0x5a6c3636, 0x417e3f3f, 0x02f5f7f7, 0x4f83cccc,
0x5c683434, 0xf451a5a5, 0x34d1e5e5, 0x08f9f1f1, 0x93e27171, 0x73abd8d8, 0x53623131, 0x3f2a1515,
0x0c080404, 0x5295c7c7, 0x65462323, 0x5e9dc3c3, 0x28301818, 0xa1379696, 0x0f0a0505, 0xb52f9a9a,
0x090e0707, 0x36241212, 0x9b1b8080, 0x3ddfe2e2, 0x26cdebeb, 0x694e2727, 0xcd7fb2b2, 0x9fea7575,
0x1b120909, 0x9e1d8383, 0x74582c2c, 0x2e341a1a, 0x2d361b1b, 0xb2dc6e6e, 0xeeb45a5a, 0xfb5ba0a0,
0xf6a45252, 0x4d763b3b, 0x61b7d6d6, 0xce7db3b3, 0x7b522929, 0x3edde3e3, 0x715e2f2f, 0x97138484,
0xf5a65353, 0x68b9d1d1, 0x00000000, 0x2cc1eded, 0x60402020, 0x1fe3fcfc, 0xc879b1b1, 0xedb65b5b,
0xbed46a6a, 0x468dcbcb, 0xd967bebe, 0x4b723939, 0xde944a4a, 0xd4984c4c, 0xe8b05858, 0x4a85cfcf,
0x6bbbd0d0, 0x2ac5efef, 0xe54faaaa, 0x16edfbfb, 0xc5864343, 0xd79a4d4d, 0x55663333, 0x94118585,
0xcf8a4545, 0x10e9f9f9, 0x06040202, 0x81fe7f7f, 0xf0a05050, 0x44783c3c, 0xba259f9f, 0xe34ba8a8,
0xf3a25151, 0xfe5da3a3, 0xc0804040, 0x8a058f8f, 0xad3f9292, 0xbc219d9d, 0x48703838, 0x04f1f5f5,
0xdf63bcbc, 0xc177b6b6, 0x75afdada, 0x63422121, 0x30201010, 0x1ae5ffff, 0x0efdf3f3, 0x6dbfd2d2,
0x4c81cdcd, 0x14180c0c, 0x35261313, 0x2fc3ecec, 0xe1be5f5f, 0xa2359797, 0xcc884444, 0x392e1717,
0x5793c4c4, 0xf255a7a7, 0x82fc7e7e, 0x477a3d3d, 0xacc86464, 0xe7ba5d5d, 0x2b321919, 0x95e67373,
0xa0c06060, 0x98198181, 0xd19e4f4f, 0x7fa3dcdc, 0x66442222, 0x7e542a2a, 0xab3b9090, 0x830b8888,
0xca8c4646, 0x29c7eeee, 0xd36bb8b8, 0x3c281414, 0x79a7dede, 0xe2bc5e5e, 0x1d160b0b, 0x76addbdb,
0x3bdbe0e0, 0x56643232, 0x4e743a3a, 0x1e140a0a, 0xdb924949, 0x0a0c0606, 0x6c482424, 0xe4b85c5c,
0x5d9fc2c2, 0x6ebdd3d3, 0xef43acac, 0xa6c46262, 0xa8399191, 0xa4319595, 0x37d3e4e4, 0x8bf27979,
0x32d5e7e7, 0x438bc8c8, 0x596e3737, 0xb7da6d6d, 0x8c018d8d, 0x64b1d5d5, 0xd29c4e4e, 0xe049a9a9,
0xb4d86c6c, 0xfaac5656, 0x07f3f4f4, 0x25cfeaea, 0xafca6565, 0x8ef47a7a, 0xe947aeae, 0x18100808,
0xd56fbaba, 0x88f07878, 0x6f4a2525, 0x725c2e2e, 0x24381c1c, 0xf157a6a6, 0xc773b4b4, 0x5197c6c6,
0x23cbe8e8, 0x7ca1dddd, 0x9ce87474, 0x213e1f1f, 0xdd964b4b, 0xdc61bdbd, 0x860d8b8b, 0x850f8a8a,
0x90e07070, 0x427c3e3e, 0xc471b5b5, 0xaacc6666, 0xd8904848, 0x05060303, 0x01f7f6f6, 0x121c0e0e,
0xa3c26161, 0x5f6a3535, 0xf9ae5757, 0xd069b9b9, 0x91178686, 0x5899c1c1, 0x273a1d1d, 0xb9279e9e,
0x38d9e1e1, 0x13ebf8f8, 0xb32b9898, 0x33221111, 0xbbd26969, 0x70a9d9d9, 0x89078e8e, 0xa7339494,
0xb62d9b9b, 0x223c1e1e, 0x92158787, 0x20c9e9e9, 0x4987cece, 0xffaa5555, 0x78502828, 0x7aa5dfdf,
0x8f038c8c, 0xf859a1a1, 0x80098989, 0x171a0d0d, 0xda65bfbf, 0x31d7e6e6, 0xc6844242, 0xb8d06868,
0xc3824141, 0xb0299999, 0x775a2d2d, 0x111e0f0f, 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616,
}
var te2 = [256]uint32{
0x63a5c663, 0x7c84f87c, 0x7799ee77, 0x7b8df67b, 0xf20dfff2, 0x6bbdd66b, 0x6fb1de6f, 0xc55491c5,
0x30506030, 0x01030201, 0x67a9ce67, 0x2b7d562b, 0xfe19e7fe, 0xd762b5d7, 0xabe64dab, 0x769aec76,
0xca458fca, 0x829d1f82, 0xc94089c9, 0x7d87fa7d, 0xfa15effa, 0x59ebb259, 0x47c98e47, 0xf00bfbf0,
0xadec41ad, 0xd467b3d4, 0xa2fd5fa2, 0xafea45af, 0x9cbf239c, 0xa4f753a4, 0x7296e472, 0xc05b9bc0,
0xb7c275b7, 0xfd1ce1fd, 0x93ae3d93, 0x266a4c26, 0x365a6c36, 0x3f417e3f, 0xf702f5f7, 0xcc4f83cc,
0x345c6834, 0xa5f451a5, 0xe534d1e5, 0xf108f9f1, 0x7193e271, 0xd873abd8, 0x31536231, 0x153f2a15,
0x040c0804, 0xc75295c7, 0x23654623, 0xc35e9dc3, 0x18283018, 0x96a13796, 0x050f0a05, 0x9ab52f9a,
0x07090e07, 0x12362412, 0x809b1b80, 0xe23ddfe2, 0xeb26cdeb, 0x27694e27, 0xb2cd7fb2, 0x759fea75,
0x091b1209, 0x839e1d83, 0x2c74582c, 0x1a2e341a, 0x1b2d361b, 0x6eb2dc6e, 0x5aeeb45a, 0xa0fb5ba0,
0x52f6a452, 0x3b4d763b, 0xd661b7d6, 0xb3ce7db3, 0x297b5229, 0xe33edde3, 0x2f715e2f, 0x84971384,
0x53f5a653, 0xd168b9d1, 0x00000000, 0xed2cc1ed, 0x20604020, 0xfc1fe3fc, 0xb1c879b1, 0x5bedb65b,
0x6abed46a, 0xcb468dcb, 0xbed967be, 0x394b7239, 0x4ade944a, 0x4cd4984c, 0x58e8b058, 0xcf4a85cf,
0xd06bbbd0, 0xef2ac5ef, 0xaae54faa, 0xfb16edfb, 0x43c58643, 0x4dd79a4d, 0x33556633, 0x85941185,
0x45cf8a45, 0xf910e9f9, 0x02060402, 0x7f81fe7f, 0x50f0a050, 0x3c44783c, 0x9fba259f, 0xa8e34ba8,
0x51f3a251, 0xa3fe5da3, 0x40c08040, 0x8f8a058f, 0x92ad3f92, 0x9dbc219d, 0x38487038, 0xf504f1f5,
0xbcdf63bc, 0xb6c177b6, 0xda75afda, 0x21634221, 0x10302010, 0xff1ae5ff, 0xf30efdf3, 0xd26dbfd2,
0xcd4c81cd, 0x0c14180c, 0x13352613, 0xec2fc3ec, 0x5fe1be5f, 0x97a23597, 0x44cc8844, 0x17392e17,
0xc45793c4, 0xa7f255a7, 0x7e82fc7e, 0x3d477a3d, 0x64acc864, 0x5de7ba5d, 0x192b3219, 0x7395e673,
0x60a0c060, 0x81981981, 0x4fd19e4f, 0xdc7fa3dc, 0x22664422, 0x2a7e542a, 0x90ab3b90, 0x88830b88,
0x46ca8c46, 0xee29c7ee, 0xb8d36bb8, 0x143c2814, 0xde79a7de, 0x5ee2bc5e, 0x0b1d160b, 0xdb76addb,
0xe03bdbe0, 0x32566432, 0x3a4e743a, 0x0a1e140a, 0x49db9249, 0x060a0c06, 0x246c4824, 0x5ce4b85c,
0xc25d9fc2, 0xd36ebdd3, 0xacef43ac, 0x62a6c462, 0x91a83991, 0x95a43195, 0xe437d3e4, 0x798bf279,
0xe732d5e7, 0xc8438bc8, 0x37596e37, 0x6db7da6d, 0x8d8c018d, 0xd564b1d5, 0x4ed29c4e, 0xa9e049a9,
0x6cb4d86c, 0x56faac56, 0xf407f3f4, 0xea25cfea, 0x65afca65, 0x7a8ef47a, 0xaee947ae, 0x08181008,
0xbad56fba, 0x7888f078, 0x256f4a25, 0x2e725c2e, 0x1c24381c, 0xa6f157a6, 0xb4c773b4, 0xc65197c6,
0xe823cbe8, 0xdd7ca1dd, 0x749ce874, 0x1f213e1f, 0x4bdd964b, 0xbddc61bd, 0x8b860d8b, 0x8a850f8a,
0x7090e070, 0x3e427c3e, 0xb5c471b5, 0x66aacc66, 0x48d89048, 0x03050603, 0xf601f7f6, 0x0e121c0e,
0x61a3c261, 0x355f6a35, 0x57f9ae57, 0xb9d069b9, 0x86911786, 0xc15899c1, 0x1d273a1d, 0x9eb9279e,
0xe138d9e1, 0xf813ebf8, 0x98b32b98, 0x11332211, 0x69bbd269, 0xd970a9d9, 0x8e89078e, 0x94a73394,
0x9bb62d9b, 0x1e223c1e, 0x87921587, 0xe920c9e9, 0xce4987ce, 0x55ffaa55, 0x28785028, 0xdf7aa5df,
0x8c8f038c, 0xa1f859a1, 0x89800989, 0x0d171a0d, 0xbfda65bf, 0xe631d7e6, 0x42c68442, 0x68b8d068,
0x41c38241, 0x99b02999, 0x2d775a2d, 0x0f111e0f, 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16,
}
var te3 = [256]uint32{
0x6363a5c6, 0x7c7c84f8, 0x777799ee, 0x7b7b8df6, 0xf2f20dff, 0x6b6bbdd6, 0x6f6fb1de, 0xc5c55491,
0x30305060, 0x01010302, 0x6767a9ce, 0x2b2b7d56, 0xfefe19e7, 0xd7d762b5, 0xababe64d, 0x76769aec,
0xcaca458f, 0x82829d1f, 0xc9c94089, 0x7d7d87fa, 0xfafa15ef, 0x5959ebb2, 0x4747c98e, 0xf0f00bfb,
0xadadec41, 0xd4d467b3, 0xa2a2fd5f, 0xafafea45, 0x9c9cbf23, 0xa4a4f753, 0x727296e4, 0xc0c05b9b,
0xb7b7c275, 0xfdfd1ce1, 0x9393ae3d, 0x26266a4c, 0x36365a6c, 0x3f3f417e, 0xf7f702f5, 0xcccc4f83,
0x34345c68, 0xa5a5f451, 0xe5e534d1, 0xf1f108f9, 0x717193e2, 0xd8d873ab, 0x31315362, 0x15153f2a,
0x04040c08, 0xc7c75295, 0x23236546, 0xc3c35e9d, 0x18182830, 0x9696a137, 0x05050f0a, 0x9a9ab52f,
0x0707090e, 0x12123624, 0x80809b1b, 0xe2e23ddf, 0xebeb26cd, 0x2727694e, 0xb2b2cd7f, 0x75759fea,
0x09091b12, 0x83839e1d, 0x2c2c7458, 0x1a1a2e34, 0x1b1b2d36, 0x6e6eb2dc, 0x5a5aeeb4, 0xa0a0fb5b,
0x5252f6a4, 0x3b3b4d76, 0xd6d661b7, 0xb3b3ce7d, 0x29297b52, 0xe3e33edd, 0x2f2f715e, 0x84849713,
0x5353f5a6, 0xd1d168b9, 0x00000000, 0xeded2cc1, 0x20206040, 0xfcfc1fe3, 0xb1b1c879, 0x5b5bedb6,
0x6a6abed4, 0xcbcb468d, 0xbebed967, 0x39394b72, 0x4a4ade94, 0x4c4cd498, 0x5858e8b0, 0xcfcf4a85,
0xd0d06bbb, 0xefef2ac5, 0xaaaae54f, 0xfbfb16ed, 0x4343c586, 0x4d4dd79a, 0x33335566, 0x85859411,
0x4545cf8a, 0xf9f910e9, 0x02020604, 0x7f7f81fe, 0x5050f0a0, 0x3c3c4478, 0x9f9fba25, 0xa8a8e34b,
0x5151f3a2, 0xa3a3fe5d, 0x4040c080, 0x8f8f8a05, 0x9292ad3f, 0x9d9dbc21, 0x38384870, 0xf5f504f1,
0xbcbcdf63, 0xb6b6c177, 0xdada75af, 0x21216342, 0x10103020, 0xffff1ae5, 0xf3f30efd, 0xd2d26dbf,
0xcdcd4c81, 0x0c0c1418, 0x13133526, 0xecec2fc3, 0x5f5fe1be, 0x9797a235, 0x4444cc88, 0x1717392e,
0xc4c45793, 0xa7a7f255, 0x7e7e82fc, 0x3d3d477a, 0x6464acc8, 0x5d5de7ba, 0x19192b32, 0x737395e6,
0x6060a0c0, 0x81819819, 0x4f4fd19e, 0xdcdc7fa3, 0x22226644, 0x2a2a7e54, 0x9090ab3b, 0x8888830b,
0x4646ca8c, 0xeeee29c7, 0xb8b8d36b, 0x14143c28, 0xdede79a7, 0x5e5ee2bc, 0x0b0b1d16, 0xdbdb76ad,
0xe0e03bdb, 0x32325664, 0x3a3a4e74, 0x0a0a1e14, 0x4949db92, 0x06060a0c, 0x24246c48, 0x5c5ce4b8,
0xc2c25d9f, 0xd3d36ebd, 0xacacef43, 0x6262a6c4, 0x9191a839, 0x9595a431, 0xe4e437d3, 0x79798bf2,
0xe7e732d5, 0xc8c8438b, 0x3737596e, 0x6d6db7da, 0x8d8d8c01, 0xd5d564b1, 0x4e4ed29c, 0xa9a9e049,
0x6c6cb4d8, 0x5656faac, 0xf4f407f3, 0xeaea25cf, 0x6565afca, 0x7a7a8ef4, 0xaeaee947, 0x08081810,
0xbabad56f, 0x787888f0, 0x25256f4a, 0x2e2e725c, 0x1c1c2438, 0xa6a6f157, 0xb4b4c773, 0xc6c65197,
0xe8e823cb, 0xdddd7ca1, 0x74749ce8, 0x1f1f213e, 0x4b4bdd96, 0xbdbddc61, 0x8b8b860d, 0x8a8a850f,
0x707090e0, 0x3e3e427c, 0xb5b5c471, 0x6666aacc, 0x4848d890, 0x03030506, 0xf6f601f7, 0x0e0e121c,
0x6161a3c2, 0x35355f6a, 0x5757f9ae, 0xb9b9d069, 0x86869117, 0xc1c15899, 0x1d1d273a, 0x9e9eb927,
0xe1e138d9, 0xf8f813eb, 0x9898b32b, 0x11113322, 0x6969bbd2, 0xd9d970a9, 0x8e8e8907, 0x9494a733,
0x9b9bb62d, 0x1e1e223c, 0x87879215, 0xe9e920c9, 0xcece4987, 0x5555ffaa, 0x28287850, 0xdfdf7aa5,
0x8c8c8f03, 0xa1a1f859, 0x89898009, 0x0d0d171a, 0xbfbfda65, 0xe6e631d7, 0x4242c684, 0x6868b8d0,
0x4141c382, 0x9999b029, 0x2d2d775a, 0x0f0f111e, 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c,
}
// Lookup tables for decryption.
// These can be recomputed by adapting the tests in aes_test.go.
var td0 = [256]uint32{
0x51f4a750, 0x7e416553, 0x1a17a4c3, 0x3a275e96, 0x3bab6bcb, 0x1f9d45f1, 0xacfa58ab, 0x4be30393,
0x2030fa55, 0xad766df6, 0x88cc7691, 0xf5024c25, 0x4fe5d7fc, 0xc52acbd7, 0x26354480, 0xb562a38f,
0xdeb15a49, 0x25ba1b67, 0x45ea0e98, 0x5dfec0e1, 0xc32f7502, 0x814cf012, 0x8d4697a3, 0x6bd3f9c6,
0x038f5fe7, 0x15929c95, 0xbf6d7aeb, 0x955259da, 0xd4be832d, 0x587421d3, 0x49e06929, 0x8ec9c844,
0x75c2896a, 0xf48e7978, 0x99583e6b, 0x27b971dd, 0xbee14fb6, 0xf088ad17, 0xc920ac66, 0x7dce3ab4,
0x63df4a18, 0xe51a3182, 0x97513360, 0x62537f45, 0xb16477e0, 0xbb6bae84, 0xfe81a01c, 0xf9082b94,
0x70486858, 0x8f45fd19, 0x94de6c87, 0x527bf8b7, 0xab73d323, 0x724b02e2, 0xe31f8f57, 0x6655ab2a,
0xb2eb2807, 0x2fb5c203, 0x86c57b9a, 0xd33708a5, 0x302887f2, 0x23bfa5b2, 0x02036aba, 0xed16825c,
0x8acf1c2b, 0xa779b492, 0xf307f2f0, 0x4e69e2a1, 0x65daf4cd, 0x0605bed5, 0xd134621f, 0xc4a6fe8a,
0x342e539d, 0xa2f355a0, 0x058ae132, 0xa4f6eb75, 0x0b83ec39, 0x4060efaa, 0x5e719f06, 0xbd6e1051,
0x3e218af9, 0x96dd063d, 0xdd3e05ae, 0x4de6bd46, 0x91548db5, 0x71c45d05, 0x0406d46f, 0x605015ff,
0x1998fb24, 0xd6bde997, 0x894043cc, 0x67d99e77, 0xb0e842bd, 0x07898b88, 0xe7195b38, 0x79c8eedb,
0xa17c0a47, 0x7c420fe9, 0xf8841ec9, 0x00000000, 0x09808683, 0x322bed48, 0x1e1170ac, 0x6c5a724e,
0xfd0efffb, 0x0f853856, 0x3daed51e, 0x362d3927, 0x0a0fd964, 0x685ca621, 0x9b5b54d1, 0x24362e3a,
0x0c0a67b1, 0x9357e70f, 0xb4ee96d2, 0x1b9b919e, 0x80c0c54f, 0x61dc20a2, 0x5a774b69, 0x1c121a16,
0xe293ba0a, 0xc0a02ae5, 0x3c22e043, 0x121b171d, 0x0e090d0b, 0xf28bc7ad, 0x2db6a8b9, 0x141ea9c8,
0x57f11985, 0xaf75074c, 0xee99ddbb, 0xa37f60fd, 0xf701269f, 0x5c72f5bc, 0x44663bc5, 0x5bfb7e34,
0x8b432976, 0xcb23c6dc, 0xb6edfc68, 0xb8e4f163, 0xd731dcca, 0x42638510, 0x13972240, 0x84c61120,
0x854a247d, 0xd2bb3df8, 0xaef93211, 0xc729a16d, 0x1d9e2f4b, 0xdcb230f3, 0x0d8652ec, 0x77c1e3d0,
0x2bb3166c, 0xa970b999, 0x119448fa, 0x47e96422, 0xa8fc8cc4, 0xa0f03f1a, 0x567d2cd8, 0x223390ef,
0x87494ec7, 0xd938d1c1, 0x8ccaa2fe, 0x98d40b36, 0xa6f581cf, 0xa57ade28, 0xdab78e26, 0x3fadbfa4,
0x2c3a9de4, 0x5078920d, 0x6a5fcc9b, 0x547e4662, 0xf68d13c2, 0x90d8b8e8, 0x2e39f75e, 0x82c3aff5,
0x9f5d80be, 0x69d0937c, 0x6fd52da9, 0xcf2512b3, 0xc8ac993b, 0x10187da7, 0xe89c636e, 0xdb3bbb7b,
0xcd267809, 0x6e5918f4, 0xec9ab701, 0x834f9aa8, 0xe6956e65, 0xaaffe67e, 0x21bccf08, 0xef15e8e6,
0xbae79bd9, 0x4a6f36ce, 0xea9f09d4, 0x29b07cd6, 0x31a4b2af, 0x2a3f2331, 0xc6a59430, 0x35a266c0,
0x744ebc37, 0xfc82caa6, 0xe090d0b0, 0x33a7d815, 0xf104984a, 0x41ecdaf7, 0x7fcd500e, 0x1791f62f,
0x764dd68d, 0x43efb04d, 0xccaa4d54, 0xe49604df, 0x9ed1b5e3, 0x4c6a881b, 0xc12c1fb8, 0x4665517f,
0x9d5eea04, 0x018c355d, 0xfa877473, 0xfb0b412e, 0xb3671d5a, 0x92dbd252, 0xe9105633, 0x6dd64713,
0x9ad7618c, 0x37a10c7a, 0x59f8148e, 0xeb133c89, 0xcea927ee, 0xb761c935, 0xe11ce5ed, 0x7a47b13c,
0x9cd2df59, 0x55f2733f, 0x1814ce79, 0x73c737bf, 0x53f7cdea, 0x5ffdaa5b, 0xdf3d6f14, 0x7844db86,
0xcaaff381, 0xb968c43e, 0x3824342c, 0xc2a3405f, 0x161dc372, 0xbce2250c, 0x283c498b, 0xff0d9541,
0x39a80171, 0x080cb3de, 0xd8b4e49c, 0x6456c190, 0x7bcb8461, 0xd532b670, 0x486c5c74, 0xd0b85742,
}
var td1 = [256]uint32{
0x5051f4a7, 0x537e4165, 0xc31a17a4, 0x963a275e, 0xcb3bab6b, 0xf11f9d45, 0xabacfa58, 0x934be303,
0x552030fa, 0xf6ad766d, 0x9188cc76, 0x25f5024c, 0xfc4fe5d7, 0xd7c52acb, 0x80263544, 0x8fb562a3,
0x49deb15a, 0x6725ba1b, 0x9845ea0e, 0xe15dfec0, 0x02c32f75, 0x12814cf0, 0xa38d4697, 0xc66bd3f9,
0xe7038f5f, 0x9515929c, 0xebbf6d7a, 0xda955259, 0x2dd4be83, 0xd3587421, 0x2949e069, 0x448ec9c8,
0x6a75c289, 0x78f48e79, 0x6b99583e, 0xdd27b971, 0xb6bee14f, 0x17f088ad, 0x66c920ac, 0xb47dce3a,
0x1863df4a, 0x82e51a31, 0x60975133, 0x4562537f, 0xe0b16477, 0x84bb6bae, 0x1cfe81a0, 0x94f9082b,
0x58704868, 0x198f45fd, 0x8794de6c, 0xb7527bf8, 0x23ab73d3, 0xe2724b02, 0x57e31f8f, 0x2a6655ab,
0x07b2eb28, 0x032fb5c2, 0x9a86c57b, 0xa5d33708, 0xf2302887, 0xb223bfa5, 0xba02036a, 0x5ced1682,
0x2b8acf1c, 0x92a779b4, 0xf0f307f2, 0xa14e69e2, 0xcd65daf4, 0xd50605be, 0x1fd13462, 0x8ac4a6fe,
0x9d342e53, 0xa0a2f355, 0x32058ae1, 0x75a4f6eb, 0x390b83ec, 0xaa4060ef, 0x065e719f, 0x51bd6e10,
0xf93e218a, 0x3d96dd06, 0xaedd3e05, 0x464de6bd, 0xb591548d, 0x0571c45d, 0x6f0406d4, 0xff605015,
0x241998fb, 0x97d6bde9, 0xcc894043, 0x7767d99e, 0xbdb0e842, 0x8807898b, 0x38e7195b, 0xdb79c8ee,
0x47a17c0a, 0xe97c420f, 0xc9f8841e, 0x00000000, 0x83098086, 0x48322bed, 0xac1e1170, 0x4e6c5a72,
0xfbfd0eff, 0x560f8538, 0x1e3daed5, 0x27362d39, 0x640a0fd9, 0x21685ca6, 0xd19b5b54, 0x3a24362e,
0xb10c0a67, 0x0f9357e7, 0xd2b4ee96, 0x9e1b9b91, 0x4f80c0c5, 0xa261dc20, 0x695a774b, 0x161c121a,
0x0ae293ba, 0xe5c0a02a, 0x433c22e0, 0x1d121b17, 0x0b0e090d, 0xadf28bc7, 0xb92db6a8, 0xc8141ea9,
0x8557f119, 0x4caf7507, 0xbbee99dd, 0xfda37f60, 0x9ff70126, 0xbc5c72f5, 0xc544663b, 0x345bfb7e,
0x768b4329, 0xdccb23c6, 0x68b6edfc, 0x63b8e4f1, 0xcad731dc, 0x10426385, 0x40139722, 0x2084c611,
0x7d854a24, 0xf8d2bb3d, 0x11aef932, 0x6dc729a1, 0x4b1d9e2f, 0xf3dcb230, 0xec0d8652, 0xd077c1e3,
0x6c2bb316, 0x99a970b9, 0xfa119448, 0x2247e964, 0xc4a8fc8c, 0x1aa0f03f, 0xd8567d2c, 0xef223390,
0xc787494e, 0xc1d938d1, 0xfe8ccaa2, 0x3698d40b, 0xcfa6f581, 0x28a57ade, 0x26dab78e, 0xa43fadbf,
0xe42c3a9d, 0x0d507892, 0x9b6a5fcc, 0x62547e46, 0xc2f68d13, 0xe890d8b8, 0x5e2e39f7, 0xf582c3af,
0xbe9f5d80, 0x7c69d093, 0xa96fd52d, 0xb3cf2512, 0x3bc8ac99, 0xa710187d, 0x6ee89c63, 0x7bdb3bbb,
0x09cd2678, 0xf46e5918, 0x01ec9ab7, 0xa8834f9a, 0x65e6956e, 0x7eaaffe6, 0x0821bccf, 0xe6ef15e8,
0xd9bae79b, 0xce4a6f36, 0xd4ea9f09, 0xd629b07c, 0xaf31a4b2, 0x312a3f23, 0x30c6a594, 0xc035a266,
0x37744ebc, 0xa6fc82ca, 0xb0e090d0, 0x1533a7d8, 0x4af10498, 0xf741ecda, 0x0e7fcd50, 0x2f1791f6,
0x8d764dd6, 0x4d43efb0, 0x54ccaa4d, 0xdfe49604, 0xe39ed1b5, 0x1b4c6a88, 0xb8c12c1f, 0x7f466551,
0x049d5eea, 0x5d018c35, 0x73fa8774, 0x2efb0b41, 0x5ab3671d, 0x5292dbd2, 0x33e91056, 0x136dd647,
0x8c9ad761, 0x7a37a10c, 0x8e59f814, 0x89eb133c, 0xeecea927, 0x35b761c9, 0xede11ce5, 0x3c7a47b1,
0x599cd2df, 0x3f55f273, 0x791814ce, 0xbf73c737, 0xea53f7cd, 0x5b5ffdaa, 0x14df3d6f, 0x867844db,
0x81caaff3, 0x3eb968c4, 0x2c382434, 0x5fc2a340, 0x72161dc3, 0x0cbce225, 0x8b283c49, 0x41ff0d95,
0x7139a801, 0xde080cb3, 0x9cd8b4e4, 0x906456c1, 0x617bcb84, 0x70d532b6, 0x74486c5c, 0x42d0b857,
}
var td2 = [256]uint32{
0xa75051f4, 0x65537e41, 0xa4c31a17, 0x5e963a27, 0x6bcb3bab, 0x45f11f9d, 0x58abacfa, 0x03934be3,
0xfa552030, 0x6df6ad76, 0x769188cc, 0x4c25f502, 0xd7fc4fe5, 0xcbd7c52a, 0x44802635, 0xa38fb562,
0x5a49deb1, 0x1b6725ba, 0x0e9845ea, 0xc0e15dfe, 0x7502c32f, 0xf012814c, 0x97a38d46, 0xf9c66bd3,
0x5fe7038f, 0x9c951592, 0x7aebbf6d, 0x59da9552, 0x832dd4be, 0x21d35874, 0x692949e0, 0xc8448ec9,
0x896a75c2, 0x7978f48e, 0x3e6b9958, 0x71dd27b9, 0x4fb6bee1, 0xad17f088, 0xac66c920, 0x3ab47dce,
0x4a1863df, 0x3182e51a, 0x33609751, 0x7f456253, 0x77e0b164, 0xae84bb6b, 0xa01cfe81, 0x2b94f908,
0x68587048, 0xfd198f45, 0x6c8794de, 0xf8b7527b, 0xd323ab73, 0x02e2724b, 0x8f57e31f, 0xab2a6655,
0x2807b2eb, 0xc2032fb5, 0x7b9a86c5, 0x08a5d337, 0x87f23028, 0xa5b223bf, 0x6aba0203, 0x825ced16,
0x1c2b8acf, 0xb492a779, 0xf2f0f307, 0xe2a14e69, 0xf4cd65da, 0xbed50605, 0x621fd134, 0xfe8ac4a6,
0x539d342e, 0x55a0a2f3, 0xe132058a, 0xeb75a4f6, 0xec390b83, 0xefaa4060, 0x9f065e71, 0x1051bd6e,
0x8af93e21, 0x063d96dd, 0x05aedd3e, 0xbd464de6, 0x8db59154, 0x5d0571c4, 0xd46f0406, 0x15ff6050,
0xfb241998, 0xe997d6bd, 0x43cc8940, 0x9e7767d9, 0x42bdb0e8, 0x8b880789, 0x5b38e719, 0xeedb79c8,
0x0a47a17c, 0x0fe97c42, 0x1ec9f884, 0x00000000, 0x86830980, 0xed48322b, 0x70ac1e11, 0x724e6c5a,
0xfffbfd0e, 0x38560f85, 0xd51e3dae, 0x3927362d, 0xd9640a0f, 0xa621685c, 0x54d19b5b, 0x2e3a2436,
0x67b10c0a, 0xe70f9357, 0x96d2b4ee, 0x919e1b9b, 0xc54f80c0, 0x20a261dc, 0x4b695a77, 0x1a161c12,
0xba0ae293, 0x2ae5c0a0, 0xe0433c22, 0x171d121b, 0x0d0b0e09, 0xc7adf28b, 0xa8b92db6, 0xa9c8141e,
0x198557f1, 0x074caf75, 0xddbbee99, 0x60fda37f, 0x269ff701, 0xf5bc5c72, 0x3bc54466, 0x7e345bfb,
0x29768b43, 0xc6dccb23, 0xfc68b6ed, 0xf163b8e4, 0xdccad731, 0x85104263, 0x22401397, 0x112084c6,
0x247d854a, 0x3df8d2bb, 0x3211aef9, 0xa16dc729, 0x2f4b1d9e, 0x30f3dcb2, 0x52ec0d86, 0xe3d077c1,
0x166c2bb3, 0xb999a970, 0x48fa1194, 0x642247e9, 0x8cc4a8fc, 0x3f1aa0f0, 0x2cd8567d, 0x90ef2233,
0x4ec78749, 0xd1c1d938, 0xa2fe8cca, 0x0b3698d4, 0x81cfa6f5, 0xde28a57a, 0x8e26dab7, 0xbfa43fad,
0x9de42c3a, 0x920d5078, 0xcc9b6a5f, 0x4662547e, 0x13c2f68d, 0xb8e890d8, 0xf75e2e39, 0xaff582c3,
0x80be9f5d, 0x937c69d0, 0x2da96fd5, 0x12b3cf25, 0x993bc8ac, 0x7da71018, 0x636ee89c, 0xbb7bdb3b,
0x7809cd26, 0x18f46e59, 0xb701ec9a, 0x9aa8834f, 0x6e65e695, 0xe67eaaff, 0xcf0821bc, 0xe8e6ef15,
0x9bd9bae7, 0x36ce4a6f, 0x09d4ea9f, 0x7cd629b0, 0xb2af31a4, 0x23312a3f, 0x9430c6a5, 0x66c035a2,
0xbc37744e, 0xcaa6fc82, 0xd0b0e090, 0xd81533a7, 0x984af104, 0xdaf741ec, 0x500e7fcd, 0xf62f1791,
0xd68d764d, 0xb04d43ef, 0x4d54ccaa, 0x04dfe496, 0xb5e39ed1, 0x881b4c6a, 0x1fb8c12c, 0x517f4665,
0xea049d5e, 0x355d018c, 0x7473fa87, 0x412efb0b, 0x1d5ab367, 0xd25292db, 0x5633e910, 0x47136dd6,
0x618c9ad7, 0x0c7a37a1, 0x148e59f8, 0x3c89eb13, 0x27eecea9, 0xc935b761, 0xe5ede11c, 0xb13c7a47,
0xdf599cd2, 0x733f55f2, 0xce791814, 0x37bf73c7, 0xcdea53f7, 0xaa5b5ffd, 0x6f14df3d, 0xdb867844,
0xf381caaf, 0xc43eb968, 0x342c3824, 0x405fc2a3, 0xc372161d, 0x250cbce2, 0x498b283c, 0x9541ff0d,
0x017139a8, 0xb3de080c, 0xe49cd8b4, 0xc1906456, 0x84617bcb, 0xb670d532, 0x5c74486c, 0x5742d0b8,
}
var td3 = [256]uint32{
0xf4a75051, 0x4165537e, 0x17a4c31a, 0x275e963a, 0xab6bcb3b, 0x9d45f11f, 0xfa58abac, 0xe303934b,
0x30fa5520, 0x766df6ad, 0xcc769188, 0x024c25f5, 0xe5d7fc4f, 0x2acbd7c5, 0x35448026, 0x62a38fb5,
0xb15a49de, 0xba1b6725, 0xea0e9845, 0xfec0e15d, 0x2f7502c3, 0x4cf01281, 0x4697a38d, 0xd3f9c66b,
0x8f5fe703, 0x929c9515, 0x6d7aebbf, 0x5259da95, 0xbe832dd4, 0x7421d358, 0xe0692949, 0xc9c8448e,
0xc2896a75, 0x8e7978f4, 0x583e6b99, 0xb971dd27, 0xe14fb6be, 0x88ad17f0, 0x20ac66c9, 0xce3ab47d,
0xdf4a1863, 0x1a3182e5, 0x51336097, 0x537f4562, 0x6477e0b1, 0x6bae84bb, 0x81a01cfe, 0x082b94f9,
0x48685870, 0x45fd198f, 0xde6c8794, 0x7bf8b752, 0x73d323ab, 0x4b02e272, 0x1f8f57e3, 0x55ab2a66,
0xeb2807b2, 0xb5c2032f, 0xc57b9a86, 0x3708a5d3, 0x2887f230, 0xbfa5b223, 0x036aba02, 0x16825ced,
0xcf1c2b8a, 0x79b492a7, 0x07f2f0f3, 0x69e2a14e, 0xdaf4cd65, 0x05bed506, 0x34621fd1, 0xa6fe8ac4,
0x2e539d34, 0xf355a0a2, 0x8ae13205, 0xf6eb75a4, 0x83ec390b, 0x60efaa40, 0x719f065e, 0x6e1051bd,
0x218af93e, 0xdd063d96, 0x3e05aedd, 0xe6bd464d, 0x548db591, 0xc45d0571, 0x06d46f04, 0x5015ff60,
0x98fb2419, 0xbde997d6, 0x4043cc89, 0xd99e7767, 0xe842bdb0, 0x898b8807, 0x195b38e7, 0xc8eedb79,
0x7c0a47a1, 0x420fe97c, 0x841ec9f8, 0x00000000, 0x80868309, 0x2bed4832, 0x1170ac1e, 0x5a724e6c,
0x0efffbfd, 0x8538560f, 0xaed51e3d, 0x2d392736, 0x0fd9640a, 0x5ca62168, 0x5b54d19b, 0x362e3a24,
0x0a67b10c, 0x57e70f93, 0xee96d2b4, 0x9b919e1b, 0xc0c54f80, 0xdc20a261, 0x774b695a, 0x121a161c,
0x93ba0ae2, 0xa02ae5c0, 0x22e0433c, 0x1b171d12, 0x090d0b0e, 0x8bc7adf2, 0xb6a8b92d, 0x1ea9c814,
0xf1198557, 0x75074caf, 0x99ddbbee, 0x7f60fda3, 0x01269ff7, 0x72f5bc5c, 0x663bc544, 0xfb7e345b,
0x4329768b, 0x23c6dccb, 0xedfc68b6, 0xe4f163b8, 0x31dccad7, 0x63851042, 0x97224013, 0xc6112084,
0x4a247d85, 0xbb3df8d2, 0xf93211ae, 0x29a16dc7, 0x9e2f4b1d, 0xb230f3dc, 0x8652ec0d, 0xc1e3d077,
0xb3166c2b, 0x70b999a9, 0x9448fa11, 0xe9642247, 0xfc8cc4a8, 0xf03f1aa0, 0x7d2cd856, 0x3390ef22,
0x494ec787, 0x38d1c1d9, 0xcaa2fe8c, 0xd40b3698, 0xf581cfa6, 0x7ade28a5, 0xb78e26da, 0xadbfa43f,
0x3a9de42c, 0x78920d50, 0x5fcc9b6a, 0x7e466254, 0x8d13c2f6, 0xd8b8e890, 0x39f75e2e, 0xc3aff582,
0x5d80be9f, 0xd0937c69, 0xd52da96f, 0x2512b3cf, 0xac993bc8, 0x187da710, 0x9c636ee8, 0x3bbb7bdb,
0x267809cd, 0x5918f46e, 0x9ab701ec, 0x4f9aa883, 0x956e65e6, 0xffe67eaa, 0xbccf0821, 0x15e8e6ef,
0xe79bd9ba, 0x6f36ce4a, 0x9f09d4ea, 0xb07cd629, 0xa4b2af31, 0x3f23312a, 0xa59430c6, 0xa266c035,
0x4ebc3774, 0x82caa6fc, 0x90d0b0e0, 0xa7d81533, 0x04984af1, 0xecdaf741, 0xcd500e7f, 0x91f62f17,
0x4dd68d76, 0xefb04d43, 0xaa4d54cc, 0x9604dfe4, 0xd1b5e39e, 0x6a881b4c, 0x2c1fb8c1, 0x65517f46,
0x5eea049d, 0x8c355d01, 0x877473fa, 0x0b412efb, 0x671d5ab3, 0xdbd25292, 0x105633e9, 0xd647136d,
0xd7618c9a, 0xa10c7a37, 0xf8148e59, 0x133c89eb, 0xa927eece, 0x61c935b7, 0x1ce5ede1, 0x47b13c7a,
0xd2df599c, 0xf2733f55, 0x14ce7918, 0xc737bf73, 0xf7cdea53, 0xfdaa5b5f, 0x3d6f14df, 0x44db8678,
0xaff381ca, 0x68c43eb9, 0x24342c38, 0xa3405fc2, 0x1dc37216, 0xe2250cbc, 0x3c498b28, 0x0d9541ff,
0xa8017139, 0x0cb3de08, 0xb4e49cd8, 0x56c19064, 0xcb84617b, 0x32b670d5, 0x6c5c7448, 0xb85742d0,
}

View File

@ -1,401 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes12
import (
"crypto/subtle"
"errors"
)
// AEAD is a cipher mode providing authenticated encryption with associated
// data. For a description of the methodology, see
// https://en.wikipedia.org/wiki/Authenticated_encryption
type AEAD interface {
// NonceSize returns the size of the nonce that must be passed to Seal
// and Open.
NonceSize() int
// Overhead returns the maximum difference between the lengths of a
// plaintext and its ciphertext.
Overhead() int
// Seal encrypts and authenticates plaintext, authenticates the
// additional data and appends the result to dst, returning the updated
// slice. The nonce must be NonceSize() bytes long and unique for all
// time, for a given key.
//
// The plaintext and dst may alias exactly or not at all. To reuse
// plaintext's storage for the encrypted output, use plaintext[:0] as dst.
Seal(dst, nonce, plaintext, additionalData []byte) []byte
// Open decrypts and authenticates ciphertext, authenticates the
// additional data and, if successful, appends the resulting plaintext
// to dst, returning the updated slice. The nonce must be NonceSize()
// bytes long and both it and the additional data must match the
// value passed to Seal.
//
// The ciphertext and dst may alias exactly or not at all. To reuse
// ciphertext's storage for the decrypted output, use ciphertext[:0] as dst.
//
// Even if the function fails, the contents of dst, up to its capacity,
// may be overwritten.
Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error)
}
// gcmAble is an interface implemented by ciphers that have a specific optimized
// implementation of GCM, like crypto/aes. NewGCM will check for this interface
// and return the specific AEAD if found.
type gcmAble interface {
NewGCM(int) (AEAD, error)
}
// gcmFieldElement represents a value in GF(2¹²⁸). In order to reflect the GCM
// standard and make getUint64 suitable for marshaling these values, the bits
// are stored backwards. For example:
// the coefficient of x⁰ can be obtained by v.low >> 63.
// the coefficient of x⁶³ can be obtained by v.low & 1.
// the coefficient of x⁶⁴ can be obtained by v.high >> 63.
// the coefficient of x¹²⁷ can be obtained by v.high & 1.
type gcmFieldElement struct {
low, high uint64
}
// gcm represents a Galois Counter Mode with a specific key. See
// http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
type gcm struct {
cipher Block
nonceSize int
// productTable contains the first sixteen powers of the key, H.
// However, they are in bit reversed order. See NewGCMWithNonceSize.
productTable [16]gcmFieldElement
}
// NewGCM returns the given 128-bit, block cipher wrapped in Galois Counter Mode
// with the standard nonce length.
func NewGCM(cipher Block) (AEAD, error) {
return NewGCMWithNonceSize(cipher, gcmStandardNonceSize)
}
// NewGCMWithNonceSize returns the given 128-bit, block cipher wrapped in Galois
// Counter Mode, which accepts nonces of the given length.
//
// Only use this function if you require compatibility with an existing
// cryptosystem that uses non-standard nonce lengths. All other users should use
// NewGCM, which is faster and more resistant to misuse.
func NewGCMWithNonceSize(cipher Block, size int) (AEAD, error) {
if cipher, ok := cipher.(gcmAble); ok {
return cipher.NewGCM(size)
}
if cipher.BlockSize() != gcmBlockSize {
return nil, errors.New("cipher: NewGCM requires 128-bit block cipher")
}
var key [gcmBlockSize]byte
cipher.Encrypt(key[:], key[:])
g := &gcm{cipher: cipher, nonceSize: size}
// We precompute 16 multiples of |key|. However, when we do lookups
// into this table we'll be using bits from a field element and
// therefore the bits will be in the reverse order. So normally one
// would expect, say, 4*key to be in index 4 of the table but due to
// this bit ordering it will actually be in index 0010 (base 2) = 2.
x := gcmFieldElement{
getUint64(key[:8]),
getUint64(key[8:]),
}
g.productTable[reverseBits(1)] = x
for i := 2; i < 16; i += 2 {
g.productTable[reverseBits(i)] = gcmDouble(&g.productTable[reverseBits(i/2)])
g.productTable[reverseBits(i+1)] = gcmAdd(&g.productTable[reverseBits(i)], &x)
}
return g, nil
}
const (
gcmBlockSize = 16
gcmTagSize = 12
gcmStandardNonceSize = 12
)
func (g *gcm) NonceSize() int {
return g.nonceSize
}
func (*gcm) Overhead() int {
return gcmTagSize
}
func (g *gcm) Seal(dst, nonce, plaintext, data []byte) []byte {
if len(nonce) != g.nonceSize {
panic("cipher: incorrect nonce length given to GCM")
}
ret, out := sliceForAppend(dst, len(plaintext)+gcmTagSize)
var counter, tagMask [gcmBlockSize]byte
g.deriveCounter(&counter, nonce)
g.cipher.Encrypt(tagMask[:], counter[:])
gcmInc32(&counter)
g.counterCrypt(out, plaintext, &counter)
tag := make([]byte, 16)
g.auth(tag, out[:len(plaintext)], data, &tagMask)
copy(ret[len(ret)-12:], tag)
return ret
}
var errOpen = errors.New("cipher: message authentication failed")
func (g *gcm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(nonce) != g.nonceSize {
panic("cipher: incorrect nonce length given to GCM")
}
if len(ciphertext) < gcmTagSize {
return nil, errOpen
}
tag := ciphertext[len(ciphertext)-gcmTagSize:]
ciphertext = ciphertext[:len(ciphertext)-gcmTagSize]
var counter, tagMask [gcmBlockSize]byte
g.deriveCounter(&counter, nonce)
g.cipher.Encrypt(tagMask[:], counter[:])
gcmInc32(&counter)
var expectedTag [gcmBlockSize]byte
g.auth(expectedTag[:], ciphertext, data, &tagMask)
ret, out := sliceForAppend(dst, len(ciphertext))
if subtle.ConstantTimeCompare(expectedTag[:gcmTagSize], tag) != 1 {
// The AESNI code decrypts and authenticates concurrently, and
// so overwrites dst in the event of a tag mismatch. That
// behaviour is mimicked here in order to be consistent across
// platforms.
for i := range out {
out[i] = 0
}
return nil, errOpen
}
g.counterCrypt(out, ciphertext, &counter)
return ret, nil
}
// reverseBits reverses the order of the bits of 4-bit number in i.
func reverseBits(i int) int {
i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
return i
}
// gcmAdd adds two elements of GF(2¹²⁸) and returns the sum.
func gcmAdd(x, y *gcmFieldElement) gcmFieldElement {
// Addition in a characteristic 2 field is just XOR.
return gcmFieldElement{x.low ^ y.low, x.high ^ y.high}
}
// gcmDouble returns the result of doubling an element of GF(2¹²⁸).
func gcmDouble(x *gcmFieldElement) (double gcmFieldElement) {
msbSet := x.high&1 == 1
// Because of the bit-ordering, doubling is actually a right shift.
double.high = x.high >> 1
double.high |= x.low << 63
double.low = x.low >> 1
// If the most-significant bit was set before shifting then it,
// conceptually, becomes a term of x^128. This is greater than the
// irreducible polynomial so the result has to be reduced. The
// irreducible polynomial is 1+x+x^2+x^7+x^128. We can subtract that to
// eliminate the term at x^128 which also means subtracting the other
// four terms. In characteristic 2 fields, subtraction == addition ==
// XOR.
if msbSet {
double.low ^= 0xe100000000000000
}
return
}
var gcmReductionTable = []uint16{
0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
}
// mul sets y to y*H, where H is the GCM key, fixed during NewGCMWithNonceSize.
func (g *gcm) mul(y *gcmFieldElement) {
var z gcmFieldElement
for i := 0; i < 2; i++ {
word := y.high
if i == 1 {
word = y.low
}
// Multiplication works by multiplying z by 16 and adding in
// one of the precomputed multiples of H.
for j := 0; j < 64; j += 4 {
msw := z.high & 0xf
z.high >>= 4
z.high |= z.low << 60
z.low >>= 4
z.low ^= uint64(gcmReductionTable[msw]) << 48
// the values in |table| are ordered for
// little-endian bit positions. See the comment
// in NewGCMWithNonceSize.
t := &g.productTable[word&0xf]
z.low ^= t.low
z.high ^= t.high
word >>= 4
}
}
*y = z
}
// updateBlocks extends y with more polynomial terms from blocks, based on
// Horner's rule. There must be a multiple of gcmBlockSize bytes in blocks.
func (g *gcm) updateBlocks(y *gcmFieldElement, blocks []byte) {
for len(blocks) > 0 {
y.low ^= getUint64(blocks)
y.high ^= getUint64(blocks[8:])
g.mul(y)
blocks = blocks[gcmBlockSize:]
}
}
// update extends y with more polynomial terms from data. If data is not a
// multiple of gcmBlockSize bytes long then the remainder is zero padded.
func (g *gcm) update(y *gcmFieldElement, data []byte) {
fullBlocks := (len(data) >> 4) << 4
g.updateBlocks(y, data[:fullBlocks])
if len(data) != fullBlocks {
var partialBlock [gcmBlockSize]byte
copy(partialBlock[:], data[fullBlocks:])
g.updateBlocks(y, partialBlock[:])
}
}
// gcmInc32 treats the final four bytes of counterBlock as a big-endian value
// and increments it.
func gcmInc32(counterBlock *[16]byte) {
for i := gcmBlockSize - 1; i >= gcmBlockSize-4; i-- {
counterBlock[i]++
if counterBlock[i] != 0 {
break
}
}
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// counterCrypt crypts in to out using g.cipher in counter mode.
func (g *gcm) counterCrypt(out, in []byte, counter *[gcmBlockSize]byte) {
var mask [gcmBlockSize]byte
for len(in) >= gcmBlockSize {
g.cipher.Encrypt(mask[:], counter[:])
gcmInc32(counter)
xorWords(out, in, mask[:])
out = out[gcmBlockSize:]
in = in[gcmBlockSize:]
}
if len(in) > 0 {
g.cipher.Encrypt(mask[:], counter[:])
gcmInc32(counter)
xorBytes(out, in, mask[:])
}
}
// deriveCounter computes the initial GCM counter state from the given nonce.
// See NIST SP 800-38D, section 7.1. This assumes that counter is filled with
// zeros on entry.
func (g *gcm) deriveCounter(counter *[gcmBlockSize]byte, nonce []byte) {
// GCM has two modes of operation with respect to the initial counter
// state: a "fast path" for 96-bit (12-byte) nonces, and a "slow path"
// for nonces of other lengths. For a 96-bit nonce, the nonce, along
// with a four-byte big-endian counter starting at one, is used
// directly as the starting counter. For other nonce sizes, the counter
// is computed by passing it through the GHASH function.
if len(nonce) == gcmStandardNonceSize {
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
var y gcmFieldElement
g.update(&y, nonce)
y.high ^= uint64(len(nonce)) * 8
g.mul(&y)
putUint64(counter[:8], y.low)
putUint64(counter[8:], y.high)
}
}
// auth calculates GHASH(ciphertext, additionalData), masks the result with
// tagMask and writes the result to out.
func (g *gcm) auth(out, ciphertext, additionalData []byte, tagMask *[gcmBlockSize]byte) {
var y gcmFieldElement
g.update(&y, additionalData)
g.update(&y, ciphertext)
y.low ^= uint64(len(additionalData)) * 8
y.high ^= uint64(len(ciphertext)) * 8
g.mul(&y)
putUint64(out, y.low)
putUint64(out[8:], y.high)
xorWords(out, out, tagMask[:])
}
func getUint64(data []byte) uint64 {
r := uint64(data[0])<<56 |
uint64(data[1])<<48 |
uint64(data[2])<<40 |
uint64(data[3])<<32 |
uint64(data[4])<<24 |
uint64(data[5])<<16 |
uint64(data[6])<<8 |
uint64(data[7])
return r
}
func putUint64(out []byte, v uint64) {
out[0] = byte(v >> 56)
out[1] = byte(v >> 48)
out[2] = byte(v >> 40)
out[3] = byte(v >> 32)
out[4] = byte(v >> 24)
out[5] = byte(v >> 16)
out[6] = byte(v >> 8)
out[7] = byte(v)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,84 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes12
import (
"runtime"
"unsafe"
)
const wordSize = int(unsafe.Sizeof(uintptr(0)))
const supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x"
// fastXORBytes xors in bulk. It only works on architectures that
// support unaligned read/writes.
func fastXORBytes(dst, a, b []byte) int {
n := len(a)
if len(b) < n {
n = len(b)
}
w := n / wordSize
if w > 0 {
dw := *(*[]uintptr)(unsafe.Pointer(&dst))
aw := *(*[]uintptr)(unsafe.Pointer(&a))
bw := *(*[]uintptr)(unsafe.Pointer(&b))
for i := 0; i < w; i++ {
dw[i] = aw[i] ^ bw[i]
}
}
for i := (n - n%wordSize); i < n; i++ {
dst[i] = a[i] ^ b[i]
}
return n
}
func safeXORBytes(dst, a, b []byte) int {
n := len(a)
if len(b) < n {
n = len(b)
}
for i := 0; i < n; i++ {
dst[i] = a[i] ^ b[i]
}
return n
}
// xorBytes xors the bytes in a and b. The destination is assumed to have enough
// space. Returns the number of bytes xor'd.
func xorBytes(dst, a, b []byte) int {
if supportsUnaligned {
return fastXORBytes(dst, a, b)
} else {
// TODO(hanwen): if (dst, a, b) have common alignment
// we could still try fastXORBytes. It is not clear
// how often this happens, and it's only worth it if
// the block encryption itself is hardware
// accelerated.
return safeXORBytes(dst, a, b)
}
}
// fastXORWords XORs multiples of 4 or 8 bytes (depending on architecture.)
// The arguments are assumed to be of equal length.
func fastXORWords(dst, a, b []byte) {
dw := *(*[]uintptr)(unsafe.Pointer(&dst))
aw := *(*[]uintptr)(unsafe.Pointer(&a))
bw := *(*[]uintptr)(unsafe.Pointer(&b))
n := len(b) / wordSize
for i := 0; i < n; i++ {
dw[i] = aw[i] ^ bw[i]
}
}
func xorWords(dst, a, b []byte) {
if supportsUnaligned {
fastXORWords(dst, a, b)
} else {
safeXORBytes(dst, a, b)
}
}

View File

@ -1,3 +0,0 @@
# fnv128a
Implementation of the FNV-1a 128bit hash in go

View File

@ -1,87 +0,0 @@
// Package fnv128a implements FNV-1 and FNV-1a, non-cryptographic hash functions
// created by Glenn Fowler, Landon Curt Noll, and Phong Vo.
// See https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function.
//
// Write() algorithm taken and modified from github.com/romain-jacotin/quic
package fnv128a
import "hash"
// Hash128 is the common interface implemented by all 128-bit hash functions.
type Hash128 interface {
hash.Hash
Sum128() (uint64, uint64)
}
type sum128a struct {
v0, v1, v2, v3 uint64
}
var _ Hash128 = &sum128a{}
// New1 returns a new 128-bit FNV-1a hash.Hash.
func New() Hash128 {
s := &sum128a{}
s.Reset()
return s
}
func (s *sum128a) Reset() {
s.v0 = 0x6295C58D
s.v1 = 0x62B82175
s.v2 = 0x07BB0142
s.v3 = 0x6C62272E
}
func (s *sum128a) Sum128() (uint64, uint64) {
return s.v3<<32 | s.v2, s.v1<<32 | s.v0
}
func (s *sum128a) Write(data []byte) (int, error) {
var t0, t1, t2, t3 uint64
const fnv128PrimeLow = 0x0000013B
const fnv128PrimeShift = 24
for _, v := range data {
// xor the bottom with the current octet
s.v0 ^= uint64(v)
// multiply by the 128 bit FNV magic prime mod 2^128
// fnv_prime = 309485009821345068724781371 (decimal)
// = 0x0000000001000000000000000000013B (hexadecimal)
// = 0x00000000 0x01000000 0x00000000 0x0000013B (in 4*32 words)
// = 0x0 1<<fnv128PrimeShift 0x0 fnv128PrimeLow
//
// fnv128PrimeLow = 0x0000013B
// fnv128PrimeShift = 24
// multiply by the lowest order digit base 2^32 and by the other non-zero digit
t0 = s.v0 * fnv128PrimeLow
t1 = s.v1 * fnv128PrimeLow
t2 = s.v2*fnv128PrimeLow + s.v0<<fnv128PrimeShift
t3 = s.v3*fnv128PrimeLow + s.v1<<fnv128PrimeShift
// propagate carries
t1 += (t0 >> 32)
t2 += (t1 >> 32)
t3 += (t2 >> 32)
s.v0 = t0 & 0xffffffff
s.v1 = t1 & 0xffffffff
s.v2 = t2 & 0xffffffff
s.v3 = t3 // & 0xffffffff
// Doing a s.v3 &= 0xffffffff is not really needed since it simply
// removes multiples of 2^128. We can discard these excess bits
// outside of the loop when writing the hash in Little Endian.
}
return len(data), nil
}
func (s *sum128a) Size() int { return 16 }
func (s *sum128a) BlockSize() int { return 1 }
func (s *sum128a) Sum(in []byte) []byte {
panic("FNV: not supported")
}

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Lucas Clemente
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,3 +0,0 @@
# certsets
Common certificate sets for quic-go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,34 +0,0 @@
#!/usr/bin/env ruby
#
# Extract the common certificate sets from the chromium source to go
#
# Usage:
# createCertSets.rb 1 ~/src/chromium/src/net/quic/crypto/common_cert_set_1*
# createCertSets.rb 2 ~/src/chromium/src/net/quic/crypto/common_cert_set_2*
n = ARGV.shift
mainFile = ARGV.shift
dataFiles = ARGV
data = "package certsets\n"
data += File.read(mainFile)
data += (dataFiles.map{|p| File.read(p)}).join
# Good enough
data.gsub!(/\/\*(.*?)\*\//m, '')
data.gsub!(/^#include.+/, '')
data.gsub!(/^#if 0(.*?)\n#endif/m, '')
data.gsub!(/^static const size_t kNumCerts.+/, '')
data.gsub!(/static const size_t kLens[^}]+};/m, '')
data.gsub!('static const unsigned char* const kCerts[] = {', "var CertSet#{n} = [][]byte{")
data.gsub!('static const uint64_t kHash = UINT64_C', "const CertSet#{n}Hash uint64 = ")
data.gsub!(/static const unsigned char kDERCert(\d+)\[\] = /, "var kDERCert\\1 = []byte")
data.gsub!(/kDERCert(\d+)/, "certSet#{n}Cert\\1")
File.write("cert_set_#{n}.go", data)
system("gofmt -w -s cert_set_#{n}.go")

View File

@ -1,6 +1,30 @@
# Changelog
## v0.6.0 (unreleased)
## v0.10.0 (2018-08-28)
- Add support for QUIC 44, drop support for QUIC 42.
## v0.9.0 (2018-08-15)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
- Split Session.Close into one method for regular closing and one for closing with an error.
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43.
- Add dial functions that use a context.
- Multiplex clients on a net.PacketConn, when using Dial(conn).
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows

View File

@ -3,16 +3,24 @@
<img src="docs/quic.png" width=303 height=124>
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/lucas-clemente/quic-go)
[![Linux Build Status](https://img.shields.io/travis/lucas-clemente/quic-go/master.svg?style=flat-square&label=linux+build)](https://travis-ci.org/lucas-clemente/quic-go)
[![Travis Build Status](https://img.shields.io/travis/lucas-clemente/quic-go/master.svg?style=flat-square&label=Travis+build)](https://travis-ci.org/lucas-clemente/quic-go)
[![CircleCI Build Status](https://img.shields.io/circleci/project/github/lucas-clemente/quic-go.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/lucas-clemente/quic-go)
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
## Roadmap
## Version compatibility
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
## Google QUIC
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
## Guides
@ -26,31 +34,19 @@ Running tests:
go test ./...
### Running the example server
### HTTP mapping
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client
go run example/client/main.go https://clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))

View File

@ -1,7 +0,0 @@
package main
import (
_ "github.com/clipperhouse/linkedlist"
_ "github.com/clipperhouse/slice"
_ "github.com/clipperhouse/stringer"
)

View File

@ -1,34 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
SendingAllowed() bool
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
ShouldSendRetransmittablePacket() bool
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@ -1,34 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
// +gen linkedlist
type Packet struct {
PacketNumber protocol.PacketNumber
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
var fs []wire.Frame
for _, frame := range p.Frames {
switch frame.(type) {
case *wire.AckFrame:
continue
case *wire.StopWaitingFrame:
continue
}
fs = append(fs, frame)
}
return fs
}

View File

@ -1,141 +0,0 @@
package ackhandler
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
lowerLimit protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
version protocol.VersionNumber
}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 {
return errInvalidPacketNumber
}
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
}
if packetNumber <= h.lowerLimit {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
// SetLowerLimit sets a lower limit for acking packets.
// Packets with packet numbers smaller or equal than p will not be acked.
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
h.lowerLimit = p
h.packetHistory.DeleteUpTo(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
}
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true
}
if h.version < protocol.Version39 {
// Always send an ack every 20 packets in order to allow the peer to discard
// information from the SentPacketManager and provide an RTT measurement.
// From QUIC 39, this is not needed anymore, since the peer will regularly send a retransmittable packet.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
ack := &wire.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].First,
PacketReceivedTime: h.largestObservedReceivedTime,
}
if len(ackRanges) > 1 {
ack.AckRanges = ackRanges
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -1,455 +0,0 @@
package ackhandler
import (
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// The default RTT used before an RTT sample is taken.
// Note: This constant is also defined in the congestion package.
defaultInitialRTT = 100 * time.Millisecond
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
var (
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
)
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
skippedPackets []protocol.PacketNumber
numNonRetransmittablePackets int // number of non-retransmittable packets since the last retransmittable packet
LargestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
packetHistory *PacketList
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetHistory: NewPacketList(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
}
}
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
if f := h.packetHistory.Front(); f != nil {
return f.Value.PacketNumber - 1
}
return h.LargestAcked
}
func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool {
return h.numNonRetransmittablePackets >= protocol.MaxNonRetransmittablePackets
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
if packet.PacketNumber <= h.lastSentPacketNumber {
return errPacketNumberNotIncreasing
}
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
}
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
h.numNonRetransmittablePackets = 0
} else {
h.numNonRetransmittablePackets++
}
h.congestion.OnPacketSent(
now,
h.bytesInFlight,
packet.PacketNumber,
packet.Length,
isRetransmittable,
)
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
if ackFrame.LargestAcked > h.lastSentPacketNumber {
return errAckForUnsentPacket
}
// duplicate or out-of-order ACK
if withPacketNumber <= h.largestReceivedPacketWithAck {
return ErrDuplicateOrOutOfOrderAck
}
h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
return nil
}
h.LargestAcked = ackFrame.LargestAcked
if h.skippedPacketsAcked(ackFrame) {
return ErrAckForSkippedPacket
}
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
if rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
if len(ackedPackets) > 0 {
for _, p := range ackedPackets {
if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
}
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
ackRangeIndex := 0
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
packetNumber := packet.PacketNumber
// Ignore packets below the LowestAcked
if packetNumber < ackFrame.LowestAcked {
continue
}
// Break after LargestAcked is reached
if packetNumber > ackFrame.LargestAcked {
break
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if packetNumber >= ackRange.First { // packet i contained in ACK range
if packetNumber > ackRange.Last {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
}
ackedPackets = append(ackedPackets, el)
}
} else {
ackedPackets = append(ackedPackets, el)
}
}
return ackedPackets, nil
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
return true
}
// Packets are sorted by number, so we can stop searching
if packet.PacketNumber > largestAcked {
break
}
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
return
}
// TODO(#497): TLP
if !h.handshakeComplete {
h.alarm = time.Now().Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = time.Now().Add(h.computeRTOTimeout())
}
}
func (h *sentPacketHandler) detectLostPackets() {
h.lossTime = time.Time{}
now := time.Now()
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber > h.LargestAcked {
break
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() {
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
}
if len(lostPackets) > 0 {
for _, p := range lostPackets {
h.queuePacketForRetransmission(p)
h.congestion.OnPacketLost(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
}
func (h *sentPacketHandler) OnAlarm() {
// TODO(#497): TLP
if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission()
h.handshakeCount++
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets()
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
}
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
h.bytesInFlight -= packetElement.Value.Length
h.rtoCount = 0
h.handshakeCount = 0
// TODO(#497): h.tlpCount = 0
h.packetHistory.Remove(packetElement)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
return h.largestInOrderAcked() + 1
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
h.bytesInFlight,
h.congestion.GetCongestionWindow())
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
}
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
}
}
func (h *sentPacketHandler) queueRTO(el *PacketElement) {
packet := &el.Value
utils.Debugf(
"\tQueueing packet 0x%x for retransmission (RTO), %d outstanding",
packet.PacketNumber,
h.packetHistory.Len(),
)
h.queuePacketForRetransmission(el)
h.congestion.OnPacketLost(packet.PacketNumber, packet.Length, h.bytesInFlight)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
var handshakePackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, el)
}
}
for _, el := range handshakePackets {
h.queuePacketForRetransmission(el)
}
}
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
h.retransmissionQueue = append(h.retransmissionQueue, packet)
h.packetHistory.Remove(packetElement)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := 2 * h.rttStats.SmoothedRTT()
if duration == 0 {
duration = 2 * defaultInitialRTT
}
duration = utils.MaxDuration(duration, minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
rto := h.congestion.RetransmissionDelay()
if rto == 0 {
rto = defaultRTOTimeout
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto = rto << h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lioa := h.largestInOrderAcked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p <= lioa {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

View File

@ -1,42 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
}
return nil
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
if ack.LargestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = ack.LargestAcked + 1
}
}
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
if p >= s.nextLeastUnacked {
s.nextLeastUnacked = p + 1
}
}

View File

@ -10,16 +10,18 @@ environment:
- GOARCH: 386
- GOARCH: amd64
hosts:
quic.clemente.io: 127.0.0.1
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:
- rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.2.windows-amd64.zip
- 7z x go1.9.2.windows-amd64.zip -y -oC:\ > NUL
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.11.windows-amd64.zip
- 7z x go1.11.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH%
- echo %GOPATH%
- git submodule update --init --recursive
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/onsi/gomega
- go version

View File

@ -8,19 +8,20 @@ import (
var bufferPool sync.Pool
func getPacketBuffer() []byte {
return bufferPool.Get().([]byte)
func getPacketBuffer() *[]byte {
return bufferPool.Get().(*[]byte)
}
func putPacketBuffer(buf []byte) {
if cap(buf) != int(protocol.MaxReceivePacketSize) {
func putPacketBuffer(buf *[]byte) {
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!")
}
bufferPool.Put(buf[:0])
bufferPool.Put(buf)
}
func init() {
bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxReceivePacketSize)
b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
}
}

View File

@ -1,69 +1,81 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
mutex sync.Mutex
conn connection
hostname string
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
handshakeChan <-chan handshakeEvent
packetHandlers packetHandlerManager
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet
token []byte
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
connectionID protocol.ConnectionID
version protocol.VersionNumber
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
session packetHandler
initialVersion protocol.VersionNumber
version protocol.VersionNumber
handshakeChan chan struct{}
session quicSession
logger utils.Logger
}
var _ packetHandler = &client{}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
) (Session, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -72,52 +84,7 @@ func DialAddrNonFWSecure(
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := generateConnectionID()
if err != nil {
return nil, err
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.establishSecureConnection(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@ -129,19 +96,89 @@ func Dial(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
if err := sess.WaitUntilHandshakeComplete(); err != nil {
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
if err != nil {
return nil, err
}
return sess, nil
c.packetHandlers = packetHandlers
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.session, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
createdPacketConn bool,
) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
}
if tlsConf.ServerName == "" {
var err error
tlsConf.ServerName, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
@ -161,163 +198,146 @@ func populateClientConfig(config *Config) *Config {
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive,
}
}
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
if err := c.createNewSession(c.version, nil); err != nil {
func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
go c.listen()
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
if err := c.createNewTLSSession(c.version); err != nil {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - any other error that might occur
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
var runErr error
errorChan := make(chan struct{})
go func() {
// session.run() returns as soon as the session is closed
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
// run the new session
runErr = c.session.run()
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
errorChan <- err
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-errorChan:
return runErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
}
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
case <-ctx.Done():
// The session will send a PeerGoingAway error to the server.
c.session.Close()
return ctx.Err()
case err := <-errorChan:
return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
}
}
// Listen listens
func (c *client) listen() {
var err error
for {
var n int
var addr net.Addr
data := getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err = c.conn.Read(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.session.Close(err)
}
break
}
data = data[:n]
c.handlePacket(addr, data)
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header
return
}
// reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return
}
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := wire.ParsePublicReset(r)
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
c.session.destroy(err)
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
// version negotiation packets have no payload
return err
}
isVersionNegotiationPacket := hdr.VersionFlag /* gQUIC Version Negotiation Packet */ || hdr.Type == protocol.PacketTypeVersionNegotiation /* IETF draft style Version Negotiation Packet */
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
// handle Version Negotiation Packets
if isVersionNegotiationPacket {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return
}
// version negotiation packets have no payload
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
c.session.Close(err)
}
return
if p.header.Type == protocol.PacketTypeRetry {
c.handleRetryPacket(p.header)
return nil
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
c.session.handlePacket(p)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
@ -327,42 +347,115 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
}
}
c.receivedVersionNegotiationPacket = true
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
initialVersion := c.version
c.initialVersion = c.version
c.version = newVersion
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return
}
if hdr.SrcConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return
}
// If a token is already set, this means that we already received a Retry from the server.
// Ignore this Retry packet.
if len(c.token) > 0 {
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
return
}
c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
InitialMaxData: protocol.InitialMaxData,
IdleTimeout: c.config.IdleTimeout,
MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
retireConnectionIDImpl: c.packetHandlers.Retire,
removeConnectionIDImpl: c.packetHandlers.Remove,
}
sess, err := newClientSession(
c.conn,
runner,
c.token,
c.origDestConnID,
c.destConnID,
c.srcConnID,
c.config,
c.tlsConf,
params,
c.initialVersion,
c.logger,
c.version,
)
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
// create a new session and close the old one
// the new session must be created first to update client member variables
oldSession := c.session
defer oldSession.Close(errCloseSessionForNewVersion)
return c.createNewSession(initialVersion, hdr.SupportedVersions)
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
return nil
}
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
var err error
utils.Debugf("createNewSession with initial version %s", initialVersion)
c.session, c.handshakeChan, err = newClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
initialVersion,
negotiatedVersions,
)
return err
func (c *client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return nil
}
return c.session.Close()
}
func (c *client) destroy(e error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return
}
c.session.destroy(e)
}
func (c *client) GetVersion() protocol.VersionNumber {
c.mutex.Lock()
v := c.version
c.mutex.Unlock()
return v
}
func (c *client) GetPerspective() protocol.Perspective {
return protocol.PerspectiveClient
}

View File

@ -1,11 +1,16 @@
coverage:
round: nearest
ignore:
- ackhandler/packet_linkedlist.go
- streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- h2quic/gzipreader.go
- h2quic/response.go
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go
- internal/utils/linkedlist/linkedlist.go
status:
project:
default:

View File

@ -1,183 +0,0 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
// Note: This constant is also defined in the ackhandler package.
initialRTTus = 100 * 1000
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
halfWindow float32 = 0.5
quarterWindow float32 = 0.25
)
type rttSample struct {
rtt time.Duration
time time.Time
}
// RTTStats provides round-trip statistics
type RTTStats struct {
initialRTTus int64
recentMinRTTwindow time.Duration
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
numMinRTTsamplesRemaining uint32
newMinRTT rttSample
recentMinRTT rttSample
halfWindowRTT rttSample
quarterWindowRTT rttSample
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{
initialRTTus: initialRTTus,
recentMinRTTwindow: utils.InfDuration,
}
}
// InitialRTTus is the initial RTT in us
func (r *RTTStats) InitialRTTus() int64 { return r.initialRTTus }
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// RecentMinRTT the minRTT since SampleNewRecentMinRtt has been called, or the
// minRTT for the entire connection if SampleNewMinRtt was never called.
func (r *RTTStats) RecentMinRTT() time.Duration { return r.recentMinRTT.rtt }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// GetQuarterWindowRTT gets the quarter window RTT
func (r *RTTStats) GetQuarterWindowRTT() time.Duration { return r.quarterWindowRTT.rtt }
// GetHalfWindowRTT gets the half window RTT
func (r *RTTStats) GetHalfWindowRTT() time.Duration { return r.halfWindowRTT.rtt }
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// SetRecentMinRTTwindow sets how old a recent min rtt sample can be.
func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
r.recentMinRTTwindow = recentMinRTTwindow
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
r.updateRecentMinRTT(sendDelta, now)
// Correct for ackDelay if information received from the peer results in a
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
// measure for smoothedRTT.
sample := sendDelta
if sample > ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if r.smoothedRTT == 0 {
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
func (r *RTTStats) updateRecentMinRTT(sample time.Duration, now time.Time) { // Recent minRTT update.
if r.numMinRTTsamplesRemaining > 0 {
r.numMinRTTsamplesRemaining--
if r.newMinRTT.rtt == 0 || sample <= r.newMinRTT.rtt {
r.newMinRTT = rttSample{rtt: sample, time: now}
}
if r.numMinRTTsamplesRemaining == 0 {
r.recentMinRTT = r.newMinRTT
r.halfWindowRTT = r.newMinRTT
r.quarterWindowRTT = r.newMinRTT
}
}
// Update the three recent rtt samples.
if r.recentMinRTT.rtt == 0 || sample <= r.recentMinRTT.rtt {
r.recentMinRTT = rttSample{rtt: sample, time: now}
r.halfWindowRTT = r.recentMinRTT
r.quarterWindowRTT = r.recentMinRTT
} else if sample <= r.halfWindowRTT.rtt {
r.halfWindowRTT = rttSample{rtt: sample, time: now}
r.quarterWindowRTT = r.halfWindowRTT
} else if sample <= r.quarterWindowRTT.rtt {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
// Expire old min rtt samples.
if r.recentMinRTT.time.Before(now.Add(-r.recentMinRTTwindow)) {
r.recentMinRTT = r.halfWindowRTT
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.halfWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*halfWindow) * time.Microsecond)) {
r.halfWindowRTT = r.quarterWindowRTT
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
} else if r.quarterWindowRTT.time.Before(now.Add(-time.Duration(float32(r.recentMinRTTwindow/time.Microsecond)*quarterWindow) * time.Microsecond)) {
r.quarterWindowRTT = rttSample{rtt: sample, time: now}
}
}
// SampleNewRecentMinRTT forces RttStats to sample a new recent min rtt within the next
// |numSamples| UpdateRTT calls.
func (r *RTTStats) SampleNewRecentMinRTT(numSamples uint32) {
r.numMinRTTsamplesRemaining = numSamples
r.newMinRTT = rttSample{}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
r.initialRTTus = initialRTTus
r.numMinRTTsamplesRemaining = 0
r.recentMinRTTwindow = utils.InfDuration
r.recentMinRTT = rttSample{}
r.halfWindowRTT = rttSample{}
r.quarterWindowRTT = rttSample{}
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@ -0,0 +1,108 @@
package quic
import (
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStream interface {
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
HasData() bool
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type cryptoStreamImpl struct {
queue *frameSorter
msgBuf []byte
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount
writeBuf []byte
}
func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{
queue: newFrameSorter(),
}
}
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
}
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return errors.New("received crypto data after change of encryption level")
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
return err
}
for {
data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
}
func (s *cryptoStreamImpl) Finish() error {
if s.queue.HasMoreData() {
return errors.New("encryption level changed, but crypto stream has more data to read")
}
s.finished = true
return nil
}
// Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *cryptoStreamImpl) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
f.Data = s.writeBuf[:n]
s.writeBuf = s.writeBuf[n:]
s.writeOffset += n
return f
}

View File

@ -0,0 +1,55 @@
package quic
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool
}
type cryptoStreamManager struct {
cryptoHandler cryptoDataHandler
initialStream cryptoStream
handshakeStream cryptoStream
}
func newCryptoStreamManager(
cryptoHandler cryptoDataHandler,
initialStream cryptoStream,
handshakeStream cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
cryptoHandler: cryptoHandler,
initialStream: initialStream,
handshakeStream: handshakeStream,
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
var str cryptoStream
switch encLevel {
case protocol.EncryptionInitial:
str = m.initialStream
case protocol.EncryptionHandshake:
str = m.handshakeStream
default:
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return false, err
}
for {
data := str.GetCryptoData()
if data == nil {
return false, nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
}
}
}

View File

@ -5,51 +5,55 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type streamFrameSorter struct {
queuedFrames map[protocol.ByteCount]*wire.StreamFrame
readPosition protocol.ByteCount
gaps *utils.ByteIntervalList
type frameSorter struct {
queue map[protocol.ByteCount][]byte
readPos protocol.ByteCount
finalOffset protocol.ByteCount
gaps *utils.ByteIntervalList
}
var (
errTooManyGapsInReceivedStreamData = errors.New("Too many gaps in received StreamFrame data")
errDuplicateStreamData = errors.New("Duplicate Stream Data")
errEmptyStreamData = errors.New("Stream Data empty")
)
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
func newStreamFrameSorter() *streamFrameSorter {
s := streamFrameSorter{
gaps: utils.NewByteIntervalList(),
queuedFrames: make(map[protocol.ByteCount]*wire.StreamFrame),
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: utils.NewByteIntervalList(),
queue: make(map[protocol.ByteCount][]byte),
finalOffset: protocol.MaxByteCount,
}
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
if frame.DataLen() == 0 {
if frame.FinBit {
s.queuedFrames[frame.Offset] = frame
return nil
}
return errEmptyStreamData
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
err := s.push(data, offset, fin)
if err == errDuplicateStreamData {
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
if fin {
s.finalOffset = offset + protocol.ByteCount(len(data))
}
if len(data) == 0 {
return nil
}
var wasCut bool
if oldFrame, ok := s.queuedFrames[frame.Offset]; ok {
if frame.DataLen() <= oldFrame.DataLen() {
if oldData, ok := s.queue[offset]; ok {
if len(data) <= len(oldData) {
return errDuplicateStreamData
}
frame.Data = frame.Data[oldFrame.DataLen():]
frame.Offset += oldFrame.DataLen()
data = data[len(oldData):]
offset += protocol.ByteCount(len(oldData))
wasCut = true
}
start := frame.Offset
end := frame.Offset + frame.DataLen()
start := offset
end := offset + protocol.ByteCount(len(data))
// skip all gaps that are before this stream frame
var gap *utils.ByteIntervalElement
@ -69,9 +73,9 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
if start < gap.Value.Start {
add := gap.Value.Start - start
frame.Offset += add
offset += add
start += add
frame.Data = frame.Data[add:]
data = data[add:]
wasCut = true
}
@ -89,15 +93,15 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
break
}
// delete queued frames completely covered by the current frame
delete(s.queuedFrames, endGap.Value.End)
delete(s.queue, endGap.Value.End)
endGap = nextEndGap
}
if end > endGap.Value.End {
cutLen := end - endGap.Value.End
len := frame.DataLen() - cutLen
len := protocol.ByteCount(len(data)) - cutLen
end -= cutLen
frame.Data = frame.Data[:len]
data = data[:len]
wasCut = true
}
@ -130,32 +134,30 @@ func (s *streamFrameSorter) Push(frame *wire.StreamFrame) error {
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errTooManyGapsInReceivedStreamData
return errors.New("Too many gaps in received data")
}
if wasCut {
data := make([]byte, frame.DataLen())
copy(data, frame.Data)
frame.Data = data
newData := make([]byte, len(data))
copy(newData, data)
data = newData
}
s.queuedFrames[frame.Offset] = frame
s.queue[offset] = data
return nil
}
func (s *streamFrameSorter) Pop() *wire.StreamFrame {
frame := s.Head()
if frame != nil {
s.readPosition += frame.DataLen()
delete(s.queuedFrames, frame.Offset)
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
data, ok := s.queue[s.readPos]
if !ok {
return nil, s.readPos >= s.finalOffset
}
return frame
delete(s.queue, s.readPos)
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}
func (s *streamFrameSorter) Head() *wire.StreamFrame {
frame, ok := s.queuedFrames[s.readPosition]
if ok {
return frame
}
return nil
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}

109
vendor/github.com/lucas-clemente/quic-go/framer.go generated vendored Normal file
View File

@ -0,0 +1,109 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type framer interface {
QueueControlFrame(wire.Frame)
AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
}
type framerI struct {
mutex sync.Mutex
streamGetter streamGetter
version protocol.VersionNumber
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
}
var _ framer = &framerI{}
func newFramer(
streamGetter streamGetter,
v protocol.VersionNumber,
) framer {
return &framerI{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
version: v,
}
}
func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
f.controlFrames = append(f.controlFrames, frame)
f.controlFrameMutex.Unlock()
}
func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
var length protocol.ByteCount
f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(f.version)
if length+frameLen > maxLen {
break
}
frames = append(frames, frame)
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
f.controlFrameMutex.Unlock()
return frames, length
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame {
var length protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxLen-length < protocol.MinStreamFrameSize {
break
}
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
frame, hasMoreData := str.popStreamFrame(maxLen - length)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
}
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
continue
}
frames = append(frames, frame)
length += frame.Length(f.version)
}
f.mutex.Unlock()
return frames
}

View File

@ -17,22 +17,47 @@ type StreamID = protocol.StreamID
type VersionNumber = protocol.VersionNumber
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
type Cookie struct {
RemoteAddr string
SentTime time.Time
}
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
StreamID() StreamID
// Reset closes the stream with an error.
Reset(error)
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
@ -53,34 +78,78 @@ type Stream interface {
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
// New streams always have the smallest possible stream ID.
// TODO: Enable testing for the special error
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
AcceptUniStream() (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened.
// It always picks the smallest possible stream ID.
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// Close the connection.
io.Closer
// Close the connection with an error.
// The error must not be nil.
CloseWithError(ErrorCode, error) error
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.
@ -89,10 +158,13 @@ type Config struct {
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
@ -113,6 +185,14 @@ type Config struct {
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}

View File

@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View File

@ -0,0 +1,48 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@ -0,0 +1,29 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
// There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
}

View File

@ -1,13 +1,10 @@
// Generated by: main
// TypeWriter: linkedlist
// Directive: +gen on Packet
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package ackhandler
// List is a modification of http://golang.org/pkg/container/list/
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Linked list implementation from the Go standard library.
// PacketElement is an element of a linked list.
type PacketElement struct {
@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
return nil
}
// PacketList represents a doubly linked list.
// The zero value for PacketList is an empty list ready to use.
// PacketList is a linked list of Packets.
type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
// The complexity is O(1).
func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil.
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement {
if l.len == 0 {
return nil
@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
return l.root.next
}
// Back returns the last element of list l or nil.
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement {
if l.len == 0 {
return nil
@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
return l.root.prev
}
// lazyInit lazily initializes a zero PacketList value.
// lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() {
if l.root.next == nil {
l.Init()
@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
return e
}
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at).
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at)
}
@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketElement) and l.remove will crash
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in PacketList.Remove about initialization of l
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
}
// MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {

View File

@ -1,10 +1,11 @@
package quic
package ackhandler
import (
"crypto/rand"
"math"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// The packetNumberGenerator generates the packet number for the next packet
@ -15,13 +16,17 @@ type packetNumberGenerator struct {
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
history []protocol.PacketNumber
}
func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator {
return &packetNumberGenerator{
next: 1,
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
g := &packetNumberGenerator{
next: initial,
averagePeriod: averagePeriod,
}
g.generateNewSkip()
return g
}
func (p *packetNumberGenerator) Peek() protocol.PacketNumber {
@ -35,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
p.next++
if p.next == p.nextToSkip {
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
p.history = p.history[1:]
}
p.history = append(p.history, p.next)
p.next++
p.generateNewSkip()
}
@ -42,28 +51,28 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
return next
}
func (p *packetNumberGenerator) generateNewSkip() error {
num, err := p.getRandomNumber()
if err != nil {
return err
}
func (p *packetNumberGenerator) generateNewSkip() {
num := p.getRandomNumber()
skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2)
// make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 2 + skip
return nil
}
// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535)
// The expectation value is 65535/2
func (p *packetNumberGenerator) getRandomNumber() (uint16, error) {
func (p *packetNumberGenerator) getRandomNumber() uint16 {
b := make([]byte, 2)
_, err := rand.Read(b)
if err != nil {
return 0, err
}
rand.Read(b) // ignore the error here
num := uint16(b[0])<<8 + uint16(b[1])
return num, nil
return num
}
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
for _, pn := range p.history {
if ack.AcksPacket(pn) {
return false
}
}
return true
}

View File

@ -0,0 +1,215 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -2,9 +2,9 @@ package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
// The receivedPacketHistory stores if a packet number has already been received.
@ -74,17 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return nil
}
// DeleteUpTo deletes all entries up to (and including) p
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p <= h.lowestInReceivedPacketNumbers {
return
}
h.lowestInReceivedPacketNumbers = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if p >= el.Value.Start && p < el.Value.End {
el.Value.Start = p + 1
} else if el.Value.End <= p { // delete a whole range
if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
} else if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else { // no ranges affected. Nothing to do
return
@ -101,7 +104,7 @@ func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++
}
return ackRanges
@ -111,8 +114,8 @@ func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.First = r.Start
ackRange.Last = r.End
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}

View File

@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *wire.StopWaitingFrame:
return false
case *wire.AckFrame:
return false
default:

View File

@ -0,0 +1,40 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendTLP means that a TLP probe packet should be sent
SendTLP
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendTLP:
return "tlp"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@ -0,0 +1,633 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Maximum number of tail loss probes before an RTO fires.
maxTLPs = 2
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
packetNumberGenerator *packetNumberGenerator
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
allowTLP bool
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
numRTOs int
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
packetHistory: newSentPacketHistory(),
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if p := h.packetHistory.FirstOutstanding(); p != nil {
return p.PacketNumber
}
return h.largestAcked + 1
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.Encryption1RTT {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
h.packetHistory.SentPacket(packet)
h.updateLossDetectionAlarm()
}
}
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
var p []*Packet
for _, packet := range packets {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
p = append(p, packet)
}
}
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
}
h.lastSentPacketNumber = packet.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
packet.largestAcked = ackFrame.LargestAcked()
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel != protocol.Encryption1RTT {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
packet.canBeRetransmitted = true
if h.numRTOs > 0 {
h.numRTOs--
}
h.allowTLP = false
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
return isRetransmittable
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
largestAcked := ackFrame.LargestAcked()
if largestAcked > h.lastSentPacketNumber {
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if !h.packetNumberGenerator.Validate(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
// TODO(#1534): check the encryption level
// if encLevel < p.EncryptionLevel {
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
// }
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p, rcvTime); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
return err
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
ackedPackets = append(ackedPackets, p)
}
} else {
ackedPackets = append(ackedPackets, p)
}
return true, nil
})
if h.logger.Debug() && len(ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(ackedPackets))
for i, p := range ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
return true
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if !h.packetHistory.HasOutstandingPackets() {
h.alarm = time.Time{}
return
}
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO or TLP alarm
alarmDuration := h.computeRTOTimeout()
if h.tlpCount < maxTLPs {
tlpAlarm := h.computeTLPTimeout()
// if the RTO duration is shorter than the TLP duration, use the RTO duration
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
}
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
h.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*Packet
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
if packet.PacketNumber > h.largestAcked {
return false, nil
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
}
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
return true, nil
})
if h.logger.Debug() && len(lostPackets) > 0 {
pns := make([]protocol.PacketNumber, len(lostPackets))
for i, p := range lostPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
}
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
h.packetHistory.Remove(p.PacketNumber)
}
return nil
}
func (h *sentPacketHandler) OnAlarm() error {
// When all outstanding are acknowledged, the alarm is canceled in
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.packetHistory.HasOutstandingPackets() {
if err := h.onVerifiedAlarm(); err != nil {
return err
}
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
}
// Early retransmit or time loss detection
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
} else if h.tlpCount < maxTLPs { // TLP
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
}
h.allowTLP = true
h.tlpCount++
} else { // RTO
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
}
if h.rtoCount == 0 {
h.largestSentBeforeRTO = h.lastSentPacketNumber
}
h.rtoCount++
h.numRTOs += 2
}
return err
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// only report the acking of this packet to the congestion controller if:
// * it is a retransmittable packet
// * this packet wasn't retransmitted yet
if p.isRetransmission {
// that the parent doesn't exist is expected to happen every time the original packet was already acked
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
if len(parent.retransmittedAs) == 1 {
parent.retransmittedAs = nil
} else {
// remove this packet from the slice of retransmission
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
for _, pn := range parent.retransmittedAs {
if pn != p.PacketNumber {
retransmittedAs = append(retransmittedAs, pn)
}
}
parent.retransmittedAs = retransmittedAs
}
}
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if h.rtoCount > 0 {
h.verifyRTO(p.PacketNumber)
}
if err := h.stopRetransmissionsFor(p); err != nil {
return err
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
for _, r := range p.retransmittedAs {
packet := h.packetHistory.GetPacket(r)
if packet == nil {
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
}
h.stopRetransmissionsFor(packet)
}
return nil
}
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
if pn <= h.largestSentBeforeRTO {
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
// Replace SRTT with latest_rtt and increase the variance to prevent
// a spurious RTO from happening again.
h.rttStats.ExpireSmoothedMetrics()
return
}
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
if len(h.retransmissionQueue) == 0 {
p := h.packetHistory.FirstOutstanding()
if p == nil {
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
}
if err := h.queuePacketForRetransmission(p); err != nil {
return nil, err
}
}
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
return h.packetNumberGenerator.Pop()
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.allowTLP {
return SendTLP
}
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.nextPacketSendTime
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
if h.numRTOs > 0 {
// RTO probes should not be paced, but must be sent immediately.
return h.numRTOs
}
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
}
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
}
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
// TODO(#1236): include the max_ack_delay
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()
if rtt == 0 {
rto = defaultRTOTimeout
} else {
rto = rtt + 4*h.rttStats.MeanDeviation()
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto <<= h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}

View File

@ -0,0 +1,168 @@
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets++
}
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
}

View File

@ -16,11 +16,10 @@ import (
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling
// round trip time.
// where 0.100 is 100 ms which is the scaling round trip time.
const cubeScale = 40
const cubeCongestionWindowScale = 410
const cubeFactor protocol.PacketNumber = 1 << cubeScale / cubeCongestionWindowScale
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
const defaultNumConnections = 2
@ -32,39 +31,35 @@ const beta float32 = 0.7
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// If true, Cubic's epoch is shifted when the sender is application-limited.
const shiftQuicCubicEpochWhenAppLimited = true
const maxCubicTimeInterval = 30 * time.Millisecond
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Time when sender went into application-limited period. Zero if not in
// application-limited period.
appLimitedStartTime time.Time
// Time when we updated last_congestion_window.
lastUpdateTime time.Time
// Last congestion window (in packets) used.
lastCongestionWindow protocol.PacketNumber
// Max congestion window (in packets) used just before last loss event.
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.PacketNumber
// Number of acked packets since the cycle started (epoch).
ackedPacketsCount protocol.PacketNumber
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.PacketNumber
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.PacketNumber
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.PacketNumber
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
@ -80,11 +75,8 @@ func NewCubic(clock Clock) *Cubic {
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.appLimitedStartTime = time.Time{}
c.lastUpdateTime = time.Time{}
c.lastCongestionWindow = 0
c.lastMaxCongestionWindow = 0
c.ackedPacketsCount = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
@ -107,57 +99,59 @@ func (c *Cubic) beta() float32 {
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
if shiftQuicCubicEpochWhenAppLimited {
// When sender is not using the available congestion window, Cubic's epoch
// should not continue growing. Record the time when sender goes into an
// app-limited period here, to compensate later when cwnd growth happens.
if c.appLimitedStartTime.IsZero() {
c.appLimitedStartTime = c.clock.Now()
}
} else {
// When sender is not using the available congestion window, Cubic's epoch
// should not continue growing. Reset the epoch when in such a period.
c.epoch = time.Time{}
}
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.PacketNumber) protocol.PacketNumber {
if currentCongestionWindow < c.lastMaxCongestionWindow {
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.PacketNumber(betaLastMax * float32(currentCongestionWindow))
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.PacketNumber(float32(currentCongestionWindow) * c.beta())
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, delayMin time.Duration) protocol.PacketNumber {
c.ackedPacketsCount++ // Packets acked.
currentTime := c.clock.Now()
// Cubic is "independent" of RTT, the update is limited by the time elapsed.
if c.lastCongestionWindow == currentCongestionWindow && (currentTime.Sub(c.lastUpdateTime) <= maxCubicTimeInterval) {
return utils.MaxPacketNumber(c.lastTargetCongestionWindow, c.estimatedTCPcongestionWindow)
}
c.lastCongestionWindow = currentCongestionWindow
c.lastUpdateTime = currentTime
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = currentTime // Start of epoch.
c.ackedPacketsCount = 1 // Reset count.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
@ -167,48 +161,37 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
} else {
// If sender was app-limited, then freeze congestion window growth during
// app-limited period. Continue growth now by shifting the epoch-start
// through the app-limited period.
if shiftQuicCubicEpochWhenAppLimited && !c.appLimitedStartTime.IsZero() {
shift := currentTime.Sub(c.appLimitedStartTime)
c.epoch = c.epoch.Add(shift)
c.appLimitedStartTime = time.Time{}
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64((currentTime.Add(delayMin).Sub(c.epoch)/time.Microsecond)<<10) / 1000000
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
// Right-shifts of negative, signed numbers have
// implementation-dependent behavior. Force the offset to be
// positive, similar to the kernel implementation.
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.PacketNumber((cubeCongestionWindowScale * offset * offset * offset) >> cubeScale)
var targetCongestionWindow protocol.PacketNumber
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// With dynamic beta/alpha based on number of active streams, it is possible
// for the required_ack_count to become much lower than acked_packets_count_
// suddenly, leading to more than one iteration through the following loop.
for {
// Update estimated TCP congestion_window.
requiredAckCount := protocol.PacketNumber(float32(c.estimatedTCPcongestionWindow) / c.alpha())
if c.ackedPacketsCount < requiredAckCount {
break
}
c.ackedPacketsCount -= requiredAckCount
c.estimatedTCPcongestionWindow++
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
@ -218,7 +201,6 @@ func (c *Cubic) CongestionWindowAfterAck(currentCongestionWindow protocol.Packet
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}

View File

@ -8,9 +8,9 @@ import (
)
const (
maxBurstBytes = 3 * protocol.DefaultTCPMSS
defaultMinimumCongestionWindow protocol.PacketNumber = 2
renoBeta float32 = 0.7 // Reno backoff factor.
maxBurstBytes = 3 * protocol.DefaultTCPMSS
renoBeta float32 = 0.7 // Reno backoff factor.
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
)
type cubicSender struct {
@ -31,12 +31,6 @@ type cubicSender struct {
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Congestion window in packets.
congestionWindow protocol.PacketNumber
// Slow start congestion window in packets, aka ssthresh.
slowstartThreshold protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
@ -44,24 +38,35 @@ type cubicSender struct {
// When true, exit slow start with large cutback of congestion window.
slowStartLargeReduction bool
// Minimum congestion window in packets.
minCongestionWindow protocol.PacketNumber
// Congestion window in packets.
congestionWindow protocol.ByteCount
// Maximum number of outstanding packets for tcp.
maxTCPCongestionWindow protocol.PacketNumber
// Minimum congestion window in packets.
minCongestionWindow protocol.ByteCount
// Maximum congestion window.
maxCongestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowstartThreshold protocol.ByteCount
// Number of connections to simulate.
numConnections int
// ACK counter for the Reno implementation.
congestionWindowCount protocol.ByteCount
numAckedPackets uint64
initialCongestionWindow protocol.PacketNumber
initialMaxCongestionWindow protocol.PacketNumber
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
minSlowStartExitWindow protocol.ByteCount
}
var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
// NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.PacketNumber) SendAlgorithmWithDebugInfo {
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
return &cubicSender{
rttStats: rttStats,
initialCongestionWindow: initialCongestionWindow,
@ -69,28 +74,37 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
congestionWindow: initialCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow,
slowstartThreshold: initialMaxCongestionWindow,
maxTCPCongestionWindow: initialMaxCongestionWindow,
maxCongestionWindow: initialMaxCongestionWindow,
numConnections: defaultNumConnections,
cubic: NewCubic(clock),
reno: reno,
}
}
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() {
// PRR is used when in recovery.
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
return 0
}
}
if c.GetCongestionWindow() > bytesInFlight {
return 0
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return utils.InfDuration
return delay
}
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
// Only update bytesInFlight for data packets.
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
bytesInFlight protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
if !isRetransmittable {
return false
return
}
if c.InRecovery() {
// PRR is used when in recovery.
@ -98,7 +112,6 @@ func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.By
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
return true
}
func (c *cubicSender) InRecovery() bool {
@ -110,18 +123,18 @@ func (c *cubicSender) InSlowStart() bool {
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return protocol.ByteCount(c.congestionWindow) * protocol.DefaultTCPMSS
return c.congestionWindow
}
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
return protocol.ByteCount(c.slowstartThreshold) * protocol.DefaultTCPMSS
return c.slowstartThreshold
}
func (c *cubicSender) ExitSlowstart() {
c.slowstartThreshold = c.congestionWindow
}
func (c *cubicSender) SlowstartThreshold() protocol.PacketNumber {
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
@ -131,20 +144,29 @@ func (c *cubicSender) MaybeExitSlowStart() {
}
}
func (c *cubicSender) OnPacketAcked(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketAcked(ackedBytes)
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, bytesInFlight)
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) OnPacketLost(
packetNumber protocol.PacketNumber,
lostBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
@ -152,10 +174,8 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++
c.stats.slowstartBytesLost += lostBytes
if c.slowStartLargeReduction {
if c.stats.slowstartPacketsLost == 1 || (c.stats.slowstartBytesLost/protocol.DefaultTCPMSS) > (c.stats.slowstartBytesLost-lostBytes)/protocol.DefaultTCPMSS {
// Reduce congestion window by 1 for every mss of bytes lost.
c.congestionWindow = utils.MaxPacketNumber(c.congestionWindow-1, c.minCongestionWindow)
}
// Reduce congestion window by lost_bytes for every loss.
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
c.slowstartThreshold = c.congestionWindow
}
}
@ -166,17 +186,19 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.stats.slowstartPacketsLost++
}
c.prr.OnPacketLost(bytesInFlight)
c.prr.OnPacketLost(priorInFlight)
// TODO(chromium): Separate out all of slow start into a separate class.
if c.slowStartLargeReduction && c.InSlowStart() {
c.congestionWindow = c.congestionWindow - 1
if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow -= protocol.DefaultTCPMSS
} else if c.reno {
c.congestionWindow = protocol.PacketNumber(float32(c.congestionWindow) * c.RenoBeta())
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
// Enforce a minimum congestion window.
if c.congestionWindow < c.minCongestionWindow {
c.congestionWindow = c.minCongestionWindow
}
@ -184,7 +206,7 @@ func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.congestionWindowCount = 0
c.numAckedPackets = 0
}
func (c *cubicSender) RenoBeta() float32 {
@ -197,32 +219,38 @@ func (c *cubicSender) RenoBeta() float32 {
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(ackedPacketNumber protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount) {
func (c *cubicSender) maybeIncreaseCwnd(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(bytesInFlight) {
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
return
}
if c.congestionWindow >= c.maxTCPCongestionWindow {
if c.congestionWindow >= c.maxCongestionWindow {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow++
c.congestionWindow += protocol.DefaultTCPMSS
return
}
// Congestion avoidance
if c.reno {
// Classic Reno congestion avoidance.
c.congestionWindowCount++
c.numAckedPackets++
// Divide by num_connections to smoothly increase the CWND at a faster
// rate than conventional Reno.
if protocol.PacketNumber(c.congestionWindowCount*protocol.ByteCount(c.numConnections)) >= c.congestionWindow {
c.congestionWindow++
c.congestionWindowCount = 0
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
c.congestionWindow += protocol.DefaultTCPMSS
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.MinPacketNumber(c.maxTCPCongestionWindow, c.cubic.CongestionWindowAfterAck(c.congestionWindow, c.rttStats.MinRTT()))
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
@ -278,21 +306,13 @@ func (c *cubicSender) OnConnectionMigration() {
c.largestSentAtLastCutback = 0
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.congestionWindowCount = 0
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowstartThreshold = c.initialMaxCongestionWindow
c.maxTCPCongestionWindow = c.initialMaxCongestionWindow
c.maxCongestionWindow = c.initialMaxCongestionWindow
}
// SetSlowStartLargeReduction allows enabling the SSLR experiment
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled
}
// RetransmissionDelay gives the time to retransmission
func (c *cubicSender) RetransmissionDelay() time.Duration {
if c.rttStats.SmoothedRTT() == 0 {
return 0
}
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
}

View File

@ -8,16 +8,15 @@ import (
// A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface {
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, bytesInFlight protocol.ByteCount)
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()
RetransmissionDelay() time.Duration
// Experiments
SetSlowStartLargeReduction(enabled bool)
@ -31,7 +30,7 @@ type SendAlgorithmWithDebugInfo interface {
// Stuff only used in testing
HybridSlowStart() *HybridSlowStart
SlowstartThreshold() protocol.PacketNumber
SlowstartThreshold() protocol.ByteCount
RenoBeta() float32
InRecovery() bool
}

View File

@ -1,10 +1,7 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
@ -23,9 +20,9 @@ func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
// OnPacketLost should be called on the first loss that triggers a recovery
// period and all other methods in this class should only be called when in
// recovery.
func (p *PrrSender) OnPacketLost(bytesInFlight protocol.ByteCount) {
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
p.bytesSentSinceLoss = 0
p.bytesInFlightBeforeLoss = bytesInFlight
p.bytesInFlightBeforeLoss = priorInFlight
p.bytesDeliveredSinceLoss = 0
p.ackCountSinceLoss = 0
}
@ -36,28 +33,22 @@ func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
p.ackCountSinceLoss++
}
// TimeUntilSend calculates the time until a packet can be sent
func (p *PrrSender) TimeUntilSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) time.Duration {
// CanSend returns if packets can be sent
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
// Return QuicTime::Zero In order to ensure limited transmit always works.
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
return 0
return true
}
if congestionWindow > bytesInFlight {
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
// of sending the entire available window. This prevents burst retransmits
// when more packets are lost than the CWND reduction.
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
if p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS <= p.bytesSentSinceLoss {
return utils.InfDuration
}
return 0
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
}
// Implement Proportional Rate Reduction (RFC6937).
// Checks a simplified version of the PRR formula that doesn't use division:
// AvailableSendWindow =
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
if p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss {
return 0
}
return utils.InfDuration
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
}

View File

@ -0,0 +1,101 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
// RTTStats provides round-trip statistics
type RTTStats struct {
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
// If no valid updates have occurred, it returns the initial RTT.
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
if r.smoothedRTT != 0 {
return r.smoothedRTT
}
return defaultInitialRTT
}
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
// Correct for ackDelay if information received from the peer results in a
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if r.smoothedRTT == 0 {
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@ -1,72 +0,0 @@
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM12 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM12{}
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM12{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
func (aead *aeadAESGCM12) Overhead() int {
return aead.encrypter.Overhead()
}

View File

@ -1,48 +0,0 @@
package crypto
import (
"fmt"
"hash/fnv"
"github.com/hashicorp/golang-lru"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
compressedCertsCache *lru.Cache
)
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
// Hash all inputs
hasher := fnv.New64a()
for _, v := range chain {
hasher.Write(v)
}
hasher.Write(pCommonSetHashes)
hasher.Write(pCachedHashes)
hash := hasher.Sum64()
var result []byte
resultI, isCached := compressedCertsCache.Get(hash)
if isCached {
result = resultI.([]byte)
} else {
var err error
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
if err != nil {
return nil, err
}
compressedCertsCache.Add(hash, result)
}
return result, nil
}
func init() {
var err error
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
if err != nil {
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
}
}

View File

@ -1,113 +0,0 @@
package crypto
import (
"crypto/tls"
"errors"
"strings"
)
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof
type certChain struct {
config *tls.Config
}
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return signServerProof(cert, chlo, serverConfigData)
}
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
}
// GetLeafCert gets the leaf certificate
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return cert.Certificate[0], nil
}
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
c := cc.config
c, err := maybeGetConfigForClient(c, sni)
if err != nil {
return nil, err
}
// The rest of this function is mostly copied from crypto/tls.getCertificate
if c.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
}
}
if len(c.Certificates) == 0 {
return nil, errNoMatchingCertificate
}
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &c.Certificates[0], nil
}
name := strings.ToLower(sni)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := c.NameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -1,272 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type entryType uint8
const (
entryCompressed entryType = 1
entryCached entryType = 2
entryCommon entryType = 3
)
type entry struct {
t entryType
h uint64 // set hash
i uint32 // index
}
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
res := &bytes.Buffer{}
cachedHashes, err := splitHashes(pCachedHashes)
if err != nil {
return nil, err
}
setHashes, err := splitHashes(pCommonSetHashes)
if err != nil {
return nil, err
}
chainHashes := make([]uint64, len(chain))
for i := range chain {
chainHashes[i] = HashCert(chain[i])
}
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
totalUncompressedLen := 0
for i, e := range entries {
res.WriteByte(uint8(e.t))
switch e.t {
case entryCached:
utils.LittleEndian.WriteUint64(res, e.h)
case entryCommon:
utils.LittleEndian.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint32(res, e.i)
case entryCompressed:
totalUncompressedLen += 4 + len(chain[i])
}
}
res.WriteByte(0) // end of list
if totalUncompressedLen > 0 {
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
if err != nil {
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
}
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
for i, e := range entries {
if e.t != entryCompressed {
continue
}
lenCert := len(chain[i])
gz.Write([]byte{
byte(lenCert & 0xff),
byte((lenCert >> 8) & 0xff),
byte((lenCert >> 16) & 0xff),
byte((lenCert >> 24) & 0xff),
})
gz.Write(chain[i])
}
gz.Close()
}
return res.Bytes(), nil
}
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.LittleEndian.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.LittleEndian.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain))
chainLoop:
for i := range chain {
// Check if hash is in cachedHashes
for j := range cachedHashes {
if chainHashes[i] == cachedHashes[j] {
res[i] = entry{t: entryCached, h: chainHashes[i]}
continue chainLoop
}
}
// Go through common sets and check if it's in there
for _, setHash := range setHashes {
set, ok := certSets[setHash]
if !ok {
// We don't have this set
continue
}
// We have this set, check if chain[i] is in the set
pos := set.findCertInSet(chain[i])
if pos >= 0 {
// Found
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
continue chainLoop
}
}
res[i] = entry{t: entryCompressed}
}
return res
}
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
var dict bytes.Buffer
// First the cached and common in reverse order
for i := len(entries) - 1; i >= 0; i-- {
if entries[i].t == entryCompressed {
continue
}
dict.Write(chain[i])
}
dict.Write(certDictZlib)
return dict.Bytes()
}
func splitHashes(hashes []byte) ([]uint64, error) {
if len(hashes)%8 != 0 {
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
}
n := len(hashes) / 8
res := make([]uint64, n)
for i := 0; i < n; i++ {
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
}
return res, nil
}
func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert)
return h.Sum64()
}

View File

@ -1,128 +0,0 @@
package crypto
var certDictZlib = []byte{
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
}

View File

@ -1,130 +0,0 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
}
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}

View File

@ -1,24 +0,0 @@
package crypto
import (
"bytes"
"github.com/lucas-clemente/quic-go-certificates"
)
type certSet [][]byte
var certSets = map[uint64]certSet{
certsets.CertSet2Hash: certsets.CertSet2,
certsets.CertSet3Hash: certsets.CertSet3,
}
// findCertInSet searches for the cert in the set. Negative return value means not found.
func (s *certSet) findCertInSet(cert []byte) int {
for i, c := range *s {
if bytes.Equal(c, cert) {
return i
}
}
return -1
}

View File

@ -1,61 +0,0 @@
// +build ignore
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
}
// copy because ChaCha20Poly1305 expects array pointers
var MyKey, OtherKey [32]byte
copy(MyKey[:], myKey)
copy(OtherKey[:], otherKey)
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
if err != nil {
return nil, err
}
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
if err != nil {
return nil, err
}
return &aeadChacha20Poly1305{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View File

@ -1,71 +0,0 @@
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View File

@ -1,45 +0,0 @@
package crypto
import (
"crypto/rand"
"errors"
"golang.org/x/crypto/curve25519"
)
// KeyExchange manages the exchange of keys
type curve25519KEX struct {
secret [32]byte
public [32]byte
}
var _ KeyExchange = &curve25519KEX{}
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
func NewCurve25519KEX() (KeyExchange, error) {
c := &curve25519KEX{}
if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key")
}
// See https://cr.yp.to/ecdh.html
c.secret[0] &= 248
c.secret[31] &= 127
c.secret[31] |= 64
curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil
}
func (c *curve25519KEX) PublicKey() []byte {
return c.public[:]
}
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if len(otherPublic) != 32 {
return nil, errors.New("Curve25519: expected public key of 32 byte")
}
var res [32]byte
var otherPublicArray [32]byte
copy(otherPublicArray[:], otherPublic)
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
return res[:], nil
}

Some files were not shown because too many files have changed in this diff Show More