Skip to content

Commit 26dc502

Browse files
committed
allow the clients of acceptor to specify their own tls.Config
1 parent 8a53aa9 commit 26dc502

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

accepter_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
package quickfix
1717

1818
import (
19+
"crypto/tls"
1920
"net"
2021
"testing"
2122

2223
"github.com/quickfixgo/quickfix/config"
2324

2425
proxyproto "github.com/pires/go-proxyproto"
2526
"github.com/stretchr/testify/assert"
27+
"github.com/stretchr/testify/require"
2628
)
2729

2830
func TestAcceptor_Start(t *testing.T) {
@@ -83,3 +85,44 @@ func TestAcceptor_Start(t *testing.T) {
8385
})
8486
}
8587
}
88+
89+
func TestAcceptor_SetTLSConfig(t *testing.T) {
90+
sessionSettings := NewSessionSettings()
91+
sessionSettings.Set(config.BeginString, BeginStringFIX42)
92+
sessionSettings.Set(config.SenderCompID, "sender")
93+
sessionSettings.Set(config.TargetCompID, "target")
94+
95+
genericSettings := NewSettings()
96+
97+
genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001")
98+
_, err := genericSettings.AddSession(sessionSettings)
99+
require.NoError(t, err)
100+
101+
logger, err := NewScreenLogFactory().Create()
102+
require.NoError(t, err)
103+
acceptor := &Acceptor{settings: genericSettings, globalLog: logger}
104+
defer acceptor.Stop()
105+
// example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function
106+
// as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs.
107+
customizedTLSConfig := tls.Config{
108+
Certificates: []tls.Certificate{},
109+
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
110+
cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key")
111+
if err != nil {
112+
return nil, err
113+
}
114+
return &cert, nil
115+
},
116+
}
117+
118+
acceptor.SetTLSConfig(&customizedTLSConfig)
119+
assert.NoError(t, acceptor.Start())
120+
assert.Len(t, acceptor.listeners, 1)
121+
122+
conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{
123+
InsecureSkipVerify: true,
124+
})
125+
require.NoError(t, err)
126+
assert.NotNil(t, conn)
127+
defer conn.Close()
128+
}

acceptor.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"bufio"
2020
"bytes"
2121
"crypto/tls"
22+
"fmt"
2223
"io"
2324
"net"
2425
"runtime/debug"
@@ -48,6 +49,7 @@ type Acceptor struct {
4849
sessionHostPort map[SessionID]int
4950
listeners map[string]net.Listener
5051
connectionValidator ConnectionValidator
52+
tlsConfig *tls.Config
5153
sessionFactory
5254
}
5355

@@ -81,9 +83,12 @@ func (a *Acceptor) Start() (err error) {
8183
a.listeners[address] = nil
8284
}
8385

84-
var tlsConfig *tls.Config
85-
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
86-
return
86+
if a.tlsConfig == nil {
87+
var tlsConfig *tls.Config
88+
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
89+
return
90+
}
91+
a.tlsConfig = tlsConfig
8792
}
8893

8994
var useTCPProxy bool
@@ -94,8 +99,8 @@ func (a *Acceptor) Start() (err error) {
9499
}
95100

96101
for address := range a.listeners {
97-
if tlsConfig != nil {
98-
if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil {
102+
if a.tlsConfig != nil {
103+
if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil {
99104
return
100105
}
101106
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
@@ -228,6 +233,7 @@ func (a *Acceptor) invalidMessage(msg *bytes.Buffer, err error) {
228233
func (a *Acceptor) handleConnection(netConn net.Conn) {
229234
defer func() {
230235
if err := recover(); err != nil {
236+
fmt.Println("asdqwe", a.globalLog)
231237
a.globalLog.OnEventf("Connection Terminated with Panic: %s", debug.Stack())
232238
}
233239

@@ -421,3 +427,12 @@ LOOP:
421427
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
422428
a.connectionValidator = validator
423429
}
430+
431+
// SetTLSConfig allows the creator of the Acceptor to specify a fully customizable tls.Config.
432+
433+
// When the caller explicitly provides a tls.Config with this function,
434+
// it takes precendent over TLS settings specified in the acceptor's settings.GlobalSettings(),
435+
// meaning that the setting object is not inspected or used for the creation of the tls.Config.
436+
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
437+
a.tlsConfig = tlsConfig
438+
}

0 commit comments

Comments
 (0)