|
| 1 | +namespace FSharp.Data.GraphQL.Samples.StarWarsApi.Middleware |
| 2 | + |
| 3 | +open System |
| 4 | +open System.Threading.Tasks |
| 5 | +open Microsoft.FSharp.Quotations |
| 6 | +open Microsoft.FSharp.Quotations.Patterns |
| 7 | +open Microsoft.FSharp.Linq.RuntimeHelpers |
| 8 | +open Microsoft.AspNetCore.Authorization |
| 9 | +open Microsoft.AspNetCore.Http |
| 10 | +open Microsoft.Extensions.DependencyInjection |
| 11 | + |
| 12 | +open FSharp.Data.GraphQL |
| 13 | +open FSharp.Data.GraphQL.Types |
| 14 | +open FSharp.Data.GraphQL.Samples.StarWarsApi |
| 15 | + |
| 16 | +type FieldPolicyMiddleware<'Val, 'Res> = ResolveFieldContext -> 'Val -> (ResolveFieldContext -> 'Val -> Async<'Res>) -> Async<'Res> |
| 17 | + |
| 18 | +type internal CustomPolicyFieldDefinition<'Val, 'Res> (source : FieldDef<'Val, 'Res>, middleware : FieldPolicyMiddleware<'Val, 'Res>) = |
| 19 | + |
| 20 | + interface FieldDef<'Val, 'Res> with |
| 21 | + |
| 22 | + member _.Name = source.Name |
| 23 | + member _.Description = source.Description |
| 24 | + member _.DeprecationReason = source.DeprecationReason |
| 25 | + member _.TypeDef = source.TypeDef |
| 26 | + member _.Args = source.Args |
| 27 | + member _.Metadata = source.Metadata |
| 28 | + member _.Resolve = |
| 29 | + |
| 30 | + let changeAsyncResolver expr = |
| 31 | + let expr = |
| 32 | + match expr with |
| 33 | + | WithValue (_, _, e) -> e |
| 34 | + | _ -> failwith "Unexpected resolver expression." |
| 35 | + let resolver = |
| 36 | + <@ fun ctx input -> middleware ctx input (%%expr : ResolveFieldContext -> 'Val -> Async<'Res>) @> |
| 37 | + let compiledResolver = LeafExpressionConverter.EvaluateQuotation resolver |
| 38 | + Expr.WithValue (compiledResolver, resolver.Type, resolver) |
| 39 | + |
| 40 | + let changeSyncResolver expr = |
| 41 | + let expr = |
| 42 | + match expr with |
| 43 | + | WithValue (_, _, e) -> e |
| 44 | + | _ -> failwith "Unexpected resolver expression." |
| 45 | + let resolver = |
| 46 | + <@ |
| 47 | + fun ctx input -> |
| 48 | + middleware ctx input (fun ctx input -> |
| 49 | + ((%%expr : ResolveFieldContext -> 'Val -> 'Res) ctx input) |
| 50 | + |> async.Return) |
| 51 | + @> |
| 52 | + try |
| 53 | + let compiledResolver = LeafExpressionConverter.EvaluateQuotation resolver |
| 54 | + Expr.WithValue (compiledResolver, resolver.Type, resolver) |
| 55 | + with :? NotSupportedException as ex -> |
| 56 | + let message = |
| 57 | + $"F# compiler cannot convert '{source.Name}' field resolver expression to LINQ, use function instead" |
| 58 | + raise (NotSupportedException (message, ex)) |
| 59 | + |
| 60 | + match source.Resolve with |
| 61 | + | Sync (input, output, expr) -> Async (input, output, changeSyncResolver expr) |
| 62 | + | Async (input, output, expr) -> Async (input, output, changeAsyncResolver expr) |
| 63 | + | Undefined -> failwith "Field has no resolve function." |
| 64 | + | x -> failwith <| sprintf "Resolver '%A' is not supported." x |
| 65 | + |
| 66 | + interface IEquatable<FieldDef> with |
| 67 | + member _.Equals (other) = source.Equals (other) |
| 68 | + |
| 69 | + override _.Equals y = source.Equals y |
| 70 | + override _.GetHashCode () = source.GetHashCode () |
| 71 | + override _.ToString () = source.ToString () |
| 72 | + |
| 73 | +[<AutoOpen>] |
| 74 | +module TypeSystemExtensions = |
| 75 | + |
| 76 | + let handlePolicies (policies : string array) (ctx : ResolveFieldContext) value = async { |
| 77 | + |
| 78 | + let root : Root = downcast ctx.Context.RootValue |
| 79 | + let serviceProvider = root.ServiceProvider |
| 80 | + let authorizationService = serviceProvider.GetRequiredService<IAuthorizationService> () |
| 81 | + let principal = serviceProvider.GetRequiredService<IHttpContextAccessor>().HttpContext.User |
| 82 | + |
| 83 | + let! authorizationResults = |
| 84 | + policies |
| 85 | + |> Seq.map (fun p -> authorizationService.AuthorizeAsync (principal, value, p)) |
| 86 | + |> Seq.toArray |
| 87 | + |> Task.WhenAll |
| 88 | + |> Async.AwaitTask |
| 89 | + |
| 90 | + let failedRequirements = |
| 91 | + authorizationResults |
| 92 | + |> Seq.where (fun r -> not r.Succeeded) |
| 93 | + |> Seq.collect (fun r -> r.Failure.FailedRequirements) |
| 94 | + |
| 95 | + if Seq.isEmpty failedRequirements then |
| 96 | + return Ok () |
| 97 | + else |
| 98 | + return Error "Forbidden" |
| 99 | + } |
| 100 | + |
| 101 | + [<Literal>] |
| 102 | + let AuthorizationPolicy = "AuthorizationPolicy" |
| 103 | + |
| 104 | + type FieldDef<'Val, 'Res> with |
| 105 | + |
| 106 | + member field.WithPolicyMiddleware<'Val, 'Res> (middleware : FieldPolicyMiddleware<'Val, 'Res>) : FieldDef<'Val, 'Res> = |
| 107 | + upcast CustomPolicyFieldDefinition (field, middleware) |
| 108 | + |
| 109 | + member field.WithAuthorizationPolicies<'Val, 'Res> ([<ParamArray>] policies : string array) : FieldDef<'Val, 'Res> = |
| 110 | + |
| 111 | + let middleware ctx value (resolver : ResolveFieldContext -> 'Val -> Async<'Res>) : Async<'Res> = async { |
| 112 | + let! result = handlePolicies policies ctx value |
| 113 | + match result with |
| 114 | + | Ok _ -> return! resolver ctx value |
| 115 | + | Error error -> return raise (GQLMessageException error) |
| 116 | + } |
| 117 | + |
| 118 | + field.WithPolicyMiddleware<'Val, 'Res> middleware |
0 commit comments