Skip to content

Commit b300a66

Browse files
authored
fix: add custom headers for the upgrade request (#64)
* fix: add custom headers for the upgrade request Signed-off-by: Ying Wang <yingwang@us.ibm.com>
1 parent 5724808 commit b300a66

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

fluent/client/ws_client.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ type DefaultWSConnectionFactory struct {
9595
URL string
9696
AuthInfo *IAMAuthInfo
9797
TLSConfig *tls.Config
98+
Header http.Header
9899
}
99100

100101
func (wcf *DefaultWSConnectionFactory) New() (ext.Conn, error) {
@@ -103,6 +104,13 @@ func (wcf *DefaultWSConnectionFactory) New() (ext.Conn, error) {
103104
header = http.Header{}
104105
)
105106

107+
// set additional custom headers. here we do not validate
108+
// header names and values. Caller should make sure the
109+
// headers provided are not conflict with protocols
110+
if wcf.Header != nil {
111+
header = wcf.Header
112+
}
113+
106114
if wcf.AuthInfo != nil && len(wcf.AuthInfo.IAMToken()) > 0 {
107115
header.Add(AuthorizationHeader, wcf.AuthInfo.IAMToken())
108116
}

fluent/client/ws_client_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ var _ = Describe("DefaultWSConnectionFactory", func() {
6464
svr *httptest.Server
6565
ch chan struct{}
6666
useTLS, testError bool
67+
testHeaders http.Header
6768
customErr *client.WSConnError
6869
)
6970

@@ -78,6 +79,11 @@ var _ = Describe("DefaultWSConnectionFactory", func() {
7879
header := r.Header.Get(fclient.AuthorizationHeader)
7980
Expect(header).To(Equal("oi"))
8081

82+
for k := range testHeaders {
83+
v := r.Header.Get(k)
84+
Expect(v).To(Equal(testHeaders[k][0]))
85+
}
86+
8187
svrConnection, err := ws.NewConnection(wc, svrOpts)
8288
if err != nil {
8389
Fail("broke")
@@ -111,6 +117,7 @@ var _ = Describe("DefaultWSConnectionFactory", func() {
111117

112118
AfterEach(func() {
113119
svr.Close()
120+
testHeaders = nil
114121
})
115122

116123
It("sends auth headers", func() {
@@ -131,6 +138,31 @@ var _ = Describe("DefaultWSConnectionFactory", func() {
131138
Expect(cli.Disconnect()).ToNot(HaveOccurred())
132139
})
133140

141+
It("sends auth headers with additional header", func() {
142+
u := "ws" + strings.TrimPrefix(svr.URL, "http")
143+
144+
testHeaders = http.Header{
145+
"User-Agent": []string{"xxxx:1.0.5"}, // user agent
146+
"X-a": []string{""}, // empty value
147+
"X-b": []string{"value"}, // some string value
148+
}
149+
150+
cli := fclient.NewWS(client.WSConnectionOptions{
151+
Factory: &client.DefaultWSConnectionFactory{
152+
URL: u,
153+
TLSConfig: &tls.Config{
154+
InsecureSkipVerify: true,
155+
},
156+
AuthInfo: NewIAMAuthInfo("oi"),
157+
Header: testHeaders,
158+
},
159+
})
160+
161+
Expect(cli.Connect()).ToNot(HaveOccurred())
162+
Eventually(ch).Should(Receive())
163+
Expect(cli.Disconnect()).ToNot(HaveOccurred())
164+
})
165+
134166
When("sends wrong url, expects error", func() {
135167

136168
BeforeEach(func() {

0 commit comments

Comments
 (0)