diff --git a/.gitignore b/.gitignore index 7840d7e..57726ad 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ _testmain.go *.test *.bak + +cmd/gost \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6c8bef2..0000000 --- a/.travis.yml +++ /dev/null @@ -1,5 +0,0 @@ -language: go - -go: - 1.6 - 1.7 \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 2033b3a..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2016 ginuerzh - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md deleted file mode 100644 index 3b8eb53..0000000 --- a/README.md +++ /dev/null @@ -1,358 +0,0 @@ -gost - GO Simple Tunnel -====== - -### GO语言实现的安全隧道 - -[English README](README_en.md) - -特性 ------- -* 可同时监听多端口 -* 可设置转发代理,支持多级转发(代理链) -* 支持标准HTTP/HTTPS/SOCKS4(A)/SOCKS5代理协议 -* SOCKS5代理支持TLS协商加密 -* Tunnel UDP over TCP -* 支持Shadowsocks协议 (OTA: 2.2+,UDP: 2.4+) -* 支持本地/远程端口转发 (2.1+) -* 支持HTTP 2.0 (2.2+) -* 实验性支持QUIC (2.3+) -* 支持KCP协议 (2.3+) -* 透明代理 (2.3+) -* SSH隧道 (2.4+) - -二进制文件下载:https://github.com/ginuerzh/gost/releases - -Google讨论组: https://groups.google.com/d/forum/go-gost - -在gost中,gost与其他代理服务都被看作是代理节点,gost可以自己处理请求,或者将请求转发给任意一个或多个代理节点。 - -参数说明 ------- -#### 代理及代理链 - -适用于-L和-F参数 - -```bash -[scheme://][user:pass@host]:port -``` -scheme分为两部分: protocol+transport - -protocol: 代理协议类型(http, socks4(a), socks5, shadowsocks), transport: 数据传输方式(ws, wss, tls, http2, quic, kcp, pht), 二者可以任意组合,或单独使用: - -> http - HTTP代理: http://:8080 - -> http+tls - HTTPS代理(可能需要提供受信任的证书): http+tls://:443或https://:443 - -> http2 - HTTP2代理并向下兼容HTTPS代理: http2://:443 - -> socks4(a) - 标准SOCKS4(A)代理: socks4://:1080或socks4a://:1080 - -> socks - 标准SOCKS5代理(支持TLS协商加密): socks://:1080 - -> socks+wss - SOCKS5代理,使用websocket传输数据: socks+wss://:1080 - -> tls - HTTPS/SOCKS5代理,使用TLS传输数据: tls://:443 - -> ss - Shadowsocks代理,ss://chacha20:123456@:8338 - -> ssu - Shadowsocks UDP relay,ssu://chacha20:123456@:8338 - -> quic - QUIC代理,quic://:6121 - -> kcp - KCP通道,kcp://:8388或kcp://aes:123456@:8388 - -> pht - 普通HTTP通道,pht://:8080 - -> redirect - 透明代理,redirect://:12345 - -> ssh - SSH转发隧道,ssh://admin:123456@:2222 - -#### 端口转发 - -适用于-L参数 - -```bash -scheme://[bind_address]:port/[host]:hostport -``` -> scheme - 端口转发模式, 本地端口转发: tcp, udp; 远程端口转发: rtcp, rudp - -> bind_address:port - 本地/远程绑定地址 - -> host:hostport - 目标访问地址 - -#### 配置文件 - -> -C : 指定配置文件路径 - -配置文件为标准json格式: -```json -{ - "ServeNodes": [ - ":8080", - "ss://chacha20:12345678@:8338" - ], - "ChainNodes": [ - "http://192.168.1.1:8080", - "https://10.0.2.1:443" - ] -} -``` - -ServeNodes等同于-L参数,ChainNodes等同于-F参数 - -#### 开启日志 - -> -logtostderr : 输出到控制台 - -> -v=3 : 日志级别(1-5),级别越高,日志越详细(级别5将开启http2 debug) - -> -log_dir=/log/dir/path : 输出到目录/log/dir/path - - -使用方法 ------- -#### 不设置转发代理 - - - -* 作为标准HTTP/SOCKS5代理 -```bash -gost -L=:8080 -``` - -* 设置代理认证信息 -```bash -gost -L=admin:123456@localhost:8080 -``` - -* 多组认证信息 -```bash -gost -L=localhost:8080?secrets=secrets.txt -``` - -通过secrets参数可以为HTTP/SOCKS5代理设置多组认证信息,格式为: -```plain -# username password - -test001 123456 -test002 12345678 -``` - -* 多端口监听 -```bash -gost -L=http2://:443 -L=socks://:1080 -L=ss://aes-128-cfb:123456@:8338 -``` - -#### 设置转发代理 - - -```bash -gost -L=:8080 -F=192.168.1.1:8081 -``` - -* 转发代理认证 -```bash -gost -L=:8080 -F=http://admin:123456@192.168.1.1:8081 -``` - -#### 设置多级转发代理(代理链) - - -```bash -gost -L=:8080 -F=http+tls://192.168.1.1:443 -F=socks+ws://192.168.1.2:1080 -F=ss://aes-128-cfb:123456@192.168.1.3:8338 -F=a.b.c.d:NNNN -``` -gost按照-F设置的顺序通过代理链将请求最终转发给a.b.c.d:NNNN处理,每一个转发代理可以是任意HTTP/HTTPS/HTTP2/SOCKS5/Shadowsocks类型代理。 - -#### 本地端口转发(TCP) - -```bash -gost -L=tcp://:2222/192.168.1.1:22 -F=... -``` -将本地TCP端口2222上的数据(通过代理链)转发到192.168.1.1:22上。当代理链末端(最后一个-F参数)为SSH类型时,gost会直接使用SSH的本地端口转发功能。 -#### 本地端口转发(UDP) - -```bash -gost -L=udp://:5353/192.168.1.1:53?ttl=60 -F=... -``` -将本地UDP端口5353上的数据(通过代理链)转发到192.168.1.1:53上。 -每条转发通道都有超时时间,当超过此时间,且在此时间段内无任何数据交互,则此通道将关闭。可以通过`ttl`参数来设置超时时间,默认值为60秒。 - -**注:** 转发UDP数据时,如果有代理链,则代理链的末端(最后一个-F参数)必须是gost SOCKS5类型代理。 - -#### 远程端口转发(TCP) - -```bash -gost -L=rtcp://:2222/192.168.1.1:22 -F=... -F=socks://172.24.10.1:1080 -``` -将172.24.10.1:2222上的数据(通过代理链)转发到192.168.1.1:22上。当代理链末端(最后一个-F参数)为SSH类型时,gost会直接使用SSH的远程端口转发功能。 - -#### 远程端口转发(UDP) - -```bash -gost -L=rudp://:5353/192.168.1.1:53 -F=... -F=socks://172.24.10.1:1080 -``` -将172.24.10.1:5353上的数据(通过代理链)转发到192.168.1.1:53上。 - -**注:** 若要使用远程端口转发功能,代理链不能为空(至少要设置一个-F参数),且代理链的末端(最后一个-F参数)必须是gost SOCKS5类型代理。 - -#### HTTP2 -gost的HTTP2支持两种模式并自适应: -* 作为标准的HTTP2代理,并向下兼容HTTPS代理。 -* 作为transport(类似于wss),传输其他协议。 - -服务端: -```bash -gost -L=http2://:443 -``` -客户端: -```bash -gost -L=:8080 -F=http2://server_ip:443?ping=30 -``` - -客户端支持`ping`参数开启心跳检测(默认不开启),参数值代表心跳间隔秒数。 - -**注:** gost的代理链仅支持一个HTTP2代理节点,采用就近原则,会将第一个遇到的HTTP2代理节点视为HTTP2代理,其他HTTP2代理节点则被视为HTTPS代理。 - -#### QUIC -gost对QUIC的支持是基于[quic-go](https://github.com/lucas-clemente/quic-go)库。 - -服务端: -```bash -gost -L=quic://:6121 -``` - -客户端(Chrome): -```bash -chrome --enable-quic --proxy-server=quic://server_ip:6121 -``` - -**注:** 由于Chrome自身的限制,目前只能通过QUIC访问HTTP网站,无法访问HTTPS网站。 - -#### KCP -gost对KCP的支持是基于[kcp-go](https://github.com/xtaci/kcp-go)和[kcptun](https://github.com/xtaci/kcptun)库。 - -服务端: -```bash -gost -L=kcp://:8388 -``` - -客户端: -```bash -gost -L=:8080 -F=kcp://server_ip:8388 -``` - -或者手动指定加密方法和密码(手动指定的加密方法和密码会覆盖配置文件中的相应值) - -服务端: -```bash -gost -L=kcp://aes:123456@:8388 -``` - -客户端: -```bash -gost -L=:8080 -F=kcp://aes:123456@server_ip:8388 -``` - -gost会自动加载当前工作目录中的kcp.json(如果存在)配置文件,或者可以手动通过参数指定配置文件路径: -```bash -gost -L=kcp://:8388?c=/path/to/conf/file -``` - -**注:** 客户端若要开启KCP转发,当且仅当代理链不为空且首个代理节点(第一个-F参数)为kcp类型。 - -#### 透明代理 -基于iptables的透明代理。 - -```bash -gost -L=redirect://:12345 -F=http2://server_ip:443 -``` - -加密机制 ------- -#### HTTP -对于HTTP可以使用TLS加密整个通讯过程,即HTTPS代理: - -服务端: -```bash -gost -L=http+tls://:443 -``` -客户端: -```bash -gost -L=:8080 -F=http+tls://server_ip:443 -``` - -#### HTTP2 -gost仅支持使用TLS加密的HTTP2协议,不支持明文HTTP2传输。 - - -#### SOCKS5 -gost支持标准SOCKS5协议的no-auth(0x00)和user/pass(0x02)方法,并在此基础上扩展了两个:tls(0x80)和tls-auth(0x82),用于数据加密。 - -服务端: -```bash -gost -L=socks://:1080 -``` -客户端: -```bash -gost -L=:8080 -F=socks://server_ip:1080 -``` - -如果两端都是gost(如上)则数据传输会被加密(协商使用tls或tls-auth方法),否则使用标准SOCKS5进行通讯(no-auth或user/pass方法)。 - -**注:** 如果transport已经支持加密(wss, tls, http2, kcp),则SOCKS5不会再使用加密方法,防止不必要的双重加密。 - -#### Shadowsocks -gost对shadowsocks的支持是基于[shadowsocks-go](https://github.com/shadowsocks/shadowsocks-go)库。 - -服务端(可以通过ota参数开启OTA强制模式,开启后客户端必须使用OTA模式): -```bash -gost -L=ss://aes-128-cfb:123456@:8338?ota=1 -``` -客户端(可以通过ota参数开启OTA模式): -```bash -gost -L=:8080 -F=ss://aes-128-cfb:123456@server_ip:8338?ota=1 -``` - -##### Shadowsocks UDP relay - -目前仅服务端支持UDP,且仅支持OTA模式。 - -服务端: -```bash -gost -L=ssu://aes-128-cfb:123456@:8338 -``` - -#### TLS -gost内置了TLS证书,如果需要使用其他TLS证书,有两种方法: -* 在gost运行目录放置cert.pem(公钥)和key.pem(私钥)两个文件即可,gost会自动加载运行目录下的cert.pem和key.pem文件。 -* 使用参数指定证书文件路径: -```bash -gost -L="http2://:443?cert=/path/to/my/cert/file&key=/path/to/my/key/file" -``` - -SOCKS5 UDP数据处理 ------- -#### 不设置转发代理 - - - -gost作为标准SOCKS5代理处理UDP数据 - -#### 设置转发代理 - - - -#### 设置多个转发代理(代理链) - - - -当设置转发代理时,gost会使用UDP-over-TCP方式转发UDP数据。proxy1 - proxyN可以为任意HTTP/HTTPS/HTTP2/SOCKS5/Shadowsocks类型代理。 - -限制条件 ------- -代理链中的HTTP代理节点必须支持CONNECT方法。 - -如果要转发SOCKS5的BIND和UDP请求,代理链的末端(最后一个-F参数)必须支持gost SOCKS5类型代理。 - - - diff --git a/README_en.md b/README_en.md deleted file mode 100644 index 6c8e7e3..0000000 --- a/README_en.md +++ /dev/null @@ -1,362 +0,0 @@ -gost - GO Simple Tunnel -====== - -### A simple security tunnel written in Golang - -Features ------- -* Listening on multiple ports -* Multi-level forward proxy - proxy chain -* Standard HTTP/HTTPS/SOCKS4(A)/SOCKS5 proxy protocols support -* TLS encryption via negotiation support for SOCKS5 proxy -* Tunnel UDP over TCP -* Shadowsocks protocol support (OTA: 2.2+, UDP: 2.4+) -* Local/remote port forwarding (2.1+) -* HTTP 2.0 support (2.2+) -* Experimental QUIC support (2.3+) -* KCP protocol support (2.3+) -* Transparent proxy (2.3+) -* SSH tunnel (2.4+) - -Binary file download:https://github.com/ginuerzh/gost/releases - -Google group: https://groups.google.com/d/forum/go-gost - -Gost and other proxy services are considered to be proxy nodes, -gost can handle the request itself, or forward the request to any one or more proxy nodes. - -Parameter Description ------- -#### Proxy and proxy chain - -Effective for the -L and -F parameters - -```bash -[scheme://][user:pass@host]:port -``` -scheme can be divided into two parts: protocol+transport - -protocol: proxy protocol types (http, socks4(a), socks5, shadowsocks), -transport: data transmission mode (ws, wss, tls, http2, quic, kcp, pht), may be used in any combination or individually: - -> http - standard HTTP proxy: http://:8080 - -> http+tls - standard HTTPS proxy (may need to provide a trusted certificate): http+tls://:443 or https://:443 - -> http2 - HTTP2 proxy and backwards-compatible with HTTPS proxy: http2://:443 - -> socks4(a) - standard SOCKS4(A) proxy: socks4://:1080 or socks4a://:1080 - -> socks - standard SOCKS5 proxy: socks://:1080 - -> socks+wss - SOCKS5 over websocket: socks+wss://:1080 - -> tls - HTTPS/SOCKS5 over TLS: tls://:443 - -> ss - standard shadowsocks proxy, ss://chacha20:123456@:8338 - -> ssu - shadowsocks UDP relay,ssu://chacha20:123456@:8338 - -> quic - standard QUIC proxy, quic://:6121 - -> kcp - standard KCP tunnel,kcp://:8388 or kcp://aes:123456@:8388 - -> pht - plain HTTP tunnel, pht://:8080 - -> redirect - transparent proxy,redirect://:12345 - -> ssh - SSH tunnel, ssh://admin:123456@:2222 - -#### Port forwarding - -Effective for the -L parameter - -```bash -scheme://[bind_address]:port/[host]:hostport -``` -> scheme - forward mode, local: tcp, udp; remote: rtcp, rudp - -> bind_address:port - local/remote binding address - -> host:hostport - target address - -#### Configuration file - -> -C : specifies the configuration file path - -The configuration file is in standard JSON format: -```json -{ - "ServeNodes": [ - ":8080", - "ss://chacha20:12345678@:8338" - ], - "ChainNodes": [ - "http://192.168.1.1:8080", - "https://10.0.2.1:443" - ] -} -``` - -ServeNodes is equivalent to the -L parameter, ChainNodes is equivalent to the -F parameter. - -#### Logging - -> -logtostderr : log to console - -> -v=3 : log level (1-5),The higher the level, the more detailed the log (level 5 will enable HTTP2 debug) - -> -log_dir=/log/dir/path : log to directory /log/dir/path - -Usage ------- -#### No forward proxy - - - -* Standard HTTP/SOCKS5 proxy -```bash -gost -L=:8080 -``` - -* Proxy authentication -```bash -gost -L=admin:123456@localhost:8080 -``` - -* Multiple sets of authentication information -```bash -gost -L=localhost:8080?secrets=secrets.txt -``` - -The secrets parameter allows you to set multiple authentication information for HTTP/SOCKS5 proxies, the format is: -```plain -# username password - -test001 123456 -test002 12345678 -``` - -* Listen on multiple ports -```bash -gost -L=http2://:443 -L=socks://:1080 -L=ss://aes-128-cfb:123456@:8338 -``` - -#### Forward proxy - - -```bash -gost -L=:8080 -F=192.168.1.1:8081 -``` - -* Forward proxy authentication -```bash -gost -L=:8080 -F=http://admin:123456@192.168.1.1:8081 -``` - -#### Multi-level forward proxy - - -```bash -gost -L=:8080 -F=http+tls://192.168.1.1:443 -F=socks+ws://192.168.1.2:1080 -F=ss://aes-128-cfb:123456@192.168.1.3:8338 -F=a.b.c.d:NNNN -``` -Gost forwards the request to a.b.c.d:NNNN through the proxy chain in the order set by -F, -each forward proxy can be any HTTP/HTTPS/HTTP2/SOCKS5/Shadowsocks type. - -#### Local TCP port forwarding - -```bash -gost -L=tcp://:2222/192.168.1.1:22 -F=... -``` -The data on the local TCP port 2222 is forwarded to 192.168.1.1:22 (through the proxy chain). If the last node of the chain (the last -F parameter) is a SSH tunnel, then gost will use the local port forwarding function of SSH directly. - -#### Local UDP port forwarding - -```bash -gost -L=udp://:5353/192.168.1.1:53?ttl=60 -F=... -``` -The data on the local UDP port 5353 is forwarded to 192.168.1.1:53 (through the proxy chain). -Each forwarding channel has a timeout period. When this time is exceeded and there is no data interaction during this time period, the channel will be closed. The timeout value can be set by the `ttl` parameter. The default value is 60 seconds. - -**NOTE:** When forwarding UDP data, if there is a proxy chain, the end of the chain (the last -F parameter) must be gost SOCKS5 proxy. - -#### Remote TCP port forwarding - -```bash -gost -L=rtcp://:2222/192.168.1.1:22 -F=... -F=socks://172.24.10.1:1080 -``` -The data on 172.24.10.1:2222 is forwarded to 192.168.1.1:22 (through the proxy chain). If the last node of the chain (the last -F parameter) is a SSH tunnel, then gost will use the remote port forwarding function of SSH directly. - -#### Remote UDP port forwarding - -```bash -gost -L=rudp://:5353/192.168.1.1:53 -F=... -F=socks://172.24.10.1:1080 -``` -The data on 172.24.10.1:5353 is forwarded to 192.168.1.1:53 (through the proxy chain). - -**NOTE:** To use the remote port forwarding feature, the proxy chain can not be empty (at least one -F parameter is set) -and the end of the chain (last -F parameter) must be gost SOCKS5 proxy. - -#### HTTP2 -Gost HTTP2 supports two modes and self-adapting: -* As a standard HTTP2 proxy, and backwards-compatible with the HTTPS proxy. -* As transport (similar to wss), tunnel other protocol. - -Server: -```bash -gost -L=http2://:443 -``` -Client: -```bash -gost -L=:8080 -F=http2://server_ip:443?ping=30 -``` - -The client supports the `ping` parameter to enable heartbeat detection (which is disabled by default). -Parameter value represents heartbeat interval seconds. - -**NOTE:** The proxy chain of gost supports only one HTTP2 proxy node and the nearest rule applies, -the first HTTP2 proxy node is treated as an HTTP2 proxy, and the other HTTP2 proxy nodes are treated as HTTPS proxies. - -#### QUIC -Support for QUIC is based on library [quic-go](https://github.com/lucas-clemente/quic-go). - -Server: -```bash -gost -L=quic://:6121 -``` -Client(Chrome): -```bash -chrome --enable-quic --proxy-server=quic://server_ip:6121 -``` - -**NOTE:** Due to Chrome's limitations, it is currently only possible to access the HTTP (but not HTTPS) site through QUIC. - -#### KCP -Support for KCP is based on libraries [kcp-go](https://github.com/xtaci/kcp-go) and [kcptun](https://github.com/xtaci/kcptun). - -Server: -```bash -gost -L=kcp://:8388 -``` -Client: -```bash -gost -L=:8080 -F=kcp://server_ip:8388 -``` - -Or manually specify the encryption method and password (Manually specifying the encryption method and password overwrites the corresponding value in the configuration file) - -Server: -```bash -gost -L=kcp://aes:123456@:8388 -``` - -Client: -```bash -gost -L=:8080 -F=kcp://aes:123456@server_ip:8388 -``` - -Gost will automatically load kcp.json configuration file from current working directory if exists, -or you can use the parameter to specify the path to the file. -```bash -gost -L=kcp://:8388?c=/path/to/conf/file -``` - -**NOTE:** KCP will be enabled if and only if the proxy chain is not empty and the first proxy node (the first -F parameter) is of type KCP. - -#### Transparent proxy -Iptables-based transparent proxy - -```bash -gost -L=redirect://:12345 -F=http2://server_ip:443 -``` - -Encryption Mechanism ------- -#### HTTP -For HTTP, you can use TLS to encrypt the entire communication process, the HTTPS proxy: - -Server: -```bash -gost -L=http+tls://:443 -``` -Client: -```bash -gost -L=:8080 -F=http+tls://server_ip:443 -``` - -#### HTTP2 -Gost supports only the HTTP2 protocol that uses TLS encryption (h2) and does not support plaintext HTTP2 (h2c) transport. - - -#### SOCKS5 -Gost supports the standard SOCKS5 protocol methods: no-auth (0x00) and user/pass (0x02), -and extends two methods for data encryption: tls(0x80) and tls-auth(0x82). - -Server: -```bash -gost -L=socks://:1080 -``` -Client: -```bash -gost -L=:8080 -F=socks://server_ip:1080 -``` - -If both ends are gosts (as example above), the data transfer will be encrypted (using tls or tls-auth). -Otherwise, use standard SOCKS5 for communication (no-auth or user/pass). - -**NOTE:** If transport already supports encryption (wss, tls, http2, kcp), SOCKS5 will no longer use the encryption method to prevent unnecessary double encryption. - -#### Shadowsocks -Support for shadowsocks is based on library [shadowsocks-go](https://github.com/shadowsocks/shadowsocks-go). - -Server (The OTA mode can be enabled by the ota parameter. When enabled, the client must use OTA mode): -```bash -gost -L=ss://aes-128-cfb:123456@:8338?ota=1 -``` -Client (The OTA mode can be enabled by the ota parameter): -```bash -gost -L=:8080 -F=ss://aes-128-cfb:123456@server_ip:8338?ota=1 -``` - -##### Shadowsocks UDP relay -Currently, only the server supports UDP, and only OTA mode is supported. - -Server: -```bash -gost -L=ssu://aes-128-cfb:123456@:8338 -``` - -#### TLS -There is built-in TLS certificate in gost, if you need to use other TLS certificate, there are two ways: -* Place two files cert.pem (public key) and key.pem (private key) in the current working directory, gost will automatically load them. -* Use the parameter to specify the path to the certificate file: -```bash -gost -L="http2://:443?cert=/path/to/my/cert/file&key=/path/to/my/key/file" -``` - -SOCKS5 UDP Data Processing ------- -#### No forward proxy - - - -Gost acts as the standard SOCKS5 proxy for UDP relay. - -#### Forward proxy - - - -#### Multi-level forward proxy - - - -When forward proxies are set, gost uses UDP-over-TCP to forward UDP data, proxy1 to proxyN can be any HTTP/HTTPS/HTTP2/SOCKS5/Shadowsocks type. - -Limitation ------- -The HTTP proxy node in the proxy chain must support the CONNECT method. - -If the BIND and UDP requests for SOCKS5 are to be forwarded, the end of the chain (the last -F parameter) must be the gost SOCKS5 proxy. - - - diff --git a/chain.go b/chain.go index f4dc6dd..3cd0172 100644 --- a/chain.go +++ b/chain.go @@ -1,558 +1,110 @@ package gost import ( - "crypto/rand" - "crypto/tls" - "crypto/x509" - "encoding/base64" "errors" - "io" - "io/ioutil" "net" - "net/http" - "net/http/httputil" - "net/url" - "strconv" - "strings" - "sync" - "time" - - "github.com/ginuerzh/pht" - "github.com/golang/glog" - "github.com/lucas-clemente/quic-go/h2quic" - "golang.org/x/net/http2" ) -// Proxy chain holds a list of proxy nodes -type ProxyChain struct { - nodes []ProxyNode - lastNode *ProxyNode - http2NodeIndex int - http2Enabled bool - http2Client *http.Client - kcpEnabled bool - kcpConfig *KCPConfig - kcpSession *KCPSession - kcpMutex sync.Mutex - phtClient *pht.Client - quicClient *http.Client +var ( + // ErrEmptyChain is an error that implies the chain is empty. + ErrEmptyChain = errors.New("empty chain") +) + +// Chain is a proxy chain that holds a list of proxy nodes. +type Chain struct { + nodes []Node } -func NewProxyChain(nodes ...ProxyNode) *ProxyChain { - chain := &ProxyChain{nodes: nodes, http2NodeIndex: -1} - return chain -} - -func (c *ProxyChain) AddProxyNode(node ...ProxyNode) { - c.nodes = append(c.nodes, node...) -} - -func (c *ProxyChain) AddProxyNodeString(snode ...string) error { - for _, sn := range snode { - node, err := ParseProxyNode(sn) - if err != nil { - return err - } - c.AddProxyNode(node) +// NewChain creates a proxy chain with proxy nodes nodes. +func NewChain(nodes ...Node) *Chain { + return &Chain{ + nodes: nodes, } - return nil } -func (c *ProxyChain) Nodes() []ProxyNode { +// Nodes returns the proxy nodes that the chain holds. +func (c *Chain) Nodes() []Node { return c.nodes } -func (c *ProxyChain) GetNode(index int) *ProxyNode { - if index < len(c.nodes) { - return &c.nodes[index] +// LastNode returns the last node of the node list. +// If the chain is empty, an empty node is returns. +func (c *Chain) LastNode() Node { + if c.IsEmpty() { + return Node{} } - return nil + return c.nodes[len(c.nodes)-1] } -func (c *ProxyChain) SetNode(index int, node ProxyNode) { - if index < len(c.nodes) { - c.nodes[index] = node - } -} - -// Init initialize the proxy chain. -// KCP will be enabled if the first proxy node is KCP proxy (transport == kcp). -// HTTP2 will be enabled when at least one HTTP2 proxy node (scheme == http2) is present. -// -// NOTE: Should be called immediately when proxy nodes are ready. -func (c *ProxyChain) Init() { - length := len(c.nodes) - if length == 0 { +// AddNode appends the node(s) to the chain. +func (c *Chain) AddNode(nodes ...Node) { + if c == nil { return } - - c.lastNode = &c.nodes[length-1] - - // HTTP2 restrict: HTTP2 will be enabled when at least one HTTP2 proxy node is present. - for i, node := range c.nodes { - if node.Transport == "http2" { - glog.V(LINFO).Infoln("HTTP2 is enabled") - cfg := &tls.Config{ - InsecureSkipVerify: node.insecureSkipVerify(), - ServerName: node.serverName, - } - - caFile := node.caFile() - - if caFile != "" { - cfg.RootCAs = x509.NewCertPool() - - data, err := ioutil.ReadFile(caFile) - if err != nil { - glog.Fatal(err) - } - - if !cfg.RootCAs.AppendCertsFromPEM(data) { - glog.Fatal(err) - } - } - - c.http2NodeIndex = i - c.initHttp2Client(cfg, c.nodes[:i]...) - break // shortest chain for HTTP2 - } - } - - for i, node := range c.nodes { - if (node.Transport == "kcp" || node.Transport == "pht" || node.Transport == "quic") && i > 0 { - glog.Fatal("KCP/PHT/QUIC must be the first node in the proxy chain") - } - } - - if c.nodes[0].Transport == "kcp" { - glog.V(LINFO).Infoln("KCP is enabled") - c.kcpEnabled = true - config, err := ParseKCPConfig(c.nodes[0].Get("c")) - if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - if config == nil { - config = DefaultKCPConfig - } - if c.nodes[0].Users != nil { - config.Crypt = c.nodes[0].Users[0].Username() - config.Key, _ = c.nodes[0].Users[0].Password() - } - c.kcpConfig = config - go snmpLogger(config.SnmpLog, config.SnmpPeriod) - go kcpSigHandler() - - return - } - - if c.nodes[0].Transport == "quic" { - glog.V(LINFO).Infoln("QUIC is enabled") - c.quicClient = &http.Client{ - Transport: &h2quic.QuicRoundTripper{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: c.nodes[0].insecureSkipVerify(), - ServerName: c.nodes[0].serverName, - }, - }, - } - } - - if c.nodes[0].Transport == "pht" { - glog.V(LINFO).Infoln("Pure HTTP mode is enabled") - c.phtClient = pht.NewClient(c.nodes[0].Addr, c.nodes[0].Get("key")) - } + c.nodes = append(c.nodes, nodes...) } -func (c *ProxyChain) KCPEnabled() bool { - return c.kcpEnabled +// IsEmpty checks if the chain is empty. +// An empty chain means that there is no proxy node in the chain. +func (c *Chain) IsEmpty() bool { + return c == nil || len(c.nodes) == 0 } -func (c *ProxyChain) Http2Enabled() bool { - return c.http2Enabled -} - -// Wrap a net.Conn into a client tls connection, performing any -// additional verification as needed. -// -// As of go 1.3, crypto/tls only supports either doing no certificate -// verification, or doing full verification including of the peer's -// DNS name. For consul, we want to validate that the certificate is -// signed by a known CA, but because consul doesn't use DNS names for -// node names, we don't verify the certificate DNS names. Since go 1.3 -// no longer supports this mode of operation, we have to do it -// manually. -// -// This code is taken from consul: -// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go -func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { - var err error - var tlsConn *tls.Conn - - tlsConn = tls.Client(conn, tlsConfig) - - // If crypto/tls is doing verification, there's no need to do our own. - if tlsConfig.InsecureSkipVerify == false { - return tlsConn, nil +// Dial connects to the target address addr through the chain. +// If the chain is empty, it will use the net.Dial directly. +func (c *Chain) Dial(addr string) (net.Conn, error) { + if c.IsEmpty() { + return net.Dial("tcp", addr) } - // Similarly if we use host's CA, we can do full handshake - if tlsConfig.RootCAs == nil { - return tlsConn, nil - } - - // Otherwise perform handshake, but don't verify the domain - // - // The following is lightly-modified from the doFullHandshake - // method in https://golang.org/src/crypto/tls/handshake_client.go - if err = tlsConn.Handshake(); err != nil { - tlsConn.Close() - return nil, err - } - - opts := x509.VerifyOptions{ - Roots: tlsConfig.RootCAs, - CurrentTime: time.Now(), - DNSName: "", - Intermediates: x509.NewCertPool(), - } - - certs := tlsConn.ConnectionState().PeerCertificates - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - - _, err = certs[0].Verify(opts) + conn, err := c.Conn() if err != nil { - tlsConn.Close() return nil, err } - return tlsConn, err + cc, err := c.LastNode().Client.Connect(conn, addr) + if err != nil { + conn.Close() + return nil, err + } + return cc, nil } -func (c *ProxyChain) initHttp2Client(config *tls.Config, nodes ...ProxyNode) { - if c.http2NodeIndex < 0 || c.http2NodeIndex >= len(c.nodes) { - return - } - http2Node := c.nodes[c.http2NodeIndex] - - tr := http2.Transport{ - TLSClientConfig: config, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - // replace the default dialer with our proxy chain. - conn, err := c.dialWithNodes(false, http2Node.Addr, nodes...) - if err != nil { - return conn, err - } - - conn, err = wrapTLSClient(conn, cfg) - if err != nil { - return conn, err - } - - // enable HTTP2 ping-pong - pingIntvl, _ := strconv.Atoi(http2Node.Get("ping")) - if pingIntvl > 0 { - enablePing(conn, time.Duration(pingIntvl)*time.Second) - } - - return conn, nil - }, - } - c.http2Client = &http.Client{Transport: &tr} - c.http2Enabled = true - -} - -func enablePing(conn net.Conn, interval time.Duration) { - if conn == nil || interval == 0 { - return - } - - glog.V(LINFO).Infoln("[http2] ping enabled, interval:", interval) - go func() { - t := time.NewTicker(interval) - var framer *http2.Framer - for { - select { - case <-t.C: - if framer == nil { - framer = http2.NewFramer(conn, conn) - } - - var p [8]byte - rand.Read(p[:]) - err := framer.WritePing(false, p) - if err != nil { - t.Stop() - framer = nil - glog.V(LWARNING).Infoln("[http2] ping:", err) - return - } - } - } - }() -} - -// Connect to addr through proxy chain -func (c *ProxyChain) Dial(addr string) (net.Conn, error) { - if !strings.Contains(addr, ":") { - addr += ":80" - } - return c.dialWithNodes(true, addr, c.nodes...) -} - -// GetConn initializes a proxy chain connection, -// if no proxy nodes on this chain, it will return error -func (c *ProxyChain) GetConn() (net.Conn, error) { - nodes := c.nodes - if len(nodes) == 0 { +// Conn obtains a handshaked connection to the last node of the chain. +// If the chain is empty, it returns an ErrEmptyChain error. +func (c *Chain) Conn() (net.Conn, error) { + if c.IsEmpty() { return nil, ErrEmptyChain } - if c.Http2Enabled() { - nodes = nodes[c.http2NodeIndex+1:] - if len(nodes) == 0 { - header := make(http.Header) - header.Set("Proxy-Switch", "gost") // Flag header to indicate server to switch to HTTP2 transport mode - conn, err := c.getHttp2Conn(header) - if err != nil { - return nil, err - } - http2Node := c.nodes[c.http2NodeIndex] - if http2Node.Transport == "http2" { - http2Node.Transport = "h2" - } - if http2Node.Protocol == "http2" { - http2Node.Protocol = "socks5" // assume it as socks5 protocol, so we can do much more things. - } - pc := NewProxyConn(conn, http2Node) - if err := pc.Handshake(); err != nil { - conn.Close() - return nil, err - } - return pc, nil - } - } - return c.travelNodes(true, nodes...) -} - -func (c *ProxyChain) dialWithNodes(withHttp2 bool, addr string, nodes ...ProxyNode) (conn net.Conn, err error) { - if len(nodes) == 0 { - return net.DialTimeout("tcp", addr, DialTimeout) - } - - if withHttp2 && c.Http2Enabled() { - nodes = nodes[c.http2NodeIndex+1:] - if len(nodes) == 0 { - return c.http2Connect(addr) - } - } - - if nodes[0].Transport == "quic" { - glog.V(LINFO).Infoln("Dial with QUIC") - return c.quicConnect(addr) - } - - pc, err := c.travelNodes(withHttp2, nodes...) + nodes := c.nodes + conn, err := nodes[0].Client.Dial(nodes[0].Addr, nodes[0].DialOptions...) if err != nil { - return + return nil, err } - if err = pc.Connect(addr); err != nil { - pc.Close() - return - } - conn = pc - return -} -func (c *ProxyChain) travelNodes(withHttp2 bool, nodes ...ProxyNode) (conn *ProxyConn, err error) { - defer func() { - if err != nil && conn != nil { + conn, err = nodes[0].Client.Handshake(conn, nodes[0].HandshakeOptions...) + if err != nil { + return nil, err + } + + for i, node := range nodes { + if i == len(nodes)-1 { + break + } + + next := nodes[i+1] + cc, err := node.Client.Connect(conn, next.Addr) + if err != nil { conn.Close() - conn = nil + return nil, err } - }() - - var cc net.Conn - node := nodes[0] - - if withHttp2 && c.Http2Enabled() { - cc, err = c.http2Connect(node.Addr) - } else if node.Transport == "kcp" { - cc, err = c.getKCPConn() - } else if node.Transport == "pht" { - cc, err = c.phtClient.Dial() - } else { - cc, err = net.DialTimeout("tcp", node.Addr, DialTimeout) - } - if err != nil { - return - } - setKeepAlive(cc, KeepAliveTime) - - pc := NewProxyConn(cc, node) - conn = pc - if err = pc.Handshake(); err != nil { - return - } - - for _, node := range nodes[1:] { - if err = conn.Connect(node.Addr); err != nil { - return - } - pc := NewProxyConn(conn, node) - conn = pc - if err = pc.Handshake(); err != nil { - return + cc, err = next.Client.Handshake(cc, next.HandshakeOptions...) + if err != nil { + conn.Close() + return nil, err } + conn = cc } - return -} - -func (c *ProxyChain) initKCPSession() (err error) { - c.kcpMutex.Lock() - defer c.kcpMutex.Unlock() - - if c.kcpSession == nil || c.kcpSession.IsClosed() { - glog.V(LINFO).Infoln("[kcp] new kcp session") - c.kcpSession, err = DialKCP(c.nodes[0].Addr, c.kcpConfig) - } - return -} - -func (c *ProxyChain) getKCPConn() (conn net.Conn, err error) { - if !c.KCPEnabled() { - return nil, errors.New("KCP is not enabled") - } - - if err = c.initKCPSession(); err != nil { - return nil, err - } - return c.kcpSession.GetConn() -} - -// Initialize an HTTP2 transport if HTTP2 is enabled. -func (c *ProxyChain) getHttp2Conn(header http.Header) (net.Conn, error) { - if !c.Http2Enabled() { - return nil, errors.New("HTTP2 is not enabled") - } - http2Node := c.nodes[c.http2NodeIndex] - pr, pw := io.Pipe() - - if header == nil { - header = make(http.Header) - } - - req := http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: http2Node.Addr}, - Header: header, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Body: pr, - Host: http2Node.Addr, - ContentLength: -1, - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(&req, false) - glog.Infoln(string(dump)) - } - resp, err := c.http2Client.Do(&req) - if err != nil { - return nil, err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpResponse(resp, false) - glog.Infoln(string(dump)) - } - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, errors.New(resp.Status) - } - conn := &http2Conn{r: resp.Body, w: pw} - conn.remoteAddr, _ = net.ResolveTCPAddr("tcp", http2Node.Addr) - return conn, nil -} - -// Use HTTP2 as transport to connect target addr. -// -// BUG: SOCKS5 is ignored, only HTTP supported -func (c *ProxyChain) http2Connect(addr string) (net.Conn, error) { - if !c.Http2Enabled() { - return nil, errors.New("HTTP2 is not enabled") - } - http2Node := c.nodes[c.http2NodeIndex] - - header := make(http.Header) - header.Set("Gost-Target", addr) // Flag header to indicate the address that server connected to - if http2Node.Users != nil { - header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(http2Node.Users[0].String()))) - } - return c.getHttp2Conn(header) -} - -func (c *ProxyChain) quicConnect(addr string) (net.Conn, error) { - quicNode := c.nodes[0] - header := make(http.Header) - header.Set("Gost-Target", addr) // Flag header to indicate the address that server connected to - if quicNode.Users != nil { - header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(quicNode.Users[0].String()))) - } - return c.getQuicConn(header) -} - -func (c *ProxyChain) getQuicConn(header http.Header) (net.Conn, error) { - quicNode := c.nodes[0] - pr, pw := io.Pipe() - - if header == nil { - header = make(http.Header) - } - - /* - req := http.Request{ - Method: http.MethodGet, - URL: &url.URL{Scheme: "https", Host: quicNode.Addr}, - Header: header, - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Body: pr, - Host: quicNode.Addr, - ContentLength: -1, - } - */ - req, err := http.NewRequest(http.MethodPost, "https://"+quicNode.Addr, pr) - if err != nil { - return nil, err - } - req.ContentLength = -1 - req.Header = header - - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } - resp, err := c.quicClient.Do(req) - if err != nil { - return nil, err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpResponse(resp, false) - glog.Infoln(string(dump)) - } - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, errors.New(resp.Status) - } - conn := &http2Conn{r: resp.Body, w: pw} - conn.remoteAddr, _ = net.ResolveUDPAddr("udp", quicNode.Addr) return conn, nil } diff --git a/gost/client.go b/client.go similarity index 100% rename from gost/client.go rename to client.go diff --git a/cmd/gost/.gitignore b/cmd/gost/.gitignore deleted file mode 100644 index c4b36ef..0000000 --- a/cmd/gost/.gitignore +++ /dev/null @@ -1 +0,0 @@ -gost diff --git a/cmd/gost/main.go b/cmd/gost/main.go index ef39362..35f5395 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,22 +1,29 @@ package main import ( + "bufio" + "crypto/tls" "encoding/json" + "errors" "flag" "fmt" "io/ioutil" + "net" + "net/url" "os" "runtime" - "sync" + "strconv" + "strings" + "time" "github.com/ginuerzh/gost" - "github.com/golang/glog" - "golang.org/x/net/http2" + "github.com/go-log/log" ) var ( options struct { - ChainNodes, ServeNodes flagStringList + chainNodes, serveNodes stringList + debugMode bool } ) @@ -26,55 +33,340 @@ func init() { printVersion bool ) + flag.Var(&options.chainNodes, "F", "forward address, can make a forward chain") + flag.Var(&options.serveNodes, "L", "listen address, can listen on multiple ports") flag.StringVar(&configureFile, "C", "", "configure file") - flag.Var(&options.ChainNodes, "F", "forward address, can make a forward chain") - flag.Var(&options.ServeNodes, "L", "listen address, can listen on multiple ports") + flag.BoolVar(&options.debugMode, "D", false, "enable debug log") flag.BoolVar(&printVersion, "V", false, "print version") flag.Parse() if err := loadConfigureFile(configureFile); err != nil { - glog.Fatal(err) - } - - if glog.V(5) { - http2.VerboseLogs = true + log.Log(err) + os.Exit(1) } if flag.NFlag() == 0 { flag.PrintDefaults() - return + os.Exit(0) } if printVersion { fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) - return + os.Exit(0) } + + gost.Debug = options.debugMode } func main() { - chain := gost.NewProxyChain() - if err := chain.AddProxyNodeString(options.ChainNodes...); err != nil { - glog.Fatal(err) + chain, err := initChain() + if err != nil { + log.Log(err) + os.Exit(1) } - chain.Init() + if err := serve(chain); err != nil { + log.Log(err) + os.Exit(1) + } + select {} +} - var wg sync.WaitGroup - for _, ns := range options.ServeNodes { - serverNode, err := gost.ParseProxyNode(ns) +func initChain() (*gost.Chain, error) { + chain := gost.NewChain() + for _, ns := range options.chainNodes { + node, err := gost.ParseNode(ns) if err != nil { - glog.Fatal(err) + return nil, err } - glog.Info(serverNode) + serverName, _, _ := net.SplitHostPort(node.Addr) + if serverName == "" { + serverName = "localhost" // default server name + } - wg.Add(1) - go func(node gost.ProxyNode) { - defer wg.Done() - server := gost.NewProxyServer(node, chain) - glog.Fatal(server.Serve()) - }(serverNode) + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: !toBool(node.Values.Get("scure")), + } + var tr gost.Transporter + switch node.Transport { + case "tls": + tr = gost.TLSTransporter() + case "ws": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + node.HandshakeOptions = append(node.HandshakeOptions, + gost.WSOptionsHandshakeOption(wsOpts), + ) + tr = gost.WSTransporter(nil) + case "wss": + tr = gost.WSSTransporter(nil) + case "kcp": + if !chain.IsEmpty() { + return nil, errors.New("KCP must be the first node in the proxy chain") + } + config, err := parseKCPConfig(node.Values.Get("c")) + if err != nil { + log.Log("[kcp]", err) + } + node.HandshakeOptions = append(node.HandshakeOptions, + gost.KCPConfigHandshakeOption(config), + ) + tr = gost.KCPTransporter(nil) + case "ssh": + if node.Protocol == "direct" || node.Protocol == "remote" || node.Protocol == "forward" { + tr = gost.SSHForwardTransporter() + } else { + tr = gost.SSHTunnelTransporter() + } + + node.DialOptions = append(node.DialOptions, + gost.ChainDialOption(chain), + ) + chain = gost.NewChain() // cutoff the chain for multiplex + case "quic": + if !chain.IsEmpty() { + return nil, errors.New("QUIC must be the first node in the proxy chain") + } + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: toBool(node.Values.Get("keepalive")), + } + node.HandshakeOptions = append(node.HandshakeOptions, + gost.QUICConfigHandshakeOption(config), + ) + tr = gost.QUICTransporter(nil) + case "http2": + tr = gost.HTTP2Transporter(nil) + node.DialOptions = append(node.DialOptions, + gost.ChainDialOption(chain), + ) + chain = gost.NewChain() // cutoff the chain for multiplex + case "h2": + tr = gost.H2Transporter(nil) + node.DialOptions = append(node.DialOptions, + gost.ChainDialOption(chain), + ) + chain = gost.NewChain() // cutoff the chain for multiplex + case "h2c": + tr = gost.H2CTransporter() + node.DialOptions = append(node.DialOptions, + gost.ChainDialOption(chain), + ) + chain = gost.NewChain() // cutoff the chain for multiplex + default: + tr = gost.TCPTransporter() + } + + var connector gost.Connector + switch node.Protocol { + case "http2": + connector = gost.HTTP2Connector(node.User) + case "socks", "socks5": + connector = gost.SOCKS5Connector(node.User) + case "socks4": + connector = gost.SOCKS4Connector() + case "socks4a": + connector = gost.SOCKS4AConnector() + case "ss": + connector = gost.ShadowConnector(node.User) + case "direct", "forward": + connector = gost.SSHDirectForwardConnector() + case "remote": + connector = gost.SSHRemoteForwardConnector() + case "http": + fallthrough + default: + node.Protocol = "http" // default protocol is HTTP + connector = gost.HTTPConnector(node.User) + } + + node.DialOptions = append(node.DialOptions, + gost.TimeoutDialOption(gost.DialTimeout), + ) + + interval, _ := strconv.Atoi(node.Values.Get("ping")) + node.HandshakeOptions = append(node.HandshakeOptions, + gost.AddrHandshakeOption(node.Addr), + gost.UserHandshakeOption(node.User), + gost.TLSConfigHandshakeOption(tlsCfg), + gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), + ) + node.Client = &gost.Client{ + Connector: connector, + Transporter: tr, + } + chain.AddNode(node) } - wg.Wait() + + return chain, nil +} + +func serve(chain *gost.Chain) error { + for _, ns := range options.serveNodes { + node, err := gost.ParseNode(ns) + if err != nil { + return err + } + users, err := parseUsers(node.Values.Get("secrets")) + if err != nil { + return err + } + if node.User != nil { + users = append(users, node.User) + } + tlsCfg, err := tlsConfig(node.Values.Get("cert"), node.Values.Get("key")) + if err != nil { + return err + } + + var ln gost.Listener + switch node.Transport { + case "tls": + ln, err = gost.TLSListener(node.Addr, tlsCfg) + case "ws": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + ln, err = gost.WSListener(node.Addr, wsOpts) + case "wss": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) + case "kcp": + config, err := parseKCPConfig(node.Values.Get("c")) + if err != nil { + log.Log("[kcp]", err) + } + ln, err = gost.KCPListener(node.Addr, config) + case "ssh": + config := &gost.SSHConfig{ + Users: users, + TLSConfig: tlsCfg, + } + if node.Protocol == "forward" { + ln, err = gost.TCPListener(node.Addr) + } else { + ln, err = gost.SSHTunnelListener(node.Addr, config) + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: toBool(node.Values.Get("keepalive")), + } + timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + config.Timeout = time.Duration(timeout) * time.Second + ln, err = gost.QUICListener(node.Addr, config) + case "http2": + ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) + case "h2": + ln, err = gost.H2Listener(node.Addr, tlsCfg) + case "h2c": + ln, err = gost.H2CListener(node.Addr) + case "tcp": + ln, err = gost.TCPListener(node.Addr) + case "rtcp": + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHRemoteForwardConnector() + } + ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) + case "udp": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second) + case "rudp": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second) + case "redirect": + ln, err = gost.TCPListener(node.Addr) + case "ssu": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second) + default: + ln, err = gost.TCPListener(node.Addr) + } + if err != nil { + return err + } + + var whitelist, blacklist *gost.Permissions + if node.Values.Get("whitelist") != "" { + if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil { + return err + } + } else { + // By default allow for everyting + whitelist, _ = gost.ParsePermissions("*:*:*") + } + + if node.Values.Get("blacklist") != "" { + if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil { + return err + } + } else { + // By default block nothing + blacklist, _ = gost.ParsePermissions("") + } + + var handlerOptions []gost.HandlerOption + + handlerOptions = append(handlerOptions, + gost.AddrHandlerOption(node.Addr), + gost.ChainHandlerOption(chain), + gost.UsersHandlerOption(users...), + gost.TLSConfigHandlerOption(tlsCfg), + gost.WhitelistHandlerOption(whitelist), + gost.BlacklistHandlerOption(blacklist), + ) + var handler gost.Handler + switch node.Protocol { + case "http2": + handler = gost.HTTP2Handler(handlerOptions...) + case "socks", "socks5": + handler = gost.SOCKS5Handler(handlerOptions...) + case "socks4", "socks4a": + handler = gost.SOCKS4Handler(handlerOptions...) + case "ss": + handler = gost.ShadowHandler(handlerOptions...) + case "http": + handler = gost.HTTPHandler(handlerOptions...) + case "tcp": + handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) + case "rtcp": + handler = gost.TCPRemoteForwardHandler(node.Remote, handlerOptions...) + case "udp": + handler = gost.UDPDirectForwardHandler(node.Remote, handlerOptions...) + case "rudp": + handler = gost.UDPRemoteForwardHandler(node.Remote, handlerOptions...) + case "forward": + handler = gost.SSHForwardHandler(handlerOptions...) + case "redirect": + handler = gost.TCPRedirectHandler(handlerOptions...) + case "ssu": + handler = gost.ShadowUDPdHandler(handlerOptions...) + default: + handler = gost.AutoHandler(handlerOptions...) + } + go new(gost.Server).Serve(ln, handler) + } + + return nil +} + +// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. +func tlsConfig(certFile, keyFile string) (*tls.Config, error) { + if certFile == "" || keyFile == "" { + return nil, nil + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return &tls.Config{Certificates: []tls.Certificate{cert}}, nil } func loadConfigureFile(configureFile string) error { @@ -91,12 +383,65 @@ func loadConfigureFile(configureFile string) error { return nil } -type flagStringList []string +type stringList []string -func (this *flagStringList) String() string { - return fmt.Sprintf("%s", *this) +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) } -func (this *flagStringList) Set(value string) error { - *this = append(*this, value) +func (l *stringList) Set(value string) error { + *l = append(*l, value) return nil } + +func toBool(s string) bool { + if b, _ := strconv.ParseBool(s); b { + return b + } + n, _ := strconv.Atoi(s) + return n > 0 +} + +func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { + if configFile == "" { + return nil, nil + } + file, err := os.Open(configFile) + if err != nil { + return nil, err + } + defer file.Close() + + config := &gost.KCPConfig{} + if err = json.NewDecoder(file).Decode(config); err != nil { + return nil, err + } + return config, nil +} + +func parseUsers(authFile string) (users []*url.Userinfo, err error) { + if authFile == "" { + return + } + + file, err := os.Open(authFile) + if err != nil { + return + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} diff --git a/conn.go b/conn.go deleted file mode 100644 index b446425..0000000 --- a/conn.go +++ /dev/null @@ -1,313 +0,0 @@ -package gost - -import ( - "bufio" - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "net" - "net/http" - "net/http/httputil" - "net/url" - "strconv" - "strings" - "sync" - "time" - - "github.com/ginuerzh/gosocks4" - "github.com/ginuerzh/gosocks5" - "github.com/golang/glog" - ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" -) - -type ProxyConn struct { - conn net.Conn - Node ProxyNode - handshaked bool - handshakeMutex sync.Mutex - handshakeErr error -} - -func NewProxyConn(conn net.Conn, node ProxyNode) *ProxyConn { - return &ProxyConn{ - conn: conn, - Node: node, - } -} - -// Handshake handshake with this proxy node based on the proxy node info: transport, protocol, authentication, etc. -// -// NOTE: any HTTP2 scheme will be treated as http (for protocol) or tls (for transport). -func (c *ProxyConn) Handshake() error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.handshaked { - return nil - } - c.handshakeErr = c.handshake() - return c.handshakeErr -} - -func (c *ProxyConn) handshake() error { - var tlsUsed bool - - switch c.Node.Transport { - case "ws": // websocket connection - rbuf, _ := strconv.Atoi(c.Node.Get("rbuf")) - wbuf, _ := strconv.Atoi(c.Node.Get("wbuf")) - comp := c.Node.getBool("compression") - opt := WSOptions{ - ReadBufferSize: rbuf, - WriteBufferSize: wbuf, - HandshakeTimeout: DialTimeout, - EnableCompression: comp, - } - u := url.URL{Scheme: "ws", Host: c.Node.Addr, Path: "/ws"} - conn, err := WebsocketClientConn(u.String(), c.conn, &opt) - if err != nil { - return err - } - c.conn = conn - case "wss": // websocket security - tlsUsed = true - - rbuf, _ := strconv.Atoi(c.Node.Get("rbuf")) - wbuf, _ := strconv.Atoi(c.Node.Get("wbuf")) - comp := c.Node.getBool("compression") - opt := WSOptions{ - ReadBufferSize: rbuf, - WriteBufferSize: wbuf, - HandshakeTimeout: DialTimeout, - EnableCompression: comp, - TLSConfig: &tls.Config{ - InsecureSkipVerify: c.Node.insecureSkipVerify(), - ServerName: c.Node.serverName, - }, - } - - u := url.URL{Scheme: "wss", Host: c.Node.Addr, Path: "/ws"} - conn, err := WebsocketClientConn(u.String(), c.conn, &opt) - if err != nil { - return err - } - c.conn = conn - case "tls", "http2": // tls connection - tlsUsed = true - cfg := &tls.Config{ - InsecureSkipVerify: c.Node.insecureSkipVerify(), - ServerName: c.Node.serverName, - } - c.conn = tls.Client(c.conn, cfg) - case "h2": // same as http2, but just set a flag for later using. - tlsUsed = true - case "kcp": // kcp connection - tlsUsed = true - default: - } - - switch c.Node.Protocol { - case "socks", "socks5": // socks5 handshake with auth and tls supported - selector := &ClientSelector{ - methods: []uint8{ - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - //MethodTLS, - }, - } - - if len(c.Node.Users) > 0 { - selector.User = c.Node.Users[0] - } - - if !tlsUsed { // if transport is not security, enable security socks5 - selector.methods = append(selector.methods, MethodTLS) - selector.TLSConfig = &tls.Config{ - InsecureSkipVerify: c.Node.insecureSkipVerify(), - ServerName: c.Node.serverName, - } - } - - conn := gosocks5.ClientConn(c.conn, selector) - if err := conn.Handleshake(); err != nil { - return err - } - c.conn = conn - case "ss": // shadowsocks - // nothing to do - case "http", "http2": - fallthrough - default: - } - - c.handshaked = true - - return nil -} - -// Connect connect to addr through this proxy node -func (c *ProxyConn) Connect(addr string) error { - switch c.Node.Protocol { - case "ss": // shadowsocks - rawaddr, err := ss.RawAddr(addr) - if err != nil { - return err - } - - var method, password string - if len(c.Node.Users) > 0 { - method = c.Node.Users[0].Username() - password, _ = c.Node.Users[0].Password() - } - if c.Node.getBool("ota") && !strings.HasSuffix(method, "-auth") { - method += "-auth" - } - - cipher, err := ss.NewCipher(method, password) - if err != nil { - return err - } - - ssc, err := ss.DialWithRawAddrConn(rawaddr, c.conn, cipher) - if err != nil { - return err - } - c.conn = &shadowConn{conn: ssc} - return nil - case "socks", "socks5": - host, port, err := net.SplitHostPort(addr) - if err != nil { - return err - } - p, _ := strconv.Atoi(port) - req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ - Type: gosocks5.AddrDomain, - Host: host, - Port: uint16(p), - }) - if err := req.Write(c); err != nil { - return err - } - glog.V(LDEBUG).Infoln("[socks5]", req) - - reply, err := gosocks5.ReadReply(c) - if err != nil { - return err - } - glog.V(LDEBUG).Infoln("[socks5]", reply) - if reply.Rep != gosocks5.Succeeded { - return errors.New("Service unavailable") - } - case "socks4", "socks4a": - atype := gosocks4.AddrDomain - host, port, err := net.SplitHostPort(addr) - if err != nil { - return err - } - p, _ := strconv.Atoi(port) - - if c.Node.Protocol == "socks4" { - taddr, err := net.ResolveTCPAddr("tcp4", addr) - if err != nil { - return err - } - host = taddr.IP.String() - p = taddr.Port - atype = gosocks4.AddrIPv4 - } - req := gosocks4.NewRequest(gosocks4.CmdConnect, - &gosocks4.Addr{Type: atype, Host: host, Port: uint16(p)}, nil) - if err := req.Write(c); err != nil { - return err - } - glog.V(LDEBUG).Infof("[%s] %s", c.Node.Protocol, req) - - reply, err := gosocks4.ReadReply(c) - if err != nil { - return err - } - glog.V(LDEBUG).Infof("[%s] %s", c.Node.Protocol, reply) - - if reply.Code != gosocks4.Granted { - return errors.New(fmt.Sprintf("%s: code=%d", c.Node.Protocol, reply.Code)) - } - case "http": - fallthrough - default: - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Host: addr}, - Host: addr, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - req.Header.Set("Proxy-Connection", "keep-alive") - if len(c.Node.Users) > 0 { - user := c.Node.Users[0] - s := user.String() - if _, set := user.Password(); !set { - s += ":" - } - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) - } - if err := req.Write(c); err != nil { - return err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } - - resp, err := http.ReadResponse(bufio.NewReader(c), req) - if err != nil { - return err - } - if glog.V(LDEBUG) { - dump, _ := httputil.DumpResponse(resp, false) - glog.Infoln(string(dump)) - } - if resp.StatusCode != http.StatusOK { - return errors.New(resp.Status) - } - } - - return nil -} - -func (c *ProxyConn) Read(b []byte) (n int, err error) { - return c.conn.Read(b) -} - -func (c *ProxyConn) Write(b []byte) (n int, err error) { - return c.conn.Write(b) -} - -func (c *ProxyConn) Close() error { - return c.conn.Close() -} - -func (c *ProxyConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *ProxyConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *ProxyConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *ProxyConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *ProxyConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/gost/examples/bench/cli.go b/examples/bench/cli.go similarity index 99% rename from gost/examples/bench/cli.go rename to examples/bench/cli.go index bcd9644..57c189c 100644 --- a/gost/examples/bench/cli.go +++ b/examples/bench/cli.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" "golang.org/x/net/http2" ) diff --git a/gost/examples/bench/srv.go b/examples/bench/srv.go similarity index 99% rename from gost/examples/bench/srv.go rename to examples/bench/srv.go index 86379b7..e2bb953 100644 --- a/gost/examples/bench/srv.go +++ b/examples/bench/srv.go @@ -9,7 +9,7 @@ import ( "net/url" "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" "golang.org/x/net/http2" ) diff --git a/gost/examples/forward/direct/client.go b/examples/forward/direct/client.go similarity index 93% rename from gost/examples/forward/direct/client.go rename to examples/forward/direct/client.go index f8310a4..11cdf96 100644 --- a/gost/examples/forward/direct/client.go +++ b/examples/forward/direct/client.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) func main() { diff --git a/gost/examples/forward/direct/server.go b/examples/forward/direct/server.go similarity index 99% rename from gost/examples/forward/direct/server.go rename to examples/forward/direct/server.go index 2e68828..1427bc2 100644 --- a/gost/examples/forward/direct/server.go +++ b/examples/forward/direct/server.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) func main() { diff --git a/gost/examples/forward/remote/client.go b/examples/forward/remote/client.go similarity index 94% rename from gost/examples/forward/remote/client.go rename to examples/forward/remote/client.go index 716d9b0..623a17b 100644 --- a/gost/examples/forward/remote/client.go +++ b/examples/forward/remote/client.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) func main() { diff --git a/gost/examples/forward/remote/server.go b/examples/forward/remote/server.go similarity index 99% rename from gost/examples/forward/remote/server.go rename to examples/forward/remote/server.go index 971bf98..2268f49 100644 --- a/gost/examples/forward/remote/server.go +++ b/examples/forward/remote/server.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) func main() { diff --git a/gost/examples/forward/udp/cli.go b/examples/forward/udp/cli.go similarity index 100% rename from gost/examples/forward/udp/cli.go rename to examples/forward/udp/cli.go diff --git a/gost/examples/forward/udp/direct.go b/examples/forward/udp/direct.go similarity index 96% rename from gost/examples/forward/udp/direct.go rename to examples/forward/udp/direct.go index 7156d7c..c9da045 100644 --- a/gost/examples/forward/udp/direct.go +++ b/examples/forward/udp/direct.go @@ -5,7 +5,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/forward/udp/remote.go b/examples/forward/udp/remote.go similarity index 96% rename from gost/examples/forward/udp/remote.go rename to examples/forward/udp/remote.go index 3d5cce8..6ef3e36 100644 --- a/gost/examples/forward/udp/remote.go +++ b/examples/forward/udp/remote.go @@ -5,7 +5,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/forward/udp/srv.go b/examples/forward/udp/srv.go similarity index 100% rename from gost/examples/forward/udp/srv.go rename to examples/forward/udp/srv.go diff --git a/gost/examples/http2/http2.go b/examples/http2/http2.go similarity index 99% rename from gost/examples/http2/http2.go rename to examples/http2/http2.go index 5e185e6..1d3f6c5 100644 --- a/gost/examples/http2/http2.go +++ b/examples/http2/http2.go @@ -8,7 +8,7 @@ import ( "golang.org/x/net/http2" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/quic/quicc.go b/examples/quic/quicc.go similarity index 99% rename from gost/examples/quic/quicc.go rename to examples/quic/quicc.go index d8c03d3..5b25106 100644 --- a/gost/examples/quic/quicc.go +++ b/examples/quic/quicc.go @@ -6,7 +6,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/quic/quics.go b/examples/quic/quics.go similarity index 99% rename from gost/examples/quic/quics.go rename to examples/quic/quics.go index 1fa5bb1..b40d455 100644 --- a/gost/examples/quic/quics.go +++ b/examples/quic/quics.go @@ -5,7 +5,7 @@ import ( "flag" "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/ssh/sshc.go b/examples/ssh/sshc.go similarity index 96% rename from gost/examples/ssh/sshc.go rename to examples/ssh/sshc.go index 098c8f2..137febf 100644 --- a/gost/examples/ssh/sshc.go +++ b/examples/ssh/sshc.go @@ -4,8 +4,9 @@ import ( "crypto/tls" "flag" "log" + "time" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( @@ -33,6 +34,9 @@ func main() { Protocol: "socks5", Transport: "ssh", Addr: faddr, + HandshakeOptions: []gost.HandshakeOption{ + gost.IntervalHandshakeOption(30 * time.Second), + }, Client: &gost.Client{ Connector: gost.SOCKS5Connector(nil), Transporter: gost.SSHTunnelTransporter(), diff --git a/gost/examples/ssh/sshd.go b/examples/ssh/sshd.go similarity index 99% rename from gost/examples/ssh/sshd.go rename to examples/ssh/sshd.go index 5fa3eb8..68f1f81 100644 --- a/gost/examples/ssh/sshd.go +++ b/examples/ssh/sshd.go @@ -5,7 +5,7 @@ import ( "flag" "log" - "github.com/ginuerzh/gost/gost" + "github.com/ginuerzh/gost" ) var ( diff --git a/gost/examples/ssu/ssu.go b/examples/ssu/ssu.go similarity index 100% rename from gost/examples/ssu/ssu.go rename to examples/ssu/ssu.go diff --git a/forward.go b/forward.go index f1ff558..5c5e9a6 100644 --- a/forward.go +++ b/forward.go @@ -2,725 +2,662 @@ package gost import ( "errors" - "fmt" "net" - "strconv" - "strings" + "sync" "time" + "fmt" + "github.com/ginuerzh/gosocks5" - "github.com/golang/glog" - "golang.org/x/crypto/ssh" + "github.com/go-log/log" ) -type TcpForwardServer struct { - Base *ProxyServer - sshClient *ssh.Client - Handler func(conn net.Conn, raddr *net.TCPAddr) +type tcpDirectForwardHandler struct { + raddr string + options *HandlerOptions } -func NewTcpForwardServer(base *ProxyServer) *TcpForwardServer { - return &TcpForwardServer{Base: base} +// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. +// The raddr is the remote address that the server will forward to. +func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpDirectForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h } -func (s *TcpForwardServer) ListenAndServe() error { - raddr, err := net.ResolveTCPAddr("tcp", s.Base.Node.Remote) - if err != nil { - return err - } - - ln, err := net.Listen("tcp", s.Base.Node.Addr) - if err != nil { - return err - } - defer ln.Close() - - if s.Handler == nil { - s.Handler = s.handleTcpForward - } - - quit := make(chan interface{}) - close(quit) // first init ssh client - - for { - start: - conn, err := ln.Accept() - if err != nil { - glog.V(LWARNING).Infoln("[ssh]", err) - continue - } - setKeepAlive(conn, KeepAliveTime) - - select { - case <-quit: - if s.Base.Chain.lastNode == nil || s.Base.Chain.lastNode.Transport != "ssh" { - break - } - if err := s.initSSHClient(); err != nil { - glog.V(LWARNING).Infoln("[ssh]", err) - conn.Close() - goto start - } - quit = make(chan interface{}) - exit := make(chan error, 1) - go func() { - exit <- s.sshClient.Wait() - }() - - go func() { - var c <-chan time.Time - ping, _ := strconv.Atoi(s.Base.Chain.lastNode.Get("ping")) - if ping > 0 { - d := time.Second * time.Duration(ping) - glog.V(LINFO).Infoln("[tcp-ssh] ping is enabled:", d) - t := time.NewTicker(d) - defer t.Stop() - c = t.C - } - - for { - select { - case <-c: - _, _, err := s.sshClient.SendRequest("ping", true, nil) - if err != nil { - glog.V(LWARNING).Infoln("[tcp-ssh] ping", err) - close(quit) - return - } - glog.V(LDEBUG).Infoln("[tcp-ssh] heartbeat OK") - - case er := <-exit: - if er != nil { - glog.V(LWARNING).Infoln("[tcp-ssh] ssh connection closed:", er) - } - close(quit) - return - } - } - }() - - default: - } - - go s.Handler(conn, raddr) - } -} - -func (s *TcpForwardServer) initSSHClient() error { - if s.sshClient != nil { - s.sshClient.Close() - s.sshClient = nil - } - - sshNode := s.Base.Chain.lastNode - c, err := s.Base.Chain.GetConn() - if err != nil { - return err - } - var user, password string - if len(sshNode.Users) > 0 { - user = sshNode.Users[0].Username() - password, _ = sshNode.Users[0].Password() - } - config := ssh.ClientConfig{ - User: user, - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, - } - sshConn, chans, reqs, err := ssh.NewClientConn(c, sshNode.Addr, &config) - if err != nil { - return err - } - s.sshClient = ssh.NewClient(sshConn, chans, reqs) - s.Handler = s.handleTcpForwardSSH - - return nil -} - -func (s *TcpForwardServer) handleTcpForward(conn net.Conn, raddr *net.TCPAddr) { +func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() - glog.V(LINFO).Infof("[tcp] %s - %s", conn.RemoteAddr(), raddr) - cc, err := s.Base.Chain.Dial(raddr.String()) + log.Logf("[tcp] %s - %s", conn.RemoteAddr(), h.raddr) + cc, err := h.options.Chain.Dial(h.raddr) if err != nil { - glog.V(LWARNING).Infof("[tcp] %s -> %s : %s", conn.RemoteAddr(), raddr, err) + log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), h.raddr, err) return } defer cc.Close() - glog.V(LINFO).Infof("[tcp] %s <-> %s", conn.RemoteAddr(), raddr) - s.Base.transport(conn, cc) - glog.V(LINFO).Infof("[tcp] %s >-< %s", conn.RemoteAddr(), raddr) + log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), h.raddr) + transport(conn, cc) + log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), h.raddr) } -func (s *TcpForwardServer) handleTcpForwardSSH(conn net.Conn, raddr *net.TCPAddr) { +type udpDirectForwardHandler struct { + raddr string + options *HandlerOptions +} + +// UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. +// The raddr is the remote address that the server will forward to. +func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpDirectForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *udpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() - if s.sshClient == nil { - return - } - - rc, err := s.sshClient.DialTCP("tcp", nil, raddr) - if err != nil { - glog.V(LWARNING).Infof("[tcp] %s -> %s : %s", conn.RemoteAddr(), raddr, err) - return - } - defer rc.Close() - - glog.V(LINFO).Infof("[tcp] %s <-> %s", conn.RemoteAddr(), raddr) - Transport(conn, rc) - glog.V(LINFO).Infof("[tcp] %s >-< %s", conn.RemoteAddr(), raddr) -} - -type packet struct { - srcAddr string // src address - dstAddr string // dest address - data []byte -} - -type cnode struct { - chain *ProxyChain - conn net.Conn - srcAddr, dstAddr string - rChan, wChan chan *packet - err error - ttl time.Duration -} - -func (node *cnode) getUDPTunnel() (net.Conn, error) { - conn, err := node.chain.GetConn() - if err != nil { - return nil, err - } - - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err = gosocks5.NewRequest(CmdUdpTun, nil).Write(conn); err != nil { - conn.Close() - return nil, err - } - conn.SetWriteDeadline(time.Time{}) - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err := gosocks5.ReadReply(conn) - if err != nil { - conn.Close() - return nil, err - } - conn.SetReadDeadline(time.Time{}) - - if reply.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("UDP tunnel failure") - } - - return conn, nil -} - -func (node *cnode) run() { - if len(node.chain.Nodes()) == 0 { - lconn, err := net.ListenUDP("udp", nil) + var cc net.Conn + if h.options.Chain.IsEmpty() { + raddr, err := net.ResolveUDPAddr("udp", h.raddr) if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", node.srcAddr, node.dstAddr, err) - node.err = err + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + cc, err = net.DialUDP("udp", nil, raddr) + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) return } - node.conn = lconn } else { - tc, err := node.getUDPTunnel() + var err error + cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s -> %s : %s", node.srcAddr, node.dstAddr, err) - node.err = err + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) return } - node.conn = tc + cc = &udpTunnelConn{Conn: cc, raddr: h.raddr} } - defer node.conn.Close() + defer cc.Close() - timer := time.NewTimer(node.ttl) - errChan := make(chan error, 2) - - go func() { - for { - switch c := node.conn.(type) { - case *net.UDPConn: - b := make([]byte, MediumBufferSize) - n, addr, err := c.ReadFromUDP(b) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", node.srcAddr, node.dstAddr, err) - node.err = err - errChan <- err - return - } - - timer.Reset(node.ttl) - glog.V(LDEBUG).Infof("[udp] %s <<< %s : length %d", node.srcAddr, addr, n) - - select { - // swap srcAddr with dstAddr - case node.rChan <- &packet{srcAddr: addr.String(), dstAddr: node.srcAddr, data: b[:n]}: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", node.srcAddr, node.dstAddr, "recv queue is full, discard") - } - - default: - dgram, err := gosocks5.ReadUDPDatagram(c) - if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", node.srcAddr, node.dstAddr, err) - node.err = err - errChan <- err - return - } - - timer.Reset(node.ttl) - glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s : length %d", node.srcAddr, dgram.Header.Addr.String(), len(dgram.Data)) - - select { - // swap srcAddr with dstAddr - case node.rChan <- &packet{srcAddr: dgram.Header.Addr.String(), dstAddr: node.srcAddr, data: dgram.Data}: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", node.srcAddr, node.dstAddr, "recv queue is full, discard") - } - } - } - }() - - go func() { - for pkt := range node.wChan { - timer.Reset(node.ttl) - - dstAddr, err := net.ResolveUDPAddr("udp", pkt.dstAddr) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, err) - continue - } - - switch c := node.conn.(type) { - case *net.UDPConn: - if _, err := c.WriteToUDP(pkt.data, dstAddr); err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, err) - node.err = err - errChan <- err - return - } - glog.V(LDEBUG).Infof("[udp] %s >>> %s : length %d", pkt.srcAddr, pkt.dstAddr, len(pkt.data)) - - default: - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(pkt.data)), 0, ToSocksAddr(dstAddr)), pkt.data) - if err := dgram.Write(c); err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, err) - node.err = err - errChan <- err - return - } - glog.V(LDEBUG).Infof("[udp-tun] %s >>> %s : length %d", pkt.srcAddr, pkt.dstAddr, len(pkt.data)) - } - } - }() - - select { - case <-errChan: - case <-timer.C: - } + log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), h.raddr) + transport(conn, cc) + log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) } -type UdpForwardServer struct { - Base *ProxyServer - TTL int +type tcpRemoteForwardHandler struct { + raddr string + options *HandlerOptions } -func NewUdpForwardServer(base *ProxyServer, ttl int) *UdpForwardServer { - return &UdpForwardServer{Base: base, TTL: ttl} +// TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. +// The raddr is the remote address that the server will forward to. +func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpRemoteForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h } -func (s *UdpForwardServer) ListenAndServe() error { - laddr, err := net.ResolveUDPAddr("udp", s.Base.Node.Addr) - if err != nil { - return err - } - - raddr, err := net.ResolveUDPAddr("udp", s.Base.Node.Remote) - if err != nil { - return err - } - - conn, err := net.ListenUDP("udp", laddr) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - return err - } +func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() - rChan, wChan := make(chan *packet, 128), make(chan *packet, 128) - // start send queue - go func(ch chan<- *packet) { - for { - b := make([]byte, MediumBufferSize) - n, addr, err := conn.ReadFromUDP(b) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - continue - } + cc, err := net.DialTimeout("tcp", h.raddr, DialTimeout) + if err != nil { + log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + defer cc.Close() + + log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), h.raddr) + transport(cc, conn) + log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), h.raddr) +} + +type udpRemoteForwardHandler struct { + raddr string + options *HandlerOptions +} + +// UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. +// The raddr is the remote address that the server will forward to. +func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpRemoteForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { + defer conn.Close() + + raddr, err := net.ResolveUDPAddr("udp", h.raddr) + if err != nil { + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + cc, err := net.DialUDP("udp", nil, raddr) + if err != nil { + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + + log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), h.raddr) + transport(conn, cc) + log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), h.raddr) +} + +type udpDirectForwardListener struct { + ln net.PacketConn + conns map[string]*udpServerConn + connChan chan net.Conn + errChan chan error + ttl time.Duration +} + +// UDPDirectForwardListener creates a Listener for UDP port forwarding server. +func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + return nil, err + } + l := &udpDirectForwardListener{ + ln: ln, + conns: make(map[string]*udpServerConn), + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + ttl: ttl, + } + go l.listenLoop() + return l, nil +} + +func (l *udpDirectForwardListener) listenLoop() { + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := l.ln.ReadFrom(b) + if err != nil { + log.Logf("[udp] peer -> %s : %s", l.Addr(), err) + l.ln.Close() + l.errChan <- err + close(l.errChan) + return + } + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + conn, ok := l.conns[raddr.String()] + if !ok || conn.Closed() { + conn = newUDPServerConn(l.ln, raddr, l.ttl) + l.conns[raddr.String()] = conn select { - case ch <- &packet{srcAddr: addr.String(), dstAddr: raddr.String(), data: b[:n]}: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", addr, raddr, "send queue is full, discard") + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr()) } } - }(wChan) - // start recv queue - go func(ch <-chan *packet) { - for pkt := range ch { - dstAddr, err := net.ResolveUDPAddr("udp", pkt.dstAddr) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) - continue - } - if _, err := conn.WriteToUDP(pkt.data, dstAddr); err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) - return - } - } - }(rChan) - - // mapping client to node - m := make(map[string]*cnode) - - // start dispatcher - for pkt := range wChan { - // clear obsolete nodes - for k, node := range m { - if node != nil && node.err != nil { - close(node.wChan) - delete(m, k) - glog.V(LINFO).Infof("[udp] clear node %s", k) - } - } - - node, ok := m[pkt.srcAddr] - if !ok { - node = &cnode{ - chain: s.Base.Chain, - srcAddr: pkt.srcAddr, - dstAddr: pkt.dstAddr, - rChan: rChan, - wChan: make(chan *packet, 32), - ttl: time.Duration(s.TTL) * time.Second, - } - m[pkt.srcAddr] = node - go node.run() - glog.V(LINFO).Infof("[udp] %s -> %s : new client (%d)", pkt.srcAddr, pkt.dstAddr, len(m)) - } select { - case node.wChan <- pkt: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, "node send queue is full, discard") + case conn.rChan <- b[:n]: + default: + log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr()) } } +} +func (l *udpDirectForwardListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *udpDirectForwardListener) Addr() net.Addr { + return l.ln.LocalAddr() +} + +func (l *udpDirectForwardListener) Close() error { + return l.ln.Close() +} + +type udpServerConn struct { + conn net.PacketConn + raddr net.Addr + rChan, wChan chan []byte + closed chan struct{} + brokenChan chan struct{} + closeMutex sync.Mutex + ttl time.Duration + nopChan chan int +} + +func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn { + c := &udpServerConn{ + conn: conn, + raddr: raddr, + rChan: make(chan []byte, 128), + wChan: make(chan []byte, 128), + closed: make(chan struct{}), + brokenChan: make(chan struct{}), + nopChan: make(chan int), + ttl: ttl, + } + go c.writeLoop() + go c.ttlWait() + return c +} + +func (c *udpServerConn) Read(b []byte) (n int, err error) { + select { + case bb := <-c.rChan: + n = copy(b, bb) + if n != len(bb) { + err = errors.New("read partial data") + return + } + case <-c.brokenChan: + err = errors.New("Broken pipe") + case <-c.closed: + err = errors.New("read from closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + return +} + +func (c *udpServerConn) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + select { + case c.wChan <- b: + n = len(b) + case <-c.brokenChan: + err = errors.New("Broken pipe") + case <-c.closed: + err = errors.New("write to closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + + return +} + +func (c *udpServerConn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + return errors.New("connection is closed") + default: + close(c.closed) + } return nil } -type RTcpForwardServer struct { - Base *ProxyServer +func (c *udpServerConn) Closed() bool { + select { + case <-c.closed: + return true + default: + return false + } } -func NewRTcpForwardServer(base *ProxyServer) *RTcpForwardServer { - return &RTcpForwardServer{Base: base} -} - -func (s *RTcpForwardServer) Serve() error { - if len(s.Base.Chain.nodes) == 0 { - return errors.New("rtcp: at least one -F must be assigned") - } - - laddr, err := net.ResolveTCPAddr("tcp", s.Base.Node.Addr) - if err != nil { - return err - } - raddr, err := net.ResolveTCPAddr("tcp", s.Base.Node.Remote) - if err != nil { - return err - } - - retry := 0 +func (c *udpServerConn) writeLoop() { for { - conn, err := s.Base.Chain.GetConn() - if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s - %s : %s", laddr, raddr, err) - time.Sleep((1 << uint(retry)) * time.Second) - if retry < 5 { - retry++ + select { + case b, ok := <-c.wChan: + if !ok { + return } + n, err := c.conn.WriteTo(b, c.raddr) + if err != nil { + log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err) + return + } + if Debug { + log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n) + } + case <-c.brokenChan: + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) ttlWait() { + ttl := c.ttl + if ttl == 0 { + ttl = defaultTTL + } + timer := time.NewTimer(ttl) + + for { + select { + case <-c.nopChan: + timer.Reset(ttl) + case <-timer.C: + close(c.brokenChan) + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *udpServerConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *udpServerConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *udpServerConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *udpServerConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type tcpRemoteForwardListener struct { + addr net.Addr + chain *Chain + ln net.Listener + closed chan struct{} +} + +// TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. +func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + + return &tcpRemoteForwardListener{ + addr: laddr, + chain: chain, + closed: make(chan struct{}), + }, nil +} + +func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) { + select { + case <-l.closed: + return nil, errors.New("closed") + default: + } + + var tempDelay time.Duration + for { + conn, err := l.accept() + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("[rtcp] Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) continue } - retry = 0 - - glog.V(LINFO).Infof("[rtcp] %s - %s", laddr, raddr) - - lastNode := s.Base.Chain.lastNode - if lastNode != nil && lastNode.Transport == "ssh" { - s.connectRTcpForwardSSH(conn, lastNode, laddr, raddr) - } else { - if err := s.connectRTcpForward(conn, laddr, raddr); err != nil { - conn.Close() - } - } - time.Sleep(3 * time.Second) + return conn, nil } } -func (s *RTcpForwardServer) connectRTcpForwardSSH(conn net.Conn, sshNode *ProxyNode, laddr, raddr net.Addr) error { - defer conn.Close() - - var user, password string - if len(sshNode.Users) > 0 { - user = sshNode.Users[0].Username() - password, _ = sshNode.Users[0].Password() - } - config := ssh.ClientConfig{ - User: user, - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, - } - c, chans, reqs, err := ssh.NewClientConn(conn, sshNode.Addr, &config) - if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) - return err - } - client := ssh.NewClient(c, chans, reqs) - - quit := make(chan interface{}) - defer close(quit) - - go func() { - defer client.Close() - - var c <-chan time.Time - - ping, _ := strconv.Atoi(sshNode.Get("ping")) - if ping > 0 { - d := time.Second * time.Duration(ping) - glog.V(LINFO).Infoln("[rtcp] ping is enabled:", d) - t := time.NewTicker(d) - defer t.Stop() - c = t.C +func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { + lastNode := l.chain.LastNode() + if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { + conn, err = l.chain.Dial(l.addr.String()) + } else if lastNode.Protocol == "socks5" { + cc, er := l.chain.Conn() + if er != nil { + return nil, er } - - for { - select { - case <-c: - _, _, err := client.SendRequest("ping", true, nil) - if err != nil { - glog.V(LWARNING).Infoln("[rtcp] ping", err) - return - } - glog.V(LDEBUG).Infoln("[rtcp] heartbeat OK") - - case <-quit: - glog.V(LWARNING).Infoln("[rtcp] ssh connection closed") - return - } - } - }() - - addr := laddr.String() - if strings.HasPrefix(addr, ":") { - addr = "0.0.0.0" + addr - } - ln, err := client.Listen("tcp", addr) - if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) - return err - } - defer ln.Close() - - for { - rc, err := ln.Accept() + conn, err = l.waitConnectSOCKS5(cc) if err != nil { - return err + cc.Close() } - - go func(c net.Conn) { - defer c.Close() - - tc, err := net.DialTimeout("tcp", raddr.String(), time.Second*30) + } else { + if l.ln == nil { + l.ln, err = net.Listen("tcp", l.addr.String()) if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) return } - defer tc.Close() - - glog.V(LINFO).Infof("[rtcp] %s <-> %s", c.RemoteAddr(), c.LocalAddr()) - Transport(c, tc) - glog.V(LINFO).Infof("[rtcp] %s >-< %s", c.RemoteAddr(), c.LocalAddr()) - }(rc) + } + conn, err = l.ln.Accept() } + return } -func (s *RTcpForwardServer) connectRTcpForward(conn net.Conn, laddr, raddr net.Addr) error { - req := gosocks5.NewRequest(gosocks5.CmdBind, ToSocksAddr(laddr)) +func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { + conn, err := socks5Handshake(conn, l.chain.LastNode().User) + if err != nil { + return nil, err + } + req := gosocks5.NewRequest(gosocks5.CmdBind, toSocksAddr(l.addr)) if err := req.Write(conn); err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) - return err + log.Log("[rtcp] SOCKS5 BIND request: ", err) + return nil, err } // first reply, bind status conn.SetReadDeadline(time.Now().Add(ReadTimeout)) rep, err := gosocks5.ReadReply(conn) if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) - return err + log.Log("[rtcp] SOCKS5 BIND reply: ", err) + return nil, err } conn.SetReadDeadline(time.Time{}) if rep.Rep != gosocks5.Succeeded { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : bind on %s failure", laddr, raddr, laddr) - return errors.New("Bind on " + laddr.String() + " failure") + log.Logf("[rtcp] bind on %s failure", l.addr) + return nil, fmt.Errorf("Bind on %s failure", l.addr.String()) } - glog.V(LINFO).Infof("[rtcp] %s - %s BIND ON %s OK", laddr, raddr, rep.Addr) + log.Logf("[rtcp] BIND ON %s OK", rep.Addr) - // second reply, peer connection + // second reply, peer connected rep, err = gosocks5.ReadReply(conn) if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", laddr, raddr, err) - return err + log.Log("[rtcp]", err) + return nil, err } if rep.Rep != gosocks5.Succeeded { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : peer connect failure", laddr, raddr) - return errors.New("peer connect failure") + log.Logf("[rtcp] peer connect failure: %d", rep.Rep) + return nil, errors.New("peer connect failure") } - glog.V(LINFO).Infof("[rtcp] %s -> %s PEER %s CONNECTED", laddr, raddr, rep.Addr) + log.Logf("[rtcp] PEER %s CONNECTED", rep.Addr) + return conn, nil +} - go func() { - defer conn.Close() - - lconn, err := net.DialTimeout("tcp", raddr.String(), time.Second*30) - if err != nil { - glog.V(LWARNING).Infof("[rtcp] %s -> %s : %s", rep.Addr, raddr, err) - return - } - defer lconn.Close() - - glog.V(LINFO).Infof("[rtcp] %s <-> %s", rep.Addr, lconn.RemoteAddr()) - s.Base.transport(lconn, conn) - glog.V(LINFO).Infof("[rtcp] %s >-< %s", rep.Addr, lconn.RemoteAddr()) - }() +func (l *tcpRemoteForwardListener) Addr() net.Addr { + return l.addr +} +func (l *tcpRemoteForwardListener) Close() error { + close(l.closed) return nil } -type RUdpForwardServer struct { - Base *ProxyServer +type udpRemoteForwardListener struct { + addr *net.UDPAddr + chain *Chain + conns map[string]*udpServerConn + connChan chan net.Conn + errChan chan error + ttl time.Duration + closed chan struct{} } -func NewRUdpForwardServer(base *ProxyServer) *RUdpForwardServer { - return &RUdpForwardServer{Base: base} +// UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server. +func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + ln := &udpRemoteForwardListener{ + addr: laddr, + chain: chain, + conns: make(map[string]*udpServerConn), + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + ttl: ttl, + closed: make(chan struct{}), + } + go ln.listenLoop() + return ln, nil } -func (s *RUdpForwardServer) Serve() error { - if len(s.Base.Chain.nodes) == 0 { - return errors.New("rudp: at least one -F must be assigned") - } - - laddr, err := net.ResolveUDPAddr("udp", s.Base.Node.Addr) - if err != nil { - return err - } - raddr, err := net.ResolveUDPAddr("udp", s.Base.Node.Remote) - if err != nil { - return err - } - - retry := 0 +func (l *udpRemoteForwardListener) listenLoop() { for { - conn, err := s.Base.Chain.GetConn() + conn, err := l.connect() if err != nil { - glog.V(LWARNING).Infof("[rudp] %s - %s : %s", laddr, raddr, err) - time.Sleep((1 << uint(retry)) * time.Second) - if retry < 5 { - retry++ + log.Logf("[rudp] %s : %s", l.Addr(), err) + return + } + + defer conn.Close() + + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := conn.ReadFrom(b) + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + break } + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + uc, ok := l.conns[raddr.String()] + if !ok || uc.Closed() { + uc = newUDPServerConn(conn, raddr, l.ttl) + l.conns[raddr.String()] = uc + + select { + case l.connChan <- uc: + default: + uc.Close() + log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr()) + } + } + + select { + case uc.rChan <- b[:n]: + default: + log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr()) + } + } + } + +} + +func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { + var tempDelay time.Duration + + for { + select { + case <-l.closed: + return nil, errors.New("closed") + default: + } + + lastNode := l.chain.LastNode() + if lastNode.Protocol == "socks5" { + var cc net.Conn + cc, err = getSOCKS5UDPTunnel(l.chain, l.addr) + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + } else { + conn = &udpTunnelConn{Conn: cc} + } + } else { + conn, err = net.ListenUDP("udp", l.addr) + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("[rudp] Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) continue } - retry = 0 - - if err := s.connectRUdpForward(conn, laddr, raddr); err != nil { - conn.Close() - time.Sleep(6 * time.Second) - } + return } } -func (s *RUdpForwardServer) connectRUdpForward(conn net.Conn, laddr, raddr *net.UDPAddr) error { - glog.V(LINFO).Infof("[rudp] %s - %s", laddr, raddr) - - req := gosocks5.NewRequest(CmdUdpTun, ToSocksAddr(laddr)) - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err := req.Write(conn); err != nil { - glog.V(LWARNING).Infof("[rudp] %s -> %s : %s", laddr, raddr, err) - return err - } - conn.SetWriteDeadline(time.Time{}) - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - rep, err := gosocks5.ReadReply(conn) - if err != nil { - glog.V(LWARNING).Infof("[rudp] %s <- %s : %s", laddr, raddr, err) - return err - } - conn.SetReadDeadline(time.Time{}) - - if rep.Rep != gosocks5.Succeeded { - glog.V(LWARNING).Infof("[rudp] %s <- %s : bind on %s failure", laddr, raddr, laddr) - return errors.New(fmt.Sprintf("bind on %s failure", laddr)) - } - - glog.V(LINFO).Infof("[rudp] %s - %s BIND ON %s OK", laddr, raddr, rep.Addr) - - for { - dgram, err := gosocks5.ReadUDPDatagram(conn) - if err != nil { - glog.V(LWARNING).Infof("[rudp] %s <- %s : %s", laddr, raddr, err) - return err +func (l *udpRemoteForwardListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") } - - go func() { - b := make([]byte, MediumBufferSize) - - relay, err := net.DialUDP("udp", nil, raddr) - if err != nil { - glog.V(LWARNING).Infof("[rudp] %s -> %s : %s", laddr, raddr, err) - return - } - defer relay.Close() - - if _, err := relay.Write(dgram.Data); err != nil { - glog.V(LWARNING).Infof("[rudp] %s -> %s : %s", laddr, raddr, err) - return - } - glog.V(LDEBUG).Infof("[rudp] %s >>> %s length: %d", laddr, raddr, len(dgram.Data)) - - relay.SetReadDeadline(time.Now().Add(ReadTimeout)) - n, err := relay.Read(b) - if err != nil { - glog.V(LWARNING).Infof("[rudp] %s <- %s : %s", laddr, raddr, err) - return - } - relay.SetReadDeadline(time.Time{}) - - glog.V(LDEBUG).Infof("[rudp] %s <<< %s length: %d", laddr, raddr, n) - - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(n), 0, dgram.Header.Addr), b[:n]).Write(conn); err != nil { - glog.V(LWARNING).Infof("[rudp] %s <- %s : %s", laddr, raddr, err) - return - } - conn.SetWriteDeadline(time.Time{}) - }() } + return +} + +func (l *udpRemoteForwardListener) Addr() net.Addr { + return l.addr +} + +func (l *udpRemoteForwardListener) Close() error { + close(l.closed) + return nil } diff --git a/gost.go b/gost.go index a5b27bd..e8a4fc6 100644 --- a/gost.go +++ b/gost.go @@ -6,88 +6,76 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "encoding/base64" "encoding/pem" - "errors" - "io" "math/big" - "net" - "strings" "time" - "github.com/golang/glog" + "github.com/go-log/log" ) -const ( - Version = "2.4-dev20170711" -) - -// Log level for glog -const ( - LFATAL = iota - LERROR - LWARNING - LINFO - LDEBUG -) +// Version is the gost version. +const Version = "2.4-dev20170803" +// Debug is a flag that enables the debug log. var Debug bool var ( + tinyBufferSize = 128 + smallBufferSize = 1 * 1024 // 1KB small buffer + mediumBufferSize = 8 * 1024 // 8KB medium buffer + largeBufferSize = 32 * 1024 // 32KB large buffer +) + +var ( + // KeepAliveTime is the keep alive time period for TCP connection. KeepAliveTime = 180 * time.Second - DialTimeout = 30 * time.Second - ReadTimeout = 90 * time.Second - WriteTimeout = 90 * time.Second - - DefaultTTL = 60 // default udp node TTL in second for udp port forwarding + // DialTimeout is the timeout of dial. + DialTimeout = 30 * time.Second + // ReadTimeout is the timeout for reading. + ReadTimeout = 30 * time.Second + // WriteTimeout is the timeout for writing. + WriteTimeout = 60 * time.Second + // PingTimeout is the timeout for pinging. + PingTimeout = 30 * time.Second + // PingRetries is the reties of ping. + PingRetries = 3 + // default udp node TTL in second for udp port forwarding. + defaultTTL = 60 * time.Second ) var ( - SmallBufferSize = 1 * 1024 // 1KB small buffer - MediumBufferSize = 8 * 1024 // 8KB medium buffer - LargeBufferSize = 32 * 1024 // 32KB large buffer + DefaultTLSConfig *tls.Config ) -var ( - DefaultCertFile = "cert.pem" - DefaultKeyFile = "key.pem" - - // This is the default cert and key data for convenience, providing your own cert is recommended. - defaultRawCert []byte - defaultRawKey []byte -) - -var ( - ErrEmptyChain = errors.New("empty chain") -) - -func setKeepAlive(conn net.Conn, d time.Duration) error { - c, ok := conn.(*net.TCPConn) - if !ok { - return errors.New("Not a TCP connection") +func init() { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + panic(err) } - if err := c.SetKeepAlive(true); err != nil { - return err + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) } - if err := c.SetKeepAlivePeriod(d); err != nil { - return err + DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, } - return nil + + log.DefaultLogger = &LogLogger{} } -func generateKeyPair() (rawCert, rawKey []byte) { - if defaultRawCert != nil && defaultRawKey != nil { - return defaultRawCert, defaultRawKey - } +func SetLogger(logger log.Logger) { + log.DefaultLogger = logger +} +func generateKeyPair() (rawCert, rawKey []byte, err error) { // Create private key and self-signed certificate // Adapted from https://golang.org/src/crypto/tls/generate_cert.go priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - glog.Fatal(err) + return } - validFor := time.Hour * 24 * 365 * 10 + validFor := time.Hour * 24 * 365 * 10 // ten years notBefore := time.Now() notAfter := notBefore.Add(validFor) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) @@ -106,68 +94,11 @@ func generateKeyPair() (rawCert, rawKey []byte) { } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - glog.Fatal(err) + return } rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - return rawCert, rawKey -} - -// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. -func LoadCertificate(certFile, keyFile string) (tls.Certificate, error) { - tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err == nil { - return tlsCert, nil - } - glog.V(LWARNING).Infoln(err) - - rawCert, rawKey := defaultRawCert, defaultRawKey - if defaultRawCert == nil || defaultRawKey == nil { - rawCert, rawKey = generateKeyPair() - } - return tls.X509KeyPair(rawCert, rawKey) -} - -// Replace the default certificate by your own -func SetDefaultCertificate(rawCert, rawKey []byte) { - defaultRawCert = rawCert - defaultRawKey = rawKey -} - -func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { - if proxyAuth == "" { - return - } - - if !strings.HasPrefix(proxyAuth, "Basic ") { - return - } - c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) - if err != nil { - return - } - cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return - } - - return cs[:s], cs[s+1:], true -} - -func Transport(rw1, rw2 io.ReadWriter) error { - errc := make(chan error, 1) - go func() { - _, err := io.Copy(rw1, rw2) - errc <- err - }() - - go func() { - _, err := io.Copy(rw2, rw1) - errc <- err - }() - - return <-errc + return } diff --git a/gost/chain.go b/gost/chain.go deleted file mode 100644 index 3cd0172..0000000 --- a/gost/chain.go +++ /dev/null @@ -1,110 +0,0 @@ -package gost - -import ( - "errors" - "net" -) - -var ( - // ErrEmptyChain is an error that implies the chain is empty. - ErrEmptyChain = errors.New("empty chain") -) - -// Chain is a proxy chain that holds a list of proxy nodes. -type Chain struct { - nodes []Node -} - -// NewChain creates a proxy chain with proxy nodes nodes. -func NewChain(nodes ...Node) *Chain { - return &Chain{ - nodes: nodes, - } -} - -// Nodes returns the proxy nodes that the chain holds. -func (c *Chain) Nodes() []Node { - return c.nodes -} - -// LastNode returns the last node of the node list. -// If the chain is empty, an empty node is returns. -func (c *Chain) LastNode() Node { - if c.IsEmpty() { - return Node{} - } - return c.nodes[len(c.nodes)-1] -} - -// AddNode appends the node(s) to the chain. -func (c *Chain) AddNode(nodes ...Node) { - if c == nil { - return - } - c.nodes = append(c.nodes, nodes...) -} - -// IsEmpty checks if the chain is empty. -// An empty chain means that there is no proxy node in the chain. -func (c *Chain) IsEmpty() bool { - return c == nil || len(c.nodes) == 0 -} - -// Dial connects to the target address addr through the chain. -// If the chain is empty, it will use the net.Dial directly. -func (c *Chain) Dial(addr string) (net.Conn, error) { - if c.IsEmpty() { - return net.Dial("tcp", addr) - } - - conn, err := c.Conn() - if err != nil { - return nil, err - } - - cc, err := c.LastNode().Client.Connect(conn, addr) - if err != nil { - conn.Close() - return nil, err - } - return cc, nil -} - -// Conn obtains a handshaked connection to the last node of the chain. -// If the chain is empty, it returns an ErrEmptyChain error. -func (c *Chain) Conn() (net.Conn, error) { - if c.IsEmpty() { - return nil, ErrEmptyChain - } - - nodes := c.nodes - conn, err := nodes[0].Client.Dial(nodes[0].Addr, nodes[0].DialOptions...) - if err != nil { - return nil, err - } - - conn, err = nodes[0].Client.Handshake(conn, nodes[0].HandshakeOptions...) - if err != nil { - return nil, err - } - - for i, node := range nodes { - if i == len(nodes)-1 { - break - } - - next := nodes[i+1] - cc, err := node.Client.Connect(conn, next.Addr) - if err != nil { - conn.Close() - return nil, err - } - cc, err = next.Client.Handshake(cc, next.HandshakeOptions...) - if err != nil { - conn.Close() - return nil, err - } - conn = cc - } - return conn, nil -} diff --git a/gost/cmd/gost/main.go b/gost/cmd/gost/main.go deleted file mode 100644 index 55a7d75..0000000 --- a/gost/cmd/gost/main.go +++ /dev/null @@ -1,430 +0,0 @@ -package main - -import ( - "bufio" - "crypto/tls" - "encoding/json" - "errors" - "flag" - "fmt" - "io/ioutil" - "net" - "net/url" - "os" - "runtime" - "strconv" - "strings" - "time" - - "github.com/ginuerzh/gost/gost" - "github.com/go-log/log" -) - -var ( - options struct { - chainNodes, serveNodes stringList - debugMode bool - } -) - -func init() { - var ( - configureFile string - printVersion bool - ) - - flag.Var(&options.chainNodes, "F", "forward address, can make a forward chain") - flag.Var(&options.serveNodes, "L", "listen address, can listen on multiple ports") - flag.StringVar(&configureFile, "C", "", "configure file") - flag.BoolVar(&options.debugMode, "D", false, "enable debug log") - flag.BoolVar(&printVersion, "V", false, "print version") - flag.Parse() - - if err := loadConfigureFile(configureFile); err != nil { - log.Log(err) - os.Exit(1) - } - - if flag.NFlag() == 0 { - flag.PrintDefaults() - os.Exit(0) - } - - if printVersion { - fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) - os.Exit(0) - } - - gost.Debug = options.debugMode -} - -func main() { - chain, err := initChain() - if err != nil { - log.Log(err) - os.Exit(1) - } - if err := serve(chain); err != nil { - log.Log(err) - os.Exit(1) - } - select {} -} - -func initChain() (*gost.Chain, error) { - chain := gost.NewChain() - for _, ns := range options.chainNodes { - node, err := gost.ParseNode(ns) - if err != nil { - return nil, err - } - - serverName, _, _ := net.SplitHostPort(node.Addr) - if serverName == "" { - serverName = "localhost" // default server name - } - - tlsCfg := &tls.Config{ - ServerName: serverName, - InsecureSkipVerify: !toBool(node.Values.Get("scure")), - } - var tr gost.Transporter - switch node.Transport { - case "tls": - tr = gost.TLSTransporter() - case "ws": - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) - node.HandshakeOptions = append(node.HandshakeOptions, - gost.WSOptionsHandshakeOption(wsOpts), - ) - tr = gost.WSTransporter(nil) - case "wss": - tr = gost.WSSTransporter(nil) - case "kcp": - if !chain.IsEmpty() { - return nil, errors.New("KCP must be the first node in the proxy chain") - } - config, err := parseKCPConfig(node.Values.Get("c")) - if err != nil { - log.Log("[kcp]", err) - } - node.HandshakeOptions = append(node.HandshakeOptions, - gost.KCPConfigHandshakeOption(config), - ) - tr = gost.KCPTransporter(nil) - case "ssh": - if node.Protocol == "direct" || node.Protocol == "remote" { - tr = gost.SSHForwardTransporter() - } else { - tr = gost.SSHTunnelTransporter() - } - node.Chain = chain // cutoff the chain for multiplex - chain = gost.NewChain() - case "quic": - if !chain.IsEmpty() { - return nil, errors.New("QUIC must be the first node in the proxy chain") - } - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: toBool(node.Values.Get("keepalive")), - } - node.HandshakeOptions = append(node.HandshakeOptions, - gost.QUICConfigHandshakeOption(config), - ) - tr = gost.QUICTransporter(nil) - case "http2": - tr = gost.HTTP2Transporter(nil) - node.Chain = chain // cutoff the chain for multiplex - chain = gost.NewChain() - case "h2": - tr = gost.H2Transporter(nil) - case "h2c": - tr = gost.H2CTransporter() - default: - tr = gost.TCPTransporter() - } - - var connector gost.Connector - switch node.Protocol { - case "http2": - connector = gost.HTTP2Connector(nil) - case "socks", "socks5": - connector = gost.SOCKS5Connector(nil) - case "socks4": - connector = gost.SOCKS4Connector() - case "socks4a": - connector = gost.SOCKS4AConnector() - case "ss": - connector = gost.ShadowConnector(nil) - case "direct": - connector = gost.SSHDirectForwardConnector() - case "remote": - connector = gost.SSHRemoteForwardConnector() - case "http": - fallthrough - default: - node.Protocol = "http" // default protocol is HTTP - connector = gost.HTTPConnector(nil) - } - - node.DialOptions = append(node.DialOptions, - gost.TimeoutDialOption(gost.DialTimeout), - gost.ChainDialOption(node.Chain), - ) - - interval, _ := strconv.Atoi(node.Values.Get("ping")) - node.HandshakeOptions = append(node.HandshakeOptions, - gost.AddrHandshakeOption(node.Addr), - gost.UserHandshakeOption(node.User), - gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), - ) - node.Client = &gost.Client{ - Connector: connector, - Transporter: tr, - } - chain.AddNode(node) - } - - return chain, nil -} - -func serve(chain *gost.Chain) error { - for _, ns := range options.serveNodes { - node, err := gost.ParseNode(ns) - if err != nil { - return err - } - users, err := parseUsers(node.Values.Get("secrets")) - if err != nil { - return err - } - tlsCfg, err := tlsConfig(node.Values.Get("cert"), node.Values.Get("key")) - if err != nil { - return err - } - - var ln gost.Listener - switch node.Transport { - case "tls": - ln, err = gost.TLSListener(node.Addr, tlsCfg) - case "ws": - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) - ln, err = gost.WSListener(node.Addr, wsOpts) - case "wss": - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) - ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) - case "kcp": - config, err := parseKCPConfig(node.Values.Get("c")) - if err != nil { - log.Log("[kcp]", err) - } - ln, err = gost.KCPListener(node.Addr, config) - case "ssh": - config := &gost.SSHConfig{ - Users: users, - TLSConfig: tlsCfg, - } - if node.Protocol == "forward" { - ln, err = gost.TCPListener(node.Addr) - } else { - ln, err = gost.SSHTunnelListener(node.Addr, config) - } - case "quic": - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: toBool(node.Values.Get("keepalive")), - } - timeout, _ := strconv.Atoi(node.Values.Get("timeout")) - config.Timeout = time.Duration(timeout) * time.Second - ln, err = gost.QUICListener(node.Addr, config) - case "http2": - ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) - case "h2": - ln, err = gost.H2Listener(node.Addr, tlsCfg) - case "h2c": - ln, err = gost.H2CListener(node.Addr) - case "tcp": - ln, err = gost.TCPListener(node.Addr) - case "rtcp": - ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) - case "udp": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second) - case "rudp": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second) - case "redirect": - ln, err = gost.TCPListener(node.Addr) - case "ssu": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second) - default: - ln, err = gost.TCPListener(node.Addr) - } - if err != nil { - return err - } - - var whitelist, blacklist *gost.Permissions - if node.Values.Get("whitelist") != "" { - if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil { - return err - } - } else { - // By default allow for everyting - whitelist, _ = gost.ParsePermissions("*:*:*") - } - - if node.Values.Get("blacklist") != "" { - if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil { - return err - } - } else { - // By default block nothing - blacklist, _ = gost.ParsePermissions("") - } - - var handlerOptions []gost.HandlerOption - - handlerOptions = append(handlerOptions, - gost.AddrHandlerOption(node.Addr), - gost.ChainHandlerOption(chain), - gost.UsersHandlerOption(users...), - gost.TLSConfigHandlerOption(tlsCfg), - gost.WhitelistHandlerOption(whitelist), - gost.BlacklistHandlerOption(blacklist), - ) - var handler gost.Handler - switch node.Protocol { - case "http2": - handler = gost.HTTP2Handler(handlerOptions...) - case "socks", "socks5": - handler = gost.SOCKS5Handler(handlerOptions...) - case "socks4", "socks4a": - handler = gost.SOCKS4Handler(handlerOptions...) - case "ss": - handler = gost.ShadowHandler(handlerOptions...) - case "http": - handler = gost.HTTPHandler(handlerOptions...) - case "tcp": - handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) - case "rtcp": - handler = gost.TCPRemoteForwardHandler(node.Remote, handlerOptions...) - case "udp": - handler = gost.UDPDirectForwardHandler(node.Remote, handlerOptions...) - case "rudp": - handler = gost.UDPRemoteForwardHandler(node.Remote, handlerOptions...) - case "forward": - handler = gost.SSHForwardHandler(handlerOptions...) - case "redirect": - handler = gost.TCPRedirectHandler(handlerOptions...) - case "ssu": - handler = gost.ShadowUDPdHandler(handlerOptions...) - default: - // TODO: auto poroxy handler - handler = gost.HTTPHandler(handlerOptions...) - } - go new(gost.Server).Serve(ln, handler) - } - - return nil -} - -// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. -func tlsConfig(certFile, keyFile string) (*tls.Config, error) { - if certFile == "" || keyFile == "" { - return nil, nil - } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - return &tls.Config{Certificates: []tls.Certificate{cert}}, nil -} - -func loadConfigureFile(configureFile string) error { - if configureFile == "" { - return nil - } - content, err := ioutil.ReadFile(configureFile) - if err != nil { - return err - } - if err := json.Unmarshal(content, &options); err != nil { - return err - } - return nil -} - -type stringList []string - -func (l *stringList) String() string { - return fmt.Sprintf("%s", *l) -} -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - -func toBool(s string) bool { - if b, _ := strconv.ParseBool(s); b { - return b - } - n, _ := strconv.Atoi(s) - return n > 0 -} - -func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { - if configFile == "" { - return nil, nil - } - file, err := os.Open(configFile) - if err != nil { - return nil, err - } - defer file.Close() - - config := &gost.KCPConfig{} - if err = json.NewDecoder(file).Decode(config); err != nil { - return nil, err - } - return config, nil -} - -func parseUsers(authFile string) (users []*url.Userinfo, err error) { - if authFile == "" { - return - } - - file, err := os.Open(authFile) - if err != nil { - return - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - s := strings.SplitN(line, " ", 2) - if len(s) == 1 { - users = append(users, url.User(strings.TrimSpace(s[0]))) - } else if len(s) == 2 { - users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) - } - } - - err = scanner.Err() - return -} diff --git a/gost/forward.go b/gost/forward.go deleted file mode 100644 index bdbe361..0000000 --- a/gost/forward.go +++ /dev/null @@ -1,663 +0,0 @@ -package gost - -import ( - "errors" - "net" - "sync" - "time" - - "fmt" - - "github.com/ginuerzh/gosocks5" - "github.com/go-log/log" -) - -type tcpDirectForwardHandler struct { - raddr string - options *HandlerOptions -} - -// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. -// The raddr is the remote address that the server will forward to. -func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { - h := &tcpDirectForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { - defer conn.Close() - - log.Logf("[tcp] %s - %s", conn.RemoteAddr(), h.raddr) - cc, err := h.options.Chain.Dial(h.raddr) - if err != nil { - log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), h.raddr, err) - return - } - defer cc.Close() - - log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), h.raddr) - transport(conn, cc) - log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), h.raddr) -} - -type udpDirectForwardHandler struct { - raddr string - options *HandlerOptions -} - -// UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. -// The raddr is the remote address that the server will forward to. -func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { - h := &udpDirectForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *udpDirectForwardHandler) Handle(conn net.Conn) { - defer conn.Close() - - var cc net.Conn - if h.options.Chain.IsEmpty() { - raddr, err := net.ResolveUDPAddr("udp", h.raddr) - if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) - return - } - cc, err = net.DialUDP("udp", nil, raddr) - if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) - return - } - } else { - var err error - cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) - return - } - cc = &udpTunnelConn{Conn: cc, raddr: h.raddr} - } - - defer cc.Close() - - log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), h.raddr) - transport(conn, cc) - log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) -} - -type tcpRemoteForwardHandler struct { - raddr string - options *HandlerOptions -} - -// TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. -// The raddr is the remote address that the server will forward to. -func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { - h := &tcpRemoteForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { - defer conn.Close() - - cc, err := net.DialTimeout("tcp", h.raddr, DialTimeout) - if err != nil { - log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), h.raddr, err) - return - } - defer cc.Close() - - log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), h.raddr) - transport(cc, conn) - log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), h.raddr) -} - -type udpRemoteForwardHandler struct { - raddr string - options *HandlerOptions -} - -// UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. -// The raddr is the remote address that the server will forward to. -func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { - h := &udpRemoteForwardHandler{ - raddr: raddr, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { - defer conn.Close() - - raddr, err := net.ResolveUDPAddr("udp", h.raddr) - if err != nil { - log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return - } - cc, err := net.DialUDP("udp", nil, raddr) - if err != nil { - log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return - } - - log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), h.raddr) - transport(conn, cc) - log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), h.raddr) -} - -type udpDirectForwardListener struct { - ln net.PacketConn - conns map[string]*udpServerConn - connChan chan net.Conn - errChan chan error - ttl time.Duration -} - -// UDPDirectForwardListener creates a Listener for UDP port forwarding server. -func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - l := &udpDirectForwardListener{ - ln: ln, - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - ttl: ttl, - } - go l.listenLoop() - return l, nil -} - -func (l *udpDirectForwardListener) listenLoop() { - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := l.ln.ReadFrom(b) - if err != nil { - log.Logf("[udp] peer -> %s : %s", l.Addr(), err) - l.ln.Close() - l.errChan <- err - close(l.errChan) - return - } - if Debug { - log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - conn, ok := l.conns[raddr.String()] - if !ok || conn.Closed() { - conn = newUDPServerConn(l.ln, raddr, l.ttl) - l.conns[raddr.String()] = conn - - select { - case l.connChan <- conn: - default: - conn.Close() - log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr()) - } - } - - select { - case conn.rChan <- b[:n]: - default: - log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr()) - } - } -} - -func (l *udpDirectForwardListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *udpDirectForwardListener) Addr() net.Addr { - return l.ln.LocalAddr() -} - -func (l *udpDirectForwardListener) Close() error { - return l.ln.Close() -} - -type udpServerConn struct { - conn net.PacketConn - raddr net.Addr - rChan, wChan chan []byte - closed chan struct{} - brokenChan chan struct{} - closeMutex sync.Mutex - ttl time.Duration - nopChan chan int -} - -func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn { - c := &udpServerConn{ - conn: conn, - raddr: raddr, - rChan: make(chan []byte, 128), - wChan: make(chan []byte, 128), - closed: make(chan struct{}), - brokenChan: make(chan struct{}), - nopChan: make(chan int), - ttl: ttl, - } - go c.writeLoop() - go c.ttlWait() - return c -} - -func (c *udpServerConn) Read(b []byte) (n int, err error) { - select { - case bb := <-c.rChan: - n = copy(b, bb) - if n != len(bb) { - err = errors.New("read partial data") - return - } - case <-c.brokenChan: - err = errors.New("Broken pipe") - case <-c.closed: - err = errors.New("read from closed connection") - return - } - - select { - case c.nopChan <- n: - default: - } - return -} - -func (c *udpServerConn) Write(b []byte) (n int, err error) { - if len(b) == 0 { - return 0, nil - } - select { - case c.wChan <- b: - n = len(b) - case <-c.brokenChan: - err = errors.New("Broken pipe") - case <-c.closed: - err = errors.New("write to closed connection") - return - } - - select { - case c.nopChan <- n: - default: - } - - return -} - -func (c *udpServerConn) Close() error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - select { - case <-c.closed: - return errors.New("connection is closed") - default: - close(c.closed) - } - return nil -} - -func (c *udpServerConn) Closed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -func (c *udpServerConn) writeLoop() { - for { - select { - case b, ok := <-c.wChan: - if !ok { - return - } - n, err := c.conn.WriteTo(b, c.raddr) - if err != nil { - log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err) - return - } - if Debug { - log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n) - } - case <-c.brokenChan: - return - case <-c.closed: - return - } - } -} - -func (c *udpServerConn) ttlWait() { - ttl := c.ttl - if ttl == 0 { - ttl = defaultTTL - } - timer := time.NewTimer(ttl) - - for { - select { - case <-c.nopChan: - timer.Reset(ttl) - case <-timer.C: - close(c.brokenChan) - return - case <-c.closed: - return - } - } -} - -func (c *udpServerConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *udpServerConn) RemoteAddr() net.Addr { - return c.raddr -} - -func (c *udpServerConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *udpServerConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *udpServerConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type tcpRemoteForwardListener struct { - addr net.Addr - chain *Chain - ln net.Listener - closed chan struct{} -} - -// TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. -func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { - laddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - - return &tcpRemoteForwardListener{ - addr: laddr, - chain: chain, - closed: make(chan struct{}), - }, nil -} - -func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) { - select { - case <-l.closed: - return nil, errors.New("closed") - default: - } - - var tempDelay time.Duration - for { - conn, err := l.accept() - if err != nil { - if tempDelay == 0 { - tempDelay = 1000 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 6 * time.Second; tempDelay > max { - tempDelay = max - } - log.Logf("[rtcp] Accept error: %v; retrying in %v", err, tempDelay) - time.Sleep(tempDelay) - continue - } - return conn, nil - } -} - -func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { - lastNode := l.chain.LastNode() - if lastNode.Protocol == "remote" && lastNode.Transport == "ssh" { - conn, err = l.chain.Dial(l.addr.String()) - } else if lastNode.Protocol == "socks5" { - cc, er := l.chain.Conn() - if er != nil { - return nil, er - } - conn, err = l.waitConnectSOCKS5(cc) - if err != nil { - cc.Close() - } - } else { - if l.ln == nil { - l.ln, err = net.Listen("tcp", l.addr.String()) - if err != nil { - return - } - } - conn, err = l.ln.Accept() - } - return -} - -func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { - conn, err := socks5Handshake(conn, l.chain.LastNode().User) - if err != nil { - return nil, err - } - req := gosocks5.NewRequest(gosocks5.CmdBind, toSocksAddr(l.addr)) - if err := req.Write(conn); err != nil { - log.Log("[rtcp] SOCKS5 BIND request: ", err) - return nil, err - } - - // first reply, bind status - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - rep, err := gosocks5.ReadReply(conn) - if err != nil { - log.Log("[rtcp] SOCKS5 BIND reply: ", err) - return nil, err - } - conn.SetReadDeadline(time.Time{}) - if rep.Rep != gosocks5.Succeeded { - log.Logf("[rtcp] bind on %s failure", l.addr) - return nil, fmt.Errorf("Bind on %s failure", l.addr.String()) - } - log.Logf("[rtcp] BIND ON %s OK", rep.Addr) - - // second reply, peer connected - rep, err = gosocks5.ReadReply(conn) - if err != nil { - log.Log("[rtcp]", err) - return nil, err - } - if rep.Rep != gosocks5.Succeeded { - log.Logf("[rtcp] peer connect failure: %d", rep.Rep) - return nil, errors.New("peer connect failure") - } - - log.Logf("[rtcp] PEER %s CONNECTED", rep.Addr) - return conn, nil -} - -func (l *tcpRemoteForwardListener) Addr() net.Addr { - return l.addr -} - -func (l *tcpRemoteForwardListener) Close() error { - close(l.closed) - return nil -} - -type udpRemoteForwardListener struct { - addr *net.UDPAddr - chain *Chain - conns map[string]*udpServerConn - connChan chan net.Conn - errChan chan error - ttl time.Duration - closed chan struct{} -} - -// UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server. -func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - - ln := &udpRemoteForwardListener{ - addr: laddr, - chain: chain, - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - ttl: ttl, - closed: make(chan struct{}), - } - go ln.listenLoop() - return ln, nil -} - -func (l *udpRemoteForwardListener) listenLoop() { - for { - conn, err := l.connect() - if err != nil { - log.Logf("[rudp] %s : %s", l.Addr(), err) - return - } - - defer conn.Close() - - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := conn.ReadFrom(b) - if err != nil { - log.Logf("[rudp] %s : %s", l.Addr(), err) - break - } - if Debug { - log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - uc, ok := l.conns[raddr.String()] - if !ok || uc.Closed() { - uc = newUDPServerConn(conn, raddr, l.ttl) - l.conns[raddr.String()] = uc - - select { - case l.connChan <- uc: - default: - uc.Close() - log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr()) - } - } - - select { - case uc.rChan <- b[:n]: - default: - log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr()) - } - } - } - -} - -func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { - var tempDelay time.Duration - - for { - select { - case <-l.closed: - return nil, errors.New("closed") - default: - } - - lastNode := l.chain.LastNode() - if lastNode.Protocol == "socks5" { - var cc net.Conn - cc, err = getSOCKS5UDPTunnel(l.chain, l.addr) - if err != nil { - log.Logf("[rudp] %s : %s", l.Addr(), err) - } else { - conn = &udpTunnelConn{Conn: cc} - } - } else { - conn, err = net.ListenUDP("udp", l.addr) - } - - if err != nil { - if tempDelay == 0 { - tempDelay = 1000 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 6 * time.Second; tempDelay > max { - tempDelay = max - } - log.Logf("[rudp] Accept error: %v; retrying in %v", err, tempDelay) - time.Sleep(tempDelay) - continue - } - return - } -} - -func (l *udpRemoteForwardListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *udpRemoteForwardListener) Addr() net.Addr { - return l.addr -} - -func (l *udpRemoteForwardListener) Close() error { - close(l.closed) - return nil -} diff --git a/gost/gost.go b/gost/gost.go deleted file mode 100644 index 91768e2..0000000 --- a/gost/gost.go +++ /dev/null @@ -1,108 +0,0 @@ -package gost - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "time" - - "github.com/go-log/log" -) - -// Version is the gost version. -const Version = "2.4-dev20170803" - -// Debug is a flag that enables the debug log. -var Debug bool - -var ( - tinyBufferSize = 128 - smallBufferSize = 1 * 1024 // 1KB small buffer - mediumBufferSize = 8 * 1024 // 8KB medium buffer - largeBufferSize = 32 * 1024 // 32KB large buffer -) - -var ( - // KeepAliveTime is the keep alive time period for TCP connection. - KeepAliveTime = 180 * time.Second - // DialTimeout is the timeout of dial. - DialTimeout = 30 * time.Second - // ReadTimeout is the timeout for reading. - ReadTimeout = 30 * time.Second - // WriteTimeout is the timeout for writing. - WriteTimeout = 60 * time.Second - // PingTimeout is the timeout for pinging. - PingTimeout = 30 * time.Second - // PingRetries is the reties of ping. - PingRetries = 3 - // default udp node TTL in second for udp port forwarding. - defaultTTL = 60 * time.Second -) - -var ( - defaultRawCert []byte - defaultRawKey []byte -) - -func init() { - rawCert, rawKey, err := generateKeyPair() - if err != nil { - panic(err) - } - defaultRawCert, defaultRawKey = rawCert, rawKey - - log.DefaultLogger = &LogLogger{} -} - -func SetLogger(logger log.Logger) { - log.DefaultLogger = logger -} - -func generateKeyPair() (rawCert, rawKey []byte, err error) { - if defaultRawCert != nil && defaultRawKey != nil { - return defaultRawCert, defaultRawKey, nil - } - - // Create private key and self-signed certificate - // Adapted from https://golang.org/src/crypto/tls/generate_cert.go - - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return - } - validFor := time.Hour * 24 * 365 * 10 // ten years - notBefore := time.Now() - notAfter := notBefore.Add(validFor) - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"gost"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return - } - - rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - - return -} - -// SetDefaultCertificate replaces the default certificate by your own -func SetDefaultCertificate(rawCert, rawKey []byte) { - defaultRawCert = rawCert - defaultRawKey = rawKey -} diff --git a/gost/handler.go b/gost/handler.go deleted file mode 100644 index 6a6c7de..0000000 --- a/gost/handler.go +++ /dev/null @@ -1,67 +0,0 @@ -package gost - -import ( - "crypto/tls" - "net" - "net/url" -) - -// Handler is a proxy server handler -type Handler interface { - Handle(net.Conn) -} - -// HandlerOptions describes the options for Handler. -type HandlerOptions struct { - Addr string - Chain *Chain - Users []*url.Userinfo - TLSConfig *tls.Config - Whitelist *Permissions - Blacklist *Permissions -} - -// HandlerOption allows a common way to set handler options. -type HandlerOption func(opts *HandlerOptions) - -// AddrHandlerOption sets the Addr option of HandlerOptions. -func AddrHandlerOption(addr string) HandlerOption { - return func(opts *HandlerOptions) { - opts.Addr = addr - } -} - -// ChainHandlerOption sets the Chain option of HandlerOptions. -func ChainHandlerOption(chain *Chain) HandlerOption { - return func(opts *HandlerOptions) { - opts.Chain = chain - } -} - -// UsersHandlerOption sets the Users option of HandlerOptions. -func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { - return func(opts *HandlerOptions) { - opts.Users = users - } -} - -// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. -func TLSConfigHandlerOption(config *tls.Config) HandlerOption { - return func(opts *HandlerOptions) { - opts.TLSConfig = config - } -} - -// WhitelistHandlerOption sets the Whitelist option of HandlerOptions. -func WhitelistHandlerOption(whitelist *Permissions) HandlerOption { - return func(opts *HandlerOptions) { - opts.Whitelist = whitelist - } -} - -// BlacklistHandlerOption sets the Blacklist option of HandlerOptions. -func BlacklistHandlerOption(blacklist *Permissions) HandlerOption { - return func(opts *HandlerOptions) { - opts.Blacklist = blacklist - } -} diff --git a/gost/http.go b/gost/http.go deleted file mode 100644 index dfc5dc0..0000000 --- a/gost/http.go +++ /dev/null @@ -1,259 +0,0 @@ -package gost - -import ( - "bufio" - "encoding/base64" - "fmt" - "net" - "net/http" - "net/http/httputil" - "net/url" - "strings" - "time" - - "github.com/go-log/log" -) - -type httpConnector struct { - User *url.Userinfo -} - -// HTTPConnector creates a Connector for HTTP proxy client. -// It accepts an optional auth info for HTTP Basic Authentication. -func HTTPConnector(user *url.Userinfo) Connector { - return &httpConnector{User: user} -} - -func (c *httpConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Host: addr}, - Host: addr, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - req.Header.Set("Proxy-Connection", "keep-alive") - - if c.User != nil { - s := c.User.String() - if _, set := c.User.Password(); !set { - s += ":" - } - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) - } - - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Log(string(dump)) - } - - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - return nil, err - } - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Log(string(dump)) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%s", resp.Status) - } - - return conn, nil -} - -type httpHandler struct { - options *HandlerOptions -} - -// HTTPHandler creates a server Handler for HTTP proxy server. -func HTTPHandler(opts ...HandlerOption) Handler { - h := &httpHandler{ - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *httpHandler) Handle(conn net.Conn) { - defer conn.Close() - - req, err := http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - log.Logf("[http] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - - if Debug { - log.Logf("[http] %s %s - %s %s", req.Method, conn.RemoteAddr(), req.Host, req.Proto) - dump, _ := httputil.DumpRequest(req, false) - log.Logf(string(dump)) - } - - if req.Method == "PRI" && req.ProtoMajor == 2 { - log.Logf("[http] %s <- %s : Not an HTTP2 server", conn.RemoteAddr(), req.Host) - resp := "HTTP/1.1 400 Bad Request\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n" - conn.Write([]byte(resp)) - return - } - - u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) - if !authenticate(u, p, h.options.Users...) { - log.Logf("[http] %s <- %s : proxy authentication required", conn.RemoteAddr(), req.Host) - resp := "HTTP/1.1 407 Proxy Authentication Required\r\n" + - "Proxy-Authenticate: Basic realm=\"gost\"\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n" - conn.Write([]byte(resp)) - return - } - - req.Header.Del("Proxy-Authorization") - req.Header.Del("Proxy-Connection") - - if !Can("tcp", req.Host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[http] Unauthorized to tcp connect to %s", req.Host) - b := []byte("HTTP/1.1 403 Forbidden\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - conn.Write(b) - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) - } - return - } - - // forward http request - lastNode := h.options.Chain.LastNode() - if req.Method != http.MethodConnect && lastNode.Protocol == "http" { - h.forwardRequest(conn, req) - return - } - - host := req.Host - if !strings.Contains(req.Host, ":") { - host += ":80" - } - cc, err := h.options.Chain.Dial(host) - if err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - - b := []byte("HTTP/1.1 503 Service unavailable\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) - } - conn.Write(b) - return - } - defer cc.Close() - - if req.Method == http.MethodConnect { - b := []byte("HTTP/1.1 200 Connection established\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) - } - conn.Write(b) - } else { - req.Header.Del("Proxy-Connection") - - if err = req.Write(cc); err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - return - } - } - - log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) - transport(conn, cc) - log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) -} - -func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request) { - if h.options.Chain.IsEmpty() { - return - } - lastNode := h.options.Chain.LastNode() - - cc, err := h.options.Chain.Conn() - if err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), lastNode.Addr, err) - - b := []byte("HTTP/1.1 503 Service unavailable\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), lastNode.Addr, string(b)) - } - conn.Write(b) - return - } - defer cc.Close() - - if lastNode.User != nil { - s := lastNode.User.String() - if _, set := lastNode.User.Password(); !set { - s += ":" - } - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) - } - - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err = req.WriteProxy(cc); err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - return - } - cc.SetWriteDeadline(time.Time{}) - - log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) - transport(conn, cc) - log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) - return -} - -func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { - if proxyAuth == "" { - return - } - - if !strings.HasPrefix(proxyAuth, "Basic ") { - return - } - c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) - if err != nil { - return - } - cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return - } - - return cs[:s], cs[s+1:], true -} - -func authenticate(username, password string, users ...*url.Userinfo) bool { - if len(users) == 0 { - return true - } - - for _, user := range users { - u := user.Username() - p, _ := user.Password() - if (u == username && p == password) || - (u == username && p == "") || - (u == "" && p == password) { - return true - } - } - return false -} diff --git a/gost/kcp.go b/gost/kcp.go deleted file mode 100644 index 857c697..0000000 --- a/gost/kcp.go +++ /dev/null @@ -1,517 +0,0 @@ -package gost - -import ( - "crypto/sha1" - "encoding/csv" - "errors" - "fmt" - "net" - "os" - "time" - - "golang.org/x/crypto/pbkdf2" - - "sync" - - "github.com/go-log/log" - "github.com/klauspost/compress/snappy" - "gopkg.in/xtaci/kcp-go.v2" - "gopkg.in/xtaci/smux.v1" -) - -var ( - // KCPSalt is the default salt for KCP cipher. - KCPSalt = "kcp-go" -) - -// KCPConfig describes the config for KCP. -type KCPConfig struct { - Key string `json:"key"` - Crypt string `json:"crypt"` - Mode string `json:"mode"` - MTU int `json:"mtu"` - SndWnd int `json:"sndwnd"` - RcvWnd int `json:"rcvwnd"` - DataShard int `json:"datashard"` - ParityShard int `json:"parityshard"` - DSCP int `json:"dscp"` - NoComp bool `json:"nocomp"` - AckNodelay bool `json:"acknodelay"` - NoDelay int `json:"nodelay"` - Interval int `json:"interval"` - Resend int `json:"resend"` - NoCongestion int `json:"nc"` - SockBuf int `json:"sockbuf"` - KeepAlive int `json:"keepalive"` - SnmpLog string `json:"snmplog"` - SnmpPeriod int `json:"snmpperiod"` - Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature. -} - -// Init initializes the KCP config. -func (c *KCPConfig) Init() { - switch c.Mode { - case "normal": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 50, 2, 1 - case "fast2": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 30, 2, 1 - case "fast3": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1 - case "fast": - fallthrough - default: - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1 - } -} - -var ( - // DefaultKCPConfig is the default KCP config. - DefaultKCPConfig = &KCPConfig{ - Key: "it's a secrect", - Crypt: "aes", - Mode: "fast", - MTU: 1350, - SndWnd: 1024, - RcvWnd: 1024, - DataShard: 10, - ParityShard: 3, - DSCP: 0, - NoComp: false, - AckNodelay: false, - NoDelay: 0, - Interval: 50, - Resend: 0, - NoCongestion: 0, - SockBuf: 4194304, - KeepAlive: 10, - SnmpLog: "", - SnmpPeriod: 60, - Signal: false, - } -) - -type kcpConn struct { - conn net.Conn - stream *smux.Stream -} - -func (c *kcpConn) Read(b []byte) (n int, err error) { - return c.stream.Read(b) -} - -func (c *kcpConn) Write(b []byte) (n int, err error) { - return c.stream.Write(b) -} - -func (c *kcpConn) Close() error { - return c.stream.Close() -} - -func (c *kcpConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *kcpConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *kcpConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *kcpConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *kcpConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -type kcpSession struct { - conn net.Conn - session *smux.Session -} - -func (session *kcpSession) GetConn() (*kcpConn, error) { - stream, err := session.session.OpenStream() - if err != nil { - return nil, err - } - return &kcpConn{conn: session.conn, stream: stream}, nil -} - -func (session *kcpSession) Close() error { - return session.session.Close() -} - -func (session *kcpSession) IsClosed() bool { - return session.session.IsClosed() -} - -func (session *kcpSession) NumStreams() int { - return session.session.NumStreams() -} - -type kcpTransporter struct { - sessions map[string]*kcpSession - sessionMutex sync.Mutex - config *KCPConfig -} - -// KCPTransporter creates a Transporter that is used by KCP proxy client. -func KCPTransporter(config *KCPConfig) Transporter { - if config == nil { - config = DefaultKCPConfig - } - config.Init() - - go snmpLogger(config.SnmpLog, config.SnmpPeriod) - if config.Signal { - go kcpSigHandler() - } - - return &kcpTransporter{ - config: config, - sessions: make(map[string]*kcpSession), - } -} - -func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - uaddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if !ok { - conn, err = net.DialUDP("udp", nil, uaddr) - if err != nil { - return - } - session = &kcpSession{conn: conn} - tr.sessions[addr] = session - } - return session.conn, nil -} - -func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - config := tr.config - if opts.KCPConfig != nil { - config = opts.KCPConfig - } - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("kcp: unrecognized connection") - } - if !ok || session.session == nil { - s, err := tr.initSession(opts.Addr, conn, config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - session = s - tr.sessions[opts.Addr] = session - } - cc, err := session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - return cc, nil -} - -func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { - udpConn, ok := conn.(*net.UDPConn) - if !ok { - return nil, errors.New("kcp: wrong connection type") - } - - kcpconn, err := kcp.NewConn(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), - config.DataShard, config.ParityShard, - &kcp.ConnectedUDPConn{UDPConn: udpConn, Conn: udpConn}) - if err != nil { - return nil, err - } - - kcpconn.SetStreamMode(true) - kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion) - kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd) - kcpconn.SetMtu(config.MTU) - kcpconn.SetACKNoDelay(config.AckNodelay) - kcpconn.SetKeepAlive(config.KeepAlive) - - if err := kcpconn.SetDSCP(config.DSCP); err != nil { - log.Log("[kcp]", err) - } - if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - - // stream multiplex - smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = config.SockBuf - var cc net.Conn = kcpconn - if !config.NoComp { - cc = newCompStreamConn(kcpconn) - } - session, err := smux.Client(cc, smuxConfig) - if err != nil { - return nil, err - } - return &kcpSession{conn: conn, session: session}, nil -} - -func (tr *kcpTransporter) Multiplex() bool { - return true -} - -type kcpListener struct { - config *KCPConfig - ln *kcp.Listener - connChan chan net.Conn - errChan chan error -} - -// KCPListener creates a Listener for KCP proxy server. -func KCPListener(addr string, config *KCPConfig) (Listener, error) { - if config == nil { - config = DefaultKCPConfig - } - config.Init() - - ln, err := kcp.ListenWithOptions(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard) - if err != nil { - return nil, err - } - if err = ln.SetDSCP(config.DSCP); err != nil { - log.Log("[kcp]", err) - } - if err = ln.SetReadBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - if err = ln.SetWriteBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - - go snmpLogger(config.SnmpLog, config.SnmpPeriod) - if config.Signal { - go kcpSigHandler() - } - - l := &kcpListener{ - config: config, - ln: ln, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -func (l *kcpListener) listenLoop() { - for { - conn, err := l.ln.AcceptKCP() - if err != nil { - log.Log("[kcp] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - conn.SetStreamMode(true) - conn.SetNoDelay(l.config.NoDelay, l.config.Interval, l.config.Resend, l.config.NoCongestion) - conn.SetMtu(l.config.MTU) - conn.SetWindowSize(l.config.SndWnd, l.config.RcvWnd) - conn.SetACKNoDelay(l.config.AckNodelay) - conn.SetKeepAlive(l.config.KeepAlive) - go l.mux(conn) - } -} - -func (l *kcpListener) mux(conn net.Conn) { - smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = l.config.SockBuf - - log.Logf("[kcp] %s - %s", conn.RemoteAddr(), l.Addr()) - - if !l.config.NoComp { - conn = newCompStreamConn(conn) - } - - mux, err := smux.Server(conn, smuxConfig) - if err != nil { - log.Log("[kcp]", err) - return - } - defer mux.Close() - - log.Logf("[kcp] %s <-> %s", conn.RemoteAddr(), l.Addr()) - defer log.Logf("[kcp] %s >-< %s", conn.RemoteAddr(), l.Addr()) - - for { - stream, err := mux.AcceptStream() - if err != nil { - log.Log("[kcp] accept stream:", err) - return - } - - cc := &kcpConn{conn: conn, stream: stream} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) - } - } -} - -func (l *kcpListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} -func (l *kcpListener) Addr() net.Addr { - return l.ln.Addr() -} - -func (l *kcpListener) Close() error { - return l.ln.Close() -} - -func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) { - pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New) - - switch crypt { - case "tea": - block, _ = kcp.NewTEABlockCrypt(pass[:16]) - case "xor": - block, _ = kcp.NewSimpleXORBlockCrypt(pass) - case "none": - block, _ = kcp.NewNoneBlockCrypt(pass) - case "aes-128": - block, _ = kcp.NewAESBlockCrypt(pass[:16]) - case "aes-192": - block, _ = kcp.NewAESBlockCrypt(pass[:24]) - case "blowfish": - block, _ = kcp.NewBlowfishBlockCrypt(pass) - case "twofish": - block, _ = kcp.NewTwofishBlockCrypt(pass) - case "cast5": - block, _ = kcp.NewCast5BlockCrypt(pass[:16]) - case "3des": - block, _ = kcp.NewTripleDESBlockCrypt(pass[:24]) - case "xtea": - block, _ = kcp.NewXTEABlockCrypt(pass[:16]) - case "salsa20": - block, _ = kcp.NewSalsa20BlockCrypt(pass) - case "aes": - fallthrough - default: // aes - block, _ = kcp.NewAESBlockCrypt(pass) - } - return -} - -func snmpLogger(format string, interval int) { - if format == "" || interval == 0 { - return - } - ticker := time.NewTicker(time.Duration(interval) * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - f, err := os.OpenFile(time.Now().Format(format), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.Log("[kcp]", err) - return - } - w := csv.NewWriter(f) - // write header in empty file - if stat, err := f.Stat(); err == nil && stat.Size() == 0 { - if err := w.Write(append([]string{"Unix"}, kcp.DefaultSnmp.Header()...)); err != nil { - log.Log("[kcp]", err) - } - } - if err := w.Write(append([]string{fmt.Sprint(time.Now().Unix())}, kcp.DefaultSnmp.ToSlice()...)); err != nil { - log.Log("[kcp]", err) - } - kcp.DefaultSnmp.Reset() - w.Flush() - f.Close() - } - } -} - -type compStreamConn struct { - conn net.Conn - w *snappy.Writer - r *snappy.Reader -} - -func newCompStreamConn(conn net.Conn) *compStreamConn { - c := new(compStreamConn) - c.conn = conn - c.w = snappy.NewBufferedWriter(conn) - c.r = snappy.NewReader(conn) - return c -} - -func (c *compStreamConn) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} - -func (c *compStreamConn) Write(b []byte) (n int, err error) { - n, err = c.w.Write(b) - err = c.w.Flush() - return n, err -} - -func (c *compStreamConn) Close() error { - return c.conn.Close() -} - -func (c *compStreamConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *compStreamConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *compStreamConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *compStreamConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *compStreamConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/gost/node.go b/gost/node.go deleted file mode 100644 index b7d103e..0000000 --- a/gost/node.go +++ /dev/null @@ -1,95 +0,0 @@ -package gost - -import ( - "net" - "net/url" - "strconv" - "strings" - - "github.com/go-log/log" -) - -// Node is a proxy node, mainly used to construct a proxy chain. -type Node struct { - Addr string - Protocol string - Transport string - Remote string // remote address, used by tcp/udp port forwarding - User *url.Userinfo - Chain *Chain - Values url.Values - Client *Client - DialOptions []DialOption - HandshakeOptions []HandshakeOption -} - -func ParseNode(s string) (node Node, err error) { - if !strings.Contains(s, "://") { - s = "auto://" + s - } - u, err := url.Parse(s) - if err != nil { - return - } - - node = Node{ - Addr: u.Host, - Values: u.Query(), - User: u.User, - } - - schemes := strings.Split(u.Scheme, "+") - if len(schemes) == 1 { - node.Protocol = schemes[0] - node.Transport = schemes[0] - } - if len(schemes) == 2 { - node.Protocol = schemes[0] - node.Transport = schemes[1] - } - - switch node.Transport { - case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "redirect": - case "https": - node.Protocol = "http" - node.Transport = "tls" - case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding - node.Remote = strings.Trim(u.EscapedPath(), "/") - case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding - node.Remote = strings.Trim(u.EscapedPath(), "/") - default: - node.Transport = "" - } - - switch node.Protocol { - case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss", "ssu": - case "tcp", "udp", "rtcp", "rudp": // port forwarding - case "direct", "remote", "forward": // SSH port forwarding - default: - node.Protocol = "" - } - - return -} - -func Can(action string, addr string, whitelist, blacklist *Permissions) bool { - if !strings.Contains(addr, ":") { - addr = addr + ":80" - } - host, strport, err := net.SplitHostPort(addr) - - if err != nil { - return false - } - - port, err := strconv.Atoi(strport) - - if err != nil { - return false - } - - if Debug { - log.Logf("Can action: %s, host: %s, port %d", action, host, port) - } - return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port) -} diff --git a/gost/permissions.go b/gost/permissions.go deleted file mode 100644 index 8566c80..0000000 --- a/gost/permissions.go +++ /dev/null @@ -1,185 +0,0 @@ -package gost - -import ( - "errors" - "fmt" - "strconv" - "strings" - - glob "github.com/ryanuber/go-glob" -) - -type Permission struct { - Actions StringSet - Hosts StringSet - Ports PortSet -} - -type Permissions []Permission - -func minint(x, y int) int { - if x < y { - return x - } - return y -} - -func maxint(x, y int) int { - if x > y { - return x - } - return y -} - -type PortRange struct { - Min, Max int -} - -func (ir *PortRange) Contains(value int) bool { - return value >= ir.Min && value <= ir.Max -} - -func ParsePortRange(s string) (*PortRange, error) { - if s == "*" { - return &PortRange{Min: 0, Max: 65535}, nil - } - - minmax := strings.Split(s, "-") - switch len(minmax) { - case 1: - port, err := strconv.Atoi(s) - if err != nil { - return nil, err - } - if port < 0 || port > 65535 { - return nil, fmt.Errorf("invalid port: %s", s) - } - return &PortRange{Min: port, Max: port}, nil - case 2: - min, err := strconv.Atoi(minmax[0]) - if err != nil { - return nil, err - } - max, err := strconv.Atoi(minmax[1]) - if err != nil { - return nil, err - } - - realmin := maxint(0, minint(min, max)) - realmax := minint(65535, maxint(min, max)) - - return &PortRange{Min: realmin, Max: realmax}, nil - default: - return nil, fmt.Errorf("invalid range: %s", s) - } -} - -func (ps *PortSet) Contains(value int) bool { - for _, portRange := range *ps { - if portRange.Contains(value) { - return true - } - } - - return false -} - -type PortSet []PortRange - -func ParsePortSet(s string) (*PortSet, error) { - ps := &PortSet{} - - if s == "" { - return nil, errors.New("must specify at least one port") - } - - ranges := strings.Split(s, ",") - - for _, r := range ranges { - portRange, err := ParsePortRange(r) - - if err != nil { - return nil, err - } - - *ps = append(*ps, *portRange) - } - - return ps, nil -} - -func (ss *StringSet) Contains(subj string) bool { - for _, s := range *ss { - if glob.Glob(s, subj) { - return true - } - } - - return false -} - -type StringSet []string - -func ParseStringSet(s string) (*StringSet, error) { - ss := &StringSet{} - if s == "" { - return nil, errors.New("cannot be empty") - } - - *ss = strings.Split(s, ",") - - return ss, nil -} - -func (ps *Permissions) Can(action string, host string, port int) bool { - for _, p := range *ps { - if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { - return true - } - } - - return false -} - -func ParsePermissions(s string) (*Permissions, error) { - ps := &Permissions{} - - if s == "" { - return &Permissions{}, nil - } - - perms := strings.Split(s, " ") - - for _, perm := range perms { - parts := strings.Split(perm, ":") - - switch len(parts) { - case 3: - actions, err := ParseStringSet(parts[0]) - - if err != nil { - return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) - } - - hosts, err := ParseStringSet(parts[1]) - - if err != nil { - return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) - } - - ports, err := ParsePortSet(parts[2]) - - if err != nil { - return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) - } - - permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} - - *ps = append(*ps, permission) - default: - return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) - } - } - - return ps, nil -} diff --git a/gost/permissions_test.go b/gost/permissions_test.go deleted file mode 100644 index bc99824..0000000 --- a/gost/permissions_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package gost - -import ( - "fmt" - "testing" -) - -var portRangeTests = []struct { - in string - out *PortRange -}{ - {"1", &PortRange{Min: 1, Max: 1}}, - {"1-3", &PortRange{Min: 1, Max: 3}}, - {"3-1", &PortRange{Min: 1, Max: 3}}, - {"0-100000", &PortRange{Min: 0, Max: 65535}}, - {"*", &PortRange{Min: 0, Max: 65535}}, -} - -var stringSetTests = []struct { - in string - out *StringSet -}{ - {"*", &StringSet{"*"}}, - {"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, -} - -var portSetTests = []struct { - in string - out *PortSet -}{ - {"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, - {"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, - {"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, - {"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, -} - -var permissionsTests = []struct { - in string - out *Permissions -}{ - {"", &Permissions{}}, - {"*:*:*", &Permissions{ - Permission{ - Actions: StringSet{"*"}, - Hosts: StringSet{"*"}, - Ports: PortSet{PortRange{Min: 0, Max: 65535}}, - }, - }}, - {"bind:127.0.0.1,localhost:80,443,8000-8100 connect:*.google.pl:80,443", &Permissions{ - Permission{ - Actions: StringSet{"bind"}, - Hosts: StringSet{"127.0.0.1", "localhost"}, - Ports: PortSet{ - PortRange{Min: 80, Max: 80}, - PortRange{Min: 443, Max: 443}, - PortRange{Min: 8000, Max: 8100}, - }, - }, - Permission{ - Actions: StringSet{"connect"}, - Hosts: StringSet{"*.google.pl"}, - Ports: PortSet{ - PortRange{Min: 80, Max: 80}, - PortRange{Min: 443, Max: 443}, - }, - }, - }}, -} - -func TestPortRangeParse(t *testing.T) { - for _, test := range portRangeTests { - actual, err := ParsePortRange(test.in) - if err != nil { - t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) - } else if *actual != *test.out { - t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestPortRangeContains(t *testing.T) { - actual, _ := ParsePortRange("5-10") - - if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { - t.Errorf("5-10 should contain 5, 7 and 10") - } - - if actual.Contains(4) || actual.Contains(11) { - t.Errorf("5-10 should not contain 4, 11") - } -} - -func TestStringSetParse(t *testing.T) { - for _, test := range stringSetTests { - actual, err := ParseStringSet(test.in) - if err != nil { - t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestStringSetContains(t *testing.T) { - ss, _ := ParseStringSet("google.pl,*.google.com") - - if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { - t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") - } - - if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { - t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") - } -} - -func TestPortSetParse(t *testing.T) { - for _, test := range portSetTests { - actual, err := ParsePortSet(test.in) - if err != nil { - t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestPortSetContains(t *testing.T) { - actual, _ := ParsePortSet("5-10,20-30") - - if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { - t.Errorf("5-10,20-30 should contain 5, 7 and 10") - } - - if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { - t.Errorf("5-10,20-30 should contain 20, 27 and 30") - } - - if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { - t.Errorf("5-10,20-30 should not contain 4, 11, 31") - } -} - -func TestPermissionsParse(t *testing.T) { - for _, test := range permissionsTests { - actual, err := ParsePermissions(test.in) - if err != nil { - t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) - } - } -} diff --git a/gost/quic.go b/gost/quic.go deleted file mode 100644 index e4ce855..0000000 --- a/gost/quic.go +++ /dev/null @@ -1,237 +0,0 @@ -package gost - -import ( - "crypto/tls" - "errors" - "net" - "sync" - "time" - - "github.com/go-log/log" - quic "github.com/lucas-clemente/quic-go" -) - -type quicSession struct { - conn net.Conn - session quic.Session -} - -func (session *quicSession) GetConn() (*quicConn, error) { - stream, err := session.session.OpenStream() - if err != nil { - return nil, err - } - return &quicConn{ - Stream: stream, - laddr: session.session.LocalAddr(), - raddr: session.session.RemoteAddr(), - }, nil -} - -func (session *quicSession) Close() error { - return session.session.Close(nil) -} - -type quicTransporter struct { - config *QUICConfig - sessionMutex sync.Mutex - sessions map[string]*quicSession -} - -// QUICTransporter creates a Transporter that is used by QUIC proxy client. -func QUICTransporter(config *QUICConfig) Transporter { - if config == nil { - config = &QUICConfig{} - } - return &quicTransporter{ - config: config, - sessions: make(map[string]*quicSession), - } -} - -func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if !ok { - conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return - } - session = &quicSession{conn: conn} - tr.sessions[addr] = session - } - return session.conn, nil -} - -func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - config := tr.config - if opts.QUICConfig != nil { - config = opts.QUICConfig - } - if config.TLSConfig == nil { - config.TLSConfig = &tls.Config{InsecureSkipVerify: true} - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("quic: unrecognized connection") - } - if !ok || session.session == nil { - s, err := tr.initSession(opts.Addr, conn, config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - session = s - tr.sessions[opts.Addr] = session - } - cc, err := session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - return cc, nil -} - -func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { - udpConn, ok := conn.(*net.UDPConn) - if !ok { - return nil, errors.New("quic: wrong connection type") - } - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - quicConfig := &quic.Config{ - HandshakeTimeout: config.Timeout, - KeepAlive: config.KeepAlive, - } - session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) - if err != nil { - log.Log("quic dial", err) - return nil, err - } - return &quicSession{conn: conn, session: session}, nil -} - -func (tr *quicTransporter) Multiplex() bool { - return true -} - -type QUICConfig struct { - TLSConfig *tls.Config - Timeout time.Duration - KeepAlive bool -} - -type quicListener struct { - ln quic.Listener - connChan chan net.Conn - errChan chan error -} - -// QUICListener creates a Listener for QUIC proxy server. -func QUICListener(addr string, config *QUICConfig) (Listener, error) { - if config == nil { - config = &QUICConfig{} - } - quicConfig := &quic.Config{ - HandshakeTimeout: config.Timeout, - KeepAlive: config.KeepAlive, - } - - ln, err := quic.ListenAddr(addr, config.TLSConfig, quicConfig) - if err != nil { - return nil, err - } - - l := &quicListener{ - ln: ln, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -func (l *quicListener) listenLoop() { - for { - session, err := l.ln.Accept() - if err != nil { - log.Log("[quic] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.sessionLoop(session) - } -} - -func (l *quicListener) sessionLoop(session quic.Session) { - log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr()) - defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr()) - - for { - stream, err := session.AcceptStream() - if err != nil { - log.Log("[quic] accept stream:", err) - return - } - - cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr()) - } - } -} - -func (l *quicListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *quicListener) Addr() net.Addr { - return l.ln.Addr() -} - -func (l *quicListener) Close() error { - return l.ln.Close() -} - -type quicConn struct { - quic.Stream - laddr net.Addr - raddr net.Addr -} - -func (c *quicConn) LocalAddr() net.Addr { - return c.laddr -} - -func (c *quicConn) RemoteAddr() net.Addr { - return c.raddr -} diff --git a/gost/redirect.go b/gost/redirect.go deleted file mode 100644 index f7033c6..0000000 --- a/gost/redirect.go +++ /dev/null @@ -1,91 +0,0 @@ -// +build !windows - -package gost - -import ( - "errors" - "fmt" - "net" - "syscall" - - "github.com/go-log/log" -) - -type tcpRedirectHandler struct { - options *HandlerOptions -} - -// TCPRedirectHandler creates a server Handler for TCP redirect server. -func TCPRedirectHandler(opts ...HandlerOption) Handler { - h := &tcpRedirectHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *tcpRedirectHandler) Handle(c net.Conn) { - conn, ok := c.(*net.TCPConn) - if !ok { - log.Log("[red-tcp] not a TCP connection") - } - - srcAddr := conn.RemoteAddr() - dstAddr, conn, err := h.getOriginalDstAddr(conn) - if err != nil { - log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) - return - } - defer conn.Close() - - log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) - - cc, err := h.options.Chain.Dial(dstAddr.String()) - if err != nil { - log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) - return - } - defer cc.Close() - - log.Logf("[red-tcp] %s <-> %s", srcAddr, dstAddr) - transport(conn, cc) - log.Logf("[red-tcp] %s >-< %s", srcAddr, dstAddr) -} - -func (h *tcpRedirectHandler) getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err error) { - defer conn.Close() - - fc, err := conn.File() - if err != nil { - return - } - defer fc.Close() - - mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80) - if err != nil { - return - } - - // only ipv4 support - ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7]) - port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3]) - addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port)) - if err != nil { - return - } - - cc, err := net.FileConn(fc) - if err != nil { - return - } - - c, ok := cc.(*net.TCPConn) - if !ok { - err = errors.New("not a TCP connection") - } - return -} diff --git a/gost/redirect_win.go b/gost/redirect_win.go deleted file mode 100644 index 848b70b..0000000 --- a/gost/redirect_win.go +++ /dev/null @@ -1,31 +0,0 @@ -// +build windows - -package gost - -import ( - "net" - - "github.com/go-log/log" -) - -type tcpRedirectHandler struct { - options *HandlerOptions -} - -// TCPRedirectHandler creates a server Handler for TCP redirect server. -func TCPRedirectHandler(opts ...HandlerOption) Handler { - h := &tcpRedirectHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *tcpRedirectHandler) Handle(c net.Conn) { - log.Log("[red-tcp] TCP redirect is not available on the Windows platform") - c.Close() -} diff --git a/gost/server.go b/gost/server.go deleted file mode 100644 index ed65b35..0000000 --- a/gost/server.go +++ /dev/null @@ -1,104 +0,0 @@ -package gost - -import ( - "io" - "net" - "time" - - "github.com/go-log/log" -) - -// Server is a proxy server. -type Server struct { -} - -// Serve serves as a proxy server. -func (s *Server) Serve(l net.Listener, h Handler) error { - defer l.Close() - - if l == nil { - ln, err := TCPListener(":8080") - if err != nil { - return err - } - l = ln - } - if h == nil { - h = HTTPHandler() - } - - var tempDelay time.Duration - for { - conn, e := l.Accept() - if e != nil { - if ne, ok := e.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay) - time.Sleep(tempDelay) - continue - } - return e - } - tempDelay = 0 - go h.Handle(conn) - } - -} - -// Listener is a proxy server listener, just like a net.Listener. -type Listener interface { - net.Listener -} - -type tcpListener struct { - net.Listener -} - -// TCPListener creates a Listener for TCP proxy server. -func TCPListener(addr string) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - return &tcpListener{Listener: &tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(KeepAliveTime) - return tc, nil -} - -func transport(rw1, rw2 io.ReadWriter) error { - errc := make(chan error, 1) - go func() { - _, err := io.Copy(rw1, rw2) - errc <- err - }() - - go func() { - _, err := io.Copy(rw2, rw1) - errc <- err - }() - - err := <-errc - if err != nil && err == io.EOF { - err = nil - } - return err -} diff --git a/gost/signal.go b/gost/signal.go deleted file mode 100644 index f12e902..0000000 --- a/gost/signal.go +++ /dev/null @@ -1,5 +0,0 @@ -// +build windows - -package gost - -func kcpSigHandler() {} diff --git a/gost/signal_unix.go b/gost/signal_unix.go deleted file mode 100644 index a761318..0000000 --- a/gost/signal_unix.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !windows - -package gost - -import ( - "os" - "os/signal" - "syscall" - - "github.com/go-log/log" - "gopkg.in/xtaci/kcp-go.v2" -) - -func kcpSigHandler() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGUSR1) - - for { - switch <-ch { - case syscall.SIGUSR1: - log.Logf("[kcp] SNMP: %+v", kcp.DefaultSnmp.Copy()) - } - } -} diff --git a/gost/socks.go b/gost/socks.go deleted file mode 100644 index 54ef37c..0000000 --- a/gost/socks.go +++ /dev/null @@ -1,1166 +0,0 @@ -package gost - -import ( - "bytes" - "crypto/tls" - "errors" - "fmt" - "net" - "net/url" - "strconv" - "time" - - "io" - - "github.com/ginuerzh/gosocks4" - "github.com/ginuerzh/gosocks5" - "github.com/go-log/log" -) - -const ( - // MethodTLS is an extended SOCKS5 method for TLS. - MethodTLS uint8 = 0x80 - // MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH. - MethodTLSAuth uint8 = 0x82 -) - -const ( - // CmdUDPTun is an extended SOCKS5 method for UDP over TCP. - CmdUDPTun uint8 = 0xF3 -) - -type clientSelector struct { - methods []uint8 - User *url.Userinfo - TLSConfig *tls.Config -} - -func (selector *clientSelector) Methods() []uint8 { - return selector.methods -} - -func (selector *clientSelector) AddMethod(methods ...uint8) { - selector.methods = append(selector.methods, methods...) -} - -func (selector *clientSelector) Select(methods ...uint8) (method uint8) { - return -} - -func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { - switch method { - case MethodTLS: - conn = tls.Client(conn, selector.TLSConfig) - - case gosocks5.MethodUserPass, MethodTLSAuth: - if method == MethodTLSAuth { - conn = tls.Client(conn, selector.TLSConfig) - } - - var username, password string - if selector.User != nil { - username = selector.User.Username() - password, _ = selector.User.Password() - } - - req := gosocks5.NewUserPassRequest(gosocks5.UserPassVer, username, password) - if err := req.Write(conn); err != nil { - log.Log("[socks5]", err) - return nil, err - } - if Debug { - log.Log("[socks5]", req) - } - resp, err := gosocks5.ReadUserPassResponse(conn) - if err != nil { - log.Log("[socks5]", err) - return nil, err - } - if Debug { - log.Log("[socks5]", resp) - } - if resp.Status != gosocks5.Succeeded { - return nil, gosocks5.ErrAuthFailure - } - case gosocks5.MethodNoAcceptable: - return nil, gosocks5.ErrBadMethod - } - - return conn, nil -} - -type serverSelector struct { - methods []uint8 - Users []*url.Userinfo - TLSConfig *tls.Config -} - -func (selector *serverSelector) Methods() []uint8 { - return selector.methods -} - -func (selector *serverSelector) AddMethod(methods ...uint8) { - selector.methods = append(selector.methods, methods...) -} - -func (selector *serverSelector) Select(methods ...uint8) (method uint8) { - if Debug { - log.Logf("[socks5] %d %d %v", gosocks5.Ver5, len(methods), methods) - } - method = gosocks5.MethodNoAuth - for _, m := range methods { - if m == MethodTLS { - method = m - break - } - } - - // when user/pass is set, auth is mandatory - if len(selector.Users) > 0 { - if method == gosocks5.MethodNoAuth { - method = gosocks5.MethodUserPass - } - if method == MethodTLS { - method = MethodTLSAuth - } - } - - return -} - -func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { - if Debug { - log.Logf("[socks5] %d %d", gosocks5.Ver5, method) - } - switch method { - case MethodTLS: - conn = tls.Server(conn, selector.TLSConfig) - - case gosocks5.MethodUserPass, MethodTLSAuth: - if method == MethodTLSAuth { - conn = tls.Server(conn, selector.TLSConfig) - } - - req, err := gosocks5.ReadUserPassRequest(conn) - if err != nil { - log.Log("[socks5]", err) - return nil, err - } - if Debug { - log.Log("[socks5]", req.String()) - } - valid := false - for _, user := range selector.Users { - username := user.Username() - password, _ := user.Password() - if (req.Username == username && req.Password == password) || - (req.Username == username && password == "") || - (username == "" && req.Password == password) { - valid = true - break - } - } - if len(selector.Users) > 0 && !valid { - resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) - if err := resp.Write(conn); err != nil { - log.Log("[socks5]", err) - return nil, err - } - if Debug { - log.Log("[socks5]", resp) - } - log.Log("[socks5] proxy authentication required") - return nil, gosocks5.ErrAuthFailure - } - - resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) - if err := resp.Write(conn); err != nil { - log.Log("[socks5]", err) - return nil, err - } - if Debug { - log.Log("[socks5]", resp) - } - case gosocks5.MethodNoAcceptable: - return nil, gosocks5.ErrBadMethod - } - - return conn, nil -} - -type socks5Connector struct { - User *url.Userinfo -} - -// SOCKS5Connector creates a connector for SOCKS5 proxy client. -// It accepts an optional auth info for SOCKS5 Username/Password Authentication. -func SOCKS5Connector(user *url.Userinfo) Connector { - return &socks5Connector{User: user} -} - -func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: c.User, - } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) - - cc := gosocks5.ClientConn(conn, selector) - if err := cc.Handleshake(); err != nil { - return nil, err - } - conn = cc - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - p, _ := strconv.Atoi(port) - req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ - Type: gosocks5.AddrDomain, - Host: host, - Port: uint16(p), - }) - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - log.Log("[socks5]", req) - } - - reply, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - - if Debug { - log.Log("[socks5]", reply) - } - - if reply.Rep != gosocks5.Succeeded { - return nil, errors.New("Service unavailable") - } - - return conn, nil -} - -type socks4Connector struct{} - -// SOCKS4Connector creates a Connector for SOCKS4 proxy client. -func SOCKS4Connector() Connector { - return &socks4Connector{} -} - -func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { - taddr, err := net.ResolveTCPAddr("tcp4", addr) - if err != nil { - return nil, err - } - - req := gosocks4.NewRequest(gosocks4.CmdConnect, - &gosocks4.Addr{ - Type: gosocks4.AddrIPv4, - Host: taddr.IP.String(), - Port: uint16(taddr.Port), - }, nil, - ) - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - log.Logf("[socks4] %s", req) - } - - reply, err := gosocks4.ReadReply(conn) - if err != nil { - return nil, err - } - - if Debug { - log.Logf("[socks4] %s", reply) - } - - if reply.Code != gosocks4.Granted { - return nil, fmt.Errorf("[socks4] %d", reply.Code) - } - - return conn, nil -} - -type socks4aConnector struct{} - -// SOCKS4AConnector creates a Connector for SOCKS4A proxy client. -func SOCKS4AConnector() Connector { - return &socks4aConnector{} -} - -func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - p, _ := strconv.Atoi(port) - - req := gosocks4.NewRequest(gosocks4.CmdConnect, - &gosocks4.Addr{Type: gosocks4.AddrDomain, Host: host, Port: uint16(p)}, nil) - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - log.Logf("[socks4] %s", req) - } - - reply, err := gosocks4.ReadReply(conn) - if err != nil { - return nil, err - } - - if Debug { - log.Logf("[socks4] %s", reply) - } - - if reply.Code != gosocks4.Granted { - return nil, fmt.Errorf("[socks4] %d", reply.Code) - } - - return conn, nil -} - -type socks5Handler struct { - selector *serverSelector - options *HandlerOptions -} - -// SOCKS5Handler creates a server Handler for SOCKS5 proxy server. -func SOCKS5Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{} - for _, opt := range opts { - opt(options) - } - - selector := &serverSelector{ // socks5 server selector - Users: options.Users, - TLSConfig: options.TLSConfig, - } - // methods that socks5 server supported - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - MethodTLSAuth, - ) - return &socks5Handler{ - options: options, - selector: selector, - } -} - -func (h *socks5Handler) Handle(conn net.Conn) { - defer conn.Close() - - conn = gosocks5.ServerConn(conn, h.selector) - req, err := gosocks5.ReadRequest(conn) - if err != nil { - log.Log("[socks5]", err) - return - } - - if Debug { - log.Logf("[socks5] %s - %s\n%s", conn.RemoteAddr(), req.Addr, req) - } - switch req.Cmd { - case gosocks5.CmdConnect: - h.handleConnect(conn, req) - - case gosocks5.CmdBind: - h.handleBind(conn, req) - - case gosocks5.CmdUdp: - h.handleUDPRelay(conn, req) - - case CmdUDPTun: - h.handleUDPTunnel(conn, req) - - default: - log.Log("[socks5] Unrecognized request:", req.Cmd) - } -} - -func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { - addr := req.Addr.String() - if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-connect] Unauthorized to tcp connect to %s", addr) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } - - cc, err := h.options.Chain.Dial(addr) - if err != nil { - log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) - rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } - defer cc.Close() - - rep := gosocks5.NewReply(gosocks5.Succeeded, nil) - if err := rep.Write(conn); err != nil { - log.Logf("[socks5-connect] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) - return - } - if Debug { - log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - log.Logf("[socks5-connect] %s <-> %s", conn.RemoteAddr(), req.Addr) - transport(conn, cc) - log.Logf("[socks5-connect] %s >-< %s", conn.RemoteAddr(), req.Addr) -} - -func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { - if h.options.Chain.IsEmpty() { - addr := req.Addr.String() - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("Unauthorized to tcp bind to %s", addr) - return - } - h.bindOn(conn, addr) - return - } - - cc, err := h.options.Chain.Conn() - if err != nil { - log.Logf("[socks5-bind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(conn) - if Debug { - log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) - } - return - } - - // forward request - // note: this type of request forwarding is defined when starting server, - // so we don't need to authenticate it, as it's as explicit as whitelisting - defer cc.Close() - req.Write(cc) - log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - transport(conn, cc) - log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) -} - -func (h *socks5Handler) bindOn(conn net.Conn, addr string) { - bindAddr, _ := net.ResolveTCPAddr("tcp", addr) - ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error - if err != nil { - log.Logf("[socks5-bind] %s -> %s : %s", conn.RemoteAddr(), addr, err) - gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) - return - } - - socksAddr := toSocksAddr(ln.Addr()) - // Issue: may not reachable when host has multi-interface - socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(conn); err != nil { - log.Logf("[socks5-bind] %s <- %s : %s", conn.RemoteAddr(), addr, err) - ln.Close() - return - } - if Debug { - log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) - } - log.Logf("[socks5-bind] %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr) - - var pconn net.Conn - accept := func() <-chan error { - errc := make(chan error, 1) - - go func() { - defer close(errc) - defer ln.Close() - - c, err := ln.AcceptTCP() - if err != nil { - errc <- err - return - } - pconn = c - }() - - return errc - } - - pc1, pc2 := net.Pipe() - pipe := func() <-chan error { - errc := make(chan error, 1) - - go func() { - defer close(errc) - defer pc1.Close() - - errc <- transport(conn, pc1) - }() - - return errc - } - - defer pc2.Close() - - for { - select { - case err := <-accept(): - if err != nil || pconn == nil { - log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) - return - } - defer pconn.Close() - - reply := gosocks5.NewReply(gosocks5.Succeeded, toSocksAddr(pconn.RemoteAddr())) - if err := reply.Write(pc2); err != nil { - log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) - } - if Debug { - log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) - } - log.Logf("[socks5-bind] %s <- %s PEER %s ACCEPTED", conn.RemoteAddr(), socksAddr, pconn.RemoteAddr()) - - log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), pconn.RemoteAddr()) - if err = transport(pc2, pconn); err != nil { - log.Logf("[socks5-bind] %s - %s : %v", conn.RemoteAddr(), pconn.RemoteAddr(), err) - } - log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), pconn.RemoteAddr()) - return - case err := <-pipe(): - if err != nil { - log.Logf("[socks5-bind] %s -> %s : %v", conn.RemoteAddr(), addr, err) - } - ln.Close() - return - } - } -} - -func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { - addr := req.Addr.String() - if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } - - relay, err := net.ListenUDP("udp", nil) - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), relay.LocalAddr(), err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(conn) - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), relay.LocalAddr(), reply) - } - return - } - defer relay.Close() - - socksAddr := toSocksAddr(relay.LocalAddr()) - socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(conn); err != nil { - log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), relay.LocalAddr(), err) - return - } - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), reply.Addr, reply) - } - log.Logf("[socks5-udp] %s - %s BIND ON %s OK", conn.RemoteAddr(), relay.LocalAddr(), socksAddr) - - // serve as standard socks5 udp relay local <-> remote - if h.options.Chain.IsEmpty() { - peer, er := net.ListenUDP("udp", nil) - if er != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, er) - return - } - defer peer.Close() - - go h.transportUDP(relay, peer) - log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) - if err := h.discardClientData(conn); err != nil { - log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) - } - log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) - return - } - - cc, err := h.options.Chain.Conn() - // connection error - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) - return - } - // forward udp local <-> tunnel - defer cc.Close() - - cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) - return - } - - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - r := gosocks5.NewRequest(CmdUDPTun, nil) - if err := r.Write(cc); err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) - return - } - cc.SetWriteDeadline(time.Time{}) - if Debug { - log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), r) - } - cc.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err = gosocks5.ReadReply(cc) - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) - return - } - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), reply) - } - - if reply.Rep != gosocks5.Succeeded { - log.Logf("[socks5-udp] %s <- %s : udp associate failed", conn.RemoteAddr(), cc.RemoteAddr()) - return - } - cc.SetReadDeadline(time.Time{}) - log.Logf("[socks5-udp] %s <-> %s [tun: %s]", conn.RemoteAddr(), socksAddr, reply.Addr) - - go h.tunnelClientUDP(relay, cc) - log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) - if err := h.discardClientData(conn); err != nil { - log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) - } - log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) -} - -func (h *socks5Handler) discardClientData(conn net.Conn) (err error) { - b := make([]byte, tinyBufferSize) - n := 0 - for { - n, err = conn.Read(b) // discard any data from tcp connection - if err != nil { - if err == io.EOF { // disconnect normally - err = nil - } - break // client disconnected - } - log.Logf("[socks5-udp] read %d UNEXPECTED TCP data from client", n) - } - return -} - -func (h *socks5Handler) transportUDP(relay, peer *net.UDPConn) (err error) { - errc := make(chan error, 2) - - var clientAddr *net.UDPAddr - - go func() { - b := make([]byte, largeBufferSize) - - for { - n, laddr, err := relay.ReadFromUDP(b) - if err != nil { - errc <- err - return - } - if clientAddr == nil { - clientAddr = laddr - } - dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n])) - if err != nil { - errc <- err - return - } - - raddr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - if err != nil { - continue // drop silently - } - if _, err := peer.WriteToUDP(dgram.Data, raddr); err != nil { - errc <- err - return - } - if Debug { - log.Logf("[socks5-udp] %s >>> %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) - } - } - }() - - go func() { - b := make([]byte, largeBufferSize) - - for { - n, raddr, err := peer.ReadFromUDP(b) - if err != nil { - errc <- err - return - } - if clientAddr == nil { - continue - } - buf := bytes.Buffer{} - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), b[:n]) - dgram.Write(&buf) - if _, err := relay.WriteToUDP(buf.Bytes(), clientAddr); err != nil { - errc <- err - return - } - if Debug { - log.Logf("[socks5-udp] %s <<< %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) - } - } - }() - - select { - case err = <-errc: - //log.Println("w exit", err) - } - - return -} - -func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) { - errc := make(chan error, 2) - - var clientAddr *net.UDPAddr - - go func() { - b := make([]byte, mediumBufferSize) - - for { - n, addr, err := uc.ReadFromUDP(b) - if err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) - errc <- err - return - } - - // glog.V(LDEBUG).Infof("read udp %d, % #x", n, b[:n]) - // pipe from relay to tunnel - dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n])) - if err != nil { - errc <- err - return - } - if clientAddr == nil { - clientAddr = addr - } - dgram.Header.Rsv = uint16(len(dgram.Data)) - if err := dgram.Write(cc); err != nil { - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s >>> %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) - } - } - }() - - go func() { - for { - dgram, err := gosocks5.ReadUDPDatagram(cc) - if err != nil { - log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) - errc <- err - return - } - - // pipe from tunnel to relay - if clientAddr == nil { - continue - } - dgram.Header.Rsv = 0 - - buf := bytes.Buffer{} - dgram.Write(&buf) - if _, err := uc.WriteToUDP(buf.Bytes(), clientAddr); err != nil { - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s <<< %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) - } - } - }() - - select { - case err = <-errc: - } - - return -} - -func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { - // serve tunnel udp, tunnel <-> remote, handle tunnel udp request - if h.options.Chain.IsEmpty() { - addr := req.Addr.String() - - if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr) - return - } - - bindAddr, _ := net.ResolveUDPAddr("udp", addr) - uc, err := net.ListenUDP("udp", bindAddr) - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) - return - } - defer uc.Close() - - socksAddr := toSocksAddr(uc.LocalAddr()) - socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(conn); err != nil { - log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) - return - } - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) - } - log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) - h.tunnelServerUDP(conn, uc) - log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) - return - } - - cc, err := h.options.Chain.Conn() - // connection error - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(conn) - log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) - return - } - defer cc.Close() - - cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) - if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) - return - } - // tunnel <-> tunnel, direct forwarding - // note: this type of request forwarding is defined when starting server - // so we don't need to authenticate it, as it's as explicit as whitelisting - req.Write(cc) - - log.Logf("[socks5-udp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) - transport(conn, cc) - log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) -} - -func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) { - errc := make(chan error, 2) - - go func() { - b := make([]byte, mediumBufferSize) - - for { - n, addr, err := uc.ReadFromUDP(b) - if err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) - errc <- err - return - } - - // pipe from peer to tunnel - dgram := gosocks5.NewUDPDatagram( - gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) - if err := dgram.Write(cc); err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) - } - } - }() - - go func() { - for { - dgram, err := gosocks5.ReadUDPDatagram(cc) - if err != nil { - log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) - errc <- err - return - } - - // pipe from tunnel to peer - addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - if err != nil { - continue // drop silently - } - if _, err := uc.WriteToUDP(dgram.Data, addr); err != nil { - log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err) - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) - } - } - }() - - select { - case err = <-errc: - } - - return -} - -func toSocksAddr(addr net.Addr) *gosocks5.Addr { - host := "0.0.0.0" - port := 0 - if addr != nil { - h, p, _ := net.SplitHostPort(addr.String()) - host = h - port, _ = strconv.Atoi(p) - } - return &gosocks5.Addr{ - Type: gosocks5.AddrIPv4, - Host: host, - Port: uint16(port), - } -} - -type socks4Handler struct { - options *HandlerOptions -} - -// SOCKS4Handler creates a server Handler for SOCKS4(A) proxy server. -func SOCKS4Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{} - for _, opt := range opts { - opt(options) - } - return &socks4Handler{ - options: options, - } -} - -func (h *socks4Handler) Handle(conn net.Conn) { - defer conn.Close() - - req, err := gosocks4.ReadRequest(conn) - if err != nil { - log.Log("[socks4]", err) - return - } - - if Debug { - log.Logf("[socks4] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, req) - } - - switch req.Cmd { - case gosocks4.CmdConnect: - log.Logf("[socks4-connect] %s -> %s", conn.RemoteAddr(), req.Addr) - h.handleConnect(conn, req) - - case gosocks4.CmdBind: - log.Logf("[socks4-bind] %s - %s", conn.RemoteAddr(), req.Addr) - h.handleBind(conn, req) - - default: - log.Logf("[socks4] Unrecognized request: %d", req.Cmd) - } -} - -func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { - addr := req.Addr.String() - - if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks4-connect] Unauthorized to tcp connect to %s", addr) - rep := gosocks5.NewReply(gosocks4.Rejected, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } - - cc, err := h.options.Chain.Dial(addr) - if err != nil { - log.Logf("[socks4-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) - rep := gosocks4.NewReply(gosocks4.Failed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } - defer cc.Close() - - rep := gosocks4.NewReply(gosocks4.Granted, nil) - if err := rep.Write(conn); err != nil { - log.Logf("[socks4-connect] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) - return - } - if Debug { - log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - - log.Logf("[socks4-connect] %s <-> %s", conn.RemoteAddr(), req.Addr) - transport(conn, cc) - log.Logf("[socks4-connect] %s >-< %s", conn.RemoteAddr(), req.Addr) -} - -func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) { - // TODO: serve socks4 bind - if h.options.Chain.IsEmpty() { - reply := gosocks4.NewReply(gosocks4.Rejected, nil) - reply.Write(conn) - if Debug { - log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) - } - return - } - - cc, err := h.options.Chain.Conn() - // connection error - if err != nil && err != ErrEmptyChain { - log.Logf("[socks4-bind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) - reply := gosocks4.NewReply(gosocks4.Failed, nil) - reply.Write(conn) - if Debug { - log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) - } - return - } - - defer cc.Close() - // forward request - req.Write(cc) - - log.Logf("[socks4-bind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - transport(conn, cc) - log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) -} - -func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { - conn, err := chain.Conn() - if err != nil { - return nil, err - } - cc, err := socks5Handshake(conn, chain.LastNode().User) - if err != nil { - conn.Close() - return nil, err - } - conn = cc - - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(addr)) - if err := req.Write(conn); err != nil { - conn.Close() - return nil, err - } - if Debug { - log.Log("[socks5]", req) - } - conn.SetWriteDeadline(time.Time{}) - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err := gosocks5.ReadReply(conn) - if err != nil { - conn.Close() - return nil, err - } - conn.SetReadDeadline(time.Time{}) - if Debug { - log.Log("[socks5]", reply) - } - - if reply.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("UDP tunnel failure") - } - return conn, nil -} - -func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: user, - } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) - cc := gosocks5.ClientConn(conn, selector) - if err := cc.Handleshake(); err != nil { - return nil, err - } - return cc, nil -} - -type udpTunnelConn struct { - raddr string - net.Conn -} - -func (c *udpTunnelConn) Read(b []byte) (n int, err error) { - dgram, err := gosocks5.ReadUDPDatagram(c.Conn) - if err != nil { - return - } - n = copy(b, dgram.Data) - return -} - -func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - dgram, err := gosocks5.ReadUDPDatagram(c.Conn) - if err != nil { - return - } - n = copy(b, dgram.Data) - addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - return -} - -func (c *udpTunnelConn) Write(b []byte) (n int, err error) { - addr, err := net.ResolveUDPAddr("udp", c.raddr) - if err != nil { - return - } - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) - if err = dgram.Write(c.Conn); err != nil { - return - } - return len(b), nil -} - -func (c *udpTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) - if err = dgram.Write(c.Conn); err != nil { - return - } - return len(b), nil -} diff --git a/gost/ss.go b/gost/ss.go deleted file mode 100644 index 079015d..0000000 --- a/gost/ss.go +++ /dev/null @@ -1,419 +0,0 @@ -package gost - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "net/url" - "strconv" - "time" - - "github.com/ginuerzh/gosocks5" - "github.com/go-log/log" - ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" -) - -// Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write, -// we wrap around it to make io.Copy happy. -type shadowConn struct { - conn net.Conn -} - -func (c *shadowConn) Read(b []byte) (n int, err error) { - return c.conn.Read(b) -} - -func (c *shadowConn) Write(b []byte) (n int, err error) { - n = len(b) // force byte length consistent - _, err = c.conn.Write(b) - return -} - -func (c *shadowConn) Close() error { - return c.conn.Close() -} - -func (c *shadowConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *shadowConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *shadowConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *shadowConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *shadowConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -type shadowConnector struct { - Cipher *url.Userinfo -} - -// ShadowConnector creates a Connector for shadowsocks proxy client. -// It accepts a cipher info for shadowsocks data encryption/decryption. -// The cipher must not be nil. -func ShadowConnector(cipher *url.Userinfo) Connector { - return &shadowConnector{Cipher: cipher} -} - -func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { - rawaddr, err := ss.RawAddr(addr) - if err != nil { - return nil, err - } - - var method, password string - if c.Cipher != nil { - method = c.Cipher.Username() - password, _ = c.Cipher.Password() - } - - cipher, err := ss.NewCipher(method, password) - if err != nil { - return nil, err - } - - sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) - if err != nil { - return nil, err - } - return &shadowConn{conn: sc}, nil -} - -type shadowHandler struct { - options *HandlerOptions -} - -// ShadowHandler creates a server Handler for shadowsocks proxy server. -func ShadowHandler(opts ...HandlerOption) Handler { - h := &shadowHandler{ - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *shadowHandler) Handle(conn net.Conn) { - defer conn.Close() - - var method, password string - users := h.options.Users - if len(users) > 0 { - method = users[0].Username() - password, _ = users[0].Password() - } - cipher, err := ss.NewCipher(method, password) - if err != nil { - log.Log("[ss]", err) - return - } - conn = &shadowConn{conn: ss.NewConn(conn, cipher)} - - log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) - - addr, err := h.getRequest(conn) - if err != nil { - log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) - - if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ss] Unauthorized to tcp connect to %s", addr) - return - } - - cc, err := h.options.Chain.Dial(addr) - if err != nil { - log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) - return - } - defer cc.Close() - - log.Logf("[ss] %s <-> %s", conn.RemoteAddr(), addr) - transport(conn, cc) - log.Logf("[ss] %s >-< %s", conn.RemoteAddr(), addr) -} - -const ( - idType = 0 // address type index - idIP0 = 1 // ip addres start index - idDmLen = 1 // domain address length index - idDm0 = 2 // domain address start index - - typeIPv4 = 1 // type is ipv4 address - typeDm = 3 // type is domain address - typeIPv6 = 4 // type is ipv6 address - - lenIPv4 = net.IPv4len + 2 // ipv4 + 2port - lenIPv6 = net.IPv6len + 2 // ipv6 + 2port - lenDmBase = 2 // 1addrLen + 2port, plus addrLen - lenHmacSha1 = 10 -) - -// This function is copied from shadowsocks library with some modification. -func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { - // buf size should at least have the same size with the largest possible - // request size (when addrType is 3, domain name has at most 256 bytes) - // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) - buf := make([]byte, smallBufferSize) - - // read till we get possible domain length field - conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { - return - } - - var reqStart, reqEnd int - addrType := buf[idType] - switch addrType & ss.AddrMask { - case typeIPv4: - reqStart, reqEnd = idIP0, idIP0+lenIPv4 - case typeIPv6: - reqStart, reqEnd = idIP0, idIP0+lenIPv6 - case typeDm: - if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { - return - } - reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) - default: - err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) - return - } - - if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { - return - } - - // Return string for typeIP is not most efficient, but browsers (Chrome, - // Safari, Firefox) all seems using typeDm exclusively. So this is not a - // big problem. - switch addrType & ss.AddrMask { - case typeIPv4: - host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() - case typeIPv6: - host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() - case typeDm: - host = string(buf[idDm0 : idDm0+buf[idDmLen]]) - } - // parse port - port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) - host = net.JoinHostPort(host, strconv.Itoa(int(port))) - return -} - -type shadowUDPListener struct { - ln net.PacketConn - conns map[string]*udpServerConn - connChan chan net.Conn - errChan chan error - ttl time.Duration -} - -// ShadowUDPListener creates a Listener for shadowsocks UDP relay server. -func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - - var method, password string - if cipher != nil { - method = cipher.Username() - password, _ = cipher.Password() - } - cp, err := ss.NewCipher(method, password) - if err != nil { - ln.Close() - return nil, err - } - l := &shadowUDPListener{ - ln: ss.NewSecurePacketConn(ln, cp, false), - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - ttl: ttl, - } - go l.listenLoop() - return l, nil -} - -func (l *shadowUDPListener) listenLoop() { - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := l.ln.ReadFrom(b) - if err != nil { - log.Logf("[ssu] peer -> %s : %s", l.Addr(), err) - l.ln.Close() - l.errChan <- err - close(l.errChan) - return - } - if Debug { - log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n) - } - - conn, ok := l.conns[raddr.String()] - if !ok || conn.Closed() { - conn = newUDPServerConn(l.ln, raddr, l.ttl) - l.conns[raddr.String()] = conn - - select { - case l.connChan <- conn: - default: - conn.Close() - log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr()) - } - } - - select { - case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination. - default: - log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr()) - } - } -} - -func (l *shadowUDPListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *shadowUDPListener) Addr() net.Addr { - return l.ln.LocalAddr() -} - -func (l *shadowUDPListener) Close() error { - return l.ln.Close() -} - -type shadowUDPdHandler struct { - ttl time.Duration - options *HandlerOptions -} - -// ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server. -func ShadowUDPdHandler(opts ...HandlerOption) Handler { - h := &shadowUDPdHandler{ - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *shadowUDPdHandler) Handle(conn net.Conn) { - defer conn.Close() - - var err error - var cc net.PacketConn - if h.options.Chain.IsEmpty() { - cc, err = net.ListenUDP("udp", nil) - if err != nil { - log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) - return - } - } else { - var c net.Conn - c, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - if err != nil { - log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) - return - } - cc = &udpTunnelConn{Conn: c} - } - defer cc.Close() - - log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - transportUDP(conn, cc) - log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) -} - -func transportUDP(sc net.Conn, cc net.PacketConn) error { - errc := make(chan error, 1) - go func() { - for { - b := make([]byte, mediumBufferSize) - n, err := sc.Read(b[3:]) // add rsv and frag fields to make it the standard SOCKS5 UDP datagram - if err != nil { - // log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) - errc <- err - return - } - dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3])) - if err != nil { - log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) - errc <- err - return - } - //if Debug { - // log.Logf("[ssu] %s >>> %s length: %d", sc.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data)) - //} - addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - if err != nil { - errc <- err - return - } - if _, err := cc.WriteTo(dgram.Data, addr); err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - b := make([]byte, mediumBufferSize) - n, addr, err := cc.ReadFrom(b) - if err != nil { - errc <- err - return - } - //if Debug { - // log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n) - //} - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) - buf := bytes.Buffer{} - dgram.Write(&buf) - if buf.Len() < 10 { - log.Logf("[ssu] %s <- %s : invalid udp datagram", sc.RemoteAddr(), addr) - continue - } - if _, err := sc.Write(buf.Bytes()[3:]); err != nil { - errc <- err - return - } - } - }() - - err := <-errc - if err != nil && err == io.EOF { - err = nil - } - return err -} diff --git a/gost/ssh.go b/gost/ssh.go deleted file mode 100644 index a09dc9d..0000000 --- a/gost/ssh.go +++ /dev/null @@ -1,834 +0,0 @@ -package gost - -import ( - "context" - "crypto/tls" - "encoding/binary" - "errors" - "fmt" - "net" - "net/url" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-log/log" - "golang.org/x/crypto/ssh" -) - -// Applicaple SSH Request types for Port Forwarding - RFC 4254 7.X -const ( - DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 - RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 - ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 - CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 - - GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel -) - -var ( - errSessionDead = errors.New("session is dead") -) - -type sshDirectForwardConnector struct { -} - -func SSHDirectForwardConnector() Connector { - return &sshDirectForwardConnector{} -} - -func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string) (net.Conn, error) { - cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. - if !ok { - return nil, errors.New("ssh: wrong connection type") - } - conn, err := cc.session.client.Dial("tcp", raddr) - if err != nil { - log.Logf("[ssh-tcp] %s -> %s : %s", cc.session.addr, raddr, err) - return nil, err - } - return conn, nil -} - -type sshRemoteForwardConnector struct { -} - -func SSHRemoteForwardConnector() Connector { - return &sshRemoteForwardConnector{} -} - -func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { - cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. - if !ok { - return nil, errors.New("ssh: wrong connection type") - } - - cc.session.once.Do(func() { - go func() { - defer log.Log("ssh-rtcp: session is closed") - defer close(cc.session.connChan) - - if cc.session == nil || cc.session.client == nil { - return - } - if strings.HasPrefix(addr, ":") { - addr = "0.0.0.0" + addr - } - ln, err := cc.session.client.Listen("tcp", addr) - if err != nil { - return - } - for { - rc, err := ln.Accept() - if err != nil { - log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), addr, err) - return - } - - select { - case cc.session.connChan <- rc: - default: - rc.Close() - log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr) - } - } - }() - }) - - sc, ok := <-cc.session.connChan - if !ok { - return nil, errors.New("ssh-rtcp: connection is closed") - } - return sc, nil -} - -type sshForwardTransporter struct { - sessions map[string]*sshSession - sessionMutex sync.Mutex -} - -func SSHForwardTransporter() Transporter { - return &sshForwardTransporter{ - sessions: make(map[string]*sshSession), - } -} - -func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if !ok || session.Closed() { - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &sshSession{ - addr: addr, - conn: conn, - } - tr.sessions[addr] = session - } - - return session.conn, nil -} - -func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - config := ssh.ClientConfig{ - Timeout: opts.Timeout, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - if opts.User != nil { - config.User = opts.User.Username() - password, _ := opts.User.Password() - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), - } - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("ssh: unrecognized connection") - } - if !ok || session.client == nil { - sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - session = &sshSession{ - addr: opts.Addr, - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - deaded: make(chan struct{}), - connChan: make(chan net.Conn, 1024), - } - tr.sessions[opts.Addr] = session - go session.Ping(opts.Interval, 1) - go session.waitServer() - go session.waitClose() - } - if session.Closed() { - delete(tr.sessions, opts.Addr) - return nil, errSessionDead - } - - return &sshNopConn{session: session}, nil -} - -func (tr *sshForwardTransporter) Multiplex() bool { - return true -} - -type sshTunnelTransporter struct { - sessions map[string]*sshSession - sessionMutex sync.Mutex -} - -// SSHTunnelTransporter creates a Transporter that is used by SSH tunnel client. -func SSHTunnelTransporter() Transporter { - return &sshTunnelTransporter{ - sessions: make(map[string]*sshSession), - } -} - -func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if !ok || session.Closed() { - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &sshSession{ - addr: addr, - conn: conn, - } - tr.sessions[addr] = session - } - - return session.conn, nil -} - -func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - config := ssh.ClientConfig{ - Timeout: opts.Timeout, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - if opts.User != nil { - config.User = opts.User.Username() - password, _ := opts.User.Password() - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), - } - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[opts.Addr] - if session != nil && session.conn != conn { - conn.Close() - return nil, errors.New("ssh: unrecognized connection") - } - if !ok || session.client == nil { - sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - session = &sshSession{ - addr: opts.Addr, - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - deaded: make(chan struct{}), - } - tr.sessions[opts.Addr] = session - go session.Ping(opts.Interval, 1) - go session.waitServer() - go session.waitClose() - } - - if session.Closed() { - delete(tr.sessions, opts.Addr) - return nil, errSessionDead - } - - channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) - if err != nil { - return nil, err - } - go ssh.DiscardRequests(reqs) - return &sshConn{channel: channel, conn: conn}, nil -} - -func (tr *sshTunnelTransporter) Multiplex() bool { - return true -} - -type sshSession struct { - addr string - conn net.Conn - client *ssh.Client - closed chan struct{} - deaded chan struct{} - once sync.Once - connChan chan net.Conn -} - -func (s *sshSession) Ping(interval time.Duration, retries int) { - if interval <= 0 { - return - } - defer close(s.deaded) - - log.Log("[ssh] ping is enabled, interval:", interval) - baseCtx := context.Background() - t := time.NewTicker(interval) - defer t.Stop() - - for { - select { - case <-t.C: - start := time.Now() - //if Debug { - log.Log("[ssh] sending ping") - //} - ctx, cancel := context.WithTimeout(baseCtx, time.Second*30) - var err error - select { - case err = <-s.sendPing(): - case <-ctx.Done(): - err = errors.New("Timeout") - } - cancel() - if err != nil { - log.Log("[ssh] ping:", err) - return - } - //if Debug { - log.Log("[ssh] ping OK, RTT:", time.Since(start)) - //} - - case <-s.closed: - return - } - } -} - -func (s *sshSession) sendPing() <-chan error { - ch := make(chan error, 1) - go func() { - if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { - ch <- err - } - close(ch) - }() - return ch -} - -func (s *sshSession) waitServer() error { - defer close(s.closed) - return s.client.Wait() -} - -func (s *sshSession) waitClose() { - defer s.client.Close() - - select { - case <-s.deaded: - case <-s.closed: - } -} - -func (s *sshSession) Closed() bool { - select { - case <-s.deaded: - return true - case <-s.closed: - return true - default: - } - return false -} - -type sshForwardHandler struct { - options *HandlerOptions - config *ssh.ServerConfig -} - -func SSHForwardHandler(opts ...HandlerOption) Handler { - h := &sshForwardHandler{ - options: new(HandlerOptions), - config: new(ssh.ServerConfig), - } - for _, opt := range opts { - opt(h.options) - } - h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) - if len(h.options.Users) == 0 { - h.config.NoClientAuth = true - } - if h.options.TLSConfig != nil && len(h.options.TLSConfig.Certificates) > 0 { - signer, err := ssh.NewSignerFromKey(h.options.TLSConfig.Certificates[0].PrivateKey) - if err != nil { - log.Log("[ssh-forward]", err) - } - h.config.AddHostKey(signer) - } - - return h -} - -func (h *sshForwardHandler) Handle(conn net.Conn) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) - if err != nil { - log.Logf("[ssh-forward] %s -> %s : %s", conn.RemoteAddr(), h.options.Addr, err) - conn.Close() - return - } - defer sshConn.Close() - - log.Logf("[ssh-forward] %s <-> %s", conn.RemoteAddr(), h.options.Addr) - h.handleForward(sshConn, chans, reqs) - log.Logf("[ssh-forward] %s >-< %s", conn.RemoteAddr(), h.options.Addr) -} - -func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - quit := make(chan struct{}) - defer close(quit) // quit signal - - go func() { - for req := range reqs { - switch req.Type { - case RemoteForwardRequest: - go h.tcpipForwardRequest(conn, req, quit) - default: - // log.Log("[ssh] unknown channel type:", req.Type) - if req.WantReply { - req.Reply(false, nil) - } - } - } - }() - - go func() { - for newChannel := range chans { - // Check the type of channel - t := newChannel.ChannelType() - switch t { - case DirectForwardRequest: - channel, requests, err := newChannel.Accept() - if err != nil { - log.Log("[ssh] Could not accept channel:", err) - continue - } - p := directForward{} - ssh.Unmarshal(newChannel.ExtraData(), &p) - - if p.Host1 == "" { - p.Host1 = "" - } - - go ssh.DiscardRequests(requests) - go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) - default: - log.Log("[ssh] Unknown channel type:", t) - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - } - } - }() - - conn.Wait() -} - -func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { - defer channel.Close() - - log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr) - - if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) - return - } - - conn, err := h.options.Chain.Dial(raddr) - if err != nil { - log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err) - return - } - defer conn.Close() - - log.Logf("[ssh-tcp] %s <-> %s", h.options.Addr, raddr) - transport(conn, channel) - log.Logf("[ssh-tcp] %s >-< %s", h.options.Addr, raddr) -} - -// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request -type tcpipForward struct { - Host string - Port uint32 -} - -func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { - t := tcpipForward{} - ssh.Unmarshal(req.Payload, &t) - - addr := fmt.Sprintf("%s:%d", t.Host, t.Port) - - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) - req.Reply(false, nil) - return - } - - log.Log("[ssh-rtcp] listening on tcp", addr) - ln, err := net.Listen("tcp", addr) //tie to the client connection - if err != nil { - log.Log("[ssh-rtcp]", err) - req.Reply(false, nil) - return - } - defer ln.Close() - - replyFunc := func() error { - if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used - _, port, err := getHostPortFromAddr(ln.Addr()) - if err != nil { - return err - } - var b [4]byte - binary.BigEndian.PutUint32(b[:], uint32(port)) - t.Port = uint32(port) - return req.Reply(true, b[:]) - } - return req.Reply(true, nil) - } - if err := replyFunc(); err != nil { - log.Log("[ssh-rtcp]", err) - return - } - - go func() { - for { - conn, err := ln.Accept() - if err != nil { // Unable to accept new connection - listener is likely closed - return - } - - go func(conn net.Conn) { - defer conn.Close() - - p := directForward{} - var err error - - var portnum int - p.Host1 = t.Host - p.Port1 = t.Port - p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) - if err != nil { - return - } - - p.Port2 = uint32(portnum) - ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) - if err != nil { - log.Log("[ssh-rtcp] open forwarded channel:", err) - return - } - defer ch.Close() - go ssh.DiscardRequests(reqs) - - log.Logf("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - transport(ch, conn) - log.Logf("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) - }(conn) - } - }() - - <-quit -} - -// SSHConfig holds the SSH tunnel server config -type SSHConfig struct { - Users []*url.Userinfo - TLSConfig *tls.Config -} - -type sshTunnelListener struct { - net.Listener - config *ssh.ServerConfig - connChan chan net.Conn - errChan chan error -} - -// SSHTunnelListener creates a Listener for SSH tunnel server. -func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - - if config == nil { - config = &SSHConfig{} - } - - sshConfig := &ssh.ServerConfig{} - sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...) - if len(config.Users) == 0 { - sshConfig.NoClientAuth = true - } - if config.TLSConfig == nil { - cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) - if err != nil { - ln.Close() - return nil, err - } - config.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - } - - signer, err := ssh.NewSignerFromKey(config.TLSConfig.Certificates[0].PrivateKey) - if err != nil { - ln.Close() - return nil, err - - } - sshConfig.AddHostKey(signer) - - l := &sshTunnelListener{ - Listener: ln, - config: sshConfig, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - go l.listenLoop() - - return l, nil -} - -func (l *sshTunnelListener) listenLoop() { - for { - conn, err := l.Listener.Accept() - if err != nil { - log.Log("[ssh] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.serveConn(conn) - } -} - -func (l *sshTunnelListener) serveConn(conn net.Conn) { - sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) - if err != nil { - log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - conn.Close() - return - } - defer sc.Close() - - go ssh.DiscardRequests(reqs) - go func() { - for newChannel := range chans { - // Check the type of channel - t := newChannel.ChannelType() - switch t { - case GostSSHTunnelRequest: - channel, requests, err := newChannel.Accept() - if err != nil { - log.Log("[ssh] Could not accept channel:", err) - continue - } - go ssh.DiscardRequests(requests) - cc := &sshConn{conn: conn, channel: channel} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) - } - - default: - log.Log("[ssh] Unknown channel type:", t) - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - } - } - }() - - log.Logf("[ssh] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - sc.Wait() - log.Logf("[ssh] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) -} - -func (l *sshTunnelListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" -type directForward struct { - Host1 string - Port1 uint32 - Host2 string - Port2 uint32 -} - -func (p directForward) String() string { - return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) -} - -func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { - host, portString, err := net.SplitHostPort(addr.String()) - if err != nil { - return - } - port, err = strconv.Atoi(portString) - return -} - -type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) - -func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc { - return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - for _, user := range users { - u := user.Username() - p, _ := user.Password() - if u == conn.User() && p == string(password) { - return nil, nil - } - } - log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) - return nil, fmt.Errorf("password rejected for %s", conn.User()) - } -} - -type sshNopConn struct { - session *sshSession -} - -func (c *sshNopConn) Read(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "read", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("read not supported")} -} - -func (c *sshNopConn) Write(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "write", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("write not supported")} -} - -func (c *sshNopConn) Close() error { - return nil -} - -func (c *sshNopConn) LocalAddr() net.Addr { - return &net.TCPAddr{ - IP: net.IPv4zero, - Port: 0, - } -} - -func (c *sshNopConn) RemoteAddr() net.Addr { - return &net.TCPAddr{ - IP: net.IPv4zero, - Port: 0, - } -} - -func (c *sshNopConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshNopConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshNopConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -type sshConn struct { - channel ssh.Channel - conn net.Conn -} - -func (c *sshConn) Read(b []byte) (n int, err error) { - return c.channel.Read(b) -} - -func (c *sshConn) Write(b []byte) (n int, err error) { - return c.channel.Write(b) -} - -func (c *sshConn) Close() error { - return c.channel.Close() -} - -func (c *sshConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *sshConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *sshConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} diff --git a/gost/ws.go b/gost/ws.go deleted file mode 100644 index ff4605c..0000000 --- a/gost/ws.go +++ /dev/null @@ -1,295 +0,0 @@ -package gost - -import ( - "crypto/tls" - "net" - "net/http" - "net/http/httputil" - "time" - - "net/url" - - "github.com/go-log/log" - "gopkg.in/gorilla/websocket.v1" -) - -// WSOptions describes the options for websocket. -type WSOptions struct { - ReadBufferSize int - WriteBufferSize int - HandshakeTimeout time.Duration - EnableCompression bool -} - -type websocketConn struct { - conn *websocket.Conn - rb []byte -} - -func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, options *WSOptions) (net.Conn, error) { - if options == nil { - options = &WSOptions{} - } - dialer := websocket.Dialer{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - TLSClientConfig: tlsConfig, - HandshakeTimeout: options.HandshakeTimeout, - EnableCompression: options.EnableCompression, - NetDial: func(net, addr string) (net.Conn, error) { - return conn, nil - }, - } - c, resp, err := dialer.Dial(url, nil) - if err != nil { - return nil, err - } - resp.Body.Close() - return &websocketConn{conn: c}, nil -} - -func websocketServerConn(conn *websocket.Conn) net.Conn { - // conn.EnableWriteCompression(true) - return &websocketConn{ - conn: conn, - } -} - -func (c *websocketConn) Read(b []byte) (n int, err error) { - if len(c.rb) == 0 { - _, c.rb, err = c.conn.ReadMessage() - } - n = copy(b, c.rb) - c.rb = c.rb[n:] - return -} - -func (c *websocketConn) Write(b []byte) (n int, err error) { - err = c.conn.WriteMessage(websocket.BinaryMessage, b) - n = len(b) - return -} - -func (c *websocketConn) Close() error { - return c.conn.Close() -} - -func (c *websocketConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *websocketConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *websocketConn) SetDeadline(t time.Time) error { - if err := c.SetReadDeadline(t); err != nil { - return err - } - return c.SetWriteDeadline(t) -} -func (c *websocketConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *websocketConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -type wsTransporter struct { - *tcpTransporter - options *WSOptions -} - -// WSTransporter creates a Transporter that is used by websocket proxy client. -func WSTransporter(opts *WSOptions) Transporter { - return &wsTransporter{ - options: opts, - } -} - -func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} - return websocketClientConn(url.String(), conn, nil, wsOptions) -} - -type wssTransporter struct { - *tcpTransporter - options *WSOptions -} - -// WSSTransporter creates a Transporter that is used by websocket secure proxy client. -func WSSTransporter(opts *WSOptions) Transporter { - return &wssTransporter{ - options: opts, - } -} - -func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - if opts.TLSConfig == nil { - opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} - } - url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} - return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) -} - -type wsListener struct { - addr net.Addr - upgrader *websocket.Upgrader - srv *http.Server - connChan chan net.Conn - errChan chan error -} - -// WSListener creates a Listener for websocket proxy server. -func WSListener(addr string, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &wsListener{ - addr: tcpAddr, - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{Addr: addr, Handler: mux} - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - - go func() { - err := l.srv.Serve(tcpKeepAliveListener{ln}) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} - -func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { - log.Logf("[ws] %s -> %s", r.RemoteAddr, l.addr) - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Log(string(dump)) - } - conn, err := l.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Logf("[ws] %s - %s : %s", r.RemoteAddr, l.addr, err) - return - } - select { - case l.connChan <- websocketServerConn(conn): - default: - conn.Close() - log.Logf("[ws] %s - %s: connection queue is full", r.RemoteAddr, l.addr) - } -} - -func (l *wsListener) Accept() (conn net.Conn, err error) { - select { - case conn = <-l.connChan: - case err = <-l.errChan: - } - return -} - -func (l *wsListener) Close() error { - return l.srv.Close() -} - -func (l *wsListener) Addr() net.Addr { - return l.addr -} - -type wssListener struct { - *wsListener -} - -// WSSListener creates a Listener for websocket secure proxy server. -func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &wssListener{ - wsListener: &wsListener{ - addr: tcpAddr, - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - }, - } - - mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{ - Addr: addr, - TLSConfig: tlsConfig, - Handler: mux, - } - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - - go func() { - err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} diff --git a/handler.go b/handler.go index 9662783..4a454a0 100644 --- a/handler.go +++ b/handler.go @@ -1,32 +1,114 @@ package gost import ( + "bufio" + "crypto/tls" "net" + "net/url" + + "github.com/ginuerzh/gosocks4" + "github.com/ginuerzh/gosocks5" + "github.com/go-log/log" ) +// Handler is a proxy server handler type Handler interface { Handle(net.Conn) } -type defaultHandler struct { - server Server +// HandlerOptions describes the options for Handler. +type HandlerOptions struct { + Addr string + Chain *Chain + Users []*url.Userinfo + TLSConfig *tls.Config + Whitelist *Permissions + Blacklist *Permissions } -func DefaultHandler(server Server) Handler { - return &defaultHandler{server: server} +// HandlerOption allows a common way to set handler options. +type HandlerOption func(opts *HandlerOptions) + +// AddrHandlerOption sets the Addr option of HandlerOptions. +func AddrHandlerOption(addr string) HandlerOption { + return func(opts *HandlerOptions) { + opts.Addr = addr + } } -func (h *defaultHandler) Handle(conn net.Conn) { - var handler Handler +// ChainHandlerOption sets the Chain option of HandlerOptions. +func ChainHandlerOption(chain *Chain) HandlerOption { + return func(opts *HandlerOptions) { + opts.Chain = chain + } +} - switch h.server.Options().BaseOptions().Protocol { - case "http": - handler = HTTPHandler(h.server) - case "socks", "socks5": - case "ss": // shadowsocks - handler = ShadowHandler(h.server) +// UsersHandlerOption sets the Users option of HandlerOptions. +func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { + return func(opts *HandlerOptions) { + opts.Users = users + } +} +// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. +func TLSConfigHandlerOption(config *tls.Config) HandlerOption { + return func(opts *HandlerOptions) { + opts.TLSConfig = config + } +} + +// WhitelistHandlerOption sets the Whitelist option of HandlerOptions. +func WhitelistHandlerOption(whitelist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Whitelist = whitelist + } +} + +// BlacklistHandlerOption sets the Blacklist option of HandlerOptions. +func BlacklistHandlerOption(blacklist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Blacklist = blacklist + } +} + +type autoHandler struct { + options []HandlerOption +} + +// AutoHandler creates a server Handler for auto proxy server. +func AutoHandler(opts ...HandlerOption) Handler { + h := &autoHandler{ + options: opts, + } + return h +} + +func (h *autoHandler) Handle(conn net.Conn) { + defer conn.Close() + + br := bufio.NewReader(conn) + b, err := br.Peek(1) + if err != nil { + log.Log(err) + return } - handler.Handle(conn) + cc := &bufferdConn{Conn: conn, br: br} + switch b[0] { + case gosocks4.Ver4: + SOCKS4Handler(h.options...).Handle(cc) + case gosocks5.Ver5: + SOCKS5Handler(h.options...).Handle(cc) + default: // http + HTTPHandler(h.options...).Handle(cc) + } +} + +type bufferdConn struct { + net.Conn + br *bufio.Reader +} + +func (c *bufferdConn) Read(b []byte) (int, error) { + return c.br.Read(b) } diff --git a/http.go b/http.go index 7fdc638..8de15c5 100644 --- a/http.go +++ b/http.go @@ -2,96 +2,100 @@ package gost import ( "bufio" - "crypto/tls" "encoding/base64" - "errors" - "io" + "fmt" "net" "net/http" "net/http/httputil" + "net/url" + "strings" "time" - "github.com/ginuerzh/pht" "github.com/go-log/log" - "github.com/golang/glog" - "golang.org/x/net/http2" ) -type HttpServer struct { - conn net.Conn - Base *ProxyServer +type httpConnector struct { + User *url.Userinfo } -func NewHttpServer(conn net.Conn, base *ProxyServer) *HttpServer { - return &HttpServer{ - conn: conn, - Base: base, - } +// HTTPConnector creates a Connector for HTTP proxy client. +// It accepts an optional auth info for HTTP Basic Authentication. +func HTTPConnector(user *url.Userinfo) Connector { + return &httpConnector{User: user} } -// Default HTTP server handler -func (s *HttpServer) HandleRequest(req *http.Request) { - -} - -func (s *HttpServer) forwardRequest(req *http.Request) { - last := s.Base.Chain.lastNode - if last == nil { - return +func (c *httpConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: addr}, + Host: addr, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), } - cc, err := s.Base.Chain.GetConn() - if err != nil { - glog.V(LWARNING).Infof("[http] %s -> %s : %s", s.conn.RemoteAddr(), last.Addr, err) + req.Header.Set("Proxy-Connection", "keep-alive") - b := []byte("HTTP/1.1 503 Service unavailable\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - glog.V(LDEBUG).Infof("[http] %s <- %s\n%s", s.conn.RemoteAddr(), last.Addr, string(b)) - s.conn.Write(b) - return - } - defer cc.Close() - - if len(last.Users) > 0 { - user := last.Users[0] - s := user.String() - if _, set := user.Password(); !set { + if c.User != nil { + s := c.User.String() + if _, set := c.User.Password(); !set { s += ":" } req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) } - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err = req.WriteProxy(cc); err != nil { - glog.V(LWARNING).Infof("[http] %s -> %s : %s", s.conn.RemoteAddr(), req.Host, err) - return + if err := req.Write(conn); err != nil { + return nil, err } - cc.SetWriteDeadline(time.Time{}) - glog.V(LINFO).Infof("[http] %s <-> %s", s.conn.RemoteAddr(), req.Host) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[http] %s >-< %s", s.conn.RemoteAddr(), req.Host) - return + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Log(string(dump)) + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Log(string(dump)) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s", resp.Status) + } + + return conn, nil } type httpHandler struct { - server Server + options *HandlerOptions } -func HTTPHandler(server Server) Handler { - return &httpHandler{server: server} +// HTTPHandler creates a server Handler for HTTP proxy server. +func HTTPHandler(opts ...HandlerOption) Handler { + h := &httpHandler{ + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h } func (h *httpHandler) Handle(conn net.Conn) { + defer conn.Close() + req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { - log.Log("[http]", err) + log.Logf("[http] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) return } - log.Logf("[http] %s %s - %s %s", req.Method, conn.RemoteAddr(), req.Host, req.Proto) - if Debug { + log.Logf("[http] %s %s - %s %s", req.Method, conn.RemoteAddr(), req.Host, req.Proto) dump, _ := httputil.DumpRequest(req, false) log.Logf(string(dump)) } @@ -104,21 +108,8 @@ func (h *httpHandler) Handle(conn net.Conn) { return } - valid := false u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) - users := h.server.Options().BaseOptions().Users - for _, user := range users { - username := user.Username() - password, _ := user.Password() - if (u == username && p == password) || - (u == username && password == "") || - (username == "" && p == password) { - valid = true - break - } - } - - if len(users) > 0 && !valid { + if !authenticate(u, p, h.options.Users...) { log.Logf("[http] %s <- %s : proxy authentication required", conn.RemoteAddr(), req.Host) resp := "HTTP/1.1 407 Proxy Authentication Required\r\n" + "Proxy-Authenticate: Basic realm=\"gost\"\r\n" + @@ -128,20 +119,31 @@ func (h *httpHandler) Handle(conn net.Conn) { } req.Header.Del("Proxy-Authorization") + req.Header.Del("Proxy-Connection") + + if !Can("tcp", req.Host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http] Unauthorized to tcp connect to %s", req.Host) + b := []byte("HTTP/1.1 403 Forbidden\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + conn.Write(b) + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) + } + return + } // forward http request - //lastNode := s.Base.Chain.lastNode - //if lastNode != nil && lastNode.Transport == "" && (lastNode.Protocol == "http" || lastNode.Protocol == "") { - // s.forwardRequest(req) - // return - //} + lastNode := h.options.Chain.LastNode() + if req.Method != http.MethodConnect && lastNode.Protocol == "http" { + h.forwardRequest(conn, req) + return + } - // if !s.Base.Node.Can("tcp", req.Host) { - // glog.Errorf("Unauthorized to tcp connect to %s", req.Host) - // return - // } - - cc, err := h.server.Chain().Dial(req.Host) + host := req.Host + if !strings.Contains(req.Host, ":") { + host += ":80" + } + cc, err := h.options.Chain.Dial(host) if err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) @@ -164,7 +166,6 @@ func (h *httpHandler) Handle(conn net.Conn) { conn.Write(b) } else { req.Header.Del("Proxy-Connection") - // req.Header.Set("Connection", "Keep-Alive") if err = req.Write(cc); err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) @@ -172,274 +173,87 @@ func (h *httpHandler) Handle(conn net.Conn) { } } + log.Logf("[http] %s <-> %s", cc.LocalAddr(), req.Host) + transport(conn, cc) + log.Logf("[http] %s >-< %s", cc.LocalAddr(), req.Host) +} + +func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request) { + if h.options.Chain.IsEmpty() { + return + } + lastNode := h.options.Chain.LastNode() + + cc, err := h.options.Chain.Conn() + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), lastNode.Addr, err) + + b := []byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), lastNode.Addr, string(b)) + } + conn.Write(b) + return + } + defer cc.Close() + + if lastNode.User != nil { + s := lastNode.User.String() + if _, set := lastNode.User.Password(); !set { + s += ":" + } + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) + } + + cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err = req.WriteProxy(cc); err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + return + } + cc.SetWriteDeadline(time.Time{}) + log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) - Transport(conn, cc) + transport(conn, cc) log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) + return } -type Http2Server struct { - Base *ProxyServer - Handler http.Handler - TLSConfig *tls.Config -} - -func NewHttp2Server(base *ProxyServer) *Http2Server { - return &Http2Server{Base: base} -} - -func (s *Http2Server) ListenAndServeTLS(config *tls.Config) error { - srv := http.Server{ - Addr: s.Base.Node.Addr, - Handler: s.Handler, - TLSConfig: config, - } - if srv.Handler == nil { - srv.Handler = http.HandlerFunc(s.HandleRequest) - } - http2.ConfigureServer(&srv, nil) - return srv.ListenAndServeTLS("", "") -} - -// Default HTTP2 server handler -func (s *Http2Server) HandleRequest(w http.ResponseWriter, req *http.Request) { - target := req.Header.Get("Gost-Target") - if target == "" { - target = req.Host - } - glog.V(LINFO).Infof("[http2] %s %s - %s %s", req.Method, req.RemoteAddr, target, req.Proto) - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } - - w.Header().Set("Proxy-Agent", "gost/"+Version) - - if !s.Base.Node.Can("tcp", target) { - glog.Errorf("Unauthorized to tcp connect to %s", target) +func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { + if proxyAuth == "" { return } - // HTTP2 as transport - if req.Header.Get("Proxy-Switch") == "gost" { - conn, err := s.Upgrade(w, req) - if err != nil { - glog.V(LINFO).Infof("[http2] %s -> %s : %s", req.RemoteAddr, target, err) - return - } - glog.V(LINFO).Infof("[http2] %s - %s : switch to HTTP2 transport mode OK", req.RemoteAddr, target) - s.Base.handleConn(conn) + if !strings.HasPrefix(proxyAuth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { return } - valid := false - u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) - for _, user := range s.Base.Node.Users { - username := user.Username() - password, _ := user.Password() + return cs[:s], cs[s+1:], true +} + +func authenticate(username, password string, users ...*url.Userinfo) bool { + if len(users) == 0 { + return true + } + + for _, user := range users { + u := user.Username() + p, _ := user.Password() if (u == username && p == password) || - (u == username && password == "") || - (username == "" && p == password) { - valid = true - break + (u == username && p == "") || + (u == "" && p == password) { + return true } } - if len(s.Base.Node.Users) > 0 && !valid { - glog.V(LWARNING).Infof("[http2] %s <- %s : proxy authentication required", req.RemoteAddr, target) - w.WriteHeader(http.StatusProxyAuthRequired) - return - } - - req.Header.Del("Proxy-Authorization") - req.Header.Del("Proxy-Connection") - - c, err := s.Base.Chain.Dial(target) - if err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, target, err) - w.WriteHeader(http.StatusServiceUnavailable) - return - } - defer c.Close() - - glog.V(LINFO).Infof("[http2] %s <-> %s", req.RemoteAddr, target) - - if req.Method == http.MethodConnect { - w.WriteHeader(http.StatusOK) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - - // compatible with HTTP1.x - if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 { - // we take over the underly connection - conn, _, err := hj.Hijack() - if err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, target, err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer conn.Close() - glog.V(LINFO).Infof("[http2] %s -> %s : downgrade to HTTP/1.1", req.RemoteAddr, target) - s.Base.transport(conn, c) - return - } - - errc := make(chan error, 2) - go func() { - _, err := io.Copy(c, req.Body) - errc <- err - }() - go func() { - _, err := io.Copy(flushWriter{w}, c) - errc <- err - }() - - select { - case <-errc: - // glog.V(LWARNING).Infoln("exit", err) - } - glog.V(LINFO).Infof("[http2] %s >-< %s", req.RemoteAddr, target) - return - } - - // req.Header.Set("Connection", "Keep-Alive") - if err = req.Write(c); err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, target, err) - return - } - - resp, err := http.ReadResponse(bufio.NewReader(c), req) - if err != nil { - glog.V(LWARNING).Infoln("[http2] %s -> %s : %s", req.RemoteAddr, target, err) - return - } - defer resp.Body.Close() - - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil { - glog.V(LWARNING).Infof("[http2] %s <- %s : %s", req.RemoteAddr, target, err) - } - glog.V(LINFO).Infof("[http2] %s >-< %s", req.RemoteAddr, target) -} - -// Upgrade upgrade an HTTP2 request to a bidirectional connection that preparing for tunneling other protocol, just like a websocket connection. -func (s *Http2Server) Upgrade(w http.ResponseWriter, r *http.Request) (net.Conn, error) { - if r.Method != http.MethodConnect { - w.WriteHeader(http.StatusMethodNotAllowed) - return nil, errors.New("Method not allowed") - } - - w.WriteHeader(http.StatusOK) - - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - - conn := &http2Conn{r: r.Body, w: flushWriter{w}} - conn.remoteAddr, _ = net.ResolveTCPAddr("tcp", r.RemoteAddr) - conn.localAddr, _ = net.ResolveTCPAddr("tcp", r.Host) - return conn, nil -} - -// HTTP2 client connection, wrapped up just like a net.Conn -type http2Conn struct { - r io.Reader - w io.Writer - remoteAddr net.Addr - localAddr net.Addr -} - -func (c *http2Conn) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} - -func (c *http2Conn) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *http2Conn) Close() (err error) { - if rc, ok := c.r.(io.Closer); ok { - err = rc.Close() - } - if w, ok := c.w.(io.Closer); ok { - err = w.Close() - } - return -} - -func (c *http2Conn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *http2Conn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *http2Conn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2Conn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2Conn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -type flushWriter struct { - w io.Writer -} - -func (fw flushWriter) Write(p []byte) (n int, err error) { - defer func() { - if r := recover(); r != nil { - if s, ok := r.(string); ok { - err = errors.New(s) - return - } - err = r.(error) - } - }() - - n, err = fw.w.Write(p) - if err != nil { - // glog.V(LWARNING).Infoln("flush writer:", err) - return - } - if f, ok := fw.w.(http.Flusher); ok { - f.Flush() - } - return -} - -type PureHttpServer struct { - Base *ProxyServer - Handler func(net.Conn) -} - -func NewPureHttpServer(base *ProxyServer) *PureHttpServer { - return &PureHttpServer{ - Base: base, - } -} - -func (s *PureHttpServer) ListenAndServe() error { - server := pht.Server{ - Addr: s.Base.Node.Addr, - Key: s.Base.Node.Get("key"), - } - if server.Handler == nil { - server.Handler = s.handleConn - } - return server.ListenAndServe() -} - -func (s *PureHttpServer) handleConn(conn net.Conn) { - glog.V(LINFO).Infof("[pht] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) - s.Base.handleConn(conn) + return false } diff --git a/gost/http2.go b/http2.go similarity index 97% rename from gost/http2.go rename to http2.go index 2927afe..692d02a 100644 --- a/gost/http2.go +++ b/http2.go @@ -392,13 +392,7 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { errChan: make(chan error, 1), } if config == nil { - cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) - if err != nil { - return nil, err - } - config = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } + config = DefaultTLSConfig } server := &http.Server{ Addr: addr, @@ -410,7 +404,12 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { } l.server = server - go server.ListenAndServeTLS("", "") + go func() { + err := server.ListenAndServeTLS("", "") + if err != nil { + log.Log("[http2]", err) + } + }() return l, nil } @@ -473,17 +472,11 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) { return nil, err } if config == nil { - cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) - if err != nil { - return nil, err - } - config = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } + config = DefaultTLSConfig } l := &h2Listener{ - Listener: ln, + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, server: &http2.Server{ // MaxConcurrentStreams: 1000, PermitProhibitedCipherSuites: true, @@ -505,7 +498,7 @@ func H2CListener(addr string) (Listener, error) { return nil, err } l := &h2Listener{ - Listener: ln, + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, server: &http2.Server{ // MaxConcurrentStreams: 1000, }, diff --git a/kcp.go b/kcp.go index 1980942..857c697 100644 --- a/kcp.go +++ b/kcp.go @@ -1,30 +1,30 @@ -// KCP feature is based on https://github.com/xtaci/kcptun - package gost import ( "crypto/sha1" "encoding/csv" - "encoding/json" + "errors" "fmt" - "github.com/golang/glog" - "github.com/klauspost/compress/snappy" - "golang.org/x/crypto/pbkdf2" - "gopkg.in/xtaci/kcp-go.v2" - "gopkg.in/xtaci/smux.v1" "net" "os" "time" -) -const ( - DefaultKCPConfigFile = "kcp.json" + "golang.org/x/crypto/pbkdf2" + + "sync" + + "github.com/go-log/log" + "github.com/klauspost/compress/snappy" + "gopkg.in/xtaci/kcp-go.v2" + "gopkg.in/xtaci/smux.v1" ) var ( - SALT = "kcp-go" + // KCPSalt is the default salt for KCP cipher. + KCPSalt = "kcp-go" ) +// KCPConfig describes the config for KCP. type KCPConfig struct { Key string `json:"key"` Crypt string `json:"crypt"` @@ -45,25 +45,10 @@ type KCPConfig struct { KeepAlive int `json:"keepalive"` SnmpLog string `json:"snmplog"` SnmpPeriod int `json:"snmpperiod"` + Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature. } -func ParseKCPConfig(configFile string) (*KCPConfig, error) { - if configFile == "" { - configFile = DefaultKCPConfigFile - } - file, err := os.Open(configFile) - if err != nil { - return nil, err - } - defer file.Close() - - config := &KCPConfig{} - if err = json.NewDecoder(file).Decode(config); err != nil { - return nil, err - } - return config, nil -} - +// Init initializes the KCP config. func (c *KCPConfig) Init() { switch c.Mode { case "normal": @@ -80,6 +65,7 @@ func (c *KCPConfig) Init() { } var ( + // DefaultKCPConfig is the default KCP config. DefaultKCPConfig = &KCPConfig{ Key: "it's a secrect", Crypt: "aes", @@ -100,89 +86,323 @@ var ( KeepAlive: 10, SnmpLog: "", SnmpPeriod: 60, + Signal: false, } ) -type KCPServer struct { - Base *ProxyServer - Config *KCPConfig +type kcpConn struct { + conn net.Conn + stream *smux.Stream } -func NewKCPServer(base *ProxyServer, config *KCPConfig) *KCPServer { - return &KCPServer{Base: base, Config: config} +func (c *kcpConn) Read(b []byte) (n int, err error) { + return c.stream.Read(b) } -func (s *KCPServer) ListenAndServe() (err error) { - if s.Config == nil { - s.Config = DefaultKCPConfig - } - s.Config.Init() +func (c *kcpConn) Write(b []byte) (n int, err error) { + return c.stream.Write(b) +} - ln, err := kcp.ListenWithOptions(s.Base.Node.Addr, - blockCrypt(s.Config.Key, s.Config.Crypt, SALT), s.Config.DataShard, s.Config.ParityShard) +func (c *kcpConn) Close() error { + return c.stream.Close() +} + +func (c *kcpConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *kcpConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *kcpConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *kcpConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *kcpConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +type kcpSession struct { + conn net.Conn + session *smux.Session +} + +func (session *kcpSession) GetConn() (*kcpConn, error) { + stream, err := session.session.OpenStream() if err != nil { - return err + return nil, err } - if err = ln.SetDSCP(s.Config.DSCP); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + return &kcpConn{conn: session.conn, stream: stream}, nil +} + +func (session *kcpSession) Close() error { + return session.session.Close() +} + +func (session *kcpSession) IsClosed() bool { + return session.session.IsClosed() +} + +func (session *kcpSession) NumStreams() int { + return session.session.NumStreams() +} + +type kcpTransporter struct { + sessions map[string]*kcpSession + sessionMutex sync.Mutex + config *KCPConfig +} + +// KCPTransporter creates a Transporter that is used by KCP proxy client. +func KCPTransporter(config *KCPConfig) Transporter { + if config == nil { + config = DefaultKCPConfig } - if err = ln.SetReadBuffer(s.Config.SockBuf); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - if err = ln.SetWriteBuffer(s.Config.SockBuf); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + config.Init() + + go snmpLogger(config.SnmpLog, config.SnmpPeriod) + if config.Signal { + go kcpSigHandler() } - go snmpLogger(s.Config.SnmpLog, s.Config.SnmpPeriod) - go kcpSigHandler() - for { - conn, err := ln.AcceptKCP() - if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - continue - } - - conn.SetStreamMode(true) - conn.SetNoDelay(s.Config.NoDelay, s.Config.Interval, s.Config.Resend, s.Config.NoCongestion) - conn.SetMtu(s.Config.MTU) - conn.SetWindowSize(s.Config.SndWnd, s.Config.RcvWnd) - conn.SetACKNoDelay(s.Config.AckNodelay) - conn.SetKeepAlive(s.Config.KeepAlive) - - go s.handleMux(conn) + return &kcpTransporter{ + config: config, + sessions: make(map[string]*kcpSession), } } -func (s *KCPServer) handleMux(conn net.Conn) { +func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + uaddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok { + conn, err = net.DialUDP("udp", nil, uaddr) + if err != nil { + return + } + session = &kcpSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + config := tr.config + if opts.KCPConfig != nil { + config = opts.KCPConfig + } + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("kcp: unrecognized connection") + } + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { + udpConn, ok := conn.(*net.UDPConn) + if !ok { + return nil, errors.New("kcp: wrong connection type") + } + + kcpconn, err := kcp.NewConn(addr, + blockCrypt(config.Key, config.Crypt, KCPSalt), + config.DataShard, config.ParityShard, + &kcp.ConnectedUDPConn{UDPConn: udpConn, Conn: udpConn}) + if err != nil { + return nil, err + } + + kcpconn.SetStreamMode(true) + kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion) + kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd) + kcpconn.SetMtu(config.MTU) + kcpconn.SetACKNoDelay(config.AckNodelay) + kcpconn.SetKeepAlive(config.KeepAlive) + + if err := kcpconn.SetDSCP(config.DSCP); err != nil { + log.Log("[kcp]", err) + } + if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + + // stream multiplex smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = s.Config.SockBuf + smuxConfig.MaxReceiveBuffer = config.SockBuf + var cc net.Conn = kcpconn + if !config.NoComp { + cc = newCompStreamConn(kcpconn) + } + session, err := smux.Client(cc, smuxConfig) + if err != nil { + return nil, err + } + return &kcpSession{conn: conn, session: session}, nil +} - glog.V(LINFO).Infof("[kcp] %s - %s", conn.RemoteAddr(), s.Base.Node.Addr) +func (tr *kcpTransporter) Multiplex() bool { + return true +} - if !s.Config.NoComp { +type kcpListener struct { + config *KCPConfig + ln *kcp.Listener + connChan chan net.Conn + errChan chan error +} + +// KCPListener creates a Listener for KCP proxy server. +func KCPListener(addr string, config *KCPConfig) (Listener, error) { + if config == nil { + config = DefaultKCPConfig + } + config.Init() + + ln, err := kcp.ListenWithOptions(addr, + blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard) + if err != nil { + return nil, err + } + if err = ln.SetDSCP(config.DSCP); err != nil { + log.Log("[kcp]", err) + } + if err = ln.SetReadBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + if err = ln.SetWriteBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + + go snmpLogger(config.SnmpLog, config.SnmpPeriod) + if config.Signal { + go kcpSigHandler() + } + + l := &kcpListener{ + config: config, + ln: ln, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *kcpListener) listenLoop() { + for { + conn, err := l.ln.AcceptKCP() + if err != nil { + log.Log("[kcp] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + conn.SetStreamMode(true) + conn.SetNoDelay(l.config.NoDelay, l.config.Interval, l.config.Resend, l.config.NoCongestion) + conn.SetMtu(l.config.MTU) + conn.SetWindowSize(l.config.SndWnd, l.config.RcvWnd) + conn.SetACKNoDelay(l.config.AckNodelay) + conn.SetKeepAlive(l.config.KeepAlive) + go l.mux(conn) + } +} + +func (l *kcpListener) mux(conn net.Conn) { + smuxConfig := smux.DefaultConfig() + smuxConfig.MaxReceiveBuffer = l.config.SockBuf + + log.Logf("[kcp] %s - %s", conn.RemoteAddr(), l.Addr()) + + if !l.config.NoComp { conn = newCompStreamConn(conn) } mux, err := smux.Server(conn, smuxConfig) if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + log.Log("[kcp]", err) return } defer mux.Close() - glog.V(LINFO).Infof("[kcp] %s <-> %s", conn.RemoteAddr(), s.Base.Node.Addr) - defer glog.V(LINFO).Infof("[kcp] %s >-< %s", conn.RemoteAddr(), s.Base.Node.Addr) + log.Logf("[kcp] %s <-> %s", conn.RemoteAddr(), l.Addr()) + defer log.Logf("[kcp] %s >-< %s", conn.RemoteAddr(), l.Addr()) for { stream, err := mux.AcceptStream() if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + log.Log("[kcp] accept stream:", err) return } - go s.Base.handleConn(NewKCPConn(conn, stream)) + + cc := &kcpConn{conn: conn, stream: stream} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } } } +func (l *kcpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} +func (l *kcpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *kcpListener) Close() error { + return l.ln.Close() +} + func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) { pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New) @@ -217,8 +437,8 @@ func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) { return } -func snmpLogger(path string, interval int) { - if path == "" || interval == 0 { +func snmpLogger(format string, interval int) { + if format == "" || interval == 0 { return } ticker := time.NewTicker(time.Duration(interval) * time.Second) @@ -226,20 +446,20 @@ func snmpLogger(path string, interval int) { for { select { case <-ticker.C: - f, err := os.OpenFile(time.Now().Format(path), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + f, err := os.OpenFile(time.Now().Format(format), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + log.Log("[kcp]", err) return } w := csv.NewWriter(f) // write header in empty file if stat, err := f.Stat(); err == nil && stat.Size() == 0 { if err := w.Write(append([]string{"Unix"}, kcp.DefaultSnmp.Header()...)); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + log.Log("[kcp]", err) } } if err := w.Write(append([]string{fmt.Sprint(time.Now().Unix())}, kcp.DefaultSnmp.ToSlice()...)); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) + log.Log("[kcp]", err) } kcp.DefaultSnmp.Reset() w.Flush() @@ -248,117 +468,6 @@ func snmpLogger(path string, interval int) { } } -type KCPSession struct { - conn net.Conn - session *smux.Session -} - -func DialKCP(addr string, config *KCPConfig) (*KCPSession, error) { - if config == nil { - config = DefaultKCPConfig - } - config.Init() - - kcpconn, err := kcp.DialWithOptions(addr, - blockCrypt(config.Key, config.Crypt, SALT), config.DataShard, config.ParityShard) - if err != nil { - return nil, err - } - - kcpconn.SetStreamMode(true) - kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion) - kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd) - kcpconn.SetMtu(config.MTU) - kcpconn.SetACKNoDelay(config.AckNodelay) - kcpconn.SetKeepAlive(config.KeepAlive) - - if err := kcpconn.SetDSCP(config.DSCP); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - - // stream multiplex - smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = config.SockBuf - var conn net.Conn = kcpconn - if !config.NoComp { - conn = newCompStreamConn(kcpconn) - } - session, err := smux.Client(conn, smuxConfig) - if err != nil { - conn.Close() - return nil, err - } - return &KCPSession{conn: conn, session: session}, nil -} - -func (session *KCPSession) GetConn() (*KCPConn, error) { - stream, err := session.session.OpenStream() - if err != nil { - session.Close() - return nil, err - } - return NewKCPConn(session.conn, stream), nil -} - -func (session *KCPSession) Close() error { - return session.session.Close() -} - -func (session *KCPSession) IsClosed() bool { - return session.session.IsClosed() -} - -func (session *KCPSession) NumStreams() int { - return session.session.NumStreams() -} - -type KCPConn struct { - conn net.Conn - stream *smux.Stream -} - -func NewKCPConn(conn net.Conn, stream *smux.Stream) *KCPConn { - return &KCPConn{conn: conn, stream: stream} -} - -func (c *KCPConn) Read(b []byte) (n int, err error) { - return c.stream.Read(b) -} - -func (c *KCPConn) Write(b []byte) (n int, err error) { - return c.stream.Write(b) -} - -func (c *KCPConn) Close() error { - return c.stream.Close() -} - -func (c *KCPConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *KCPConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *KCPConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *KCPConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *KCPConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - type compStreamConn struct { conn net.Conn w *snappy.Writer diff --git a/gost/log.go b/log.go similarity index 100% rename from gost/log.go rename to log.go diff --git a/node.go b/node.go index bc22387..82b7a3a 100644 --- a/node.go +++ b/node.go @@ -1,90 +1,40 @@ package gost import ( - "bufio" - "fmt" "net" "net/url" - "os" "strconv" "strings" - "github.com/golang/glog" + "github.com/go-log/log" ) -// Proxy node represent a proxy -type ProxyNode struct { - Addr string // [host]:port - Protocol string // protocol: http/socks5/ss - Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp - Remote string // remote address, used by tcp/udp port forwarding - Users []*url.Userinfo // authentication for proxy - Whitelist *Permissions - Blacklist *Permissions - values url.Values - serverName string - conn net.Conn +// Node is a proxy node, mainly used to construct a proxy chain. +type Node struct { + Addr string + Protocol string + Transport string + Remote string // remote address, used by tcp/udp port forwarding + User *url.Userinfo + Values url.Values + Client *Client + DialOptions []DialOption + HandshakeOptions []HandshakeOption } -// The proxy node string pattern is [scheme://][user:pass@host]:port. -// -// Scheme can be devided into two parts by character '+', such as: http+tls. -func ParseProxyNode(s string) (node ProxyNode, err error) { +func ParseNode(s string) (node Node, err error) { if !strings.Contains(s, "://") { - s = "gost://" + s + s = "auto://" + s } u, err := url.Parse(s) if err != nil { return } - query := u.Query() - - node = ProxyNode{ - Addr: u.Host, - values: query, - serverName: u.Host, - } - - if query.Get("whitelist") != "" { - node.Whitelist, err = ParsePermissions(query.Get("whitelist")) - - if err != nil { - glog.Fatal(err) - } - } else { - // By default allow for everyting - node.Whitelist, _ = ParsePermissions("*:*:*") - } - - if query.Get("blacklist") != "" { - node.Blacklist, err = ParsePermissions(query.Get("blacklist")) - - if err != nil { - glog.Fatal(err) - } - } else { - // By default block nothing - node.Blacklist, _ = ParsePermissions("") - } - - if u.User != nil { - node.Users = append(node.Users, u.User) - } - - users, er := parseUsers(node.Get("secrets")) - if users != nil { - node.Users = append(node.Users, users...) - } - if er != nil { - glog.V(LWARNING).Infoln("secrets:", er) - } - - if strings.Contains(u.Host, ":") { - node.serverName, _, _ = net.SplitHostPort(u.Host) - if node.serverName == "" { - node.serverName = "localhost" // default server name - } + node = Node{ + Addr: u.Host, + Values: u.Query(), + User: u.User, } schemes := strings.Split(u.Scheme, "+") @@ -98,20 +48,24 @@ func ParseProxyNode(s string) (node ProxyNode, err error) { } switch node.Transport { - case "ws", "wss", "tls", "http2", "quic", "kcp", "redirect", "ssu", "pht", "ssh": + case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "redirect": case "https": node.Protocol = "http" node.Transport = "tls" case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding node.Remote = strings.Trim(u.EscapedPath(), "/") - case "rtcp", "rudp": // started from v2.1, rtcp and rudp are for remote port forwarding + case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding node.Remote = strings.Trim(u.EscapedPath(), "/") default: node.Transport = "" } switch node.Protocol { - case "http", "http2", "socks", "socks4", "socks4a", "socks5", "ss": + case "http", "http2", "socks4", "socks4a", "ss", "ssu": + case "socks", "socks5": + node.Protocol = "socks5" + case "tcp", "udp", "rtcp", "rudp": // port forwarding + case "direct", "remote", "forward": // SSH port forwarding default: node.Protocol = "" } @@ -119,40 +73,7 @@ func ParseProxyNode(s string) (node ProxyNode, err error) { return } -func parseUsers(authFile string) (users []*url.Userinfo, err error) { - if authFile == "" { - return - } - - file, err := os.Open(authFile) - if err != nil { - return - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - s := strings.SplitN(line, " ", 2) - if len(s) == 1 { - users = append(users, url.User(strings.TrimSpace(s[0]))) - } else if len(s) == 2 { - users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) - } - } - - err = scanner.Err() - return -} - -// Get get node parameter by key -func (node *ProxyNode) Get(key string) string { - return node.values.Get(key) -} - -func (node *ProxyNode) Can(action string, addr string) bool { +func Can(action string, addr string, whitelist, blacklist *Permissions) bool { if !strings.Contains(addr, ":") { addr = addr + ":80" } @@ -168,46 +89,8 @@ func (node *ProxyNode) Can(action string, addr string) bool { return false } - glog.V(LDEBUG).Infof("Can action: %s, host: %s, port %d", action, host, port) - - return node.Whitelist.Can(action, host, port) && !node.Blacklist.Can(action, host, port) -} - -func (node *ProxyNode) getBool(key string) bool { - s := node.Get(key) - if b, _ := strconv.ParseBool(s); b { - return b + if Debug { + log.Logf("Can action: %s, host: %s, port %d", action, host, port) } - n, _ := strconv.Atoi(s) - return n > 0 -} - -func (node *ProxyNode) Set(key, value string) { - node.values.Set(key, value) -} - -func (node *ProxyNode) insecureSkipVerify() bool { - return !node.getBool("secure") -} - -func (node *ProxyNode) caFile() string { - return node.Get("ca") -} - -func (node *ProxyNode) certFile() string { - if cert := node.Get("cert"); cert != "" { - return cert - } - return DefaultCertFile -} - -func (node *ProxyNode) keyFile() string { - if key := node.Get("key"); key != "" { - return key - } - return DefaultKeyFile -} - -func (node ProxyNode) String() string { - return fmt.Sprintf("transport: %s, protocol: %s, addr: %s, whitelist: %v, blacklist: %v", node.Transport, node.Protocol, node.Addr, node.Whitelist, node.Blacklist) + return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port) } diff --git a/node_test.go b/node_test.go deleted file mode 100644 index 52cd183..0000000 --- a/node_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package gost - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNodeDefaultWhitelist(t *testing.T) { - assert := assert.New(t) - - node, _ := ParseProxyNode("http2://localhost:8000") - - assert.True(node.Can("connect", "google.pl:80")) - assert.True(node.Can("connect", "google.pl:443")) - assert.True(node.Can("connect", "google.pl:22")) - assert.True(node.Can("bind", "google.pl:80")) - assert.True(node.Can("bind", "google.com:80")) -} - -func TestNodeWhitelist(t *testing.T) { - assert := assert.New(t) - - node, _ := ParseProxyNode("http2://localhost:8000?whitelist=connect:google.pl:80,443") - - assert.True(node.Can("connect", "google.pl:80")) - assert.True(node.Can("connect", "google.pl:443")) - assert.False(node.Can("connect", "google.pl:22")) - assert.False(node.Can("bind", "google.pl:80")) - assert.False(node.Can("bind", "google.com:80")) -} - -func TestNodeBlacklist(t *testing.T) { - assert := assert.New(t) - - node, _ := ParseProxyNode("http2://localhost:8000?blacklist=connect:google.pl:80,443") - - assert.False(node.Can("connect", "google.pl:80")) - assert.False(node.Can("connect", "google.pl:443")) - assert.True(node.Can("connect", "google.pl:22")) - assert.True(node.Can("bind", "google.pl:80")) - assert.True(node.Can("bind", "google.com:80")) -} diff --git a/permissions.go b/permissions.go index 3e079eb..8566c80 100644 --- a/permissions.go +++ b/permissions.go @@ -9,14 +9,6 @@ import ( glob "github.com/ryanuber/go-glob" ) -type PortRange struct { - Min, Max int -} - -type PortSet []PortRange - -type StringSet []string - type Permission struct { Actions StringSet Hosts StringSet @@ -39,6 +31,10 @@ func maxint(x, y int) int { return y } +type PortRange struct { + Min, Max int +} + func (ir *PortRange) Contains(value int) bool { return value >= ir.Min && value <= ir.Max } @@ -88,6 +84,8 @@ func (ps *PortSet) Contains(value int) bool { return false } +type PortSet []PortRange + func ParsePortSet(s string) (*PortSet, error) { ps := &PortSet{} @@ -120,6 +118,8 @@ func (ss *StringSet) Contains(subj string) bool { return false } +type StringSet []string + func ParseStringSet(s string) (*StringSet, error) { ss := &StringSet{} if s == "" { diff --git a/quic.go b/quic.go index 446acf2..712fdfb 100644 --- a/quic.go +++ b/quic.go @@ -1,81 +1,241 @@ package gost import ( - "bufio" "crypto/tls" - "github.com/golang/glog" - "github.com/lucas-clemente/quic-go/h2quic" - "io" - "net/http" - "net/http/httputil" + "errors" + "net" + "sync" + "time" + + "github.com/go-log/log" + quic "github.com/lucas-clemente/quic-go" ) -type QuicServer struct { - Base *ProxyServer - Handler http.Handler +type quicSession struct { + conn net.Conn + session quic.Session +} + +func (session *quicSession) GetConn() (*quicConn, error) { + stream, err := session.session.OpenStream() + if err != nil { + return nil, err + } + return &quicConn{ + Stream: stream, + laddr: session.session.LocalAddr(), + raddr: session.session.RemoteAddr(), + }, nil +} + +func (session *quicSession) Close() error { + return session.session.Close(nil) +} + +type quicTransporter struct { + config *QUICConfig + sessionMutex sync.Mutex + sessions map[string]*quicSession +} + +// QUICTransporter creates a Transporter that is used by QUIC proxy client. +func QUICTransporter(config *QUICConfig) Transporter { + if config == nil { + config = &QUICConfig{} + } + return &quicTransporter{ + config: config, + sessions: make(map[string]*quicSession), + } +} + +func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok { + conn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return + } + session = &quicSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + config := tr.config + if opts.QUICConfig != nil { + config = opts.QUICConfig + } + if config.TLSConfig == nil { + config.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("quic: unrecognized connection") + } + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { + udpConn, ok := conn.(*net.UDPConn) + if !ok { + return nil, errors.New("quic: wrong connection type") + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + quicConfig := &quic.Config{ + HandshakeTimeout: config.Timeout, + KeepAlive: config.KeepAlive, + } + session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) + if err != nil { + log.Log("quic dial", err) + return nil, err + } + return &quicSession{conn: conn, session: session}, nil +} + +func (tr *quicTransporter) Multiplex() bool { + return true +} + +type QUICConfig struct { TLSConfig *tls.Config + Timeout time.Duration + KeepAlive bool } -func NewQuicServer(base *ProxyServer) *QuicServer { - return &QuicServer{Base: base} +type quicListener struct { + ln quic.Listener + connChan chan net.Conn + errChan chan error } -func (s *QuicServer) ListenAndServeTLS(config *tls.Config) error { - server := &h2quic.Server{ - Server: &http.Server{ - Addr: s.Base.Node.Addr, - Handler: s.Handler, - TLSConfig: config, - }, +// QUICListener creates a Listener for QUIC proxy server. +func QUICListener(addr string, config *QUICConfig) (Listener, error) { + if config == nil { + config = &QUICConfig{} } - if server.Handler == nil { - // server.Handler = http.HandlerFunc(s.HandleRequest) - server.Handler = http.HandlerFunc(NewHttp2Server(s.Base).HandleRequest) - } - return server.ListenAndServe() -} - -func (s *QuicServer) HandleRequest(w http.ResponseWriter, req *http.Request) { - target := req.Host - glog.V(LINFO).Infof("[quic] %s %s - %s %s", req.Method, req.RemoteAddr, target, req.Proto) - - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) + quicConfig := &quic.Config{ + HandshakeTimeout: config.Timeout, + KeepAlive: config.KeepAlive, } - c, err := s.Base.Chain.Dial(target) + tlsConfig := config.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig) if err != nil { - glog.V(LWARNING).Infof("[quic] %s -> %s : %s", req.RemoteAddr, target, err) - w.WriteHeader(http.StatusServiceUnavailable) - return - } - defer c.Close() - - glog.V(LINFO).Infof("[quic] %s <-> %s", req.RemoteAddr, target) - - req.Header.Set("Connection", "Keep-Alive") - if err = req.Write(c); err != nil { - glog.V(LWARNING).Infof("[quic] %s -> %s : %s", req.RemoteAddr, target, err) - return + return nil, err } - resp, err := http.ReadResponse(bufio.NewReader(c), req) - if err != nil { - glog.V(LWARNING).Infoln(err) - return + l := &quicListener{ + ln: ln, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), } - defer resp.Body.Close() + go l.listenLoop() - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) + return l, nil +} + +func (l *quicListener) listenLoop() { + for { + session, err := l.ln.Accept() + if err != nil { + log.Log("[quic] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.sessionLoop(session) + } +} + +func (l *quicListener) sessionLoop(session quic.Session) { + log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr()) + defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr()) + + for { + stream, err := session.AcceptStream() + if err != nil { + log.Log("[quic] accept stream:", err) + return + } + + cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr()) } } - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil { - glog.V(LWARNING).Infof("[quic] %s <- %s : %s", req.RemoteAddr, target, err) - } - - glog.V(LINFO).Infof("[quic] %s >-< %s", req.RemoteAddr, target) +} + +func (l *quicListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *quicListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *quicListener) Close() error { + return l.ln.Close() +} + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr } diff --git a/redirect.go b/redirect.go index 7524418..f7033c6 100644 --- a/redirect.go +++ b/redirect.go @@ -5,70 +5,58 @@ package gost import ( "errors" "fmt" - "github.com/golang/glog" "net" "syscall" + + "github.com/go-log/log" ) -const ( - SO_ORIGINAL_DST = 80 -) - -type RedsocksTCPServer struct { - Base *ProxyServer +type tcpRedirectHandler struct { + options *HandlerOptions } -func NewRedsocksTCPServer(base *ProxyServer) *RedsocksTCPServer { - return &RedsocksTCPServer{ - Base: base, +// TCPRedirectHandler creates a server Handler for TCP redirect server. +func TCPRedirectHandler(opts ...HandlerOption) Handler { + h := &tcpRedirectHandler{ + options: &HandlerOptions{ + Chain: new(Chain), + }, } + for _, opt := range opts { + opt(h.options) + } + return h } -func (s *RedsocksTCPServer) ListenAndServe() error { - laddr, err := net.ResolveTCPAddr("tcp", s.Base.Node.Addr) - if err != nil { - return err - } - ln, err := net.ListenTCP("tcp", laddr) - if err != nil { - return err +func (h *tcpRedirectHandler) Handle(c net.Conn) { + conn, ok := c.(*net.TCPConn) + if !ok { + log.Log("[red-tcp] not a TCP connection") } - defer ln.Close() - for { - conn, err := ln.AcceptTCP() - if err != nil { - glog.V(LWARNING).Infoln(err) - continue - } - go s.handleRedirectTCP(conn) - } -} - -func (s *RedsocksTCPServer) handleRedirectTCP(conn *net.TCPConn) { srcAddr := conn.RemoteAddr() - dstAddr, conn, err := getOriginalDstAddr(conn) + dstAddr, conn, err := h.getOriginalDstAddr(conn) if err != nil { - glog.V(LWARNING).Infof("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) + log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) return } defer conn.Close() - glog.V(LINFO).Infof("[red-tcp] %s -> %s", srcAddr, dstAddr) + log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) - cc, err := s.Base.Chain.Dial(dstAddr.String()) + cc, err := h.options.Chain.Dial(dstAddr.String()) if err != nil { - glog.V(LWARNING).Infof("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) + log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) return } defer cc.Close() - glog.V(LINFO).Infof("[red-tcp] %s <-> %s", srcAddr, dstAddr) - s.Base.transport(conn, cc) - glog.V(LINFO).Infof("[red-tcp] %s >-< %s", srcAddr, dstAddr) + log.Logf("[red-tcp] %s <-> %s", srcAddr, dstAddr) + transport(conn, cc) + log.Logf("[red-tcp] %s >-< %s", srcAddr, dstAddr) } -func getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err error) { +func (h *tcpRedirectHandler) getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err error) { defer conn.Close() fc, err := conn.File() @@ -77,7 +65,7 @@ func getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err e } defer fc.Close() - mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, SO_ORIGINAL_DST) + mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80) if err != nil { return } diff --git a/redirect_win.go b/redirect_win.go index 116bdd8..848b70b 100644 --- a/redirect_win.go +++ b/redirect_win.go @@ -3,15 +3,29 @@ package gost import ( - "errors" + "net" + + "github.com/go-log/log" ) -type RedsocksTCPServer struct{} - -func NewRedsocksTCPServer(base *ProxyServer) *RedsocksTCPServer { - return &RedsocksTCPServer{} +type tcpRedirectHandler struct { + options *HandlerOptions } -func (s *RedsocksTCPServer) ListenAndServe() error { - return errors.New("Not supported") +// TCPRedirectHandler creates a server Handler for TCP redirect server. +func TCPRedirectHandler(opts ...HandlerOption) Handler { + h := &tcpRedirectHandler{ + options: &HandlerOptions{ + Chain: new(Chain), + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *tcpRedirectHandler) Handle(c net.Conn) { + log.Log("[red-tcp] TCP redirect is not available on the Windows platform") + c.Close() } diff --git a/server.go b/server.go index 99eda47..fb34958 100644 --- a/server.go +++ b/server.go @@ -1,297 +1,104 @@ package gost import ( - "bufio" - "crypto/tls" "io" "net" - "net/http" - "strconv" - "strings" + "time" - "github.com/ginuerzh/gosocks4" - "github.com/ginuerzh/gosocks5" - "github.com/golang/glog" - ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" - "golang.org/x/crypto/ssh" + "github.com/go-log/log" ) -type ProxyServer struct { - Node ProxyNode - Chain *ProxyChain - TLSConfig *tls.Config - selector *ServerSelector - cipher *ss.Cipher - ota bool +// Server is a proxy server. +type Server struct { } -func NewProxyServer(node ProxyNode, chain *ProxyChain) *ProxyServer { - certFile, keyFile := node.certFile(), node.keyFile() +// Serve serves as a proxy server. +func (s *Server) Serve(l net.Listener, h Handler) error { + defer l.Close() - cert, err := LoadCertificate(certFile, keyFile) - if err != nil { - glog.Fatal(err) - } - - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - - if chain == nil { - chain = NewProxyChain() - } - - var cipher *ss.Cipher - var ota bool - if node.Protocol == "ss" || node.Transport == "ssu" { - var err error - var method, password string - - if len(node.Users) > 0 { - method = node.Users[0].Username() - password, _ = node.Users[0].Password() - } - ota = node.getBool("ota") - if strings.HasSuffix(method, "-auth") { - ota = true - method = strings.TrimSuffix(method, "-auth") - } - cipher, err = ss.NewCipher(method, password) - if err != nil { - glog.Fatal(err) - } - } - return &ProxyServer{ - Node: node, - Chain: chain, - TLSConfig: config, - selector: &ServerSelector{ // socks5 server selector - // methods that socks5 server supported - methods: []uint8{ - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - MethodTLSAuth, - }, - // Users: node.Users, - TLSConfig: config, - }, - cipher: cipher, - ota: ota, - } -} - -func (s *ProxyServer) Serve() error { - var ln net.Listener - var err error - node := s.Node - - switch node.Transport { - case "ws": // websocket connection - return NewWebsocketServer(s).ListenAndServe() - case "wss": // websocket security connection - return NewWebsocketServer(s).ListenAndServeTLS(s.TLSConfig) - case "tls": // tls connection - ln, err = tls.Listen("tcp", node.Addr, s.TLSConfig) - case "http2": // Standard HTTP2 proxy server, compatible with HTTP1.x. - server := NewHttp2Server(s) - server.Handler = http.HandlerFunc(server.HandleRequest) - return server.ListenAndServeTLS(s.TLSConfig) - case "tcp": // Local TCP port forwarding - return NewTcpForwardServer(s).ListenAndServe() - case "udp": // Local UDP port forwarding - ttl, _ := strconv.Atoi(s.Node.Get("ttl")) - if ttl <= 0 { - ttl = DefaultTTL - } - return NewUdpForwardServer(s, ttl).ListenAndServe() - case "rtcp": // Remote TCP port forwarding - return NewRTcpForwardServer(s).Serve() - case "rudp": // Remote UDP port forwarding - return NewRUdpForwardServer(s).Serve() - case "quic": - return NewQuicServer(s).ListenAndServeTLS(s.TLSConfig) - case "kcp": - config, err := ParseKCPConfig(s.Node.Get("c")) - if err != nil { - glog.V(LWARNING).Infoln("[kcp]", err) - } - if config == nil { - config = DefaultKCPConfig - } - // override crypt and key if specified explicitly - if s.Node.Users != nil { - config.Crypt = s.Node.Users[0].Username() - config.Key, _ = s.Node.Users[0].Password() - } - return NewKCPServer(s, config).ListenAndServe() - case "redirect": - return NewRedsocksTCPServer(s).ListenAndServe() - case "ssu": // shadowsocks udp relay - ttl, _ := strconv.Atoi(s.Node.Get("ttl")) - if ttl <= 0 { - ttl = DefaultTTL - } - return NewShadowUdpServer(s, ttl).ListenAndServe() - case "pht": // pure http tunnel - return NewPureHttpServer(s).ListenAndServe() - case "ssh": // SSH tunnel - /* - key := s.Node.Get("key") - privateBytes, err := ioutil.ReadFile(key) - if err != nil { - glog.V(LWARNING).Infoln("[ssh]", err) - privateBytes = defaultRawKey - } - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - return err - } - */ - config := ssh.ServerConfig{ - PasswordCallback: DefaultPasswordCallback(s.Node.Users), - } - if len(s.Node.Users) == 0 { - config.NoClientAuth = true - } - signer, err := ssh.NewSignerFromKey(s.TLSConfig.Certificates[0].PrivateKey) + if l == nil { + ln, err := TCPListener(":8080") if err != nil { return err } - config.AddHostKey(signer) - s := &SSHServer{ - Addr: node.Addr, - Base: s, - Config: &config, - } - return s.ListenAndServe() - default: - ln, err = net.Listen("tcp", node.Addr) + l = ln + } + if h == nil { + h = HTTPHandler() } - if err != nil { - return err - } - - defer ln.Close() - + var tempDelay time.Duration for { - conn, err := ln.Accept() - if err != nil { - glog.V(LWARNING).Infoln(err) - continue + conn, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e } - - setKeepAlive(conn, KeepAliveTime) - - go s.handleConn(conn) + tempDelay = 0 + go h.Handle(conn) } + } -func (s *ProxyServer) handleConn(conn net.Conn) { - defer conn.Close() +// Listener is a proxy server listener, just like a net.Listener. +type Listener interface { + net.Listener +} - switch s.Node.Protocol { - case "ss": // shadowsocks - //server := NewShadowServer(ss.NewConn(conn, s.cipher.Copy()), s) - //server.OTA = s.ota - //server.Serve() - return - case "http": - req, err := http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - glog.V(LWARNING).Infoln("[http]", err) - return - } - NewHttpServer(conn, s).HandleRequest(req) - return - case "socks", "socks5": - conn = gosocks5.ServerConn(conn, s.selector) - req, err := gosocks5.ReadRequest(conn) - if err != nil { - glog.V(LWARNING).Infoln("[socks5]", err) - return - } - NewSocks5Server(conn, s).HandleRequest(req) - return - case "socks4", "socks4a": - req, err := gosocks4.ReadRequest(conn) - if err != nil { - glog.V(LWARNING).Infoln("[socks4]", err) - return - } - NewSocks4Server(conn, s).HandleRequest(req) - return - } +type tcpListener struct { + net.Listener +} - br := bufio.NewReader(conn) - b, err := br.Peek(1) +// TCPListener creates a Listener for TCP proxy server. +func TCPListener(addr string) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + return &tcpListener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() if err != nil { - glog.V(LWARNING).Infoln(err) return } - - switch b[0] { - case gosocks4.Ver4: - req, err := gosocks4.ReadRequest(br) - if err != nil { - glog.V(LWARNING).Infoln("[socks4]", err) - return - } - NewSocks4Server(conn, s).HandleRequest(req) - - case gosocks5.Ver5: - methods, err := gosocks5.ReadMethods(br) - if err != nil { - glog.V(LWARNING).Infoln("[socks5]", err) - return - } - method := s.selector.Select(methods...) - if _, err := conn.Write([]byte{gosocks5.Ver5, method}); err != nil { - glog.V(LWARNING).Infoln("[socks5] select:", err) - return - } - c, err := s.selector.OnSelected(method, conn) - if err != nil { - glog.V(LWARNING).Infoln("[socks5] onselected:", err) - return - } - conn = c - - req, err := gosocks5.ReadRequest(conn) - if err != nil { - glog.V(LWARNING).Infoln("[socks5] request:", err) - return - } - NewSocks5Server(conn, s).HandleRequest(req) - - default: // http - req, err := http.ReadRequest(br) - if err != nil { - glog.V(LWARNING).Infoln("[http]", err) - return - } - NewHttpServer(conn, s).HandleRequest(req) - } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(KeepAliveTime) + return tc, nil } -func (_ *ProxyServer) transport(conn1, conn2 net.Conn) (err error) { - errc := make(chan error, 2) - +func transport(rw1, rw2 io.ReadWriter) error { + errc := make(chan error, 1) go func() { - _, err := io.Copy(conn1, conn2) + _, err := io.Copy(rw1, rw2) errc <- err }() go func() { - _, err := io.Copy(conn2, conn1) + _, err := io.Copy(rw2, rw1) errc <- err }() - select { - case err = <-errc: - // glog.V(LWARNING).Infoln("transport exit", err) + err := <-errc + if err != nil && err == io.EOF { + err = nil } - - return + return err } diff --git a/signal_unix.go b/signal_unix.go index f46b394..a761318 100644 --- a/signal_unix.go +++ b/signal_unix.go @@ -3,11 +3,12 @@ package gost import ( - "github.com/golang/glog" - "gopkg.in/xtaci/kcp-go.v2" "os" "os/signal" "syscall" + + "github.com/go-log/log" + "gopkg.in/xtaci/kcp-go.v2" ) func kcpSigHandler() { @@ -17,7 +18,7 @@ func kcpSigHandler() { for { switch <-ch { case syscall.SIGUSR1: - glog.V(LINFO).Infof("[kcp] SNMP: %+v", kcp.DefaultSnmp.Copy()) + log.Logf("[kcp] SNMP: %+v", kcp.DefaultSnmp.Copy()) } } } diff --git a/socks.go b/socks.go index 31b1b9e..858e725 100644 --- a/socks.go +++ b/socks.go @@ -3,45 +3,51 @@ package gost import ( "bytes" "crypto/tls" + "errors" + "fmt" "net" "net/url" "strconv" "time" + "io" + "github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks5" "github.com/go-log/log" - "github.com/golang/glog" ) const ( - MethodTLS uint8 = 0x80 // extended method for tls - MethodTLSAuth uint8 = 0x82 // extended method for tls+auth + // MethodTLS is an extended SOCKS5 method for TLS. + MethodTLS uint8 = 0x80 + // MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH. + MethodTLSAuth uint8 = 0x82 ) const ( - CmdUdpTun uint8 = 0xF3 // extended method for udp over tcp + // CmdUDPTun is an extended SOCKS5 method for UDP over TCP. + CmdUDPTun uint8 = 0xF3 ) -type ClientSelector struct { +type clientSelector struct { methods []uint8 User *url.Userinfo TLSConfig *tls.Config } -func (selector *ClientSelector) Methods() []uint8 { +func (selector *clientSelector) Methods() []uint8 { return selector.methods } -func (selector *ClientSelector) AddMethod(methods ...uint8) { +func (selector *clientSelector) AddMethod(methods ...uint8) { selector.methods = append(selector.methods, methods...) } -func (selector *ClientSelector) Select(methods ...uint8) (method uint8) { +func (selector *clientSelector) Select(methods ...uint8) (method uint8) { return } -func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { +func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { switch method { case MethodTLS: conn = tls.Client(conn, selector.TLSConfig) @@ -63,7 +69,7 @@ func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Con return nil, err } if Debug { - log.Log(req) + log.Log("[socks5]", req) } resp, err := gosocks5.ReadUserPassResponse(conn) if err != nil { @@ -71,7 +77,7 @@ func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Con return nil, err } if Debug { - log.Log(resp) + log.Log("[socks5]", resp) } if resp.Status != gosocks5.Succeeded { return nil, gosocks5.ErrAuthFailure @@ -83,21 +89,21 @@ func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Con return conn, nil } -type ServerSelector struct { +type serverSelector struct { methods []uint8 - Users []url.Userinfo + Users []*url.Userinfo TLSConfig *tls.Config } -func (selector *ServerSelector) Methods() []uint8 { +func (selector *serverSelector) Methods() []uint8 { return selector.methods } -func (selector *ServerSelector) AddMethod(methods ...uint8) { +func (selector *serverSelector) AddMethod(methods ...uint8) { selector.methods = append(selector.methods, methods...) } -func (selector *ServerSelector) Select(methods ...uint8) (method uint8) { +func (selector *serverSelector) Select(methods ...uint8) (method uint8) { if Debug { log.Logf("[socks5] %d %d %v", gosocks5.Ver5, len(methods), methods) } @@ -110,7 +116,7 @@ func (selector *ServerSelector) Select(methods ...uint8) (method uint8) { } // when user/pass is set, auth is mandatory - if selector.Users != nil { + if len(selector.Users) > 0 { if method == gosocks5.MethodNoAuth { method = gosocks5.MethodUserPass } @@ -122,7 +128,7 @@ func (selector *ServerSelector) Select(methods ...uint8) (method uint8) { return } -func (selector *ServerSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { +func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { if Debug { log.Logf("[socks5] %d %d", gosocks5.Ver5, method) } @@ -182,14 +188,170 @@ func (selector *ServerSelector) OnSelected(method uint8, conn net.Conn) (net.Con return conn, nil } -type socks5Handler struct { - server Server +type socks5Connector struct { + User *url.Userinfo } -func (h *socks5Handler) Handle(conn net.Conn) { - selector := &ServerSelector{ - Users: h.server.Options().BaseOptions().Users, - TLSConfig: config, +// SOCKS5Connector creates a connector for SOCKS5 proxy client. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5Connector(user *url.Userinfo) Connector { + return &socks5Connector{User: user} +} + +func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { + selector := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: c.User, + } + selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + ) + + cc := gosocks5.ClientConn(conn, selector) + if err := cc.Handleshake(); err != nil { + return nil, err + } + conn = cc + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ + Type: gosocks5.AddrDomain, + Host: host, + Port: uint16(p), + }) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5]", req) + } + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5]", reply) + } + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("Service unavailable") + } + + return conn, nil +} + +type socks4Connector struct{} + +// SOCKS4Connector creates a Connector for SOCKS4 proxy client. +func SOCKS4Connector() Connector { + return &socks4Connector{} +} + +func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { + taddr, err := net.ResolveTCPAddr("tcp4", addr) + if err != nil { + return nil, err + } + + req := gosocks4.NewRequest(gosocks4.CmdConnect, + &gosocks4.Addr{ + Type: gosocks4.AddrIPv4, + Host: taddr.IP.String(), + Port: uint16(taddr.Port), + }, nil, + ) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("[socks4] %d", reply.Code) + } + + return conn, nil +} + +type socks4aConnector struct{} + +// SOCKS4AConnector creates a Connector for SOCKS4A proxy client. +func SOCKS4AConnector() Connector { + return &socks4aConnector{} +} + +func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + + req := gosocks4.NewRequest(gosocks4.CmdConnect, + &gosocks4.Addr{Type: gosocks4.AddrDomain, Host: host, Port: uint16(p)}, nil) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("[socks4] %d", reply.Code) + } + + return conn, nil +} + +type socks5Handler struct { + selector *serverSelector + options *HandlerOptions +} + +// SOCKS5Handler creates a server Handler for SOCKS5 proxy server. +func SOCKS5Handler(opts ...HandlerOption) Handler { + options := &HandlerOptions{} + for _, opt := range opts { + opt(options) + } + + tlsConfig := options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + selector := &serverSelector{ // socks5 server selector + Users: options.Users, + TLSConfig: tlsConfig, } // methods that socks5 server supported selector.AddMethod( @@ -198,285 +360,134 @@ func (h *socks5Handler) Handle(conn net.Conn) { MethodTLS, MethodTLSAuth, ) - conn = gosocks5.ServerConn(conn, s.selector) + return &socks5Handler{ + options: options, + selector: selector, + } +} + +func (h *socks5Handler) Handle(conn net.Conn) { + defer conn.Close() + + conn = gosocks5.ServerConn(conn, h.selector) req, err := gosocks5.ReadRequest(conn) if err != nil { - glog.V(LWARNING).Infoln("[socks5]", err) + log.Log("[socks5]", err) return } -} - -type Socks5Server struct { - conn net.Conn - Base *ProxyServer -} - -func NewSocks5Server(conn net.Conn, base *ProxyServer) *Socks5Server { - return &Socks5Server{conn: conn, Base: base} -} - -func (s *Socks5Server) HandleRequest(req *gosocks5.Request) { - glog.V(LDEBUG).Infof("[socks5] %s -> %s\n%s", s.conn.RemoteAddr(), req.Addr, req) + if Debug { + log.Logf("[socks5] %s - %s\n%s", conn.RemoteAddr(), req.Addr, req) + } switch req.Cmd { case gosocks5.CmdConnect: - glog.V(LINFO).Infof("[socks5-connect] %s -> %s", s.conn.RemoteAddr(), req.Addr) - s.handleConnect(req) + h.handleConnect(conn, req) case gosocks5.CmdBind: - glog.V(LINFO).Infof("[socks5-bind] %s - %s", s.conn.RemoteAddr(), req.Addr) - s.handleBind(req) + h.handleBind(conn, req) case gosocks5.CmdUdp: - glog.V(LINFO).Infof("[socks5-udp] %s - %s", s.conn.RemoteAddr(), req.Addr) - s.handleUDPRelay(req) + h.handleUDPRelay(conn, req) - case CmdUdpTun: - glog.V(LINFO).Infof("[socks5-rudp] %s - %s", s.conn.RemoteAddr(), req.Addr) - s.handleUDPTunnel(req) + case CmdUDPTun: + h.handleUDPTunnel(conn, req) default: - glog.V(LWARNING).Infoln("[socks5] Unrecognized request:", req.Cmd) + log.Log("[socks5] Unrecognized request:", req.Cmd) } } -func (s *Socks5Server) handleConnect(req *gosocks5.Request) { +func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { addr := req.Addr.String() - - if !s.Base.Node.Can("tcp", addr) { - glog.Errorf("Unauthorized to tcp connect to %s", addr) + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-connect] Unauthorized to tcp connect to %s", addr) rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(s.conn) + rep.Write(conn) + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } return } - cc, err := s.Base.Chain.Dial(addr) + cc, err := h.options.Chain.Dial(addr) if err != nil { - glog.V(LWARNING).Infof("[socks5-connect] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) - rep.Write(s.conn) - glog.V(LDEBUG).Infof("[socks5-connect] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, rep) + rep.Write(conn) + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } return } defer cc.Close() rep := gosocks5.NewReply(gosocks5.Succeeded, nil) - if err := rep.Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[socks5-connect] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err) + if err := rep.Write(conn); err != nil { + log.Logf("[socks5-connect] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) return } - glog.V(LDEBUG).Infof("[socks5-connect] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, rep) - - glog.V(LINFO).Infof("[socks5-connect] %s <-> %s", s.conn.RemoteAddr(), req.Addr) - //Transport(conn, cc) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[socks5-connect] %s >-< %s", s.conn.RemoteAddr(), req.Addr) + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + log.Logf("[socks5-connect] %s <-> %s", conn.RemoteAddr(), req.Addr) + transport(conn, cc) + log.Logf("[socks5-connect] %s >-< %s", conn.RemoteAddr(), req.Addr) } -func (s *Socks5Server) handleBind(req *gosocks5.Request) { - cc, err := s.Base.Chain.GetConn() - - // connection error when forwarding bind - if err != nil && err != ErrEmptyChain { - glog.V(LWARNING).Infof("[socks5-bind] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(s.conn) - glog.V(LDEBUG).Infof("[socks5-bind] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, reply) +func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { + if h.options.Chain.IsEmpty() { + addr := req.Addr.String() + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("Unauthorized to tcp bind to %s", addr) + return + } + h.bindOn(conn, addr) return } - // serve socks5 bind - if err == ErrEmptyChain { - addr := req.Addr.String() - - if !s.Base.Node.Can("rtcp", addr) { - glog.Errorf("Unauthorized to tcp bind to %s", addr) - return + cc, err := h.options.Chain.Conn() + if err != nil { + log.Logf("[socks5-bind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) } - - s.bindOn(addr) - return } // forward request - // note: this type of request forwarding is defined when starting server + // note: this type of request forwarding is defined when starting server, // so we don't need to authenticate it, as it's as explicit as whitelisting defer cc.Close() req.Write(cc) - glog.V(LINFO).Infof("[socks5-bind] %s <-> %s", s.conn.RemoteAddr(), cc.RemoteAddr()) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[socks5-bind] %s >-< %s", s.conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) } -func (s *Socks5Server) handleUDPRelay(req *gosocks5.Request) { - addr := req.Addr.String() - - if !s.Base.Node.Can("udp", addr) { - glog.Errorf("Unauthorized to udp connect to %s", addr) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(s.conn) - return - } - - relay, err := net.ListenUDP("udp", nil) - if err != nil { - glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), relay.LocalAddr(), err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(s.conn) - glog.V(LDEBUG).Infof("[socks5-udp] %s <- %s\n%s", s.conn.RemoteAddr(), relay.LocalAddr(), reply) - return - } - defer relay.Close() - - socksAddr := ToSocksAddr(relay.LocalAddr()) - socksAddr.Host, _, _ = net.SplitHostPort(s.conn.LocalAddr().String()) // replace the IP to out-going interface's - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[socks5-udp] %s <- %s : %s", s.conn.RemoteAddr(), relay.LocalAddr(), err) - return - } - glog.V(LDEBUG).Infof("[socks5-udp] %s <- %s\n%s", s.conn.RemoteAddr(), reply.Addr, reply) - glog.V(LINFO).Infof("[socks5-udp] %s - %s BIND ON %s OK", s.conn.RemoteAddr(), relay.LocalAddr(), socksAddr) - - cc, err := s.Base.Chain.GetConn() - // connection error - if err != nil && err != ErrEmptyChain { - glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), socksAddr, err) - return - } - - // serve as standard socks5 udp relay local <-> remote - if err == ErrEmptyChain { - peer, er := net.ListenUDP("udp", nil) - if er != nil { - glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), socksAddr, er) - return - } - defer peer.Close() - - go s.transportUDP(relay, peer) - } - - // forward udp local <-> tunnel - if err == nil { - defer cc.Close() - - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - req := gosocks5.NewRequest(CmdUdpTun, nil) - if err := req.Write(cc); err != nil { - glog.V(LWARNING).Infoln("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), cc.RemoteAddr(), err) - return - } - cc.SetWriteDeadline(time.Time{}) - glog.V(LDEBUG).Infof("[socks5-udp] %s -> %s\n%s", s.conn.RemoteAddr(), cc.RemoteAddr(), req) - - cc.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err = gosocks5.ReadReply(cc) - if err != nil { - glog.V(LWARNING).Infoln("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), cc.RemoteAddr(), err) - return - } - glog.V(LDEBUG).Infof("[socks5-udp] %s <- %s\n%s", s.conn.RemoteAddr(), cc.RemoteAddr(), reply) - - if reply.Rep != gosocks5.Succeeded { - glog.V(LWARNING).Infoln("[socks5-udp] %s <- %s : udp associate failed", s.conn.RemoteAddr(), cc.RemoteAddr()) - return - } - cc.SetReadDeadline(time.Time{}) - glog.V(LINFO).Infof("[socks5-udp] %s <-> %s [tun: %s]", s.conn.RemoteAddr(), socksAddr, reply.Addr) - - go s.tunnelClientUDP(relay, cc) - } - - glog.V(LINFO).Infof("[socks5-udp] %s <-> %s", s.conn.RemoteAddr(), socksAddr) - b := make([]byte, SmallBufferSize) - for { - _, err := s.conn.Read(b) // discard any data from tcp connection - if err != nil { - glog.V(LWARNING).Infof("[socks5-udp] %s - %s : %s", s.conn.RemoteAddr(), socksAddr, err) - break // client disconnected - } - } - glog.V(LINFO).Infof("[socks5-udp] %s >-< %s", s.conn.RemoteAddr(), socksAddr) -} - -func (s *Socks5Server) handleUDPTunnel(req *gosocks5.Request) { - cc, err := s.Base.Chain.GetConn() - - // connection error - if err != nil && err != ErrEmptyChain { - glog.V(LWARNING).Infof("[socks5-rudp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(s.conn) - glog.V(LDEBUG).Infof("[socks5-rudp] %s -> %s\n%s", s.conn.RemoteAddr(), req.Addr, reply) - return - } - - // serve tunnel udp, tunnel <-> remote, handle tunnel udp request - if err == ErrEmptyChain { - addr := req.Addr.String() - - if !s.Base.Node.Can("rudp", addr) { - glog.Errorf("Unauthorized to udp bind to %s", addr) - return - } - - bindAddr, _ := net.ResolveUDPAddr("udp", addr) - uc, err := net.ListenUDP("udp", bindAddr) - if err != nil { - glog.V(LWARNING).Infof("[socks5-rudp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err) - return - } - defer uc.Close() - - socksAddr := ToSocksAddr(uc.LocalAddr()) - socksAddr.Host, _, _ = net.SplitHostPort(s.conn.LocalAddr().String()) - reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[socks5-rudp] %s <- %s : %s", s.conn.RemoteAddr(), socksAddr, err) - return - } - glog.V(LDEBUG).Infof("[socks5-rudp] %s <- %s\n%s", s.conn.RemoteAddr(), socksAddr, reply) - - glog.V(LINFO).Infof("[socks5-rudp] %s <-> %s", s.conn.RemoteAddr(), socksAddr) - s.tunnelServerUDP(s.conn, uc) - glog.V(LINFO).Infof("[socks5-rudp] %s >-< %s", s.conn.RemoteAddr(), socksAddr) - return - } - - defer cc.Close() - - // tunnel <-> tunnel, direct forwarding - // note: this type of request forwarding is defined when starting server - // so we don't need to authenticate it, as it's as explicit as whitelisting - req.Write(cc) - - glog.V(LINFO).Infof("[socks5-rudp] %s <-> %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr()) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[socks5-rudp] %s >-< %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr()) -} - -func (s *Socks5Server) bindOn(addr string) { +func (h *socks5Handler) bindOn(conn net.Conn, addr string) { bindAddr, _ := net.ResolveTCPAddr("tcp", addr) ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error if err != nil { - glog.V(LWARNING).Infof("[socks5-bind] %s -> %s : %s", s.conn.RemoteAddr(), addr, err) - gosocks5.NewReply(gosocks5.Failure, nil).Write(s.conn) + log.Logf("[socks5-bind] %s -> %s : %s", conn.RemoteAddr(), addr, err) + gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) return } - socksAddr := ToSocksAddr(ln.Addr()) + socksAddr := toSocksAddr(ln.Addr()) // Issue: may not reachable when host has multi-interface - socksAddr.Host, _, _ = net.SplitHostPort(s.conn.LocalAddr().String()) + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) - if err := reply.Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[socks5-bind] %s <- %s : %s", s.conn.RemoteAddr(), addr, err) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-bind] %s <- %s : %s", conn.RemoteAddr(), addr, err) ln.Close() return } - glog.V(LDEBUG).Infof("[socks5-bind] %s <- %s\n%s", s.conn.RemoteAddr(), addr, reply) - glog.V(LINFO).Infof("[socks5-bind] %s - %s BIND ON %s OK", s.conn.RemoteAddr(), addr, socksAddr) + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) + } + log.Logf("[socks5-bind] %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr) var pconn net.Conn accept := func() <-chan error { @@ -505,7 +516,7 @@ func (s *Socks5Server) bindOn(addr string) { defer close(errc) defer pc1.Close() - errc <- s.Base.transport(s.conn, pc1) + errc <- transport(conn, pc1) }() return errc @@ -517,39 +528,163 @@ func (s *Socks5Server) bindOn(addr string) { select { case err := <-accept(): if err != nil || pconn == nil { - glog.V(LWARNING).Infof("[socks5-bind] %s <- %s : %s", s.conn.RemoteAddr(), addr, err) + log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) return } defer pconn.Close() - reply := gosocks5.NewReply(gosocks5.Succeeded, ToSocksAddr(pconn.RemoteAddr())) + reply := gosocks5.NewReply(gosocks5.Succeeded, toSocksAddr(pconn.RemoteAddr())) if err := reply.Write(pc2); err != nil { - glog.V(LWARNING).Infof("[socks5-bind] %s <- %s : %s", s.conn.RemoteAddr(), addr, err) + log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) } - glog.V(LDEBUG).Infof("[socks5-bind] %s <- %s\n%s", s.conn.RemoteAddr(), addr, reply) - glog.V(LINFO).Infof("[socks5-bind] %s <- %s PEER %s ACCEPTED", s.conn.RemoteAddr(), socksAddr, pconn.RemoteAddr()) + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) + } + log.Logf("[socks5-bind] %s <- %s PEER %s ACCEPTED", conn.RemoteAddr(), socksAddr, pconn.RemoteAddr()) - glog.V(LINFO).Infof("[socks5-bind] %s <-> %s", s.conn.RemoteAddr(), pconn.RemoteAddr()) - if err = s.Base.transport(pc2, pconn); err != nil { - glog.V(LWARNING).Infoln(err) + log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), pconn.RemoteAddr()) + if err = transport(pc2, pconn); err != nil { + log.Logf("[socks5-bind] %s - %s : %v", conn.RemoteAddr(), pconn.RemoteAddr(), err) } - glog.V(LINFO).Infof("[socks5-bind] %s >-< %s", s.conn.RemoteAddr(), pconn.RemoteAddr()) + log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), pconn.RemoteAddr()) return case err := <-pipe(): - glog.V(LWARNING).Infof("[socks5-bind] %s -> %s : %v", s.conn.RemoteAddr(), addr, err) + if err != nil { + log.Logf("[socks5-bind] %s -> %s : %v", conn.RemoteAddr(), addr, err) + } ln.Close() return } } } -func (s *Socks5Server) transportUDP(relay, peer *net.UDPConn) (err error) { +func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { + addr := req.Addr.String() + if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } + + relay, err := net.ListenUDP("udp", nil) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), relay.LocalAddr(), err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), relay.LocalAddr(), reply) + } + return + } + defer relay.Close() + + socksAddr := toSocksAddr(relay.LocalAddr()) + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), relay.LocalAddr(), err) + return + } + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), reply.Addr, reply) + } + log.Logf("[socks5-udp] %s - %s BIND ON %s OK", conn.RemoteAddr(), relay.LocalAddr(), socksAddr) + + // serve as standard socks5 udp relay local <-> remote + if h.options.Chain.IsEmpty() { + peer, er := net.ListenUDP("udp", nil) + if er != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, er) + return + } + defer peer.Close() + + go h.transportUDP(relay, peer) + log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) + if err := h.discardClientData(conn); err != nil { + log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) + } + log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) + return + } + + cc, err := h.options.Chain.Conn() + // connection error + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + // forward udp local <-> tunnel + defer cc.Close() + + cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + + cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) + r := gosocks5.NewRequest(CmdUDPTun, nil) + if err := r.Write(cc); err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) + return + } + cc.SetWriteDeadline(time.Time{}) + if Debug { + log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), r) + } + cc.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err = gosocks5.ReadReply(cc) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) + return + } + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5-udp] %s <- %s : udp associate failed", conn.RemoteAddr(), cc.RemoteAddr()) + return + } + cc.SetReadDeadline(time.Time{}) + log.Logf("[socks5-udp] %s <-> %s [tun: %s]", conn.RemoteAddr(), socksAddr, reply.Addr) + + go h.tunnelClientUDP(relay, cc) + log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) + if err := h.discardClientData(conn); err != nil { + log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) + } + log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) +} + +func (h *socks5Handler) discardClientData(conn net.Conn) (err error) { + b := make([]byte, tinyBufferSize) + n := 0 + for { + n, err = conn.Read(b) // discard any data from tcp connection + if err != nil { + if err == io.EOF { // disconnect normally + err = nil + } + break // client disconnected + } + log.Logf("[socks5-udp] read %d UNEXPECTED TCP data from client", n) + } + return +} + +func (h *socks5Handler) transportUDP(relay, peer *net.UDPConn) (err error) { errc := make(chan error, 2) var clientAddr *net.UDPAddr go func() { - b := make([]byte, LargeBufferSize) + b := make([]byte, largeBufferSize) for { n, laddr, err := relay.ReadFromUDP(b) @@ -574,12 +709,14 @@ func (s *Socks5Server) transportUDP(relay, peer *net.UDPConn) (err error) { errc <- err return } - glog.V(LDEBUG).Infof("[socks5-udp] %s >>> %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + if Debug { + log.Logf("[socks5-udp] %s >>> %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + } } }() go func() { - b := make([]byte, LargeBufferSize) + b := make([]byte, largeBufferSize) for { n, raddr, err := peer.ReadFromUDP(b) @@ -591,13 +728,15 @@ func (s *Socks5Server) transportUDP(relay, peer *net.UDPConn) (err error) { continue } buf := bytes.Buffer{} - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, ToSocksAddr(raddr)), b[:n]) + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), b[:n]) dgram.Write(&buf) if _, err := relay.WriteToUDP(buf.Bytes(), clientAddr); err != nil { errc <- err return } - glog.V(LDEBUG).Infof("[socks5-udp] %s <<< %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + if Debug { + log.Logf("[socks5-udp] %s <<< %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + } } }() @@ -609,18 +748,18 @@ func (s *Socks5Server) transportUDP(relay, peer *net.UDPConn) (err error) { return } -func (s *Socks5Server) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) { +func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) { errc := make(chan error, 2) var clientAddr *net.UDPAddr go func() { - b := make([]byte, LargeBufferSize) + b := make([]byte, mediumBufferSize) for { n, addr, err := uc.ReadFromUDP(b) if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) + log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) errc <- err return } @@ -640,7 +779,9 @@ func (s *Socks5Server) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) errc <- err return } - glog.V(LDEBUG).Infof("[udp-tun] %s >>> %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + if Debug { + log.Logf("[udp-tun] %s >>> %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + } } }() @@ -648,7 +789,7 @@ func (s *Socks5Server) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) for { dgram, err := gosocks5.ReadUDPDatagram(cc) if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) + log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) errc <- err return } @@ -665,7 +806,9 @@ func (s *Socks5Server) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) errc <- err return } - glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + if Debug { + log.Logf("[udp-tun] %s <<< %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + } } }() @@ -676,29 +819,91 @@ func (s *Socks5Server) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) return } -func (s *Socks5Server) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) { +func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { + // serve tunnel udp, tunnel <-> remote, handle tunnel udp request + if h.options.Chain.IsEmpty() { + addr := req.Addr.String() + + if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr) + return + } + + bindAddr, _ := net.ResolveUDPAddr("udp", addr) + uc, err := net.ListenUDP("udp", bindAddr) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } + defer uc.Close() + + socksAddr := toSocksAddr(uc.LocalAddr()) + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) + } + log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) + h.tunnelServerUDP(conn, uc) + log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) + return + } + + cc, err := h.options.Chain.Conn() + // connection error + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) + return + } + defer cc.Close() + + cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } + // tunnel <-> tunnel, direct forwarding + // note: this type of request forwarding is defined when starting server + // so we don't need to authenticate it, as it's as explicit as whitelisting + req.Write(cc) + + log.Logf("[socks5-udp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) { errc := make(chan error, 2) go func() { - b := make([]byte, LargeBufferSize) + b := make([]byte, mediumBufferSize) for { n, addr, err := uc.ReadFromUDP(b) if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) + log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) errc <- err return } // pipe from peer to tunnel dgram := gosocks5.NewUDPDatagram( - gosocks5.NewUDPHeader(uint16(n), 0, ToSocksAddr(addr)), b[:n]) + gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) if err := dgram.Write(cc); err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) + log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) errc <- err return } - glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) + if Debug { + log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) + } } }() @@ -706,7 +911,7 @@ func (s *Socks5Server) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) for { dgram, err := gosocks5.ReadUDPDatagram(cc) if err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) + log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) errc <- err return } @@ -717,11 +922,13 @@ func (s *Socks5Server) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) continue // drop silently } if _, err := uc.WriteToUDP(dgram.Data, addr); err != nil { - glog.V(LWARNING).Infof("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err) + log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err) errc <- err return } - glog.V(LDEBUG).Infof("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) + if Debug { + log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) + } } }() @@ -732,7 +939,7 @@ func (s *Socks5Server) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) return } -func ToSocksAddr(addr net.Addr) *gosocks5.Addr { +func toSocksAddr(addr net.Addr) *gosocks5.Addr { host := "0.0.0.0" port := 0 if addr != nil { @@ -747,78 +954,107 @@ func ToSocksAddr(addr net.Addr) *gosocks5.Addr { } } -type Socks4Server struct { - conn net.Conn - Base *ProxyServer +type socks4Handler struct { + options *HandlerOptions } -func NewSocks4Server(conn net.Conn, base *ProxyServer) *Socks4Server { - return &Socks4Server{conn: conn, Base: base} -} - -func (s *Socks4Server) HandleRequest(req *gosocks4.Request) { - glog.V(LDEBUG).Infof("[socks4] %s -> %s\n%s", s.conn.RemoteAddr(), req.Addr, req) - - switch req.Cmd { - case gosocks4.CmdConnect: - glog.V(LINFO).Infof("[socks4-connect] %s -> %s", s.conn.RemoteAddr(), req.Addr) - s.handleConnect(req) - - case gosocks4.CmdBind: - glog.V(LINFO).Infof("[socks4-bind] %s - %s", s.conn.RemoteAddr(), req.Addr) - s.handleBind(req) - - default: - glog.V(LWARNING).Infoln("[socks4] Unrecognized request:", req.Cmd) +// SOCKS4Handler creates a server Handler for SOCKS4(A) proxy server. +func SOCKS4Handler(opts ...HandlerOption) Handler { + options := &HandlerOptions{} + for _, opt := range opts { + opt(options) + } + return &socks4Handler{ + options: options, } } -func (s *Socks4Server) handleConnect(req *gosocks4.Request) { - addr := req.Addr.String() +func (h *socks4Handler) Handle(conn net.Conn) { + defer conn.Close() - if !s.Base.Node.Can("tcp", addr) { - glog.Errorf("Unauthorized to tcp connect to %s", addr) - rep := gosocks5.NewReply(gosocks4.Rejected, nil) - rep.Write(s.conn) + req, err := gosocks4.ReadRequest(conn) + if err != nil { + log.Log("[socks4]", err) return } - cc, err := s.Base.Chain.Dial(addr) + if Debug { + log.Logf("[socks4] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, req) + } + + switch req.Cmd { + case gosocks4.CmdConnect: + log.Logf("[socks4-connect] %s -> %s", conn.RemoteAddr(), req.Addr) + h.handleConnect(conn, req) + + case gosocks4.CmdBind: + log.Logf("[socks4-bind] %s - %s", conn.RemoteAddr(), req.Addr) + h.handleBind(conn, req) + + default: + log.Logf("[socks4] Unrecognized request: %d", req.Cmd) + } +} + +func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { + addr := req.Addr.String() + + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks4-connect] Unauthorized to tcp connect to %s", addr) + rep := gosocks5.NewReply(gosocks4.Rejected, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } + + cc, err := h.options.Chain.Dial(addr) if err != nil { - glog.V(LWARNING).Infof("[socks4-connect] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks4-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) rep := gosocks4.NewReply(gosocks4.Failed, nil) - rep.Write(s.conn) - glog.V(LDEBUG).Infof("[socks4-connect] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, rep) + rep.Write(conn) + if Debug { + log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } return } defer cc.Close() rep := gosocks4.NewReply(gosocks4.Granted, nil) - if err := rep.Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[socks4-connect] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err) + if err := rep.Write(conn); err != nil { + log.Logf("[socks4-connect] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) return } - glog.V(LDEBUG).Infof("[socks4-connect] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, rep) + if Debug { + log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } - glog.V(LINFO).Infof("[socks4-connect] %s <-> %s", s.conn.RemoteAddr(), req.Addr) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[socks4-connect] %s >-< %s", s.conn.RemoteAddr(), req.Addr) + log.Logf("[socks4-connect] %s <-> %s", conn.RemoteAddr(), req.Addr) + transport(conn, cc) + log.Logf("[socks4-connect] %s >-< %s", conn.RemoteAddr(), req.Addr) } -func (s *Socks4Server) handleBind(req *gosocks4.Request) { - cc, err := s.Base.Chain.GetConn() - - // connection error - if err != nil && err != ErrEmptyChain { - glog.V(LWARNING).Infof("[socks4-bind] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err) - reply := gosocks4.NewReply(gosocks4.Failed, nil) - reply.Write(s.conn) - glog.V(LDEBUG).Infof("[socks4-bind] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, reply) +func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) { + // TODO: serve socks4 bind + if h.options.Chain.IsEmpty() { + reply := gosocks4.NewReply(gosocks4.Rejected, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } return } - // TODO: serve socks4 bind - if err == ErrEmptyChain { - //s.bindOn(req.Addr.String()) + + cc, err := h.options.Chain.Conn() + // connection error + if err != nil && err != ErrEmptyChain { + log.Logf("[socks4-bind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks4.NewReply(gosocks4.Failed, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } return } @@ -826,7 +1062,109 @@ func (s *Socks4Server) handleBind(req *gosocks4.Request) { // forward request req.Write(cc) - glog.V(LINFO).Infof("[socks4-bind] %s <-> %s", s.conn.RemoteAddr(), cc.RemoteAddr()) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[socks4-bind] %s >-< %s", s.conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks4-bind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { + conn, err := chain.Conn() + if err != nil { + return nil, err + } + cc, err := socks5Handshake(conn, chain.LastNode().User) + if err != nil { + conn.Close() + return nil, err + } + conn = cc + + conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(addr)) + if err := req.Write(conn); err != nil { + conn.Close() + return nil, err + } + if Debug { + log.Log("[socks5]", req) + } + conn.SetWriteDeadline(time.Time{}) + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + conn.Close() + return nil, err + } + conn.SetReadDeadline(time.Time{}) + if Debug { + log.Log("[socks5]", reply) + } + + if reply.Rep != gosocks5.Succeeded { + conn.Close() + return nil, errors.New("UDP tunnel failure") + } + return conn, nil +} + +func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { + selector := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: user, + } + selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + ) + cc := gosocks5.ClientConn(conn, selector) + if err := cc.Handleshake(); err != nil { + return nil, err + } + return cc, nil +} + +type udpTunnelConn struct { + raddr string + net.Conn +} + +func (c *udpTunnelConn) Read(b []byte) (n int, err error) { + dgram, err := gosocks5.ReadUDPDatagram(c.Conn) + if err != nil { + return + } + n = copy(b, dgram.Data) + return +} + +func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + dgram, err := gosocks5.ReadUDPDatagram(c.Conn) + if err != nil { + return + } + n = copy(b, dgram.Data) + addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + return +} + +func (c *udpTunnelConn) Write(b []byte) (n int, err error) { + addr, err := net.ResolveUDPAddr("udp", c.raddr) + if err != nil { + return + } + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) + if err = dgram.Write(c.Conn); err != nil { + return + } + return len(b), nil +} + +func (c *udpTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) + if err = dgram.Write(c.Conn); err != nil { + return + } + return len(b), nil } diff --git a/ss.go b/ss.go index e612b04..079015d 100644 --- a/ss.go +++ b/ss.go @@ -3,28 +3,25 @@ package gost import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "net" + "net/url" "strconv" "time" "github.com/ginuerzh/gosocks5" "github.com/go-log/log" - "github.com/golang/glog" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" ) // Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write, -// we wrap around it to make io.Copy happy +// we wrap around it to make io.Copy happy. type shadowConn struct { conn net.Conn } -func ShadowConn(conn net.Conn) net.Conn { - return &shadowConn{conn: conn} -} - func (c *shadowConn) Read(b []byte) (n int, err error) { return c.conn.Read(b) } @@ -59,18 +56,61 @@ func (c *shadowConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } -type shadowHandler struct { - server Server +type shadowConnector struct { + Cipher *url.Userinfo } -func ShadowHandler(server Server) Handler { - return &shadowHandler{server: server} +// ShadowConnector creates a Connector for shadowsocks proxy client. +// It accepts a cipher info for shadowsocks data encryption/decryption. +// The cipher must not be nil. +func ShadowConnector(cipher *url.Userinfo) Connector { + return &shadowConnector{Cipher: cipher} +} + +func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { + rawaddr, err := ss.RawAddr(addr) + if err != nil { + return nil, err + } + + var method, password string + if c.Cipher != nil { + method = c.Cipher.Username() + password, _ = c.Cipher.Password() + } + + cipher, err := ss.NewCipher(method, password) + if err != nil { + return nil, err + } + + sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) + if err != nil { + return nil, err + } + return &shadowConn{conn: sc}, nil +} + +type shadowHandler struct { + options *HandlerOptions +} + +// ShadowHandler creates a server Handler for shadowsocks proxy server. +func ShadowHandler(opts ...HandlerOption) Handler { + h := &shadowHandler{ + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h } func (h *shadowHandler) Handle(conn net.Conn) { - var method, password string + defer conn.Close() - users := h.server.Options().BaseOptions().Users + var method, password string + users := h.options.Users if len(users) > 0 { method = users[0].Username() password, _ = users[0].Password() @@ -80,7 +120,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { log.Log("[ss]", err) return } - conn = ShadowConn(ss.NewConn(conn, cipher)) + conn = &shadowConn{conn: ss.NewConn(conn, cipher)} log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) @@ -91,7 +131,12 @@ func (h *shadowHandler) Handle(conn net.Conn) { } log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) - cc, err := h.server.Chain().Dial(addr) + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ss] Unauthorized to tcp connect to %s", addr) + return + } + + cc, err := h.options.Chain.Dial(addr) if err != nil { log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) return @@ -99,7 +144,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { defer cc.Close() log.Logf("[ss] %s <-> %s", conn.RemoteAddr(), addr) - Transport(conn, cc) + transport(conn, cc) log.Logf("[ss] %s >-< %s", conn.RemoteAddr(), addr) } @@ -124,7 +169,7 @@ func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) - buf := make([]byte, SmallBufferSize) + buf := make([]byte, smallBufferSize) // read till we get possible domain length field conn.SetReadDeadline(time.Now().Add(30 * time.Second)) @@ -170,117 +215,205 @@ func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { return } -type ShadowUdpServer struct { - Base *ProxyServer - TTL int +type shadowUDPListener struct { + ln net.PacketConn + conns map[string]*udpServerConn + connChan chan net.Conn + errChan chan error + ttl time.Duration } -func NewShadowUdpServer(base *ProxyServer, ttl int) *ShadowUdpServer { - return &ShadowUdpServer{Base: base, TTL: ttl} +// ShadowUDPListener creates a Listener for shadowsocks UDP relay server. +func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + return nil, err + } + + var method, password string + if cipher != nil { + method = cipher.Username() + password, _ = cipher.Password() + } + cp, err := ss.NewCipher(method, password) + if err != nil { + ln.Close() + return nil, err + } + l := &shadowUDPListener{ + ln: ss.NewSecurePacketConn(ln, cp, false), + conns: make(map[string]*udpServerConn), + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + ttl: ttl, + } + go l.listenLoop() + return l, nil } -func (s *ShadowUdpServer) ListenAndServe() error { - laddr, err := net.ResolveUDPAddr("udp", s.Base.Node.Addr) - if err != nil { - return err - } - lconn, err := net.ListenUDP("udp", laddr) - if err != nil { - return err - } - defer lconn.Close() +func (l *shadowUDPListener) listenLoop() { + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := l.ln.ReadFrom(b) + if err != nil { + log.Logf("[ssu] peer -> %s : %s", l.Addr(), err) + l.ln.Close() + l.errChan <- err + close(l.errChan) + return + } + if Debug { + log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n) + } - conn := ss.NewSecurePacketConn(lconn, s.Base.cipher.Copy(), true) // force OTA on - - rChan, wChan := make(chan *packet, 128), make(chan *packet, 128) - // start send queue - go func(ch chan<- *packet) { - for { - b := make([]byte, MediumBufferSize) - n, addr, err := conn.ReadFrom(b[3:]) // add rsv and frag fields to make it the standard SOCKS5 UDP datagram - if err != nil { - glog.V(LWARNING).Infof("[ssu] %s -> %s : %s", addr, laddr, err) - continue - } - - b[3] &= ss.AddrMask // remove OTA flag - dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3])) - if err != nil { - glog.V(LWARNING).Infof("[ssu] %s -> %s : %s", addr, laddr, err) - continue - } + conn, ok := l.conns[raddr.String()] + if !ok || conn.Closed() { + conn = newUDPServerConn(l.ln, raddr, l.ttl) + l.conns[raddr.String()] = conn select { - case ch <- &packet{srcAddr: addr.String(), dstAddr: dgram.Header.Addr.String(), data: dgram.Data}: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[ssu] %s -> %s : %s", addr, dgram.Header.Addr.String(), "send queue is full, discard") + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr()) } } - }(wChan) - // start recv queue - go func(ch <-chan *packet) { - for pkt := range ch { - srcAddr, err := net.ResolveUDPAddr("udp", pkt.srcAddr) - if err != nil { - glog.V(LWARNING).Infof("[ssu] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) - continue - } - dstAddr, err := net.ResolveUDPAddr("udp", pkt.dstAddr) - if err != nil { - glog.V(LWARNING).Infof("[ssu] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) - continue - } - - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, ToSocksAddr(srcAddr)), pkt.data) - b := bytes.Buffer{} - dgram.Write(&b) - if b.Len() < 10 { - glog.V(LWARNING).Infof("[ssu] %s <- %s : invalid udp datagram", pkt.dstAddr, pkt.srcAddr) - continue - } - - if _, err := conn.WriteTo(b.Bytes()[3:], dstAddr); err != nil { // remove rsv and frag fields to make it standard shadowsocks UDP datagram - glog.V(LWARNING).Infof("[ssu] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) - return - } - } - }(rChan) - - // mapping client to node - m := make(map[string]*cnode) - - // start dispatcher - for pkt := range wChan { - // clear obsolete nodes - for k, node := range m { - if node != nil && node.err != nil { - close(node.wChan) - delete(m, k) - glog.V(LINFO).Infof("[ssu] clear node %s", k) - } - } - - node, ok := m[pkt.srcAddr] - if !ok { - node = &cnode{ - chain: s.Base.Chain, - srcAddr: pkt.srcAddr, - dstAddr: pkt.dstAddr, - rChan: rChan, - wChan: make(chan *packet, 32), - ttl: time.Duration(s.TTL) * time.Second, - } - m[pkt.srcAddr] = node - go node.run() - glog.V(LINFO).Infof("[ssu] %s -> %s : new client (%d)", pkt.srcAddr, pkt.dstAddr, len(m)) - } select { - case node.wChan <- pkt: - case <-time.After(time.Second * 3): - glog.V(LWARNING).Infof("[ssu] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, "node send queue is full, discard") + case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination. + default: + log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr()) } } - - return nil +} + +func (l *shadowUDPListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *shadowUDPListener) Addr() net.Addr { + return l.ln.LocalAddr() +} + +func (l *shadowUDPListener) Close() error { + return l.ln.Close() +} + +type shadowUDPdHandler struct { + ttl time.Duration + options *HandlerOptions +} + +// ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server. +func ShadowUDPdHandler(opts ...HandlerOption) Handler { + h := &shadowUDPdHandler{ + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *shadowUDPdHandler) Handle(conn net.Conn) { + defer conn.Close() + + var err error + var cc net.PacketConn + if h.options.Chain.IsEmpty() { + cc, err = net.ListenUDP("udp", nil) + if err != nil { + log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) + return + } + } else { + var c net.Conn + c, err = getSOCKS5UDPTunnel(h.options.Chain, nil) + if err != nil { + log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) + return + } + cc = &udpTunnelConn{Conn: c} + } + defer cc.Close() + + log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + transportUDP(conn, cc) + log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) +} + +func transportUDP(sc net.Conn, cc net.PacketConn) error { + errc := make(chan error, 1) + go func() { + for { + b := make([]byte, mediumBufferSize) + n, err := sc.Read(b[3:]) // add rsv and frag fields to make it the standard SOCKS5 UDP datagram + if err != nil { + // log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) + errc <- err + return + } + dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3])) + if err != nil { + log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) + errc <- err + return + } + //if Debug { + // log.Logf("[ssu] %s >>> %s length: %d", sc.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data)) + //} + addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + if err != nil { + errc <- err + return + } + if _, err := cc.WriteTo(dgram.Data, addr); err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + b := make([]byte, mediumBufferSize) + n, addr, err := cc.ReadFrom(b) + if err != nil { + errc <- err + return + } + //if Debug { + // log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n) + //} + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) + buf := bytes.Buffer{} + dgram.Write(&buf) + if buf.Len() < 10 { + log.Logf("[ssu] %s <- %s : invalid udp datagram", sc.RemoteAddr(), addr) + continue + } + if _, err := sc.Write(buf.Bytes()[3:]); err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err } diff --git a/ssh.go b/ssh.go index 4e616d6..a2e1fe9 100644 --- a/ssh.go +++ b/ssh.go @@ -1,15 +1,19 @@ -// The ssh tunnel is inspired by easyssh(https://dev.justinjudd.org/justin/easyssh) - package gost import ( + "context" + "crypto/tls" "encoding/binary" + "errors" "fmt" "net" "net/url" "strconv" + "strings" + "sync" + "time" - "github.com/golang/glog" + "github.com/go-log/log" "golang.org/x/crypto/ssh" ) @@ -19,58 +23,436 @@ const ( RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 + + GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel ) -type SSHServer struct { - Addr string - Base *ProxyServer - Config *ssh.ServerConfig - Handler func(ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) +var ( + errSessionDead = errors.New("session is dead") +) + +type sshDirectForwardConnector struct { } -func (s *SSHServer) ListenAndServe() error { - ln, err := net.Listen("tcp", s.Addr) - if err != nil { - glog.V(LWARNING).Infoln("[ssh] Listen:", err) - return err +func SSHDirectForwardConnector() Connector { + return &sshDirectForwardConnector{} +} + +func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string) (net.Conn, error) { + cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. + if !ok { + return nil, errors.New("ssh: wrong connection type") } - defer ln.Close() + conn, err := cc.session.client.Dial("tcp", raddr) + if err != nil { + log.Logf("[ssh-tcp] %s -> %s : %s", cc.session.addr, raddr, err) + return nil, err + } + return conn, nil +} - for { - conn, err := ln.Accept() - if err != nil { - glog.V(LWARNING).Infoln("[ssh] Accept:", err) - return err - } +type sshRemoteForwardConnector struct { +} - go func(conn net.Conn) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config) - if err != nil { - glog.V(LWARNING).Infof("[ssh] %s -> %s : %s", conn.RemoteAddr(), s.Addr, err) +func SSHRemoteForwardConnector() Connector { + return &sshRemoteForwardConnector{} +} + +func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { + cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. + if !ok { + return nil, errors.New("ssh: wrong connection type") + } + + cc.session.once.Do(func() { + go func() { + defer log.Log("ssh-rtcp: session is closed") + defer close(cc.session.connChan) + + if cc.session == nil || cc.session.client == nil { return } - defer sshConn.Close() - - if s.Handler == nil { - s.Handler = s.handleSSHConn + if strings.HasPrefix(addr, ":") { + addr = "0.0.0.0" + addr } + ln, err := cc.session.client.Listen("tcp", addr) + if err != nil { + return + } + for { + rc, err := ln.Accept() + if err != nil { + log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), addr, err) + return + } - glog.V(LINFO).Infof("[ssh] %s <-> %s", conn.RemoteAddr(), s.Addr) - s.Handler(sshConn, chans, reqs) - glog.V(LINFO).Infof("[ssh] %s >-< %s", conn.RemoteAddr(), s.Addr) - }(conn) + select { + case cc.session.connChan <- rc: + default: + rc.Close() + log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr) + } + } + }() + }) + + sc, ok := <-cc.session.connChan + if !ok { + return nil, errors.New("ssh-rtcp: connection is closed") + } + return sc, nil +} + +type sshForwardTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +func SSHForwardTransporter() Transporter { + return &sshForwardTransporter{ + sessions: make(map[string]*sshSession), } } -func (s *SSHServer) handleSSHConn(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - quit := make(chan interface{}) +func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok || session.Closed() { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + config := ssh.ClientConfig{ + Timeout: opts.Timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if opts.User != nil { + config.User = opts.User.Username() + password, _ := opts.User.Password() + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("ssh: unrecognized connection") + } + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + deaded: make(chan struct{}), + connChan: make(chan net.Conn, 1024), + } + tr.sessions[opts.Addr] = session + go session.Ping(opts.Interval, opts.Timeout, 1) + go session.waitServer() + go session.waitClose() + } + if session.Closed() { + delete(tr.sessions, opts.Addr) + return nil, errSessionDead + } + + return &sshNopConn{session: session}, nil +} + +func (tr *sshForwardTransporter) Multiplex() bool { + return true +} + +type sshTunnelTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +// SSHTunnelTransporter creates a Transporter that is used by SSH tunnel client. +func SSHTunnelTransporter() Transporter { + return &sshTunnelTransporter{ + sessions: make(map[string]*sshSession), + } +} + +func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok || session.Closed() { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + config := ssh.ClientConfig{ + Timeout: opts.Timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if opts.User != nil { + config.User = opts.User.Username() + password, _ := opts.User.Password() + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("ssh: unrecognized connection") + } + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + deaded: make(chan struct{}), + } + tr.sessions[opts.Addr] = session + go session.Ping(opts.Interval, opts.Timeout, 1) + go session.waitServer() + go session.waitClose() + } + + if session.Closed() { + delete(tr.sessions, opts.Addr) + return nil, errSessionDead + } + + channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(reqs) + return &sshConn{channel: channel, conn: conn}, nil +} + +func (tr *sshTunnelTransporter) Multiplex() bool { + return true +} + +type sshSession struct { + addr string + conn net.Conn + client *ssh.Client + closed chan struct{} + deaded chan struct{} + once sync.Once + connChan chan net.Conn +} + +func (s *sshSession) Ping(interval, timeout time.Duration, retries int) { + if interval <= 0 { + return + } + if timeout <= 0 { + timeout = 30 * time.Second + } + defer close(s.deaded) + + log.Log("[ssh] ping is enabled, interval:", interval) + baseCtx := context.Background() + t := time.NewTicker(interval) + defer t.Stop() + + for { + select { + case <-t.C: + start := time.Now() + //if Debug { + log.Log("[ssh] sending ping") + //} + ctx, cancel := context.WithTimeout(baseCtx, timeout) + var err error + select { + case err = <-s.sendPing(): + case <-ctx.Done(): + err = errors.New("Timeout") + } + cancel() + if err != nil { + log.Log("[ssh] ping:", err) + return + } + //if Debug { + log.Log("[ssh] ping OK, RTT:", time.Since(start)) + //} + + case <-s.closed: + return + } + } +} + +func (s *sshSession) sendPing() <-chan error { + ch := make(chan error, 1) + go func() { + if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { + ch <- err + } + close(ch) + }() + return ch +} + +func (s *sshSession) waitServer() error { + defer close(s.closed) + return s.client.Wait() +} + +func (s *sshSession) waitClose() { + defer s.client.Close() + + select { + case <-s.deaded: + case <-s.closed: + } +} + +func (s *sshSession) Closed() bool { + select { + case <-s.deaded: + return true + case <-s.closed: + return true + default: + } + return false +} + +type sshForwardHandler struct { + options *HandlerOptions + config *ssh.ServerConfig +} + +func SSHForwardHandler(opts ...HandlerOption) Handler { + h := &sshForwardHandler{ + options: new(HandlerOptions), + config: new(ssh.ServerConfig), + } + for _, opt := range opts { + opt(h.options) + } + h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) + if len(h.options.Users) == 0 { + h.config.NoClientAuth = true + } + tlsConfig := h.options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + if tlsConfig != nil && len(tlsConfig.Certificates) > 0 { + signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) + if err != nil { + log.Log("[ssh-forward]", err) + } + h.config.AddHostKey(signer) + } + + return h +} + +func (h *sshForwardHandler) Handle(conn net.Conn) { + sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) + if err != nil { + log.Logf("[ssh-forward] %s -> %s : %s", conn.RemoteAddr(), h.options.Addr, err) + conn.Close() + return + } + defer sshConn.Close() + + log.Logf("[ssh-forward] %s <-> %s", conn.RemoteAddr(), h.options.Addr) + h.handleForward(sshConn, chans, reqs) + log.Logf("[ssh-forward] %s >-< %s", conn.RemoteAddr(), h.options.Addr) +} + +func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + quit := make(chan struct{}) + defer close(quit) // quit signal + go func() { for req := range reqs { switch req.Type { case RemoteForwardRequest: - go s.tcpipForwardRequest(conn, req, quit) + go h.tcpipForwardRequest(conn, req, quit) default: - // glog.V(LWARNING).Infoln("unknown channel type:", req.Type) + // log.Log("[ssh] unknown channel type:", req.Type) if req.WantReply { req.Reply(false, nil) } @@ -86,7 +468,7 @@ func (s *SSHServer) handleSSHConn(conn ssh.Conn, chans <-chan ssh.NewChannel, re case DirectForwardRequest: channel, requests, err := newChannel.Accept() if err != nil { - glog.V(LINFO).Infoln("[ssh] Could not accept channel:", err) + log.Log("[ssh] Could not accept channel:", err) continue } p := directForward{} @@ -97,50 +479,37 @@ func (s *SSHServer) handleSSHConn(conn ssh.Conn, chans <-chan ssh.NewChannel, re } go ssh.DiscardRequests(requests) - go s.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) + go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) default: - glog.V(LWARNING).Infoln("[ssh] Unknown channel type:", t) + log.Log("[ssh] Unknown channel type:", t) newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) } } }() conn.Wait() - close(quit) } -// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" -type directForward struct { - Host1 string - Port1 uint32 - Host2 string - Port2 uint32 -} - -func (p directForward) String() string { - return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) -} - -func (s *SSHServer) directPortForwardChannel(channel ssh.Channel, raddr string) { +func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { defer channel.Close() - glog.V(LINFO).Infof("[ssh-tcp] %s - %s", s.Addr, raddr) + log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr) - if !s.Base.Node.Can("tcp", raddr) { - glog.Errorf("Unauthorized to tcp connect to %s", raddr) + if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) return } - conn, err := s.Base.Chain.Dial(raddr) + conn, err := h.options.Chain.Dial(raddr) if err != nil { - glog.V(LINFO).Infof("[ssh-tcp] %s - %s : %s", s.Addr, raddr, err) + log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err) return } defer conn.Close() - glog.V(LINFO).Infof("[ssh-tcp] %s <-> %s", s.Addr, raddr) - Transport(conn, channel) - glog.V(LINFO).Infof("[ssh-tcp] %s >-< %s", s.Addr, raddr) + log.Logf("[ssh-tcp] %s <-> %s", h.options.Addr, raddr) + transport(conn, channel) + log.Logf("[ssh-tcp] %s >-< %s", h.options.Addr, raddr) } // tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request @@ -149,22 +518,22 @@ type tcpipForward struct { Port uint32 } -func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan interface{}) { +func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { t := tcpipForward{} ssh.Unmarshal(req.Payload, &t) addr := fmt.Sprintf("%s:%d", t.Host, t.Port) - if !s.Base.Node.Can("rtcp", addr) { - glog.Errorf("Unauthorized to tcp bind to %s", addr) + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) req.Reply(false, nil) return } - glog.V(LINFO).Infoln("[ssh-rtcp] listening tcp", addr) + log.Log("[ssh-rtcp] listening on tcp", addr) ln, err := net.Listen("tcp", addr) //tie to the client connection if err != nil { - glog.V(LWARNING).Infoln("[ssh-rtcp]", err) + log.Log("[ssh-rtcp]", err) req.Reply(false, nil) return } @@ -184,14 +553,14 @@ func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit return req.Reply(true, nil) } if err := replyFunc(); err != nil { - glog.V(LWARNING).Infoln("[ssh-rtcp]", err) + log.Log("[ssh-rtcp]", err) return } go func() { for { conn, err := ln.Accept() - if err != nil { // Unable to accept new connection - listener likely closed + if err != nil { // Unable to accept new connection - listener is likely closed return } @@ -210,18 +579,17 @@ func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit } p.Port2 = uint32(portnum) - glog.V(3).Info(p) ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) if err != nil { - glog.V(1).Infoln("[ssh-rtcp] open forwarded channel:", err) + log.Log("[ssh-rtcp] open forwarded channel:", err) return } defer ch.Close() go ssh.DiscardRequests(reqs) - glog.V(LINFO).Infof("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - Transport(ch, conn) - glog.V(LINFO).Infof("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) + log.Logf("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + transport(ch, conn) + log.Logf("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) }(conn) } }() @@ -229,6 +597,139 @@ func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-quit } +// SSHConfig holds the SSH tunnel server config +type SSHConfig struct { + Users []*url.Userinfo + TLSConfig *tls.Config +} + +type sshTunnelListener struct { + net.Listener + config *ssh.ServerConfig + connChan chan net.Conn + errChan chan error +} + +// SSHTunnelListener creates a Listener for SSH tunnel server. +func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + if config == nil { + config = &SSHConfig{} + } + + sshConfig := &ssh.ServerConfig{} + sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...) + if len(config.Users) == 0 { + sshConfig.NoClientAuth = true + } + tlsConfig := config.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + + signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) + if err != nil { + ln.Close() + return nil, err + + } + sshConfig.AddHostKey(signer) + + l := &sshTunnelListener{ + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, + config: sshConfig, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + go l.listenLoop() + + return l, nil +} + +func (l *sshTunnelListener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + log.Log("[ssh] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.serveConn(conn) + } +} + +func (l *sshTunnelListener) serveConn(conn net.Conn) { + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) + if err != nil { + log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + conn.Close() + return + } + defer sc.Close() + + go ssh.DiscardRequests(reqs) + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case GostSSHTunnelRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + log.Log("[ssh] Could not accept channel:", err) + continue + } + go ssh.DiscardRequests(requests) + cc := &sshConn{conn: conn, channel: channel} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) + } + + default: + log.Log("[ssh] Unknown channel type:", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + } + } + }() + + log.Logf("[ssh] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + sc.Wait() + log.Logf("[ssh] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) +} + +func (l *sshTunnelListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" +type directForward struct { + Host1 string + Port1 uint32 + Host2 string + Port2 uint32 +} + +func (p directForward) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) +} + func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { host, portString, err := net.SplitHostPort(addr.String()) if err != nil { @@ -240,7 +741,7 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) -func DefaultPasswordCallback(users []*url.Userinfo) PasswordCallbackFunc { +func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { for _, user := range users { u := user.Username() @@ -249,7 +750,86 @@ func DefaultPasswordCallback(users []*url.Userinfo) PasswordCallbackFunc { return nil, nil } } - glog.V(LINFO).Infof("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) + log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) return nil, fmt.Errorf("password rejected for %s", conn.User()) } } + +type sshNopConn struct { + session *sshSession +} + +func (c *sshNopConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *sshNopConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *sshNopConn) Close() error { + return nil +} + +func (c *sshNopConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type sshConn struct { + channel ssh.Channel + conn net.Conn +} + +func (c *sshConn) Read(b []byte) (n int, err error) { + return c.channel.Read(b) +} + +func (c *sshConn) Write(b []byte) (n int, err error) { + return c.channel.Write(b) +} + +func (c *sshConn) Close() error { + return c.channel.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/gost/tls.go b/tls.go similarity index 95% rename from gost/tls.go rename to tls.go index c65f6d0..d01c281 100644 --- a/gost/tls.go +++ b/tls.go @@ -34,11 +34,14 @@ type tlsListener struct { // TLSListener creates a Listener for TLS proxy server. func TLSListener(addr string, config *tls.Config) (Listener, error) { + if config == nil { + config = DefaultTLSConfig + } ln, err := tls.Listen("tcp", addr, config) if err != nil { return nil, err } - return &tlsListener{ln}, nil + return &tlsListener{tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil } // Wrap a net.Conn into a client tls connection, performing any diff --git a/ws.go b/ws.go index e92972a..8be51eb 100644 --- a/ws.go +++ b/ws.go @@ -5,119 +5,57 @@ import ( "net" "net/http" "net/http/httputil" - "strconv" "time" - "github.com/golang/glog" + "net/url" + + "github.com/go-log/log" "gopkg.in/gorilla/websocket.v1" ) -type WebsocketServer struct { - Addr string - Base *ProxyServer - Handler http.Handler - upgrader websocket.Upgrader -} - -func NewWebsocketServer(base *ProxyServer) *WebsocketServer { - rbuf, _ := strconv.Atoi(base.Node.Get("rbuf")) - wbuf, _ := strconv.Atoi(base.Node.Get("wbuf")) - comp := base.Node.getBool("compression") - - return &WebsocketServer{ - Addr: base.Node.Addr, - Base: base, - upgrader: websocket.Upgrader{ - ReadBufferSize: rbuf, - WriteBufferSize: wbuf, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: comp, - }, - } -} - -// Default websocket server handler -func (s *WebsocketServer) HandleRequest(w http.ResponseWriter, r *http.Request) { - glog.V(LINFO).Infof("[ws] %s - %s", r.RemoteAddr, s.Addr) - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(r, false) - glog.V(LDEBUG).Infof("[ws] %s - %s\n%s", r.RemoteAddr, s.Addr, string(dump)) - } - conn, err := s.upgrader.Upgrade(w, r, nil) - if err != nil { - glog.V(LERROR).Infof("[ws] %s - %s : %s", r.RemoteAddr, s.Addr, err) - return - } - s.Base.handleConn(WebsocketServerConn(conn)) -} - -func (s *WebsocketServer) ListenAndServe() error { - mux := http.NewServeMux() - if s.Handler == nil { - s.Handler = http.HandlerFunc(s.HandleRequest) - } - mux.Handle("/ws", s.Handler) - return http.ListenAndServe(s.Addr, mux) -} - -func (s *WebsocketServer) ListenAndServeTLS(config *tls.Config) error { - mux := http.NewServeMux() - if s.Handler == nil { - s.Handler = http.HandlerFunc(s.HandleRequest) - } - mux.Handle("/ws", s.Handler) - server := &http.Server{ - Addr: s.Addr, - Handler: mux, - TLSConfig: config, - } - return server.ListenAndServeTLS("", "") -} - +// WSOptions describes the options for websocket. type WSOptions struct { ReadBufferSize int WriteBufferSize int HandshakeTimeout time.Duration EnableCompression bool - TLSConfig *tls.Config } -type WebsocketConn struct { +type websocketConn struct { conn *websocket.Conn rb []byte } -func WebsocketClientConn(url string, conn net.Conn, options *WSOptions) (*WebsocketConn, error) { +func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, options *WSOptions) (net.Conn, error) { if options == nil { options = &WSOptions{} } dialer := websocket.Dialer{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, - TLSClientConfig: options.TLSConfig, + TLSClientConfig: tlsConfig, HandshakeTimeout: options.HandshakeTimeout, EnableCompression: options.EnableCompression, NetDial: func(net, addr string) (net.Conn, error) { return conn, nil }, } - c, resp, err := dialer.Dial(url, nil) if err != nil { return nil, err } resp.Body.Close() - return &WebsocketConn{conn: c}, nil + return &websocketConn{conn: c}, nil } -func WebsocketServerConn(conn *websocket.Conn) *WebsocketConn { - conn.EnableWriteCompression(true) - return &WebsocketConn{ +func websocketServerConn(conn *websocket.Conn) net.Conn { + // conn.EnableWriteCompression(true) + return &websocketConn{ conn: conn, } } -func (c *WebsocketConn) Read(b []byte) (n int, err error) { +func (c *websocketConn) Read(b []byte) (n int, err error) { if len(c.rb) == 0 { _, c.rb, err = c.conn.ReadMessage() } @@ -126,34 +64,236 @@ func (c *WebsocketConn) Read(b []byte) (n int, err error) { return } -func (c *WebsocketConn) Write(b []byte) (n int, err error) { +func (c *websocketConn) Write(b []byte) (n int, err error) { err = c.conn.WriteMessage(websocket.BinaryMessage, b) n = len(b) return } -func (c *WebsocketConn) Close() error { +func (c *websocketConn) Close() error { return c.conn.Close() } -func (c *WebsocketConn) LocalAddr() net.Addr { +func (c *websocketConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *WebsocketConn) RemoteAddr() net.Addr { +func (c *websocketConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (conn *WebsocketConn) SetDeadline(t time.Time) error { - if err := conn.SetReadDeadline(t); err != nil { +func (c *websocketConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { return err } - return conn.SetWriteDeadline(t) + return c.SetWriteDeadline(t) } -func (c *WebsocketConn) SetReadDeadline(t time.Time) error { +func (c *websocketConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *WebsocketConn) SetWriteDeadline(t time.Time) error { +func (c *websocketConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +type wsTransporter struct { + *tcpTransporter + options *WSOptions +} + +// WSTransporter creates a Transporter that is used by websocket proxy client. +func WSTransporter(opts *WSOptions) Transporter { + return &wsTransporter{ + options: opts, + } +} + +func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, nil, wsOptions) +} + +type wssTransporter struct { + *tcpTransporter + options *WSOptions +} + +// WSSTransporter creates a Transporter that is used by websocket secure proxy client. +func WSSTransporter(opts *WSOptions) Transporter { + return &wssTransporter{ + options: opts, + } +} + +func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) +} + +type wsListener struct { + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error +} + +// WSListener creates a Listener for websocket proxy server. +func WSListener(addr string, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wsListener{ + addr: tcpAddr, + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + mux := http.NewServeMux() + mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{Addr: addr, Handler: mux} + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + + go func() { + err := l.srv.Serve(tcpKeepAliveListener{ln}) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { + log.Logf("[ws] %s -> %s", r.RemoteAddr, l.addr) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log(string(dump)) + } + conn, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Logf("[ws] %s - %s : %s", r.RemoteAddr, l.addr, err) + return + } + select { + case l.connChan <- websocketServerConn(conn): + default: + conn.Close() + log.Logf("[ws] %s - %s: connection queue is full", r.RemoteAddr, l.addr) + } +} + +func (l *wsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *wsListener) Close() error { + return l.srv.Close() +} + +func (l *wsListener) Addr() net.Addr { + return l.addr +} + +type wssListener struct { + *wsListener +} + +// WSSListener creates a Listener for websocket secure proxy server. +func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wssListener{ + wsListener: &wsListener{ + addr: tcpAddr, + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + }, + } + + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + + mux := http.NewServeMux() + mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + TLSConfig: tlsConfig, + Handler: mux, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + + go func() { + err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +}