add node ID

This commit is contained in:
rui.zheng 2017-11-07 17:55:18 +08:00
parent e42d27c368
commit 8cb2269159
6 changed files with 87 additions and 58 deletions

View File

@ -1,9 +1,10 @@
package gost package gost
import ( import (
"bytes"
"errors" "errors"
"fmt"
"net" "net"
"strings"
"github.com/go-log/log" "github.com/go-log/log"
) )
@ -124,29 +125,6 @@ func (c *Chain) Conn() (conn net.Conn, err error) {
return return
} }
func (c *Chain) selectRoute() (route *Chain, err error) {
route = NewChain()
for _, group := range c.nodeGroups {
selector := group.Selector
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err := selector.Select(group.Nodes(), group.Options...)
if err != nil {
return nil, err
}
if node.Client.Transporter.Multiplex() {
node.DialOptions = append(node.DialOptions,
ChainDialOption(route),
)
route = NewChain() // cutoff the chain for multiplex
}
route.AddNode(node)
}
return
}
func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
if route.IsEmpty() { if route.IsEmpty() {
err = ErrEmptyChain err = ErrEmptyChain
@ -155,11 +133,7 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
nodes := route.Nodes() nodes := route.Nodes()
node := nodes[0] node := nodes[0]
addr, err := selectIP(&node) cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
return
}
cn, err := node.Client.Dial(addr, node.DialOptions...)
if err != nil { if err != nil {
return return
} }
@ -171,13 +145,8 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
preNode := node preNode := node
for _, node := range nodes[1:] { for _, node := range nodes[1:] {
addr, err = selectIP(&node)
if err != nil {
return
}
var cc net.Conn var cc net.Conn
cc, err = preNode.Client.Connect(cn, addr) cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil { if err != nil {
cn.Close() cn.Close()
return return
@ -195,8 +164,37 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
return return
} }
func (c *Chain) selectRoute() (route *Chain, err error) {
buf := bytes.Buffer{}
route = NewChain()
for _, group := range c.nodeGroups {
selector := group.Selector
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err := selector.Select(group.Nodes(), group.Options...)
if err != nil {
return nil, err
}
if _, err := selectIP(&node); err != nil {
return nil, err
}
buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr))
if node.Client.Transporter.Multiplex() {
node.DialOptions = append(node.DialOptions,
ChainDialOption(route),
)
route = NewChain() // cutoff the chain for multiplex
}
route.AddNode(node)
}
log.Log("select route:", buf.String())
return
}
func selectIP(node *Node) (string, error) { func selectIP(node *Node) (string, error) {
addr := node.Addr
s := node.IPSelector s := node.IPSelector
if s == nil { if s == nil {
s = &RandomIPSelector{} s = &RandomIPSelector{}
@ -207,17 +205,9 @@ func selectIP(node *Node) (string, error) {
return "", err return "", err
} }
if ip != "" { if ip != "" {
if !strings.Contains(ip, ":") {
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
ip = ip + ":" + sport
}
addr = ip
// override the original address // override the original address
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) node.Addr = ip
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr))
} }
log.Log("select IP:", node.Addr, node.IPs, addr) return node.Addr, nil
return addr, nil
} }

View File

