add tcp port forwarding

This commit is contained in:
rui.zheng 2016-08-29 18:11:09 +08:00
parent c63a0a0548
commit 1814806ecf
5 changed files with 93 additions and 30 deletions

63
conn.go
View File

@ -46,28 +46,22 @@ func listenAndServe(arg Args) error {
switch arg.Transport { switch arg.Transport {
case "ws": // websocket connection case "ws": // websocket connection
err = NewWs(arg).ListenAndServe() return NewWs(arg).ListenAndServe()
if err != nil {
glog.Infoln(err)
}
return err
case "wss": // websocket security connection case "wss": // websocket security connection
err = NewWs(arg).listenAndServeTLS() return NewWs(arg).listenAndServeTLS()
if err != nil {
glog.Infoln(err)
}
return err
case "tls": // tls connection case "tls": // tls connection
ln, err = tls.Listen("tcp", arg.Addr, ln, err = tls.Listen("tcp", arg.Addr,
&tls.Config{Certificates: []tls.Certificate{arg.Cert}}) &tls.Config{Certificates: []tls.Certificate{arg.Cert}})
case "tcp": case "tcp": // TCP port forwarding
fallthrough return listenAndServeTcpForward(arg)
case "udp": // UDP port forwarding
//return listenAndServeUdpForward(arg)
return nil
default: default:
ln, err = net.Listen("tcp", arg.Addr) ln, err = net.Listen("tcp", arg.Addr)
} }
if err != nil { if err != nil {
glog.Infoln(err)
return err return err
} }
@ -81,10 +75,47 @@ func listenAndServe(arg Args) error {
} }
go handleConn(conn, arg) go handleConn(conn, arg)
} }
}
func listenAndServeTcpForward(arg Args) error {
ln, err := net.Listen("tcp", arg.Addr)
if err != nil {
return err
}
for {
conn, err := ln.Accept()
if err != nil {
glog.V(LWARNING).Infoln(err)
continue
}
go handleTcpForward(conn, arg)
}
return nil return nil
} }
/*
func listenAndServeUdpForward(arg Args) error {
addr, err := net.ResolveUDPAddr("udp", arg.Addr)
if err != nil {
return err
}
ln, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
for {
b := udpPool.Get().([]byte)
defer udpPool.Put(b)
_, c, err := ln.ReadFromUDP(b)
if err != nil {
glog.V(LWARNING).Infoln(err)
continue
}
handleUdpForward(c, arg)
}
}
*/
func handleConn(conn net.Conn, arg Args) { func handleConn(conn net.Conn, arg Args) {
atomic.AddInt32(&connCounter, 1) atomic.AddInt32(&connCounter, 1)
glog.V(LINFO).Infof("%s connected, connections: %d", glog.V(LINFO).Infof("%s connected, connections: %d",
@ -99,7 +130,7 @@ func handleConn(conn net.Conn, arg Args) {
defer atomic.AddInt32(&connCounter, -1) defer atomic.AddInt32(&connCounter, -1)
defer conn.Close() defer conn.Close()
// server supported methods // socks5 server supported methods
selector := &serverSelector{ selector := &serverSelector{
methods: []uint8{ methods: []uint8{
gosocks5.MethodNoAuth, gosocks5.MethodNoAuth,
@ -134,7 +165,7 @@ func handleConn(conn net.Conn, arg Args) {
return return
} }
// http + socks5 // http or socks5
//b := make([]byte, 16*1024) //b := make([]byte, 16*1024)
b := tcpPool.Get().([]byte) b := tcpPool.Get().([]byte)
@ -212,7 +243,7 @@ func Connect(addr string) (conn net.Conn, err error) {
addr += ":80" addr += ":80"
} }
if len(forwardArgs) == 0 { if len(forwardArgs) == 0 {
return net.DialTimeout("tcp", addr, time.Second*30) return net.DialTimeout("tcp", addr, time.Second*60)
} }
var end Args var end Args

23
forward.go Normal file
View File

@ -0,0 +1,23 @@
package main
import (
"github.com/golang/glog"
"net"
)
func handleTcpForward(conn net.Conn, arg Args) {
glog.V(LINFO).Infoln("[tcp-forward] CONNECT", arg.Forward)
c, err := Connect(arg.Forward)
if err != nil {
glog.V(LWARNING).Infoln("[tcp-forward] CONNECT", arg.Forward, err)
return
}
defer c.Close()
glog.V(LINFO).Infoln("[tcp-forward] CONNECT", arg.Forward, "OK")
Transport(conn, c)
}
func handleUdpForward(conn *net.UDPConn, arg Args) {
}

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/golang/glog" "github.com/golang/glog"
"os" "os"
"runtime"
"sync" "sync"
) )
@ -18,7 +19,7 @@ const (
) )
const ( const (
Version = "2.0" Version = "2.1-dev"
) )
var ( var (
@ -45,6 +46,7 @@ func main() {
} }
if pv { if pv {
fmt.Fprintln(os.Stderr, "gost", Version) fmt.Fprintln(os.Stderr, "gost", Version)
fmt.Fprintln(os.Stderr, runtime.Version())
return return
} }
@ -60,7 +62,7 @@ func main() {
wg.Add(1) wg.Add(1)
go func(arg Args) { go func(arg Args) {
defer wg.Done() defer wg.Done()
listenAndServe(arg) glog.V(LERROR).Infoln(listenAndServe(arg))
}(args) }(args)
} }
wg.Wait() wg.Wait()

2
tls.go
View File

@ -6,6 +6,8 @@ import (
) )
const ( const (
// This is the default cert file for convenience, providing your own cert is recommended.
rawCert = `-----BEGIN CERTIFICATE----- rawCert = `-----BEGIN CERTIFICATE-----
MIIC5jCCAdCgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD MIIC5jCCAdCgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
bzAeFw0xNDAzMTcwNjIwNTFaFw0xNTAzMTcwNjIwNTFaMBIxEDAOBgNVBAoTB0Fj bzAeFw0xNDAzMTcwNjIwNTFaFw0xNTAzMTcwNjIwNTFaMBIxEDAOBgNVBAoTB0Fj

23
util.go
View File

@ -23,8 +23,9 @@ func (ss *strSlice) Set(value string) error {
// admin:123456@localhost:8080 // admin:123456@localhost:8080
type Args struct { type Args struct {
Addr string // host:port Addr string // host:port
Protocol string // protocol: http&socks5/http/socks/socks5/ss, default is http&socks5 Protocol string // protocol: http/socks(5)/ss
Transport string // transport: tcp/ws/tls, default is tcp(raw tcp) Transport string // transport: ws(s)/tls/tcp/udp
Forward string // forward address, used by tcp/udp port forwarding
User *url.Userinfo User *url.Userinfo
Cert tls.Certificate // tls certificate Cert tls.Certificate // tls certificate
} }
@ -35,14 +36,14 @@ func (args Args) String() string {
authUser = args.User.Username() authUser = args.User.Username()
authPass, _ = args.User.Password() authPass, _ = args.User.Password()
} }
return fmt.Sprintf("host: %s, protocol: %s, transport: %s, auth: %s:%s", return fmt.Sprintf("host: %s, protocol: %s, transport: %s, forward: %s, auth: %s/%s",
args.Addr, args.Protocol, args.Transport, authUser, authPass) args.Addr, args.Protocol, args.Transport, args.Forward, authUser, authPass)
} }
func parseArgs(ss []string) (args []Args) { func parseArgs(ss []string) (args []Args) {
for _, s := range ss { for _, s := range ss {
if !strings.Contains(s, "://") { if !strings.Contains(s, "://") {
s = "tcp://" + s s = "auto://" + s
} }
u, err := url.Parse(s) u, err := url.Parse(s)
if err != nil { if err != nil {
@ -54,6 +55,7 @@ func parseArgs(ss []string) (args []Args) {
Addr: u.Host, Addr: u.Host,
User: u.User, User: u.User,
Cert: tlsCert, Cert: tlsCert,
Forward: strings.Trim(u.EscapedPath(), "/"),
} }
schemes := strings.Split(u.Scheme, "+") schemes := strings.Split(u.Scheme, "+")
@ -69,12 +71,15 @@ func parseArgs(ss []string) (args []Args) {
switch arg.Protocol { switch arg.Protocol {
case "http", "socks", "socks5", "ss": case "http", "socks", "socks5", "ss":
default: default:
arg.Protocol = "default" arg.Protocol = ""
} }
switch arg.Transport { switch arg.Transport {
case "ws", "wss", "tls", "tcp": case "ws", "wss", "tls":
case "tcp", "udp": // started from v2.1, tcp and udp are for port forwarding
arg.Protocol = ""
default: default:
arg.Transport = "tcp" arg.Transport = ""
} }
args = append(args, arg) args = append(args, arg)
@ -83,7 +88,7 @@ func parseArgs(ss []string) (args []Args) {
return return
} }
// based on io.Copy // Based on io.Copy, but the io.ErrShortWrite is ignored (mainly for websocket)
func Copy(dst io.Writer, src io.Reader) (written int64, err error) { func Copy(dst io.Writer, src io.Reader) (written int64, err error) {
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)