200 lines
3.8 KiB
Go
200 lines
3.8 KiB
Go
package pht
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// URL paths served by Server.

// tokenURI is the path on which a new session token is issued.
const tokenURI = "/token"

// pushURI is the path on which a client uploads data for a session.
const pushURI = "/push"

// pollURI is the path on which a client streams data from a session.
const pollURI = "/poll"
|
|
|
|
type Server struct {
|
|
Addr string
|
|
Key string
|
|
Handler func(net.Conn)
|
|
manager *sessionManager
|
|
}
|
|
|
|
func (s *Server) ListenAndServe() error {
|
|
s.manager = newSessionManager()
|
|
|
|
mux := http.NewServeMux()
|
|
mux.Handle(tokenURI, http.HandlerFunc(s.tokenHandler))
|
|
mux.Handle(pushURI, http.HandlerFunc(s.pushHandler))
|
|
mux.Handle(pollURI, http.HandlerFunc(s.pollHandler))
|
|
|
|
return http.ListenAndServe(s.Addr, mux)
|
|
}
|
|
|
|
func (s *Server) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
m := parseAuth(r.Header.Get("Authorization"))
|
|
if m["key"] != s.Key {
|
|
w.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
token, session, err := s.manager.NewSession(0, 0)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
conn, err := s.upgrade(session, r)
|
|
if err != nil {
|
|
s.manager.DelSession(token)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if s.Handler != nil {
|
|
go s.Handler(conn)
|
|
}
|
|
|
|
w.Write([]byte(fmt.Sprintf("token=%s", token)))
|
|
}
|
|
|
|
func (s *Server) pushHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
m := parseAuth(r.Header.Get("Authorization"))
|
|
if m["key"] != s.Key {
|
|
w.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
token := m["token"]
|
|
session := s.manager.GetSession(token)
|
|
if session == nil {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
br := bufio.NewReader(r.Body)
|
|
data, err := br.ReadString('\n')
|
|
if err != nil {
|
|
s.manager.DelSession(token)
|
|
close(session.rchan)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
data = strings.TrimSuffix(data, "\n")
|
|
if len(data) == 0 {
|
|
s.manager.DelSession(token)
|
|
close(session.rchan)
|
|
return
|
|
}
|
|
|
|
b, err := base64.StdEncoding.DecodeString(data)
|
|
if err != nil {
|
|
s.manager.DelSession(token)
|
|
close(session.rchan)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-session.closed:
|
|
s.manager.DelSession(token)
|
|
return
|
|
case session.rchan <- b:
|
|
w.WriteHeader(http.StatusOK)
|
|
case <-time.After(time.Second * 90):
|
|
s.manager.DelSession(token)
|
|
w.WriteHeader(http.StatusRequestTimeout)
|
|
}
|
|
}
|
|
|
|
func (s *Server) pollHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
m := parseAuth(r.Header.Get("Authorization"))
|
|
if m["key"] != s.Key {
|
|
w.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
token := m["token"]
|
|
session := s.manager.GetSession(token)
|
|
if session == nil {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
if fw, ok := w.(http.Flusher); ok {
|
|
fw.Flush()
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case data, ok := <-session.wchan:
|
|
if !ok {
|
|
s.manager.DelSession(token)
|
|
return // session is closed
|
|
}
|
|
bw := bufio.NewWriter(w)
|
|
bw.WriteString(base64.StdEncoding.EncodeToString(data))
|
|
bw.WriteString("\n")
|
|
if err := bw.Flush(); err != nil {
|
|
return
|
|
}
|
|
|
|
if fw, ok := w.(http.Flusher); ok {
|
|
fw.Flush()
|
|
}
|
|
case <-time.After(time.Second * 25):
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) upgrade(sess *session, r *http.Request) (net.Conn, error) {
|
|
conn := newConn(sess)
|
|
raddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
|
if err != nil {
|
|
raddr = &net.TCPAddr{}
|
|
}
|
|
conn.remoteAddr = raddr
|
|
|
|
laddr, err := net.ResolveTCPAddr("tcp", s.Addr)
|
|
if err != nil {
|
|
laddr = &net.TCPAddr{}
|
|
}
|
|
conn.localAddr = laddr
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// parseAuth splits auth on ';' and returns the "name=value" pairs it finds.
// Each pair is split at its first '=', and whitespace around the name and
// the value is trimmed. Segments without '=' are ignored; when a name
// repeats, the last value wins. The returned map is never nil.
func parseAuth(auth string) map[string]string {
	fields := make(map[string]string)

	for _, segment := range strings.Split(auth, ";") {
		pair := strings.SplitN(segment, "=", 2)
		if len(pair) != 2 {
			continue
		}
		name, value := strings.TrimSpace(pair[0]), strings.TrimSpace(pair[1])
		fields[name] = value
	}

	return fields
}
|