Skip to content

Commit 5ec1219

Browse files
authored
Merge pull request #667 from ekovacs/feature/allow-custom-tls-config-for-acceptor
Allow the clients of acceptor to specify their own tls.Config
2 parents 8a53aa9 + f4d6fe9 commit 5ec1219

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

accepter_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
package quickfix
1717

1818
import (
19+
"crypto/tls"
1920
"net"
2021
"testing"
2122

2223
"github.com/quickfixgo/quickfix/config"
2324

2425
proxyproto "github.com/pires/go-proxyproto"
2526
"github.com/stretchr/testify/assert"
27+
"github.com/stretchr/testify/require"
2628
)
2729

2830
func TestAcceptor_Start(t *testing.T) {
@@ -83,3 +85,44 @@ func TestAcceptor_Start(t *testing.T) {
8385
})
8486
}
8587
}
88+
89+
func TestAcceptor_SetTLSConfig(t *testing.T) {
90+
sessionSettings := NewSessionSettings()
91+
sessionSettings.Set(config.BeginString, BeginStringFIX42)
92+
sessionSettings.Set(config.SenderCompID, "sender")
93+
sessionSettings.Set(config.TargetCompID, "target")
94+
95+
genericSettings := NewSettings()
96+
97+
genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001")
98+
_, err := genericSettings.AddSession(sessionSettings)
99+
require.NoError(t, err)
100+
101+
logger, err := NewScreenLogFactory().Create()
102+
require.NoError(t, err)
103+
acceptor := &Acceptor{settings: genericSettings, globalLog: logger}
104+
defer acceptor.Stop()
105+
// example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function
106+
// as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs.
107+
customizedTLSConfig := tls.Config{
108+
Certificates: []tls.Certificate{},
109+
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
110+
cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key")
111+
if err != nil {
112+
return nil, err
113+
}
114+
return &cert, nil
115+
},
116+
}
117+
118+
acceptor.SetTLSConfig(&customizedTLSConfig)
119+
assert.NoError(t, acceptor.Start())
120+
assert.Len(t, acceptor.listeners, 1)
121+
122+
conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{
123+
InsecureSkipVerify: true,
124+
})
125+
require.NoError(t, err)
126+
assert.NotNil(t, conn)
127+
defer conn.Close()
128+
}

acceptor.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type Acceptor struct {
4848
sessionHostPort map[SessionID]int
4949
listeners map[string]net.Listener
5050
connectionValidator ConnectionValidator
51+
tlsConfig *tls.Config
5152
sessionFactory
5253
}
5354

@@ -81,9 +82,12 @@ func (a *Acceptor) Start() (err error) {
8182
a.listeners[address] = nil
8283
}
8384

84-
var tlsConfig *tls.Config
85-
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
86-
return
85+
if a.tlsConfig == nil {
86+
var tlsConfig *tls.Config
87+
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
88+
return
89+
}
90+
a.tlsConfig = tlsConfig
8791
}
8892

8993
var useTCPProxy bool
@@ -94,8 +98,8 @@ func (a *Acceptor) Start() (err error) {
9498
}
9599

96100
for address := range a.listeners {
97-
if tlsConfig != nil {
98-
if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil {
101+
if a.tlsConfig != nil {
102+
if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil {
99103
return
100104
}
101105
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
@@ -421,3 +425,13 @@ LOOP:
421425
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
422426
a.connectionValidator = validator
423427
}
428+
429+
// SetTLSConfig allows the creator of the Acceptor to specify a fully customizable tls.Config of their choice,
430+
// which will be used in the Start() method.
431+
//
432+
// Note: when the caller explicitly provides a tls.Config with this function,
433+
// it takes precendent over TLS settings specified in the acceptor's settings.GlobalSettings(),
434+
// meaning that the `settings.GlobalSettings()` object is not inspected or used for the creation of the tls.Config.
435+
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
436+
a.tlsConfig = tlsConfig
437+
}

0 commit comments

Comments
 (0)