#141: Add load balancing support for proxy chain
This commit is contained in:
parent
18bb8ab2ca
commit
dedd08530a
72
chain.go
72
chain.go
@ -95,12 +95,17 @@ func (c *Chain) Dial(addr string) (net.Conn, error) {
|
|||||||
return net.Dial("tcp", addr)
|
return net.Dial("tcp", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, nodes, err := c.getConn()
|
route, err := c.selectRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err := nodes[len(nodes)-1].Client.Connect(conn, addr)
|
conn, err := c.getConn(route)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cc, err := route.LastNode().Client.Connect(conn, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -111,26 +116,44 @@ func (c *Chain) Dial(addr string) (net.Conn, error) {
|
|||||||
// Conn obtains a handshaked connection to the last node of the chain.
|
// Conn obtains a handshaked connection to the last node of the chain.
|
||||||
// If the chain is empty, it returns an ErrEmptyChain error.
|
// If the chain is empty, it returns an ErrEmptyChain error.
|
||||||
func (c *Chain) Conn() (conn net.Conn, err error) {
|
func (c *Chain) Conn() (conn net.Conn, err error) {
|
||||||
conn, _, err = c.getConn()
|
route, err := c.selectRoute()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err = c.getConn(route)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
func (c *Chain) selectRoute() (route *Chain, err error) {
|
||||||
if c.IsEmpty() {
|
route = NewChain()
|
||||||
|
for _, group := range c.nodeGroups {
|
||||||
|
selector := group.Selector
|
||||||
|
if selector == nil {
|
||||||
|
selector = &defaultSelector{}
|
||||||
|
}
|
||||||
|
// select node from node group
|
||||||
|
node, err := selector.Select(group.Nodes(), group.Options...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if node.Client.Transporter.Multiplex() {
|
||||||
|
node.DialOptions = append(node.DialOptions,
|
||||||
|
ChainDialOption(route),
|
||||||
|
)
|
||||||
|
route = NewChain() // cutoff the chain for multiplex
|
||||||
|
}
|
||||||
|
route.AddNode(node)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
|
||||||
|
if route.IsEmpty() {
|
||||||
err = ErrEmptyChain
|
err = ErrEmptyChain
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
groups := c.nodeGroups
|
nodes := route.Nodes()
|
||||||
selector := groups[0].Selector
|
node := nodes[0]
|
||||||
if selector == nil {
|
|
||||||
selector = &defaultSelector{}
|
|
||||||
}
|
|
||||||
// select node from node group
|
|
||||||
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nodes = append(nodes, node)
|
|
||||||
|
|
||||||
addr, err := selectIP(&node)
|
addr, err := selectIP(&node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -147,21 +170,7 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
preNode := node
|
preNode := node
|
||||||
for i := range groups {
|
for _, node := range nodes[1:] {
|
||||||
if i == len(groups)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
selector = groups[i+1].Selector
|
|
||||||
if selector == nil {
|
|
||||||
selector = &defaultSelector{}
|
|
||||||
}
|
|
||||||
node, err = selector.Select(groups[i+1].Nodes(), groups[i+1].Options...)
|
|
||||||
if err != nil {
|
|
||||||
cn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nodes = append(nodes, node)
|
|
||||||
|
|
||||||
addr, err = selectIP(&node)
|
addr, err = selectIP(&node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -206,6 +215,7 @@ func selectIP(node *Node) (string, error) {
|
|||||||
ip = ip + ":" + sport
|
ip = ip + ":" + sport
|
||||||
}
|
}
|
||||||
addr = ip
|
addr = ip
|
||||||
|
// override the original address
|
||||||
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
|
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
|
||||||
}
|
}
|
||||||
log.Log("select IP:", node.Addr, node.IPs, addr)
|
log.Log("select IP:", node.Addr, node.IPs, addr)
|
||||||
|
@ -94,7 +94,6 @@ func (tr *tcpTransporter) Multiplex() bool {
|
|||||||
type DialOptions struct {
|
type DialOptions struct {
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
Chain *Chain
|
Chain *Chain
|
||||||
// IPs []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOption allows a common way to set dial options.
|
// DialOption allows a common way to set dial options.
|
||||||
|
333
cmd/gost/main.go
333
cmd/gost/main.go
@ -62,6 +62,16 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// generate random self-signed certificate.
|
||||||
|
cert, err := gost.GenCertificate()
|
||||||
|
if err != nil {
|
||||||
|
log.Log(err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
gost.DefaultTLSConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
|
||||||
chain, err := initChain()
|
chain, err := initChain()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Log(err)
|
log.Log(err)
|
||||||
@ -71,162 +81,190 @@ func main() {
|
|||||||
log.Log(err)
|
log.Log(err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func initChain() (*gost.Chain, error) {
|
func initChain() (*gost.Chain, error) {
|
||||||
chain := gost.NewChain()
|
chain := gost.NewChain()
|
||||||
for _, ns := range options.ChainNodes {
|
for _, ns := range options.ChainNodes {
|
||||||
node, err := gost.ParseNode(ns)
|
// parse the base node
|
||||||
|
node, err := parseChainNode(ns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
node.IPs = parseIP(node.Values.Get("ip"))
|
ngroup := gost.NewNodeGroup(node)
|
||||||
node.IPSelector = &gost.RoundRobinIPSelector{}
|
|
||||||
|
|
||||||
users, err := parseUsers(node.Values.Get("secrets"))
|
// parse node peers if exists
|
||||||
|
peerCfg, err := loadPeerConfig(node.Values.Get("peer"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Log(err)
|
||||||
}
|
}
|
||||||
if node.User == nil && len(users) > 0 {
|
ngroup.Options = append(ngroup.Options,
|
||||||
node.User = users[0]
|
// gost.WithFilter(),
|
||||||
}
|
gost.WithStrategy(parseStrategy(peerCfg.Strategy)),
|
||||||
serverName, _, _ := net.SplitHostPort(node.Addr)
|
)
|
||||||
if serverName == "" {
|
for _, s := range peerCfg.Nodes {
|
||||||
serverName = "localhost" // default server name
|
node, err = parseChainNode(s)
|
||||||
}
|
|
||||||
|
|
||||||
rootCAs, err := loadCA(node.Values.Get("ca"))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tlsCfg := &tls.Config{
|
|
||||||
ServerName: serverName,
|
|
||||||
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
|
|
||||||
RootCAs: rootCAs,
|
|
||||||
}
|
|
||||||
wsOpts := &gost.WSOptions{}
|
|
||||||
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
|
|
||||||
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
|
|
||||||
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
|
|
||||||
wsOpts.UserAgent = node.Values.Get("agent")
|
|
||||||
|
|
||||||
var tr gost.Transporter
|
|
||||||
switch node.Transport {
|
|
||||||
case "tls":
|
|
||||||
tr = gost.TLSTransporter()
|
|
||||||
case "mtls":
|
|
||||||
tr = gost.MTLSTransporter()
|
|
||||||
case "ws":
|
|
||||||
tr = gost.WSTransporter(wsOpts)
|
|
||||||
case "mws":
|
|
||||||
tr = gost.MWSTransporter(wsOpts)
|
|
||||||
case "wss":
|
|
||||||
tr = gost.WSSTransporter(wsOpts)
|
|
||||||
case "mwss":
|
|
||||||
tr = gost.MWSSTransporter(wsOpts)
|
|
||||||
case "kcp":
|
|
||||||
if !chain.IsEmpty() {
|
|
||||||
return nil, errors.New("KCP must be the first node in the proxy chain")
|
|
||||||
}
|
|
||||||
config, err := parseKCPConfig(node.Values.Get("c"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tr = gost.KCPTransporter(config)
|
ngroup.AddNode(node)
|
||||||
case "ssh":
|
|
||||||
if node.Protocol == "direct" || node.Protocol == "remote" {
|
|
||||||
tr = gost.SSHForwardTransporter()
|
|
||||||
} else {
|
|
||||||
tr = gost.SSHTunnelTransporter()
|
|
||||||
}
|
|
||||||
case "quic":
|
|
||||||
if !chain.IsEmpty() {
|
|
||||||
return nil, errors.New("QUIC must be the first node in the proxy chain")
|
|
||||||
}
|
|
||||||
config := &gost.QUICConfig{
|
|
||||||
TLSConfig: tlsCfg,
|
|
||||||
KeepAlive: toBool(node.Values.Get("keepalive")),
|
|
||||||
}
|
|
||||||
tr = gost.QUICTransporter(config)
|
|
||||||
case "http2":
|
|
||||||
tr = gost.HTTP2Transporter(tlsCfg)
|
|
||||||
case "h2":
|
|
||||||
tr = gost.H2Transporter(tlsCfg)
|
|
||||||
case "h2c":
|
|
||||||
tr = gost.H2CTransporter()
|
|
||||||
|
|
||||||
case "obfs4":
|
|
||||||
if err := gost.Obfs4Init(node, false); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tr = gost.Obfs4Transporter()
|
|
||||||
case "ohttp":
|
|
||||||
tr = gost.ObfsHTTPTransporter()
|
|
||||||
default:
|
|
||||||
tr = gost.TCPTransporter()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tr.Multiplex() {
|
chain.AddNodeGroup(ngroup)
|
||||||
node.DialOptions = append(node.DialOptions,
|
|
||||||
gost.ChainDialOption(chain),
|
|
||||||
)
|
|
||||||
chain = gost.NewChain() // cutoff the chain for multiplex
|
|
||||||
}
|
|
||||||
|
|
||||||
var connector gost.Connector
|
|
||||||
switch node.Protocol {
|
|
||||||
case "http2":
|
|
||||||
connector = gost.HTTP2Connector(node.User)
|
|
||||||
case "socks", "socks5":
|
|
||||||
connector = gost.SOCKS5Connector(node.User)
|
|
||||||
case "socks4":
|
|
||||||
connector = gost.SOCKS4Connector()
|
|
||||||
case "socks4a":
|
|
||||||
connector = gost.SOCKS4AConnector()
|
|
||||||
case "ss":
|
|
||||||
connector = gost.ShadowConnector(node.User)
|
|
||||||
case "direct":
|
|
||||||
connector = gost.SSHDirectForwardConnector()
|
|
||||||
case "remote":
|
|
||||||
connector = gost.SSHRemoteForwardConnector()
|
|
||||||
case "forward":
|
|
||||||
connector = gost.ForwardConnector()
|
|
||||||
case "sni":
|
|
||||||
connector = gost.SNIConnector(node.Values.Get("host"))
|
|
||||||
case "http":
|
|
||||||
fallthrough
|
|
||||||
default:
|
|
||||||
node.Protocol = "http" // default protocol is HTTP
|
|
||||||
connector = gost.HTTPConnector(node.User)
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
|
||||||
node.DialOptions = append(node.DialOptions,
|
|
||||||
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
|
||||||
retry, _ := strconv.Atoi(node.Values.Get("retry"))
|
|
||||||
node.HandshakeOptions = append(node.HandshakeOptions,
|
|
||||||
gost.AddrHandshakeOption(node.Addr),
|
|
||||||
gost.UserHandshakeOption(node.User),
|
|
||||||
gost.TLSConfigHandshakeOption(tlsCfg),
|
|
||||||
gost.IntervalHandshakeOption(time.Duration(interval)*time.Second),
|
|
||||||
gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second),
|
|
||||||
gost.RetryHandshakeOption(retry),
|
|
||||||
)
|
|
||||||
node.Client = &gost.Client{
|
|
||||||
Connector: connector,
|
|
||||||
Transporter: tr,
|
|
||||||
}
|
|
||||||
chain.AddNode(node)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return chain, nil
|
return chain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseChainNode(ns string) (node gost.Node, err error) {
|
||||||
|
node, err = gost.ParseNode(ns)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
node.IPs = parseIP(node.Values.Get("ip"))
|
||||||
|
node.IPSelector = &gost.RoundRobinIPSelector{}
|
||||||
|
|
||||||
|
users, err := parseUsers(node.Values.Get("secrets"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if node.User == nil && len(users) > 0 {
|
||||||
|
node.User = users[0]
|
||||||
|
}
|
||||||
|
serverName, _, _ := net.SplitHostPort(node.Addr)
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = "localhost" // default server name
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCAs, err := loadCA(node.Values.Get("ca"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsCfg := &tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}
|
||||||
|
wsOpts := &gost.WSOptions{}
|
||||||
|
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
|
||||||
|
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
|
||||||
|
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
|
||||||
|
wsOpts.UserAgent = node.Values.Get("agent")
|
||||||
|
|
||||||
|
var tr gost.Transporter
|
||||||
|
switch node.Transport {
|
||||||
|
case "tls":
|
||||||
|
tr = gost.TLSTransporter()
|
||||||
|
case "mtls":
|
||||||
|
tr = gost.MTLSTransporter()
|
||||||
|
case "ws":
|
||||||
|
tr = gost.WSTransporter(wsOpts)
|
||||||
|
case "mws":
|
||||||
|
tr = gost.MWSTransporter(wsOpts)
|
||||||
|
case "wss":
|
||||||
|
tr = gost.WSSTransporter(wsOpts)
|
||||||
|
case "mwss":
|
||||||
|
tr = gost.MWSSTransporter(wsOpts)
|
||||||
|
case "kcp":
|
||||||
|
/*
|
||||||
|
if !chain.IsEmpty() {
|
||||||
|
return nil, errors.New("KCP must be the first node in the proxy chain")
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
config, err := parseKCPConfig(node.Values.Get("c"))
|
||||||
|
if err != nil {
|
||||||
|
return node, err
|
||||||
|
}
|
||||||
|
tr = gost.KCPTransporter(config)
|
||||||
|
case "ssh":
|
||||||
|
if node.Protocol == "direct" || node.Protocol == "remote" {
|
||||||
|
tr = gost.SSHForwardTransporter()
|
||||||
|
} else {
|
||||||
|
tr = gost.SSHTunnelTransporter()
|
||||||
|
}
|
||||||
|
case "quic":
|
||||||
|
/*
|
||||||
|
if !chain.IsEmpty() {
|
||||||
|
return nil, errors.New("QUIC must be the first node in the proxy chain")
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
config := &gost.QUICConfig{
|
||||||
|
TLSConfig: tlsCfg,
|
||||||
|
KeepAlive: toBool(node.Values.Get("keepalive")),
|
||||||
|
}
|
||||||
|
tr = gost.QUICTransporter(config)
|
||||||
|
case "http2":
|
||||||
|
tr = gost.HTTP2Transporter(tlsCfg)
|
||||||
|
case "h2":
|
||||||
|
tr = gost.H2Transporter(tlsCfg)
|
||||||
|
case "h2c":
|
||||||
|
tr = gost.H2CTransporter()
|
||||||
|
|
||||||
|
case "obfs4":
|
||||||
|
if err := gost.Obfs4Init(node, false); err != nil {
|
||||||
|
return node, err
|
||||||
|
}
|
||||||
|
tr = gost.Obfs4Transporter()
|
||||||
|
case "ohttp":
|
||||||
|
tr = gost.ObfsHTTPTransporter()
|
||||||
|
default:
|
||||||
|
tr = gost.TCPTransporter()
|
||||||
|
}
|
||||||
|
|
||||||
|
var connector gost.Connector
|
||||||
|
switch node.Protocol {
|
||||||
|
case "http2":
|
||||||
|
connector = gost.HTTP2Connector(node.User)
|
||||||
|
case "socks", "socks5":
|
||||||
|
connector = gost.SOCKS5Connector(node.User)
|
||||||
|
case "socks4":
|
||||||
|
connector = gost.SOCKS4Connector()
|
||||||
|
case "socks4a":
|
||||||
|
connector = gost.SOCKS4AConnector()
|
||||||
|
case "ss":
|
||||||
|
connector = gost.ShadowConnector(node.User)
|
||||||
|
case "direct":
|
||||||
|
connector = gost.SSHDirectForwardConnector()
|
||||||
|
case "remote":
|
||||||
|
connector = gost.SSHRemoteForwardConnector()
|
||||||
|
case "forward":
|
||||||
|
connector = gost.ForwardConnector()
|
||||||
|
case "sni":
|
||||||
|
connector = gost.SNIConnector(node.Values.Get("host"))
|
||||||
|
case "http":
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
node.Protocol = "http" // default protocol is HTTP
|
||||||
|
connector = gost.HTTPConnector(node.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||||
|
node.DialOptions = append(node.DialOptions,
|
||||||
|
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
|
||||||
|
)
|
||||||
|
|
||||||
|
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
||||||
|
retry, _ := strconv.Atoi(node.Values.Get("retry"))
|
||||||
|
node.HandshakeOptions = append(node.HandshakeOptions,
|
||||||
|
gost.AddrHandshakeOption(node.Addr),
|
||||||
|
gost.UserHandshakeOption(node.User),
|
||||||
|
gost.TLSConfigHandshakeOption(tlsCfg),
|
||||||
|
gost.IntervalHandshakeOption(time.Duration(interval)*time.Second),
|
||||||
|
gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second),
|
||||||
|
gost.RetryHandshakeOption(retry),
|
||||||
|
)
|
||||||
|
node.Client = &gost.Client{
|
||||||
|
Connector: connector,
|
||||||
|
Transporter: tr,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func serve(chain *gost.Chain) error {
|
func serve(chain *gost.Chain) error {
|
||||||
for _, ns := range options.ServeNodes {
|
for _, ns := range options.ServeNodes {
|
||||||
node, err := gost.ParseNode(ns)
|
node, err := gost.ParseNode(ns)
|
||||||
@ -533,3 +571,32 @@ func parseIP(s string) (ips []string) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type peerConfig struct {
|
||||||
|
Strategy string `json:"strategy"`
|
||||||
|
Filters []string `json:"filters"`
|
||||||
|
Nodes []string `json:"nodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPeerConfig(peer string) (config peerConfig, err error) {
|
||||||
|
if peer == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content, err := ioutil.ReadFile(peer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(content, &config)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStrategy(s string) gost.Strategy {
|
||||||
|
switch s {
|
||||||
|
case "round":
|
||||||
|
return &gost.RoundStrategy{}
|
||||||
|
case "random":
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return &gost.RandomStrategy{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
26
gost.go
26
gost.go
@ -38,7 +38,7 @@ var (
|
|||||||
// PingTimeout is the timeout for pinging.
|
// PingTimeout is the timeout for pinging.
|
||||||
PingTimeout = 30 * time.Second
|
PingTimeout = 30 * time.Second
|
||||||
// PingRetries is the reties of ping.
|
// PingRetries is the reties of ping.
|
||||||
PingRetries = 3
|
PingRetries = 1
|
||||||
// default udp node TTL in second for udp port forwarding.
|
// default udp node TTL in second for udp port forwarding.
|
||||||
defaultTTL = 60 * time.Second
|
defaultTTL = 60 * time.Second
|
||||||
)
|
)
|
||||||
@ -51,27 +51,19 @@ var (
|
|||||||
DefaultUserAgent = "Chrome/60.0.3112.90"
|
DefaultUserAgent = "Chrome/60.0.3112.90"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
rawCert, rawKey, err := generateKeyPair()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
cert, err := tls.X509KeyPair(rawCert, rawKey)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
DefaultTLSConfig = &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
}
|
|
||||||
|
|
||||||
// log.DefaultLogger = &LogLogger{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLogger sets a new logger for internal log system
|
// SetLogger sets a new logger for internal log system
|
||||||
func SetLogger(logger log.Logger) {
|
func SetLogger(logger log.Logger) {
|
||||||
log.DefaultLogger = logger
|
log.DefaultLogger = logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenCertificate() (cert tls.Certificate, err error) {
|
||||||
|
rawCert, rawKey, err := generateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return tls.X509KeyPair(rawCert, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
func generateKeyPair() (rawCert, rawKey []byte, err error) {
|
func generateKeyPair() (rawCert, rawKey []byte, err error) {
|
||||||
// Create private key and self-signed certificate
|
// Create private key and self-signed certificate
|
||||||
// Adapted from https://golang.org/src/crypto/tls/generate_cert.go
|
// Adapted from https://golang.org/src/crypto/tls/generate_cert.go
|
||||||
|
1
quic.go
1
quic.go
@ -194,6 +194,7 @@ func (l *quicListener) sessionLoop(session quic.Session) {
|
|||||||
stream, err := session.AcceptStream()
|
stream, err := session.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Log("[quic] accept stream:", err)
|
log.Log("[quic] accept stream:", err)
|
||||||
|
session.Close(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
69
selector.go
69
selector.go
@ -11,14 +11,10 @@ var (
|
|||||||
ErrNoneAvailable = errors.New("none available")
|
ErrNoneAvailable = errors.New("none available")
|
||||||
)
|
)
|
||||||
|
|
||||||
// SelectOption used when making a select call
|
|
||||||
type SelectOption func(*SelectOptions)
|
|
||||||
|
|
||||||
// NodeSelector as a mechanism to pick nodes and mark their status.
|
// NodeSelector as a mechanism to pick nodes and mark their status.
|
||||||
type NodeSelector interface {
|
type NodeSelector interface {
|
||||||
Select(nodes []Node, opts ...SelectOption) (Node, error)
|
Select(nodes []Node, opts ...SelectOption) (Node, error)
|
||||||
// Mark(node Node)
|
// Mark(node Node)
|
||||||
String() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type defaultSelector struct {
|
type defaultSelector struct {
|
||||||
@ -26,35 +22,70 @@ type defaultSelector struct {
|
|||||||
|
|
||||||
func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) {
|
func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) {
|
||||||
sopts := SelectOptions{
|
sopts := SelectOptions{
|
||||||
Strategy: defaultStrategy,
|
Strategy: &RoundStrategy{},
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&sopts)
|
opt(&sopts)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, filter := range sopts.Filters {
|
for _, filter := range sopts.Filters {
|
||||||
nodes = filter(nodes)
|
nodes = filter.Filter(nodes)
|
||||||
}
|
}
|
||||||
if len(nodes) == 0 {
|
if len(nodes) == 0 {
|
||||||
return Node{}, ErrNoneAvailable
|
return Node{}, ErrNoneAvailable
|
||||||
}
|
}
|
||||||
return sopts.Strategy(nodes), nil
|
return sopts.Strategy.Apply(nodes), nil
|
||||||
}
|
|
||||||
|
|
||||||
func (s *defaultSelector) String() string {
|
|
||||||
return "default"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter is used to filter a node during the selection process
|
// Filter is used to filter a node during the selection process
|
||||||
type Filter func([]Node) []Node
|
type Filter interface {
|
||||||
|
Filter([]Node) []Node
|
||||||
|
}
|
||||||
|
|
||||||
// Strategy is a selection strategy e.g random, round robin
|
// Strategy is a selection strategy e.g random, round robin
|
||||||
type Strategy func([]Node) Node
|
type Strategy interface {
|
||||||
|
Apply([]Node) Node
|
||||||
func defaultStrategy(nodes []Node) Node {
|
String() string
|
||||||
return nodes[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoundStrategy is a strategy for node selector
|
||||||
|
type RoundStrategy struct {
|
||||||
|
count uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies the round robin strategy for the nodes
|
||||||
|
func (s *RoundStrategy) Apply(nodes []Node) Node {
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return Node{}
|
||||||
|
}
|
||||||
|
old := s.count
|
||||||
|
atomic.AddUint64(&s.count, 1)
|
||||||
|
return nodes[int(old%uint64(len(nodes)))]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RoundStrategy) String() string {
|
||||||
|
return "round"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomStrategy is a strategy for node selector
|
||||||
|
type RandomStrategy struct{}
|
||||||
|
|
||||||
|
// Apply applies the random strategy for the nodes
|
||||||
|
func (s *RandomStrategy) Apply(nodes []Node) Node {
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return Node{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes[time.Now().Nanosecond()%len(nodes)]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RandomStrategy) String() string {
|
||||||
|
return "random"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectOption used when making a select call
|
||||||
|
type SelectOption func(*SelectOptions)
|
||||||
|
|
||||||
// SelectOptions is the options for node selection
|
// SelectOptions is the options for node selection
|
||||||
type SelectOptions struct {
|
type SelectOptions struct {
|
||||||
Filters []Filter
|
Filters []Filter
|
||||||
@ -108,9 +139,9 @@ func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
|
|||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
old := s.count
|
||||||
count := atomic.AddUint64(&s.count, 1)
|
atomic.AddUint64(&s.count, 1)
|
||||||
return ips[int(count%uint64(len(ips)))], nil
|
return ips[int(old%uint64(len(ips)))], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RoundRobinIPSelector) String() string {
|
func (s *RoundRobinIPSelector) String() string {
|
||||||
|
Loading…
Reference in New Issue
Block a user