@ -94,6 +94,9 @@ func initChain() (*gost.Chain, error) {
return nil, err return nil, err
} }
id := 1 // start from 1
node.ID = id
ngroup := gost.NewNodeGroup(node) ngroup := gost.NewNodeGroup(node)
// parse node peers if exists // parse node peers if exists
@ -110,6 +113,8 @@ func initChain() (*gost.Chain, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
id++
node.ID = id
ngroup.AddNode(node) ngroup.AddNode(node)
} }
@ -126,6 +131,15 @@ func parseChainNode(ns string) (node gost.Node, err error) {
} }
node.IPs = parseIP(node.Values.Get("ip")) node.IPs = parseIP(node.Values.Get("ip"))
for i, ip := range node.IPs {
if !strings.Contains(ip, ":") {
_, sport, _ := net.SplitHostPort(node.Addr)
if sport == "" {
sport = "8080" // default port
}
node.IPs[i] = ip + ":" + sport
}
}
node.IPSelector = &gost.RoundRobinIPSelector{} node.IPSelector = &gost.RoundRobinIPSelector{}
users, err := parseUsers(node.Values.Get("secrets")) users, err := parseUsers(node.Values.Get("secrets"))
@ -592,11 +606,12 @@ func loadPeerConfig(peer string) (config peerConfig, err error) {
func parseStrategy(s string) gost.Strategy { func parseStrategy(s string) gost.Strategy {
switch s { switch s {
case "round":
return &gost.RoundStrategy{}
case "random": case "random":
return &gost.RandomStrategy{}
case "round":
fallthrough fallthrough
default: default:
return &gost.RandomStrategy{} return &gost.RoundStrategy{}
} }
} }

View File

@ -56,6 +56,7 @@ func SetLogger(logger log.Logger) {
log.DefaultLogger = logger log.DefaultLogger = logger
} }
// GenCertificate generates a random TLS certificate
func GenCertificate() (cert tls.Certificate, err error) { func GenCertificate() (cert tls.Certificate, err error) {
rawCert, rawKey, err := generateKeyPair() rawCert, rawKey, err := generateKeyPair()
if err != nil { if err != nil {

View File

@ -93,6 +93,13 @@ func (h *httpHandler) Handle(conn net.Conn) {
return return
} }
h.handleRequest(conn, req)
}
func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
if req == nil {
return
}
if Debug { if Debug {
dump, _ := httputil.DumpRequest(req, false) dump, _ := httputil.DumpRequest(req, false)
log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), req.Host, string(dump)) log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), req.Host, string(dump))

View File

@ -3,10 +3,12 @@ package gost
import ( import (
"net/url" "net/url"
"strings" "strings"
"sync"
) )
// Node is a proxy node, mainly used to construct a proxy chain. // Node is a proxy node, mainly used to construct a proxy chain.
type Node struct { type Node struct {
ID int
Addr string Addr string
IPs []string IPs []string
Protocol string Protocol string
@ -84,12 +86,18 @@ type NodeGroup struct {
nodes []Node nodes []Node
Options []SelectOption Options []SelectOption
Selector NodeSelector Selector NodeSelector
mutex sync.Mutex
mFails map[string]int // node -> fail count
MaxFails int
FailTimeout int
Retries int
} }
// NewNodeGroup creates a node group // NewNodeGroup creates a node group
func NewNodeGroup(nodes ...Node) *NodeGroup { func NewNodeGroup(nodes ...Node) *NodeGroup {
return &NodeGroup{ return &NodeGroup{
nodes: nodes, nodes: nodes,
mFails: make(map[string]int),
} }
} }

16
sni.go
View File

@ -11,6 +11,7 @@ import (
"hash/crc32" "hash/crc32"
"io" "io"
"net" "net"
"net/http"
"strings" "strings"
"sync" "sync"
@ -53,14 +54,21 @@ func (h *sniHandler) Handle(conn net.Conn) {
return return
} }
conn = &bufferdConn{br: br, Conn: conn} conn = &bufferdConn{br: br, Conn: conn}
defer conn.Close()
if hdr[0] != dissector.Handshake { if hdr[0] != dissector.Handshake {
// We assume that it is HTTP request // We assume it is an HTTP request
HTTPHandler(h.options...).Handle(conn) req, err := http.ReadRequest(bufio.NewReader(conn))
if !req.URL.IsAbs() {
req.URL.Scheme = "http" // make sure that the URL is absolute
}
if err != nil {
log.Logf("[sni] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
HTTPHandler(h.options...).(*httpHandler).handleRequest(conn, req)
return return
} }
defer conn.Close()
b, host, err := readClientHelloRecord(conn, "", false) b, host, err := readClientHelloRecord(conn, "", false)
if err != nil { if err != nil {