11package jwtmiddleware
22
33import (
4+ "context"
45 "errors"
56 "fmt"
7+ "google.golang.org/grpc"
8+ "google.golang.org/grpc/codes"
9+ "google.golang.org/grpc/status"
610 "net/http"
711)
812
@@ -28,7 +32,7 @@ type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)
2832// DefaultErrorHandler is the default error handler implementation for the
2933// JWTMiddleware. If an error handler is not provided via the WithErrorHandler
3034// option this will be used.
31- func DefaultErrorHandler (w http.ResponseWriter , r * http.Request , err error ) {
35+ func DefaultErrorHandler (w http.ResponseWriter , _ * http.Request , err error ) {
3236 w .Header ().Set ("Content-Type" , "application/json" )
3337
3438 switch {
@@ -67,3 +71,48 @@ func (e invalidError) Error() string {
6771func (e invalidError ) Unwrap () error {
6872 return e .details
6973}
74+
75+ type GrpcErrorHandler struct {
76+ GrpcUnaryErrorHandler
77+ GrpcStreamErrorHandler
78+ }
79+
80+ type GrpcUnaryErrorHandler func (ctx context.Context , req any , info * grpc.UnaryServerInfo , handler grpc.UnaryHandler , err error ) (any , error )
81+ type GrpcStreamErrorHandler func (srv any , ss grpc.ServerStream , info * grpc.StreamServerInfo , handler grpc.StreamHandler , err error ) error
82+
83+ func DefaultGrpcErrorHandler () GrpcErrorHandler {
84+ return GrpcErrorHandler {
85+ GrpcUnaryErrorHandler : DefaultGrpcUnaryErrorHandler ,
86+ GrpcStreamErrorHandler : DefaultGrpcStreamErrorHandler ,
87+ }
88+ }
89+
90+ func DefaultGrpcUnaryErrorHandler (ctx context.Context , req any , _ * grpc.UnaryServerInfo , handler grpc.UnaryHandler , err error ) (any , error ) {
91+ if err != nil {
92+ switch {
93+ case errors .Is (err , ErrJWTMissing ):
94+ return nil , status .Errorf (codes .InvalidArgument , ErrJWTMissing .Error ())
95+ case errors .Is (err , ErrJWTInvalid ):
96+ return nil , status .Errorf (codes .Unauthenticated , ErrJWTInvalid .Error ())
97+ default :
98+ return nil , status .Errorf (codes .Internal , err .Error ())
99+ }
100+ }
101+
102+ return handler (ctx , req )
103+ }
104+
105+ func DefaultGrpcStreamErrorHandler (srv any , ss grpc.ServerStream , _ * grpc.StreamServerInfo , handler grpc.StreamHandler , err error ) error {
106+ if err != nil {
107+ switch {
108+ case errors .Is (err , ErrJWTMissing ):
109+ return status .Errorf (codes .InvalidArgument , ErrJWTMissing .Error ())
110+ case errors .Is (err , ErrJWTInvalid ):
111+ return status .Errorf (codes .Unauthenticated , ErrJWTInvalid .Error ())
112+ default :
113+ return status .Errorf (codes .Internal , err .Error ())
114+ }
115+ }
116+
117+ return handler (srv , newWrappedStream (ss ))
118+ }
0 commit comments