Skip to content

Commit 3f873f9

Browse files
authored
filtering mTLS connections based on the subject name from Caller (#4081)
* filtering TLS connections based on the subject name from Caller * add validation to client rawCerts * fix lint * update config name and error msgs * update variable name in tlsconfig * renaming var in tlssetting * custom SN verification in verifyPeerCertificate * add minor validation * update err msg * address comment * address comment
1 parent 2142649 commit 3f873f9

File tree

7 files changed

+183
-47
lines changed

7 files changed

+183
-47
lines changed

cns/configuration/cns_config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,6 @@
3535
"AZRSettings": {
3636
"PopulateHomeAzCacheRetryIntervalSecs": 60
3737
},
38-
"MinTLSVersion": "TLS 1.2"
38+
"MinTLSVersion": "TLS 1.2",
39+
"MtlsClientCertSubjectName": ""
3940
}

cns/configuration/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ type CNSConfig struct {
5959
WireserverIP string
6060
GRPCSettings GRPCSettings
6161
MinTLSVersion string
62+
MtlsClientCertSubjectName string
6263
}
6364

6465
type TelemetrySettings struct {

cns/configuration/configuration_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
222222
IPAddress: "localhost",
223223
Port: 8080,
224224
},
225-
MinTLSVersion: "TLS 1.2",
225+
MinTLSVersion: "TLS 1.2",
226+
MtlsClientCertSubjectName: "",
226227
},
227228
},
228229
{
@@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
253254
IPAddress: "192.168.1.1",
254255
Port: 9090,
255256
},
256-
MinTLSVersion: "TLS 1.3",
257+
MinTLSVersion: "TLS 1.3",
258+
MtlsClientCertSubjectName: "example.com",
257259
},
258260
want: CNSConfig{
259261
ChannelMode: "Other",
@@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
283285
IPAddress: "192.168.1.1",
284286
Port: 9090,
285287
},
286-
MinTLSVersion: "TLS 1.3",
288+
MinTLSVersion: "TLS 1.3",
289+
MtlsClientCertSubjectName: "example.com",
287290
},
288291
},
289292
}

cns/service.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,54 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.
156156
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
157157
}
158158

159+
// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name.
160+
func verifyPeerCertificate(verifiedChains [][]*x509.Certificate, clientSubjectName string) error {
161+
// no client subject name provided, skip verification
162+
if clientSubjectName == "" {
163+
return nil
164+
}
165+
166+
if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 {
167+
return errors.New("no client certificate provided during mTLS")
168+
}
169+
170+
// Get client leaf certificate
171+
clientCert := verifiedChains[0][0]
172+
// Match DNS names (case-insensitive)
173+
dnsNames := clientCert.DNSNames
174+
for _, dns := range dnsNames {
175+
if strings.EqualFold(dns, clientSubjectName) {
176+
return nil
177+
}
178+
}
179+
180+
// If SANs didn't match, fall back to Common Name (CN) match.
181+
clientCN := clientCert.Subject.CommonName
182+
if clientCN != "" && strings.EqualFold(clientCN, clientSubjectName) {
183+
return nil
184+
}
185+
186+
// maskHalf of the DNS names
187+
maskedDNS := make([]string, len(dnsNames))
188+
for i, dns := range dnsNames {
189+
maskedDNS[i] = maskHalf(dns)
190+
}
191+
192+
return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, clientCN: %s",
193+
clientSubjectName, maskedDNS, maskHalf(clientCN))
194+
}
195+
196+
// maskHalf masks half of the input string with asterisks.
197+
func maskHalf(s string) string {
198+
n := len(s)
199+
if n == 0 {
200+
return s
201+
}
202+
203+
half := n / 2
204+
return s[:half] + strings.Repeat("*", n-half)
205+
}
206+
159207
func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
160208
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
161209
if err != nil {
@@ -202,8 +250,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
202250
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
203251
tlsConfig.ClientCAs = rootCAs
204252
tlsConfig.RootCAs = rootCAs
253+
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
254+
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
255+
}
205256
}
206-
207257
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
208258

209259
return tlsConfig, nil
@@ -254,6 +304,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
254304
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
255305
tlsConfig.ClientCAs = rootCAs
256306
tlsConfig.RootCAs = rootCAs
307+
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
308+
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
309+
}
257310
}
258311

259312
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)

cns/service/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ func main() {
810810
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
811811
UseMTLS: cnsconfig.UseMTLS,
812812
MinTLSVersion: cnsconfig.MinTLSVersion,
813+
MtlsClientCertSubjectName: cnsconfig.MtlsClientCertSubjectName,
813814
}
814815
}
815816

cns/service_test.go

