gost_software/vendor/github.com/ginuerzh/gost/ws.go
2017-03-24 11:24:03 +02:00

143 lines
3.2 KiB
Go

package gost
import (
"crypto/tls"
"github.com/golang/glog"
"gopkg.in/gorilla/websocket.v1"
"net"
"net/http"
"net/http/httputil"
"time"
)
type WebsocketServer struct {
Addr string
Base *ProxyServer
Handler http.Handler
upgrader websocket.Upgrader
}
func NewWebsocketServer(base *ProxyServer) *WebsocketServer {
return &WebsocketServer{
Addr: base.Node.Addr,
Base: base,
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: true,
},
}
}
// Default websocket server handler
func (s *WebsocketServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
glog.V(LINFO).Infof("[ws] %s - %s", r.RemoteAddr, s.Addr)
if glog.V(LDEBUG) {
dump, _ := httputil.DumpRequest(r, false)
glog.V(LDEBUG).Infof("[ws] %s - %s\n%s", r.RemoteAddr, s.Addr, string(dump))
}
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
glog.V(LERROR).Infof("[ws] %s - %s : %s", r.RemoteAddr, s.Addr, err)
return
}
s.Base.handleConn(WebsocketServerConn(conn))
}
func (s *WebsocketServer) ListenAndServe() error {
mux := http.NewServeMux()
if s.Handler == nil {
s.Handler = http.HandlerFunc(s.HandleRequest)
}
mux.Handle("/ws", s.Handler)
return http.ListenAndServe(s.Addr, mux)
}
func (s *WebsocketServer) ListenAndServeTLS(config *tls.Config) error {
mux := http.NewServeMux()
if s.Handler == nil {
s.Handler = http.HandlerFunc(s.HandleRequest)
}
mux.Handle("/ws", s.Handler)
server := &http.Server{
Addr: s.Addr,
Handler: mux,
TLSConfig: config,
}
return server.ListenAndServeTLS("", "")
}
type WebsocketConn struct {
conn *websocket.Conn
rb []byte
}
func WebsocketClientConn(url string, conn net.Conn, config *tls.Config) (*WebsocketConn, error) {
dialer := websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
TLSClientConfig: config,
HandshakeTimeout: DialTimeout,
EnableCompression: true,
NetDial: func(net, addr string) (net.Conn, error) {
return conn, nil
},
}
c, resp, err := dialer.Dial(url, nil)
if err != nil {
return nil, err
}
resp.Body.Close()
return &WebsocketConn{conn: c}, nil
}
func WebsocketServerConn(conn *websocket.Conn) *WebsocketConn {
conn.EnableWriteCompression(true)
return &WebsocketConn{
conn: conn,
}
}
func (c *WebsocketConn) Read(b []byte) (n int, err error) {
if len(c.rb) == 0 {
_, c.rb, err = c.conn.ReadMessage()
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *WebsocketConn) Write(b []byte) (n int, err error) {
err = c.conn.WriteMessage(websocket.BinaryMessage, b)
n = len(b)
return
}
func (c *WebsocketConn) Close() error {
return c.conn.Close()
}
func (c *WebsocketConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *WebsocketConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (conn *WebsocketConn) SetDeadline(t time.Time) error {
if err := conn.SetReadDeadline(t); err != nil {
return err
}
return conn.SetWriteDeadline(t)
}
func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}