From 8cb22691599984909d92021f01fd92407680d067 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 7 Nov 2017 17:55:18 +0800 Subject: [PATCH] add node ID --- chain.go | 84 +++++++++++++++++++++--------------------------- cmd/gost/main.go | 21 ++++++++++-- gost.go | 1 + http.go | 7 ++++ node.go | 16 ++++++--- sni.go | 16 ++++++--- 6 files changed, 87 insertions(+), 58 deletions(-) diff --git a/chain.go b/chain.go index 9a409d9..d328f32 100644 --- a/chain.go +++ b/chain.go @@ -1,9 +1,10 @@ package gost import ( + "bytes" "errors" + "fmt" "net" - "strings" "github.com/go-log/log" ) @@ -124,29 +125,6 @@ func (c *Chain) Conn() (conn net.Conn, err error) { return } -func (c *Chain) selectRoute() (route *Chain, err error) { - route = NewChain() - for _, group := range c.nodeGroups { - selector := group.Selector - if selector == nil { - selector = &defaultSelector{} - } - // select node from node group - node, err := selector.Select(group.Nodes(), group.Options...) - if err != nil { - return nil, err - } - if node.Client.Transporter.Multiplex() { - node.DialOptions = append(node.DialOptions, - ChainDialOption(route), - ) - route = NewChain() // cutoff the chain for multiplex - } - route.AddNode(node) - } - return -} - func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { if route.IsEmpty() { err = ErrEmptyChain @@ -155,11 +133,7 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { nodes := route.Nodes() node := nodes[0] - addr, err := selectIP(&node) - if err != nil { - return - } - cn, err := node.Client.Dial(addr, node.DialOptions...) + cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { return } @@ -171,13 +145,8 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { preNode := node for _, node := range nodes[1:] { - addr, err = selectIP(&node) - if err != nil { - return - } - var cc net.Conn - cc, err = preNode.Client.Connect(cn, addr) + cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { cn.Close() return @@ -195,8 +164,37 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { return } +func (c *Chain) selectRoute() (route *Chain, err error) { + buf := bytes.Buffer{} + route = NewChain() + for _, group := range c.nodeGroups { + selector := group.Selector + if selector == nil { + selector = &defaultSelector{} + } + // select node from node group + node, err := selector.Select(group.Nodes(), group.Options...) + if err != nil { + return nil, err + } + if _, err := selectIP(&node); err != nil { + return nil, err + } + buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr)) + + if node.Client.Transporter.Multiplex() { + node.DialOptions = append(node.DialOptions, + ChainDialOption(route), + ) + route = NewChain() // cutoff the chain for multiplex + } + route.AddNode(node) + } + log.Log("select route:", buf.String()) + return +} + func selectIP(node *Node) (string, error) { - addr := node.Addr s := node.IPSelector if s == nil { s = &RandomIPSelector{} @@ -207,17 +205,9 @@ func selectIP(node *Node) (string, error) { return "", err } if ip != "" { - if !strings.Contains(ip, ":") { - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return "", err - } - ip = ip + ":" + sport - } - addr = ip // override the original address - node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) + node.Addr = ip + node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr)) } - log.Log("select IP:", node.Addr, node.IPs, addr) - return addr, nil + return node.Addr, nil } diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 5ebbe66..0ecc3d0 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -94,6 +94,9 @@ func initChain() (*gost.Chain, error) { return nil, err } + id := 1 // start from 1 + + node.ID = id ngroup := gost.NewNodeGroup(node) // parse node peers if exists @@ -110,6 +113,8 @@ func initChain() (*gost.Chain, error) { if err != nil { return nil, err } + id++ + node.ID = id ngroup.AddNode(node) } @@ -126,6 +131,15 @@ func parseChainNode(ns string) (node gost.Node, err error) { } node.IPs = parseIP(node.Values.Get("ip")) + for i, ip := range node.IPs { + if !strings.Contains(ip, ":") { + _, sport, _ := net.SplitHostPort(node.Addr) + if sport == "" { + sport = "8080" // default port + } + node.IPs[i] = ip + ":" + sport + } + } node.IPSelector = &gost.RoundRobinIPSelector{} users, err := parseUsers(node.Values.Get("secrets")) @@ -592,11 +606,12 @@ func loadPeerConfig(peer string) (config peerConfig, err error) { func parseStrategy(s string) gost.Strategy { switch s { - case "round": - return &gost.RoundStrategy{} case "random": + return &gost.RandomStrategy{} + case "round": fallthrough default: - return &gost.RandomStrategy{} + return &gost.RoundStrategy{} + } } diff --git a/gost.go b/gost.go index 39b818f..2a8baeb 100644 --- a/gost.go +++ b/gost.go @@ -56,6 +56,7 @@ func SetLogger(logger log.Logger) { log.DefaultLogger = logger } +// GenCertificate generates a random TLS certificate func GenCertificate() (cert tls.Certificate, err error) { rawCert, rawKey, err := generateKeyPair() if err != nil { diff --git a/http.go b/http.go index 0612f50..b88942d 100644 --- a/http.go +++ b/http.go @@ -93,6 +93,13 @@ func (h *httpHandler) Handle(conn net.Conn) { return } + h.handleRequest(conn, req) +} + +func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { + if req == nil { + return + } if Debug { dump, _ := httputil.DumpRequest(req, false) log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), req.Host, string(dump)) diff --git a/node.go b/node.go index 4a2ef1a..a0fa051 100644 --- a/node.go +++ b/node.go @@ -3,10 +3,12 @@ package gost import ( "net/url" "strings" + "sync" ) // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { + ID int Addr string IPs []string Protocol string @@ -81,15 +83,21 @@ func ParseNode(s string) (node Node, err error) { // NodeGroup is a group of nodes. type NodeGroup struct { - nodes []Node - Options []SelectOption - Selector NodeSelector + nodes []Node + Options []SelectOption + Selector NodeSelector + mutex sync.Mutex + mFails map[string]int // node -> fail count + MaxFails int + FailTimeout int + Retries int } // NewNodeGroup creates a node group func NewNodeGroup(nodes ...Node) *NodeGroup { return &NodeGroup{ - nodes: nodes, + nodes: nodes, + mFails: make(map[string]int), } } diff --git a/sni.go b/sni.go index d4ec608..2cb9800 100644 --- a/sni.go +++ b/sni.go @@ -11,6 +11,7 @@ import ( "hash/crc32" "io" "net" + "net/http" "strings" "sync" @@ -53,15 +54,22 @@ func (h *sniHandler) Handle(conn net.Conn) { return } conn = &bufferdConn{br: br, Conn: conn} + defer conn.Close() if hdr[0] != dissector.Handshake { - // We assume that it is HTTP request - HTTPHandler(h.options...).Handle(conn) + // We assume it is an HTTP request + req, err := http.ReadRequest(bufio.NewReader(conn)) + if !req.URL.IsAbs() { + req.URL.Scheme = "http" // make sure that the URL is absolute + } + if err != nil { + log.Logf("[sni] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + HTTPHandler(h.options...).(*httpHandler).handleRequest(conn, req) return } - defer conn.Close() - b, host, err := readClientHelloRecord(conn, "", false) if err != nil { log.Log("[sni]", err)