support mutual TLS authentication

This commit is contained in:
ginuerzh 2020-05-23 21:07:37 +08:00
parent 60d7e01164
commit b285bcd6ce
3 changed files with 18 additions and 5 deletions

View File

@ -44,8 +44,9 @@ var (
defaultKeyFile = "key.pem"
)
// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
// Load the certificate from cert & key files and optional client CA file,
// will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" || keyFile == "" {
certFile, keyFile = defaultCertFile, defaultKeyFile
}
@ -54,7 +55,15 @@ func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
if pool, _ := loadCA(caFile); pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {

View File

@ -67,7 +67,7 @@ func main() {
}
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile)
tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "")
if err != nil {
// generate random self-signed certificate.
cert, err := gost.GenCertificate()

View File

@ -128,6 +128,10 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
InsecureSkipVerify: !node.GetBool("secure"),
RootCAs: rootCAs,
}
if cert, err := tls.LoadX509KeyPair(node.Get("cert"), node.Get("key")); err == nil {
tlsCfg.Certificates = []tls.Certificate{cert}
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf")
@ -343,7 +347,7 @@ func (r *route) GenRouters() ([]router, error) {
}
}
certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile)
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
if err != nil && certFile != "" && keyFile != "" {
return nil, err
}