diff --git a/client.go b/client.go index 47564d3..3bdc5c1 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "path" + "sync" ) var ( @@ -44,16 +45,15 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client { } // Set required headers - c.Transport.header.Set("Accept", "application/json") - c.Transport.header.Set("Content-Type", "application/json") - c.Transport.header.Set("Accept-Profile", schema) - c.Transport.header.Set("Content-Profile", schema) - c.Transport.header.Set("X-Client-Info", "postgrest-go/"+version) - + c.Transport.SetHeaders(map[string]string{ + "Accept": "application/json", + "Content-Type": "application/json", + "Accept-Profile": schema, + "Content-Profile": schema, + "X-Client-Info": "postgrest-go/" + version, + }) // Set optional headers if they exist - for key, value := range headers { - c.Transport.header.Set(key, value) - } + c.Transport.SetHeaders(headers) return &c } @@ -84,20 +84,22 @@ func (c *Client) Ping() bool { // SetApiKey sets api key header for subsequent requests. func (c *Client) SetApiKey(apiKey string) *Client { - c.Transport.header.Set("apikey", apiKey) + c.Transport.SetHeader("apikey", apiKey) return c } // SetAuthToken sets authorization header for subsequent requests. func (c *Client) SetAuthToken(authToken string) *Client { - c.Transport.header.Set("Authorization", "Bearer "+authToken) + c.Transport.SetHeader("Authorization", "Bearer "+authToken) return c } // ChangeSchema modifies the schema for subsequent requests. func (c *Client) ChangeSchema(schema string) *Client { - c.Transport.header.Set("Accept-Profile", schema) - c.Transport.header.Set("Content-Profile", schema) + c.Transport.SetHeaders(map[string]string{ + "Accept-Profile": schema, + "Content-Profile": schema, + }) return c } @@ -156,17 +158,35 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string { } type transport struct { - header http.Header baseURL url.URL Parent http.RoundTripper + + mu sync.RWMutex + header http.Header +} + +func (t *transport) SetHeader(key, value string) { + t.mu.Lock() + defer t.mu.Unlock() + t.header.Set(key, value) +} + +func (t *transport) SetHeaders(headers map[string]string) { + t.mu.Lock() + defer t.mu.Unlock() + for key, value := range headers { + t.header.Set(key, value) + } } -func (t transport) RoundTrip(req *http.Request) (*http.Response, error) { +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.RLock() for headerName, values := range t.header { for _, val := range values { req.Header.Add(headerName, val) } } + t.mu.RUnlock() req.URL = t.baseURL.ResolveReference(req.URL)