Lines changed: 118 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -133,57 +133,108 @@ func TestNewService(t *testing.T) {
133133
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
134134
testCertFilePath := createTestCertificate(t)
135135

136-
config.TLSSettings = serverTLS.TlsSettings{
137-
TLSPort: "10091",
138-
TLSSubjectName: "localhost",
139-
TLSCertificatePath: testCertFilePath,
140-
UseMTLS: true,
141-
MinTLSVersion: "TLS 1.2",
136+
cases := []struct {
137+
name string
138+
tlsSettings serverTLS.TlsSettings
139+
handshakeFailureExpected bool
140+
}{
141+
{
142+
name: "matching client SANs",
143+
tlsSettings: serverTLS.TlsSettings{
144+
TLSPort: "10091",
145+
TLSSubjectName: "localhost",
146+
TLSCertificatePath: testCertFilePath,
147+
UseMTLS: true,
148+
MinTLSVersion: "TLS 1.2",
149+
MtlsClientCertSubjectName: "example.com",
150+
},
151+
handshakeFailureExpected: false,
152+
},
153+
{
154+
name: "matching client cert CN",
155+
tlsSettings: serverTLS.TlsSettings{
156+
TLSPort: "10093",
157+
TLSSubjectName: "localhost",
158+
TLSCertificatePath: testCertFilePath,
159+
UseMTLS: true,
160+
MinTLSVersion: "TLS 1.2",
161+
MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate
162+
},
163+
handshakeFailureExpected: false,
164+
},
165+
{
166+
name: "failing to match client SANs and CN",
167+
tlsSettings: serverTLS.TlsSettings{
168+
TLSPort: "10092",
169+
TLSSubjectName: "localhost",
170+
TLSCertificatePath: testCertFilePath,
171+
UseMTLS: true,
172+
MinTLSVersion: "TLS 1.2",
173+
MtlsClientCertSubjectName: "random.com",
174+
},
175+
handshakeFailureExpected: true,
176+
},
142177
}
143178

144-
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
145-
require.NoError(t, err)
146-
require.IsType(t, &Service{}, svc)
179+
for _, tc := range cases {
180+
t.Run(tc.name, func(t *testing.T) {
181+
config.TLSSettings = tc.tlsSettings
147182

148-
svc.SetOption(acn.OptCnsURL, "")
149-
svc.SetOption(acn.OptCnsPort, "")
183+
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
184+
require.NoError(t, err)
185+
require.IsType(t, &Service{}, svc)
150186

151-
err = svc.Initialize(config)
152-
t.Cleanup(func() {
153-
svc.Uninitialize()
154-
})
155-
require.NoError(t, err)
187+
svc.SetOption(acn.OptCnsURL, "")
188+
svc.SetOption(acn.OptCnsPort, "")
156189

157-
err = svc.StartListener(config)
158-
require.NoError(t, err)
190+
err = svc.Initialize(config)
191+
require.NoError(t, err)
159192

160-
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
161-
require.NoError(t, err)
193+
err = svc.StartListener(config)
194+
require.NoError(t, err)
162195

163-
client := &http.Client{
164-
Transport: &http.Transport{
165-
TLSClientConfig: mTLSConfig,
166-
},
167-
}
196+
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
197+
require.NoError(t, err)
168198

169-
// TLS listener
170-
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
171-
require.NoError(t, err)
172-
resp, err := client.Do(req)
173-
t.Cleanup(func() {
174-
resp.Body.Close()
175-
})
176-
require.NoError(t, err)
199+
client := &http.Client{
200+
Transport: &http.Transport{
201+
TLSClientConfig: mTLSConfig,
202+
},
203+
}
177204

178-
// HTTP listener
179-
httpClient := &http.Client{}
180-
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
181-
require.NoError(t, err)
182-
resp, err = httpClient.Do(req)
183-
t.Cleanup(func() {
184-
resp.Body.Close()
185-
})
186-
require.NoError(t, err)
205+
tlsURL := "https://localhost:" + tc.tlsSettings.TLSPort
206+
// TLS listener
207+
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
208+
require.NoError(t, err)
209+
resp, err := client.Do(req)
210+
if tc.handshakeFailureExpected {
211+
require.Error(t, err)
212+
require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS")
213+
} else {
214+
require.NoError(t, err)
215+
t.Cleanup(func() {
216+
if resp != nil && resp.Body != nil {
217+
resp.Body.Close()
218+
}
219+
})
220+
}
221+
222+
// HTTP listener
223+
httpClient := &http.Client{}
224+
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
225+
require.NoError(t, err)
226+
resp, err = httpClient.Do(req)
227+
require.NoError(t, err)
228+
t.Cleanup(func() {
229+
if resp != nil && resp.Body != nil {
230+
resp.Body.Close()
231+
}
232+
})
233+
234+
// Cleanup
235+
svc.Uninitialize()
236+
})
237+
}
187238
})
188239
}
189240

@@ -355,3 +406,28 @@ func TestTLSVersionNumber(t *testing.T) {
355406
require.NoError(t, err)
356407
})
357408
}
409+
410+
func TestMaskHalf(t *testing.T) {
411+
tests := []struct {
412+
name string
413+
in string
414+
want string
415+
}{
416+
{"empty", "", ""},
417+
{"one char string", "e", "*"},
418+
{"two chars string", "ex", "e*"},
419+
{"three chars string", "exa", "e**"},
420+
{"four chars string", "exam", "ex**"},
421+
{"five chars string", "examp", "ex***"},
422+
{"long string", "example.com", "examp******"},
423+
}
424+
425+
for _, tc := range tests {
426+
t.Run(tc.name, func(t *testing.T) {
427+
got := maskHalf(tc.in)
428+
if got != tc.want {
429+
t.Fatalf("maskHalf(%s) = %s, want %s", tc.in, got, tc.want)
430+
}
431+
})
432+
}
433+
}

server/tls/tlscertificate_retriever.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ type TlsSettings struct {
1515
KeyVaultCertificateRefreshInterval time.Duration
1616
UseMTLS bool
1717
MinTLSVersion string
18+
MtlsClientCertSubjectName string
1819
}
1920

2021
func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {

0 commit comments

Comments
 (0)