gost_software/util.go

176 lines
3.5 KiB
Go

package main
import (
"crypto/tls"
"errors"
"fmt"
"github.com/golang/glog"
"io"
"net"
"net/url"
"strings"
"time"
)
// keepAliveTime is a fixed keep-alive period of three minutes (180s).
// NOTE(review): presumably handed to setKeepAlive by callers elsewhere — confirm.
const keepAliveTime = 3 * time.Minute
// strSlice accumulates string values: each Set call appends one value and
// String renders everything collected so far in fmt's slice notation, e.g.
// "[a b]".
// NOTE(review): the String/Set pair matches flag.Value, so this is presumably
// registered with flag.Var for a repeatable flag elsewhere — confirm.
type strSlice []string

// String returns the collected values formatted like a Go slice ("[a b]").
func (s *strSlice) String() string {
	// Convert to a plain []string so fmt formats the elements directly.
	return fmt.Sprint([]string(*s))
}

// Set appends value to the collection. It never fails.
func (s *strSlice) Set(value string) error {
	*s = append(*s, value)
	return nil
}
// Args is the configuration of a single proxy node as built by parseArgs from
// a URL-style spec; the authority part looks like "admin:123456@localhost:8080"
// (user info is optional).
type Args struct {
	Addr      string          // host:port
	Protocol  string          // protocol: http/socks(5)/ss
	Transport string          // transport: ws(s)/tls/tcp/udp/rtcp/rudp
	Remote    string          // remote address for tcp/udp (and rtcp/rudp) port forwarding
	User      *url.Userinfo   // authentication for proxy; nil when none was given
	Cert      tls.Certificate // tls certificate
}

// String renders the node settings on one line. When User is nil the auth
// part is "/"; otherwise it is "username/password".
// NOTE(review): the password is included in clear text, so avoid logging this
// where credentials must not appear.
func (a Args) String() string {
	user, pass := "", ""
	if info := a.User; info != nil {
		user = info.Username()
		// The boolean result (whether a password was set) is not needed: an
		// unset password is rendered as an empty string either way.
		pass, _ = info.Password()
	}
	return fmt.Sprintf("host: %s, protocol: %s, transport: %s, remote: %s, auth: %s/%s",
		a.Addr, a.Protocol, a.Transport, a.Remote, user, pass)
}
// parseArgs turns each node spec in ss into an Args value.
//
// A spec without "://" is treated as "auto://<spec>". Specs that url.Parse
// rejects are logged at LWARNING verbosity and skipped. The URL scheme is
// split on "+" into "protocol" or "protocol+transport". With a single part,
// that part is used for both fields. With more than two parts, both fields
// stay empty.
//
// Unrecognized protocols and transports (including "auto") are reset to "".
// The scheme "https" maps to protocol "http" over transport "tls".
// NOTE(review): an "https" transport also forces the protocol to "http"
// (e.g. "socks+https" ends up http/tls) — confirm that is intended.
//
// For the tcp/udp/rtcp/rudp transports, the escaped URL path with surrounding
// slashes trimmed becomes Remote. Every parsed entry gets the package-level
// tlsCert.
func parseArgs(ss []string) (args []Args) {
	for _, spec := range ss {
		if !strings.Contains(spec, "://") {
			spec = "auto://" + spec
		}
		u, err := url.Parse(spec)
		if err != nil {
			glog.V(LWARNING).Infoln(err)
			continue
		}

		node := Args{
			Addr: u.Host,
			User: u.User,
			Cert: tlsCert,
		}

		switch parts := strings.Split(u.Scheme, "+"); len(parts) {
		case 1:
			node.Protocol, node.Transport = parts[0], parts[0]
		case 2:
			node.Protocol, node.Transport = parts[0], parts[1]
		}

		switch node.Protocol {
		case "http", "socks", "socks5", "ss":
			// Known protocol; keep it as-is.
		case "https":
			node.Protocol, node.Transport = "http", "tls"
		default:
			node.Protocol = ""
		}

		switch node.Transport {
		case "ws", "wss", "tls":
			// Known transport; keep it as-is.
		case "https":
			node.Protocol, node.Transport = "http", "tls"
		case "tcp", "udp", "rtcp", "rudp":
			// Port forwarding (local: tcp/udp, remote: rtcp/rudp, since v2.1):
			// the path carries the forwarding target.
			node.Remote = strings.Trim(u.EscapedPath(), "/")
		default:
			node.Transport = ""
		}

		args = append(args, node)
	}
	return args
}
// Copy streams src into dst until src returns io.EOF or either side fails.
// It returns the number of bytes dst reported as written. A clean io.EOF
// from src yields a nil error.
//
// It follows io.Copy with two differences:
//   - the read buffer is borrowed from tcpPool (its size is set wherever that
//     pool is defined) and returned to it when Copy exits;
//   - a short write without an error is NOT turned into io.ErrShortWrite
//     (mainly for websocket writers).
//
// NOTE(review): because of the second point, if dst returns fewer bytes than
// requested together with a nil error, the unwritten tail of that chunk is
// dropped silently.
func Copy(dst io.Writer, src io.Reader) (written int64, err error) {
	buf := tcpPool.Get().([]byte)
	defer tcpPool.Put(buf)

	for {
		nr, rerr := src.Read(buf)
		if nr > 0 {
			nw, werr := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
			}
			if werr != nil {
				return written, werr
			}
		}
		if rerr == io.EOF {
			return written, nil
		}
		if rerr != nil {
			return written, rerr
		}
	}
}
// Pipe copies src into dst with Copy and then sends Copy's error to ch. The
// error is nil when src ended with a clean EOF. The byte count is discarded.
// The final send blocks until ch has buffer room or a receiver.
func Pipe(src io.Reader, dst io.Writer, ch chan<- error) {
	var copyErr error
	_, copyErr = Copy(dst, src)
	ch <- copyErr
}
// Transport relays data in both directions between conn and conn2, running
// one Pipe goroutine per direction. It returns as soon as either direction
// finishes, with that direction's error (nil on a clean EOF).
//
// Transport closes neither connection, so the other goroutine keeps running
// until its own Copy returns. The result channel has room for both results,
// so that goroutine never blocks on its final send.
// NOTE(review): presumably callers close both conns afterwards so the
// remaining goroutine exits — confirm.
func Transport(conn, conn2 net.Conn) (err error) {
	done := make(chan error, 2)
	go Pipe(conn, conn2, done) // conn -> conn2
	go Pipe(conn2, conn, done) // conn2 -> conn
	err = <-done
	return err
}
// setKeepAlive turns on TCP keep-alive for conn and sets its period to d.
//
// Only a *net.TCPConn is accepted. Any other net.Conn, including a nil
// interface, gets the error "Not a TCP connection". Errors from
// SetKeepAlive or SetKeepAlivePeriod are returned unchanged, and the period
// is left alone if enabling keep-alive fails.
func setKeepAlive(conn net.Conn, d time.Duration) error {
	tcp, isTCP := conn.(*net.TCPConn)
	if !isTCP {
		return errors.New("Not a TCP connection")
	}
	err := tcp.SetKeepAlive(true)
	if err == nil {
		err = tcp.SetKeepAlivePeriod(d)
	}
	return err
}