diff --git a/gost/kcp.go b/gost/kcp.go index a5ccfa2..08c5684 100644 --- a/gost/kcp.go +++ b/gost/kcp.go @@ -191,7 +191,12 @@ func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Con session, ok := tr.sessions[addr] if !ok { - return net.DialUDP("udp", nil, uaddr) + conn, err = net.DialUDP("udp", nil, uaddr) + if err != nil { + return + } + session = &kcpSession{conn: conn} + tr.sessions[addr] = session } return session.conn, nil } @@ -209,7 +214,7 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( defer tr.sessionMutex.Unlock() session, ok := tr.sessions[opts.Addr] - if !ok { + if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, config) if err != nil { conn.Close() @@ -229,13 +234,15 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( } func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { - pc, ok := conn.(net.PacketConn) + udpConn, ok := conn.(*net.UDPConn) if !ok { return nil, errors.New("wrong connection type") } kcpconn, err := kcp.NewConn(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard, pc) + blockCrypt(config.Key, config.Crypt, KCPSalt), + config.DataShard, config.ParityShard, + &kcp.ConnectedUDPConn{UDPConn: udpConn, Conn: udpConn}) if err != nil { return nil, err }