// gost_software/cmd/gost/vendor/github.com/ginuerzh/pht/server.go
// 2017-02-05 14:35:38 +08:00
//
// 200 lines
// 3.8 KiB
// Go

package pht
import (
	"bufio"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"
)
// Request paths of the three endpoints registered by ListenAndServe.
const (
	tokenURI = "/token" // handled by tokenHandler
	pushURI  = "/push"  // handled by pushHandler
	pollURI  = "/poll"  // handled by pollHandler
)
// Server is an HTTP server that hands out sessions through the token
// endpoint and moves session data through the push and poll endpoints.
type Server struct {
	// Addr is the address given to the HTTP listener; it is also resolved to
	// fill in the local address of every connection the server creates.
	Addr string
	// Key is the shared secret each request must present as the "key" entry
	// of its Authorization header.
	Key string
	// Handler, when non-nil, is started in its own goroutine for every
	// connection created by the token endpoint.
	Handler func(net.Conn)

	// manager tracks the live sessions; ListenAndServe initialises it.
	manager *sessionManager
}
// ListenAndServe installs a fresh session manager, registers the token, push
// and poll handlers on a private mux and serves HTTP on s.Addr. It blocks and
// returns whatever error the underlying HTTP server reports.
func (s *Server) ListenAndServe() error {
	s.manager = newSessionManager()

	routes := http.NewServeMux()
	routes.HandleFunc(tokenURI, s.tokenHandler)
	routes.HandleFunc(pushURI, s.pushHandler)
	routes.HandleFunc(pollURI, s.pollHandler)

	srv := &http.Server{Addr: s.Addr, Handler: routes}
	return srv.ListenAndServe()
}
// tokenHandler serves the token endpoint. It accepts only POST requests whose
// Authorization header carries the server key, creates a new session, wraps
// it in a connection and, if a Handler is configured, runs that handler in
// its own goroutine. The response body is "token=<token>".
//
// Status codes: 405 for a non-POST method, 403 for a wrong key and 500 when
// the session or the connection cannot be created.
func (s *Server) tokenHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	m := parseAuth(r.Header.Get("Authorization"))
	// Compare the shared secret in constant time so that response timing does
	// not reveal how many leading bytes of a guessed key were correct.
	if subtle.ConstantTimeCompare([]byte(m["key"]), []byte(s.Key)) != 1 {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	token, session, err := s.manager.NewSession(0, 0)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	conn, err := s.upgrade(session, r)
	if err != nil {
		// Do not keep a session that no connection will ever serve.
		s.manager.DelSession(token)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	if s.Handler != nil {
		go s.Handler(conn)
	}

	fmt.Fprintf(w, "token=%s", token)
}
// pushHandler serves the push endpoint. The request body must hold one
// newline-terminated, base64-encoded chunk, which is decoded and delivered to
// the read channel of the session named by the "token" entry of the
// Authorization header.
//
// Status codes: 405 for a non-POST method, 403 for a wrong key, 401 for an
// unknown token, 500 when the body cannot be read up to a newline and 408
// when the session does not accept the chunk within 90 seconds. An empty
// line or undecodable base64 tears the session down without writing an
// explicit status.
func (s *Server) pushHandler(w http.ResponseWriter, r *http.Request) {
	const deliverTimeout = 90 * time.Second

	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	m := parseAuth(r.Header.Get("Authorization"))
	// Compare the shared secret in constant time so that response timing does
	// not reveal how many leading bytes of a guessed key were correct.
	if subtle.ConstantTimeCompare([]byte(m["key"]), []byte(s.Key)) != 1 {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	token := m["token"]
	session := s.manager.GetSession(token)
	if session == nil {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// terminate removes the session and closes its read channel so that the
	// receiving side observes the end of the stream.
	terminate := func() {
		s.manager.DelSession(token)
		close(session.rchan)
	}

	line, err := bufio.NewReader(r.Body).ReadString('\n')
	if err != nil {
		terminate()
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	line = strings.TrimSuffix(line, "\n")
	if len(line) == 0 {
		terminate()
		return
	}

	payload, err := base64.StdEncoding.DecodeString(line)
	if err != nil {
		terminate()
		return
	}

	// An explicit timer, unlike time.After, can be stopped on return, so it
	// does not outlive the request when delivery finishes early.
	timer := time.NewTimer(deliverTimeout)
	defer timer.Stop()

	select {
	case <-session.closed:
		s.manager.DelSession(token)
	case session.rchan <- payload:
		w.WriteHeader(http.StatusOK)
	case <-timer.C:
		s.manager.DelSession(token)
		w.WriteHeader(http.StatusRequestTimeout)
	}
}
// pollHandler serves the poll endpoint. After authenticating the request it
// sends a 200 status immediately and then streams every chunk taken from the
// session's write channel as one base64-encoded line, flushing after each
// line. The handler returns when the write channel is closed (the session is
// then deleted), when a write to the client fails, or when no chunk arrives
// for 25 seconds.
//
// Status codes: 405 for a non-GET method, 403 for a wrong key and 401 for an
// unknown token.
func (s *Server) pollHandler(w http.ResponseWriter, r *http.Request) {
	const idleTimeout = 25 * time.Second

	if r.Method != http.MethodGet {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	m := parseAuth(r.Header.Get("Authorization"))
	// Compare the shared secret in constant time so that response timing does
	// not reveal how many leading bytes of a guessed key were correct.
	if subtle.ConstantTimeCompare([]byte(m["key"]), []byte(s.Key)) != 1 {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	token := m["token"]
	session := s.manager.GetSession(token)
	if session == nil {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	flusher, canFlush := w.(http.Flusher)

	// Push the status line out right away so the client knows the poll was
	// accepted before any data is available.
	w.WriteHeader(http.StatusOK)
	if canFlush {
		flusher.Flush()
	}

	// One writer and one timer are reused for the whole stream instead of
	// being allocated again for every chunk.
	bw := bufio.NewWriter(w)
	idle := time.NewTimer(idleTimeout)
	defer idle.Stop()

	for {
		select {
		case data, ok := <-session.wchan:
			if !ok {
				s.manager.DelSession(token)
				return // session is closed
			}
			bw.WriteString(base64.StdEncoding.EncodeToString(data))
			bw.WriteString("\n")
			// bufio.Writer keeps the first write error and reports it here.
			if err := bw.Flush(); err != nil {
				return
			}
			if canFlush {
				flusher.Flush()
			}
		case <-idle.C:
			return
		}

		// Restart the idle window after each chunk. The non-blocking drain
		// discards a tick that fired while the chunk was being written and
		// never blocks if no tick is pending.
		if !idle.Stop() {
			select {
			case <-idle.C:
			default:
			}
		}
		idle.Reset(idleTimeout)
	}
}
// upgrade wraps sess in a new connection and fills in its endpoint addresses:
// the remote address comes from r.RemoteAddr and the local address from
// s.Addr. An address that cannot be resolved as TCP is replaced by an empty
// TCP address. The returned error is always nil.
func (s *Server) upgrade(sess *session, r *http.Request) (net.Conn, error) {
	// resolveOrEmpty never fails: resolution errors yield a zero TCPAddr.
	resolveOrEmpty := func(addr string) *net.TCPAddr {
		if tcp, err := net.ResolveTCPAddr("tcp", addr); err == nil {
			return tcp
		}
		return &net.TCPAddr{}
	}

	c := newConn(sess)
	c.remoteAddr = resolveOrEmpty(r.RemoteAddr)
	c.localAddr = resolveOrEmpty(s.Addr)
	return c, nil
}
// parseAuth splits an Authorization header value of the form
// "k1=v1; k2=v2" into a map. Entries are separated by ';' and each entry is
// divided at its first '='; keys and values are trimmed of surrounding
// whitespace. Entries without '=' are ignored, and a later duplicate key
// overwrites an earlier one.
func parseAuth(auth string) map[string]string {
	fields := make(map[string]string)
	for _, entry := range strings.Split(auth, ";") {
		kv := strings.SplitN(entry, "=", 2)
		if len(kv) != 2 {
			continue
		}
		fields[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
	}
	return fields
}