diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 1cba6e7..a9f8872 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -296,8 +296,10 @@ func serve(chain *gost.Chain) error { case "tcp": ln, err = gost.TCPListener(node.Addr) case "rtcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHRemoteForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() } ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) case "udp": @@ -351,6 +353,11 @@ func serve(chain *gost.Chain) error { case "http": handler = gost.HTTPHandler(handlerOptions...) case "tcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHDirectForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() + } handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) case "rtcp": handler = gost.TCPRemoteForwardHandler(node.Remote, handlerOptions...)