Skip to content

Commit 4a7dc3e

Browse files
authored
fix: add auth endpoint for getting a client by id, back links (#330)
1 parent 0eff85a commit 4a7dc3e

File tree

5 files changed

+265
-26
lines changed

5 files changed

+265
-26
lines changed

diode-server/auth/manager.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type RetrieveClientsRequest struct {
3232
type RetrieveClientsResponse struct {
3333
Clients []ClientInfo
3434
NextPageToken string
35+
PrevPageToken string
3536
}
3637

3738
// ClientManager is an interface for managing oauth2 clients.

diode-server/auth/manager_hydra.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func (h *HydraClientManager) RetrieveClients(ctx context.Context, q RetrieveClie
168168
out.Clients = append(out.Clients, clientInfoFromHydraClient(&client))
169169
}
170170

171-
out.NextPageToken = getHydraNextPageToken(response, h.logger)
171+
out.NextPageToken, out.PrevPageToken = getHydraPagingTokens(response, h.logger)
172172
return out, nil
173173
}
174174

@@ -199,7 +199,10 @@ func clientInfoFromHydraClient(client *hydra.OAuth2Client) ClientInfo {
199199
return clientInfo
200200
}
201201

202-
func getHydraNextPageToken(response *http.Response, logger *slog.Logger) string {
202+
// getHydraPagingLinks returns the next and previous page tokens from the response header
203+
func getHydraPagingTokens(response *http.Response, logger *slog.Logger) (string, string) {
204+
next := ""
205+
prev := ""
203206
for _, linkHeader := range response.Header.Values("Link") {
204207
links := strings.Split(linkHeader, ",")
205208
for _, link := range links {
@@ -209,33 +212,40 @@ func getHydraNextPageToken(response *http.Response, logger *slog.Logger) string
209212
}
210213
link := params[0]
211214
params = params[1:]
212-
// search for rel="next"
213215
for _, param := range params {
214216
vs := strings.Split(param, "=")
215217
if len(vs) != 2 {
216218
continue
217219
}
218220
k, v := strings.TrimSpace(vs[0]), strings.TrimSpace(vs[1])
219-
if k == "rel" && (v == "next" || v == "\"next\"") {
220-
link = strings.TrimPrefix(link, "<")
221-
link = strings.TrimSuffix(link, ">")
222-
parsedURL, err := url.Parse(link)
223-
if err != nil {
224-
logger.Warn("failed to parse url in rel=next link", "error", err, "link", linkHeader)
225-
return ""
221+
if k == "rel" {
222+
if v == "next" || v == "\"next\"" {
223+
next = getHydraPageToken(link, logger)
224+
} else if v == "prev" || v == "\"prev\"" {
225+
prev = getHydraPageToken(link, logger)
226226
}
227-
queryParams := parsedURL.Query()
228-
for key, values := range queryParams {
229-
if key == "page_token" {
230-
logger.Info("found next page token", "token", values[0])
231-
return values[0]
232-
}
233-
}
234-
logger.Warn("failed to find next page token in rel=next url", "link", linkHeader)
235-
return ""
236227
}
237228
}
238229
}
239230
}
231+
return next, prev
232+
}
233+
234+
func getHydraPageToken(link string, logger *slog.Logger) string {
235+
link = strings.TrimPrefix(link, "<")
236+
link = strings.TrimSuffix(link, ">")
237+
parsedURL, err := url.Parse(link)
238+
if err != nil {
239+
logger.Warn("failed to parse url in hydra paging link", "error", err, "link", link)
240+
return ""
241+
}
242+
queryParams := parsedURL.Query()
243+
for key, values := range queryParams {
244+
if key == "page_token" {
245+
logger.Debug("found page token", "token", values[0])
246+
return values[0]
247+
}
248+
}
249+
logger.Warn("failed to find page_token in hydra paging link", "link", link)
240250
return ""
241251
}

diode-server/auth/server.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ type ClientResponse struct {
6868
type ListClientsResponse struct {
6969
Data []ClientResponse `json:"data"`
7070
NextPageToken string `json:"next_page_token,omitempty"`
71+
PrevPageToken string `json:"prev_page_token,omitempty"`
7172
}
7273

7374
// ClientErrorResponse error response to client requests
@@ -158,6 +159,7 @@ func (s *Server) RegisterHandlers() {
158159
s.mux.HandleFunc("POST /token", s.token)
159160
s.mux.HandleFunc("POST /clients", s.createClient)
160161
s.mux.HandleFunc("GET /clients", s.listClients)
162+
s.mux.HandleFunc("GET /clients/{clientID}", s.getClient)
161163
s.mux.HandleFunc("DELETE /clients/{clientID}", s.deleteClient)
162164
}
163165

@@ -492,6 +494,7 @@ func (s *Server) listClients(w http.ResponseWriter, r *http.Request) {
492494
out := ListClientsResponse{
493495
Data: make([]ClientResponse, 0, len(clients.Clients)),
494496
NextPageToken: clients.NextPageToken,
497+
PrevPageToken: clients.PrevPageToken,
495498
}
496499
for _, client := range clients.Clients {
497500
out.Data = append(out.Data, ClientResponse{
@@ -557,3 +560,47 @@ func (s *Server) deleteClient(w http.ResponseWriter, r *http.Request) {
557560

558561
w.WriteHeader(http.StatusNoContent)
559562
}
563+
564+
func (s *Server) getClient(w http.ResponseWriter, r *http.Request) {
565+
jwtToken, _, ok := s.authorizeCall(w, r, []string{authutil.ScopeDiodeRead})
566+
if !ok {
567+
return
568+
}
569+
570+
ownerID, err := s.tokenOwnership.TokenOwnerID(r.Context(), jwtToken)
571+
if err != nil {
572+
s.logger.Error("failed to get token owner ID", "error", err)
573+
w.WriteHeader(statusFromError(err))
574+
return
575+
}
576+
577+
clientID := r.PathValue("clientID")
578+
if clientID == "" {
579+
err = writeJSON(w, http.StatusBadRequest, ClientErrorResponse{Error: "client ID is required"})
580+
if err != nil {
581+
s.logger.Error("failed to write response", "error", err)
582+
w.WriteHeader(http.StatusInternalServerError)
583+
}
584+
return
585+
}
586+
587+
// get the client and verify ownership
588+
client, err := s.clientManager.RetrieveClientByID(r.Context(), clientID)
589+
if err != nil {
590+
s.logger.Error("failed to get client", "error", err)
591+
w.WriteHeader(statusFromError(err))
592+
return
593+
}
594+
595+
if client.Owner != ownerID {
596+
s.logger.Error("client does not belong to requestor", "client_id", clientID, "owner_id", ownerID)
597+
w.WriteHeader(http.StatusNotFound)
598+
return
599+
}
600+
601+
err = writeJSON(w, http.StatusOK, client)
602+
if err != nil {
603+
s.logger.Error("failed to write response", "error", err)
604+
w.WriteHeader(http.StatusInternalServerError)
605+
}
606+
}

diode-server/auth/server_hydra_integration_test.go

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,38 @@ func TestServerHydraIntegration(t *testing.T) {
150150
require.Equal(t, 10, len(result.Data))
151151
require.Equal(t, result.NextPageToken, "")
152152

153-
// page through the clients in pages of size 3
154-
pageSize := 3
153+
// get the first client
154+
firstClient := client.getClient(t, result.Data[0].ClientID)
155+
require.Equal(t, result.Data[0].ClientID, firstClient.ClientID)
156+
require.Equal(t, result.Data[0].ClientName, firstClient.ClientName)
157+
require.Equal(t, result.Data[0].Scope, firstClient.Scope)
158+
159+
// delete the first client
160+
client.deleteClient(t, result.Data[0].ClientID)
161+
162+
// list clients, should include the 9 remaining test clients
163+
result = client.listClients(t, "", 100)
164+
require.Equal(t, 9, len(result.Data))
165+
166+
// page through the 9 remaining clients in pages of size 2
167+
var priorResult auth.ListClientsResponse
168+
pageSize := 2
155169
nextToken := ""
156170
seen := make(map[string]bool)
157171
pages := 0
158-
for range 5 { // should be 4 pages, stop after 5
172+
for range 6 { // should be 5 pages, stop after 6
159173
result = client.listClients(t, nextToken, pageSize)
174+
// previous page should be the same as the prior result
175+
if pages > 0 {
176+
prevPage := client.listClients(t, result.PrevPageToken, pageSize)
177+
require.Equal(t, priorResult.Data, prevPage.Data)
178+
}
179+
priorResult = result
180+
160181
pages++
161182
for _, c := range result.Data {
162183
seen[c.ClientID] = true
163184
}
164-
165185
nextToken = result.NextPageToken
166186
if nextToken == "" {
167187
break
@@ -170,9 +190,9 @@ func TestServerHydraIntegration(t *testing.T) {
170190
}
171191
}
172192

173-
// verify that we saw all 10 clients
174-
require.Equal(t, 10, len(seen))
175-
require.Equal(t, 4, pages)
193+
// verify that we saw all 9 clients
194+
require.Equal(t, 9, len(seen))
195+
require.Equal(t, 5, pages)
176196
}
177197

178198
type authTestClient struct {
@@ -233,6 +253,37 @@ func (c *authTestClient) createClient(t *testing.T, clientName string, scope str
233253
return createdClient
234254
}
235255

256+
func (c *authTestClient) getClient(t *testing.T, clientID string) auth.ClientResponse {
257+
req, err := http.NewRequest(http.MethodGet, c.endpoint+"/clients/"+clientID, nil)
258+
require.NoError(t, err)
259+
req.Header.Set("Authorization", "Bearer "+c.token)
260+
client := &http.Client{}
261+
resp, err := client.Do(req)
262+
require.NoError(t, err)
263+
defer func() {
264+
_ = resp.Body.Close()
265+
}()
266+
require.Equal(t, http.StatusOK, resp.StatusCode)
267+
268+
var result auth.ClientResponse
269+
err = json.NewDecoder(resp.Body).Decode(&result)
270+
require.NoError(t, err)
271+
return result
272+
}
273+
274+
func (c *authTestClient) deleteClient(t *testing.T, clientID string) {
275+
req, err := http.NewRequest(http.MethodDelete, c.endpoint+"/clients/"+clientID, nil)
276+
require.NoError(t, err)
277+
req.Header.Set("Authorization", "Bearer "+c.token)
278+
client := &http.Client{}
279+
resp, err := client.Do(req)
280+
require.NoError(t, err)
281+
defer func() {
282+
_ = resp.Body.Close()
283+
}()
284+
require.Equal(t, http.StatusNoContent, resp.StatusCode)
285+
}
286+
236287
func (c *authTestClient) authenticate(t *testing.T, clientID string, clientSecret string, scope string) {
237288
data := url.Values{}
238289
data.Set("grant_type", "client_credentials")

diode-server/auth/server_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,136 @@ func TestDeleteClient(t *testing.T) {
726726
}
727727
}
728728

729+
func TestGetClient(t *testing.T) {
730+
readOnlyToken := jwt.Token{
731+
Claims: jwt.MapClaims{
732+
"exp": time.Now().Add(time.Hour).Unix(),
733+
"iat": time.Now().Unix(),
734+
"client_id": "client123",
735+
"scope": "diode:read",
736+
},
737+
Valid: true,
738+
}
739+
invalidToken := jwt.Token{
740+
Valid: false,
741+
}
742+
743+
validAccessToken := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5IiwidHlwIjoiSldUIn0.eyJpc3MiOiJodHRwczovL2F1dGguZXhhbXBsZS5jb20iLCJzdWIiOiJ1c2VyMTIzIiwiYXVkIjoiYXBpIiwiZXhwIjoxNjUwMDAwMDAwLCJpYXQiOjE1MDAwMDAwMDAsImNsaWVudF9pZCI6ImNsaWVudDEyMyIsInNjb3BlIjoicmVhZCB3cml0ZSIsInVzZXJuYW1lIjoidGVzdHVzZXIifQ.WcPGXClpKD7Bc1C0CCDA1060E2GGlTfamrd8-W0ghBE"
744+
invalidAccessToken := "invalid.token.string"
745+
746+
tests := []struct {
747+
name string
748+
accessToken string
749+
clientID string
750+
parsedToken jwt.Token
751+
lookupResult auth.ClientInfo
752+
lookupErr error
753+
expectStatus int
754+
expect auth.ClientResponse
755+
}{
756+
{
757+
name: "can get client",
758+
accessToken: validAccessToken,
759+
clientID: "test-client-1-abcdef0123567890",
760+
parsedToken: readOnlyToken,
761+
lookupResult: auth.ClientInfo{
762+
ClientID: "test-client-1-abcdef0123567890",
763+
ClientName: "Test Client 1",
764+
Scope: "diode:ingest",
765+
Owner: "diode/user",
766+
CreatedAt: "2021-01-01T00:00:00Z",
767+
},
768+
expect: auth.ClientResponse{
769+
ClientID: "test-client-1-abcdef0123567890",
770+
ClientName: "Test Client 1",
771+
Scope: "diode:ingest",
772+
CreatedAt: "2021-01-01T00:00:00Z",
773+
},
774+
expectStatus: http.StatusOK,
775+
},
776+
{
777+
name: "cannot get client with invalid access token",
778+
accessToken: invalidAccessToken,
779+
clientID: "test-client-1-abcdef0123567890",
780+
parsedToken: invalidToken,
781+
expectStatus: http.StatusUnauthorized,
782+
},
783+
{
784+
name: "cannot get client that does not exist",
785+
accessToken: validAccessToken,
786+
clientID: "test-client-1-abcdef0123567890",
787+
parsedToken: readOnlyToken,
788+
lookupErr: auth.NewAuthError("client not found", http.StatusNotFound),
789+
expectStatus: http.StatusNotFound,
790+
},
791+
{
792+
name: "cannot get a client with the wrong owner",
793+
accessToken: validAccessToken,
794+
clientID: "test-client-1-abcdef0123567890",
795+
parsedToken: readOnlyToken,
796+
lookupResult: auth.ClientInfo{
797+
ClientID: "test-client-1-abcdef0123567890",
798+
ClientName: "Test Client 1",
799+
Owner: "diode/system",
800+
Scope: "diode:read diode:write",
801+
CreatedAt: "2021-01-01T00:00:00Z",
802+
},
803+
expectStatus: http.StatusNotFound,
804+
},
805+
}
806+
807+
ctx := context.Background()
808+
setupEnv()
809+
defer teardownEnv()
810+
811+
// Setup a test server to mock the OAuth2 server
812+
mockJWKSServer := mockJWKSServer()
813+
defer mockJWKSServer.Close()
814+
815+
_ = os.Setenv("OAUTH2_PUBLIC_SERVER_URL", mockJWKSServer.URL)
816+
defer func() {
817+
_ = os.Unsetenv("OAUTH2_PUBLIC_SERVER_URL")
818+
}()
819+
820+
for _, test := range tests {
821+
t.Run(test.name, func(t *testing.T) {
822+
defaultOwnership := &auth.DefaultTokenOwner{}
823+
accessToken := test.accessToken
824+
if accessToken == "" {
825+
accessToken = validAccessToken
826+
}
827+
mockTokenParser := &MockTokenParser{
828+
tokenMap: map[string]jwt.Token{
829+
accessToken: test.parsedToken,
830+
},
831+
}
832+
mockClientManager := &mocks.ClientManager{}
833+
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug, AddSource: false}))
834+
server, err := auth.NewServer(ctx, logger, mockTokenParser, mockClientManager, defaultOwnership)
835+
require.NoError(t, err)
836+
require.NotNil(t, server)
837+
838+
testServer := httptest.NewServer(server.GetMux())
839+
defer testServer.Close()
840+
841+
if test.lookupResult != (auth.ClientInfo{}) || test.lookupErr != nil {
842+
mockClientManager.EXPECT().RetrieveClientByID(mock.Anything, test.clientID).Return(test.lookupResult, test.lookupErr)
843+
}
844+
845+
req, _ := http.NewRequest("GET", testServer.URL+"/clients/"+test.clientID, nil)
846+
req.Header.Set("Authorization", "Bearer "+accessToken)
847+
req.Header.Set("Content-Type", "application/json")
848+
client := &http.Client{}
849+
resp, err := client.Do(req)
850+
require.NoError(t, err)
851+
defer func() {
852+
_ = resp.Body.Close()
853+
}()
854+
require.Equal(t, test.expectStatus, resp.StatusCode)
855+
})
856+
}
857+
}
858+
729859
func makeIntrospectRequest(serverURL, token string) (*http.Response, error) {
730860
req, _ := http.NewRequest(
731861
"POST",

0 commit comments

Comments
 (0)