Skip to content

Commit 24fd705

Browse files
committed
chore: upgrade goVirtualHost to support multi certificates
1 parent ed1519e commit 24fd705

File tree

10 files changed

+105
-29
lines changed

10 files changed

+105
-29
lines changed

src/goVirtualHost/helper.go

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,69 @@
11
package goVirtualHost
22

3-
import "crypto/tls"
3+
import (
4+
"crypto/tls"
5+
"errors"
6+
)
47

5-
func LoadCertificate(certFile, keyFile string) (*tls.Certificate, error) {
6-
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
7-
if err != nil {
8-
return nil, err
8+
var MissingCertFileAndKeyFile = errors.New("missing certificate file and key file")
9+
var MissingCertFile = errors.New("missing certificate file")
10+
var MissingKeyFile = errors.New("missing key file")
11+
var CertKeyFileCountNotMatch = errors.New("certificate file count and key file count not match")
12+
13+
func LoadCertificate(certFile, keyFile string) (cert tls.Certificate, err error) {
14+
if len(certFile) == 0 && len(keyFile) == 0 {
15+
err = MissingCertFileAndKeyFile
16+
return
17+
} else if len(certFile) == 0 {
18+
err = MissingCertFile
19+
return
20+
} else if len(keyFile) == 0 {
21+
err = MissingKeyFile
22+
return
23+
}
24+
25+
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
26+
return
27+
}
28+
29+
func LoadCertificates(certFiles, keyFiles []string) (certs []tls.Certificate, errs []error) {
30+
certLen := len(certFiles)
31+
if certLen != len(keyFiles) {
32+
errs = append(errs, CertKeyFileCountNotMatch)
33+
return
34+
}
35+
36+
if certLen == 0 {
37+
return
38+
}
39+
40+
certs = make([]tls.Certificate, 0, certLen)
41+
for i := 0; i < certLen; i++ {
42+
cert, err := LoadCertificate(certFiles[i], keyFiles[i])
43+
if err != nil {
44+
errs = append(errs, err)
45+
} else {
46+
certs = append(certs, cert)
47+
}
48+
}
49+
50+
return
51+
}
52+
53+
func LoadCertificatesFromEntries(certKeyFileEntries [][2]string) (certs []tls.Certificate, errs []error) {
54+
certLen := len(certKeyFileEntries)
55+
if certLen == 0 {
56+
return
57+
}
58+
59+
certs = make([]tls.Certificate, 0, certLen)
60+
for i := 0; i < certLen; i++ {
61+
cert, err := LoadCertificate(certKeyFileEntries[i][0], certKeyFileEntries[i][1])
62+
if err != nil {
63+
errs = append(errs, err)
64+
} else {
65+
certs = append(certs, cert)
66+
}
967
}
10-
return &cert, nil
68+
return
1169
}

src/goVirtualHost/hostInfo.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,28 @@ import "crypto/tls"
44

55
func (info *HostInfo) toParam(listen string, useTLS bool) *param {
66
proto, ip, port := splitListen(listen, false)
7-
var cert *tls.Certificate
7+
var certs []tls.Certificate
88
if useTLS {
9-
cert = info.Cert
9+
certs = info.Certs
1010
}
1111

1212
param := &param{
1313
proto: proto,
1414
ip: ip,
1515
port: port,
1616
useTLS: useTLS,
17-
cert: cert,
17+
certs: certs,
1818
}
1919

2020
return param
2121
}
2222

23-
func (info *HostInfo) parse() (hostNames []string, params params) {
23+
func (info *HostInfo) parse() (hostNames []string, params params, certs certs) {
2424
hostNames = normalizeHostNames(info.HostNames)
2525

26+
useTLSForListen := len(info.Certs) > 0
2627
for _, listen := range info.Listens {
27-
param := info.toParam(listen, info.Cert != nil)
28+
param := info.toParam(listen, useTLSForListen)
2829
param.hostNames = hostNames
2930
params = append(params, param)
3031
}
@@ -41,5 +42,9 @@ func (info *HostInfo) parse() (hostNames []string, params params) {
4142
params = append(params, param)
4243
}
4344

45+
if (useTLSForListen && len(info.Listens) > 0) || len(info.ListensTLS) > 0 {
46+
certs = append(info.Certs)
47+
}
48+
4449
return
4550
}

src/goVirtualHost/param.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func (param *param) hasHostNames(checkHostNames []string) bool {
2323
}
2424

2525
func (param *param) validate() (errs []error) {
26-
if param.useTLS && param.cert == nil {
26+
if param.useTLS && len(param.certs) == 0 {
2727
err := wrapError(CertificateNotFound, fmt.Sprintf("certificate not found for TLS listens: %+v", param))
2828
errs = append(errs, err)
2929
}

src/goVirtualHost/param_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,32 @@ func TestParamValidate(t *testing.T) {
1717
}
1818
errs = p.validate()
1919
if len(errs) > 0 {
20-
t.Error()
20+
t.Error(errs)
2121
}
2222

2323
p.useTLS = true
2424
errs = p.validate()
2525
if len(errs) == 0 {
2626
t.Error()
2727
} else if !errors.Is(errs[0], CertificateNotFound) {
28-
t.Error()
28+
t.Error(errs)
29+
}
30+
31+
p.certs = nil
32+
errs = p.validate()
33+
if len(errs) == 0 {
34+
t.Error(errs)
2935
}
3036

31-
p.cert = &tls.Certificate{}
37+
p.certs = []tls.Certificate{}
38+
errs = p.validate()
39+
if len(errs) == 0 {
40+
t.Error(errs)
41+
}
42+
43+
p.certs = append(p.certs, tls.Certificate{})
3244
errs = p.validate()
3345
if len(errs) > 0 {
34-
t.Error()
46+
t.Error(errs)
3547
}
3648
}

src/goVirtualHost/params.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ func (params params) validateParam(param *param) (errs []error) {
3333
}
3434

3535
if ownParam.proto == param.proto && ownParam.ip == param.ip && ownParam.port == param.port {
36-
ownUseTLS := ownParam.cert != nil
37-
useTLS := param.cert != nil
38-
if ownUseTLS != useTLS {
36+
if ownParam.useTLS != param.useTLS {
3937
err := wrapError(ConflictTLSMode, fmt.Sprintf("cannot serve for both Plain and TLS mode: %+v, %+v", ownParam, param))
4038
errs = append(errs, err)
4139
}

src/goVirtualHost/params_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func TestParamsValidateParam(t *testing.T) {
8282
ip: "",
8383
port: ":80",
8484
useTLS: true,
85-
cert: &tls.Certificate{},
85+
certs: []tls.Certificate{},
8686
}
8787
errs = ps.validateParam(p)
8888
if len(errs) == 0 {

src/goVirtualHost/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (server *server) updateHttpServerTLSConfig() {
5454
certs := []tls.Certificate{}
5555

5656
for _, vhost := range server.vhosts {
57-
certs = append(certs, *vhost.cert)
57+
certs = append(certs, vhost.certs...)
5858
}
5959

6060
tlsConfig = &tls.Config{

src/goVirtualHost/service.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ func (svc *Service) Add(info *HostInfo) (errs []error) {
5050
return
5151
}
5252

53-
hostNames, vhostParams := info.parse()
53+
hostNames, vhostParams, certs := info.parse()
5454

5555
errs = svc.params.validate(vhostParams)
5656
if len(errs) > 0 {
5757
return
5858
}
5959
svc.params = append(svc.params, vhostParams...)
6060

61-
vhost := newVhost(info.Cert, hostNames, info.Handler)
61+
vhost := newVhost(certs, hostNames, info.Handler)
6262
svc.vhosts = append(svc.vhosts, vhost)
6363

6464
svc.addVhostToServers(vhost, vhostParams)
@@ -114,6 +114,8 @@ func (svc *Service) Open() (errs []error) {
114114
svc.state = stateOpened
115115
svc.mu.Unlock()
116116

117+
svc.params = nil // release unused data
118+
117119
for _, s := range svc.servers {
118120
s.updateDefaultVhost()
119121
s.updateHttpServerTLSConfig()

src/goVirtualHost/type.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@ type HostInfo struct {
1212
Listens []string
1313
ListensPlain []string
1414
ListensTLS []string
15-
Cert *tls.Certificate
15+
Certs []tls.Certificate
1616
HostNames []string
1717
Handler http.Handler
1818
}
1919

20+
type certs []tls.Certificate
21+
2022
// normalized HostInfo Param
2123
type param struct {
2224
proto string // "tcp", "tcp4", "tcp6"
2325
ip string
2426
port string
2527
useTLS bool
26-
cert *tls.Certificate
28+
certs certs
2729
hostNames []string
2830
}
2931

@@ -52,7 +54,7 @@ type servers []*server
5254

5355
// virtual host
5456
type vhost struct {
55-
cert *tls.Certificate
57+
certs certs
5658
hostNames []string
5759
handler http.Handler
5860
}

src/goVirtualHost/vhost.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
package goVirtualHost
22

33
import (
4-
"crypto/tls"
54
"net/http"
65
"strings"
76
)
87

9-
func newVhost(cert *tls.Certificate, hostNames []string, handler http.Handler) *vhost {
8+
func newVhost(certs certs, hostNames []string, handler http.Handler) *vhost {
109
vhost := &vhost{
11-
cert: cert,
10+
certs: certs,
1211
hostNames: hostNames,
1312
handler: handler,
1413
}

0 commit comments

Comments
 (0)