diff --git a/Cargo.lock b/Cargo.lock index d828dd8d2..855a04b78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,51 @@ dependencies = [ "objc2", ] +[[package]] +name = "bollard" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87a52479c9237eb04047ddb94788c41ca0d26eaff8b697ecfbb4c32f7fdc3b1b" +dependencies = [ + "base64 0.22.1", + "bollard-stubs", + "bytes", + "futures-core", + "futures-util", + "hex", + "http", + "http-body-util", + "hyper", + "hyper-named-pipe", + "hyper-util", + "hyperlocal", + "log", + "pin-project-lite", + "serde", + "serde_derive", + "serde_json", + "serde_repr", + "serde_urlencoded", + "thiserror 2.0.17", + "tokio", + "tokio-util", + "tower-service", + "url", + "winapi", +] + +[[package]] +name = "bollard-stubs" +version = "1.49.1-rc.28.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5731fe885755e92beff1950774068e0cae67ea6ec7587381536fca84f1779623" +dependencies = [ + "serde", + "serde_json", + "serde_repr", + "serde_with", +] + [[package]] name = "borrow-or-share" version = "0.2.4" @@ -1346,14 +1391,29 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" name = "e2e" version = "0.0.1" dependencies = [ + "async-trait", + "bollard", + "bytes", + "dashmap", + "futures-util", + "graphql-parser", + "graphql-tools", "hive-router", "hive-router-config", + "hive-router-plan-executor", + "hive-router-query-planner", + "http", "insta", "jsonwebtoken", "lazy_static", "mockito", + "multer", "ntex", + "r2d2", + "redis", "reqwest", + "serde", + "serde_json", "sonic-rs", "subgraphs", "tempfile", @@ -2004,6 +2064,8 @@ dependencies = [ "mimalloc", "moka", "ntex", + "ntex-service", + "ntex-util", "rand 0.9.2", "regex-automata", "reqwest", @@ -2046,6 +2108,7 @@ name = "hive-router-plan-executor" version = "6.0.1" dependencies = [ "ahash", + "arc-swap", "async-trait", "bumpalo", "bytes", @@ -2064,11 +2127,14 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "ntex", "ntex-http", "ordered-float", "regex-automata", + "reqwest", "ryu", "serde", + "serde_json", "sonic-rs", "strum 0.27.2", "subgraphs", @@ -2215,6 +2281,21 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-named-pipe" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b7d8abf35697b81a825e386fc151e0d503e8cb5fcb93cc8669c376dfd6f278" +dependencies = [ + "hex", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", + "winapi", +] + [[package]] name = "hyper-rustls" version = "0.27.7" @@ -2274,6 +2355,21 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "hyperlocal" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "986c5ce3b994526b3cd75578e62554abd09f0899d6206de48b3e96ab34ccc8c7" +dependencies = [ + "hex", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -2797,6 +2893,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4129,6 +4235,17 @@ 
version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot 0.12.5", + "scheduled-thread-pool", +] + [[package]] name = "rancor" version = "0.1.1" @@ -4217,6 +4334,23 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redis" +version = "0.32.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014cc767fefab6a3e798ca45112bccad9c6e0e218fbd49720042716c73cfef44" +dependencies = [ + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "r2d2", + "ryu", + "sha1_smol", + "socket2 0.6.1", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4328,6 +4462,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -4339,6 +4474,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -4633,6 +4769,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot 0.12.5", +] + [[package]] name = "schemars" version = "0.9.0" @@ -4815,6 +4960,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + [[package]] name = "serde_spanned" version = "1.0.3" @@ -4902,6 +5058,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -5228,6 +5390,7 @@ dependencies = [ "async-graphql", "async-graphql-axum", "axum", + "dashmap", "lazy_static", "rand 0.9.2", "sonic-rs", @@ -5804,6 +5967,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/Cargo.toml b/Cargo.toml index 3b3d217ed..9a2d8e04b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,3 +59,4 @@ retry-policies = "0.4.0" reqwest-retry = "0.7.0" reqwest-middleware = "0.4.2" vrl = { version = "0.28.0", features = ["compiler", "parser", "value", "diagnostic", "stdlib", "core"] } +bytes = "1.10.1" \ No newline at end of file diff --git a/bench/subgraphs/Cargo.toml b/bench/subgraphs/Cargo.toml index 7d1f940bf..724e74e7b 100644 --- a/bench/subgraphs/Cargo.toml +++ b/bench/subgraphs/Cargo.toml @@ -17,6 +17,7 @@ lazy_static = { workspace = true } rand = { workspace = true } tokio = { workspace = true } sonic-rs = { workspace = true } +dashmap = { workspace = true } async-graphql = "7.0.17" async-graphql-axum = "7.0.17" diff --git a/bench/subgraphs/lib.rs b/bench/subgraphs/lib.rs index 575d0f499..507de10fa 100644 --- a/bench/subgraphs/lib.rs 
+++ b/bench/subgraphs/lib.rs @@ -13,14 +13,12 @@ use axum::{ routing::{get, post_service}, Router, }; +use dashmap::DashMap; use sonic_rs::Value; -use std::{collections::HashMap, env::var, sync::Arc}; +use std::{env::var, sync::Arc}; use tokio::{ net::TcpListener, - sync::{ - oneshot::{self, Sender}, - Mutex, - }, + sync::oneshot::{self, Sender}, task::JoinHandle, }; @@ -55,7 +53,7 @@ async fn add_subgraph_header(req: Request, next: Next) -> Response { } async fn track_requests( - State(state): State<SubgraphsServiceState>, + State(state): State<Arc<SubgraphsServiceState>>, request: Request, next: Next, ) -> impl IntoResponse { @@ -63,9 +61,8 @@ let (parts, body) = request.into_parts(); let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); let record = extract_record(&parts, body_bytes.clone()); - let mut log = state.request_log.lock().await; - log.entry(path).or_default().push(record); + state.request_log.entry(path).or_default().push(record); let new_body = axum::body::Body::from(body_bytes); let request = Request::from_parts(parts, new_body); @@ -74,7 +71,7 @@ fn extract_record(request_parts: &Parts, request_body: Bytes) -> RequestLog { let header_map = request_parts.headers.clone(); - let body_value: Value = sonic_rs::from_slice(&request_body).unwrap(); + let body_value: Value = sonic_rs::from_slice(&request_body).unwrap_or(Value::new()); RequestLog { headers: header_map, @@ -92,25 +89,24 @@ pub struct RequestLog { pub request_body: Value, } -#[derive(Clone)] pub struct SubgraphsServiceState { - pub request_log: Arc<Mutex<HashMap<String, Vec<RequestLog>>>>, + pub request_log: DashMap<String, Vec<RequestLog>>, pub health_check_url: String, } pub fn start_subgraphs_server( port: Option<u16>, -) -> (JoinHandle<()>, Sender<()>, SubgraphsServiceState) { +) -> (JoinHandle<()>, Sender<()>, Arc<SubgraphsServiceState>) { let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let host = var("HOST").unwrap_or("0.0.0.0".to_owned()); let port = port .map(|v| v.to_string()) .unwrap_or(var("PORT").unwrap_or("4200".to_owned())); - let shared_state = SubgraphsServiceState { - request_log: Arc::new(Mutex::new(HashMap::new())), + let shared_state = Arc::new(SubgraphsServiceState { + request_log: DashMap::new(), health_check_url: format!("http://{}:{}/health", host, port), - }; + }); let app = Router::new() .route( diff --git a/bench/subgraphs/products.rs b/bench/subgraphs/products.rs index 363f5640b..3b978dff1 100644 --- a/bench/subgraphs/products.rs +++ b/bench/subgraphs/products.rs @@ -1,5 +1,10 @@ -use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema, SimpleObject, ID}; +use std::io::Read; + +use async_graphql::{ + Context, EmptySubscription, InputObject, Object, Schema, SimpleObject, Upload, ID, +}; use lazy_static::lazy_static; +use tokio::io::AsyncWriteExt; lazy_static!
{ static ref PRODUCTS: Vec<Product> = vec![ @@ -69,6 +74,8 @@ pub struct Product { pub struct Query; +pub struct Mutation; + #[Object(extends = true)] impl Query { async fn top_products( @@ -94,8 +101,52 @@ impl Query { } } -pub fn get_subgraph() -> Schema<Query, EmptyMutation, EmptySubscription> { - Schema::build(Query, EmptyMutation, EmptySubscription) +pub fn get_subgraph() -> Schema<Query, Mutation, EmptySubscription> { + Schema::build(Query, Mutation, EmptySubscription) .enable_federation() .finish() } + +#[Object(extends = true)] +impl Mutation { + async fn upload(&self, ctx: &Context<'_>, file: Option<Upload>) -> String { + if file.is_none() { + return "No file uploaded".to_string(); + } + // Write to a temp location, and return the path + let uploaded_file = file.unwrap().value(ctx).unwrap(); + let path = format!("/tmp/{}", uploaded_file.filename); + let mut buf = vec![]; + let _ = uploaded_file.into_read().read_to_end(&mut buf); + let mut tmp_file_on_disk = tokio::fs::File::create(&path).await.unwrap(); + tmp_file_on_disk.write_all(&buf).await.unwrap(); + path + } + async fn oneof_test(&self, input: OneOfTestInput) -> OneOfTestResult { + OneOfTestResult { + string: input.string, + int: input.int, + float: input.float, + boolean: input.boolean, + id: input.id, + } + } +} + +#[derive(InputObject)] +struct OneOfTestInput { + pub string: Option<String>, + pub int: Option<i32>, + pub float: Option<f64>, + pub boolean: Option<bool>, + pub id: Option<ID>, +} + +#[derive(SimpleObject)] +struct OneOfTestResult { + pub string: Option<String>, + pub int: Option<i32>, + pub float: Option<f64>, + pub boolean: Option<bool>, + pub id: Option<ID>, +} diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 4969b0a46..15d5e2051 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -53,3 +53,5 @@ tokio-util = "0.7.16" cookie = "0.18.1" regex-automata = "0.4.10" arc-swap = "1.7.1" +ntex-util = "2.15.0" +ntex-service = "3.5.0" \ No newline at end of file diff --git a/bin/router/src/jwt/mod.rs b/bin/router/src/jwt/mod.rs index d95854b1c..d7804e156 100644 --- a/bin/router/src/jwt/mod.rs +++ b/bin/router/src/jwt/mod.rs @@ -265,26 +265,27 @@ impl JwtAuthRuntime { Ok(token_data) } - pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> { + pub fn validate_request( + &self, + request: &HttpRequest, + ) -> Result<Option<JwtRequestContext>, JwtError> { let valid_jwks = self.jwks.all(); match self.authenticate(&valid_jwks, request) { - Ok((token_payload, maybe_token_prefix, token)) => { - request.extensions_mut().insert(JwtRequestContext { - token_payload, - token_raw: token, - token_prefix: maybe_token_prefix, - }); - } + Ok((token_payload, maybe_token_prefix, token)) => Ok(Some(JwtRequestContext { + token_payload, + token_raw: token, + token_prefix: maybe_token_prefix, + })), Err(e) => { warn!("jwt token error: {:?}", e); if self.config.require_authentication.is_some_and(|v| v) { return Err(e); } + + Ok(None) } } - - Ok(()) } } diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 6a3f7f5c0..daa5047ab 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -4,6 +4,7 @@ mod http_utils; mod jwt; mod logger; mod pipeline; +pub mod plugins; mod schema_state; mod shared_state; mod supergraph; @@ -19,11 +20,17 @@ use crate::{ }, jwt::JwtAuthRuntime, logger::configure_logging, - pipeline::graphql_request_handler, + pipeline::{ + error::PipelineError, + graphql_request_handler, + header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + }, + plugins::plugins_service::PluginService, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; +pub use crate::plugins::registry::PluginRegistry;
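Review note: lib.rs now exposes a public plugin surface (`pub mod plugins` plus the re-exported `PluginRegistry`), so a downstream crate can embed Hive Router with its own plugins instead of running the stock binary. A minimal sketch of such an embedding, assuming only what this diff shows (`router_entrypoint` taking `Option<PluginRegistry>`); the registration call is hypothetical, since the registry's mutation API sits outside this hunk:

    use hive_router::{router_entrypoint, PluginRegistry};

    #[ntex::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let registry = PluginRegistry::default();
        // Hypothetical: add plugin factories here; the real method lives in
        // bin/router/src/plugins/registry.rs and is not visible in this diff.
        router_entrypoint(Some(registry)).await
    }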
use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ @@ -33,7 +40,7 @@ use tracing::{info, warn}; async fn graphql_endpoint_handler( - mut request: HttpRequest, + req: HttpRequest, body_bytes: Bytes, schema_state: web::types::State<Arc<SchemaState>>, app_state: web::types::State<Arc<RouterSharedState>>, @@ -45,26 +52,32 @@ if let Some(early_response) = app_state .cors_runtime .as_ref() - .and_then(|cors| cors.get_early_response(&request)) + .and_then(|cors| cors.get_early_response(&req)) { return early_response; } - let mut res = graphql_request_handler( - &mut request, + let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); + + let mut response = match graphql_request_handler( + &req, body_bytes, supergraph, app_state.get_ref(), schema_state.get_ref(), ) - .await; + .await + { + Ok(response_with_req) => response_with_req, + Err(error) => return PipelineError { accept_ok, error }.into(), + }; // Apply CORS headers to the final response if CORS is configured. if let Some(cors) = app_state.cors_runtime.as_ref() { - cors.set_headers(&request, res.headers_mut()); + cors.set_headers(&req, response.headers_mut()); } - res + response } else { warn!("No supergraph available yet, unable to process request"); @@ -74,7 +87,9 @@ } } -pub async fn router_entrypoint() -> Result<(), Box<dyn std::error::Error>> { +pub async fn router_entrypoint( + plugin_registry: Option<PluginRegistry>, +) -> Result<(), Box<dyn std::error::Error>> { let config_path = std::env::var("ROUTER_CONFIG_FILE_PATH").ok(); let router_config = load_config(config_path)?; configure_logging(&router_config.log); @@ -82,10 +97,11 @@ let addr = router_config.http.address(); let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry).await?; let maybe_error = web::HttpServer::new(move || { web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app) @@ -105,17 +121,31 @@ pub async fn configure_app_from_config( router_config: HiveRouterConfig, bg_tasks_manager: &mut BackgroundTasksManager, + plugin_registry: Option<PluginRegistry>, +) -> Result<(Arc<RouterSharedState>, Arc<SchemaState>), Box<dyn std::error::Error>> { let jwt_runtime = match router_config.jwt.is_jwt_auth_enabled() { true => Some(JwtAuthRuntime::init(bg_tasks_manager, &router_config.jwt).await?), false => None, }; + let plugins = match plugin_registry { + Some(plugin_registry) => plugin_registry.initialize_plugins(&router_config)?, + None => None, + }; + let router_config_arc = Arc::new(router_config); - let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?; + let shared_state = Arc::new(RouterSharedState::new( + router_config_arc.clone(), + jwt_runtime, + plugins, + )?); + let schema_state = SchemaState::new_from_config( + bg_tasks_manager, + router_config_arc.clone(), + shared_state.clone(), + ) + .await?; let schema_state_arc = Arc::new(schema_state); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?); Ok((shared_state, schema_state_arc)) } diff --git a/bin/router/src/main.rs b/bin/router/src/main.rs index b4f250b38..162b3dfa4 100644 --- a/bin/router/src/main.rs +++ b/bin/router/src/main.rs @@ -5,7 +5,7 @@ static GLOBAL: mimalloc::MiMalloc =
mimalloc::MiMalloc; #[ntex::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { - match router_entrypoint().await { + match router_entrypoint(None).await { Ok(_) => Ok(()), Err(err) => { eprintln!("Failed to start Hive Router:\n {}", err); diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index 8c472695e..ab5759b5e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use std::sync::Arc; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -8,10 +9,8 @@ use ntex::web::HttpRequest; use sonic_rs::Value; use tracing::{error, trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; #[derive(Clone, Debug)] pub struct CoerceVariablesPayload { @@ -22,22 +21,22 @@ pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, - execution_params: &mut ExecutionRequest, - normalized_operation: &Arc<GraphQLNormalizationPayload>, -) -> Result<CoerceVariablesPayload, PipelineError> { + graphql_params: &mut GraphQLParams, + normalized_operation: &GraphQLNormalizationPayload, +) -> Result<CoerceVariablesPayload, PipelineErrorVariant> { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = normalized_operation.operation_for_plan.operation_kind { error!("Mutation is not allowed over GET, stopping"); - return Err(req.new_pipeline_error(PipelineErrorVariant::MutationNotAllowedOverHttpGet)); + return Err(PipelineErrorVariant::MutationNotAllowedOverHttpGet); } } match collect_variables( &normalized_operation.operation_for_plan, - &mut execution_params.variables, + &mut graphql_params.variables, &supergraph.metadata, ) { Ok(values) => { @@ -55,7 +54,7 @@ "failed to collect variables from incoming request: {}", err_msg ); - Err(req.new_pipeline_error(PipelineErrorVariant::VariablesCoercionError(err_msg))) + Err(PipelineErrorVariant::VariablesCoercionError(err_msg)) } } }
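Review note: the pattern repeated across these pipeline files is that stage functions now return a bare `PipelineErrorVariant` instead of constructing a request-aware `PipelineError`; only the HTTP edge (`graphql_endpoint_handler` above) decides how to render it. A condensed sketch of that edge logic, using only items visible in this diff:

    // Computed once from the Accept header, then attached to whatever
    // variant bubbles up from the pipeline.
    let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR);
    if let Err(error) = result {
        // From<PipelineError> for Response picks the status code and body shape.
        return PipelineError { accept_ok, error }.into();
    }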
diff --git a/bin/router/src/pipeline/csrf_prevention.rs b/bin/router/src/pipeline/csrf_prevention.rs index 51561dd99..37c063b09 100644 --- a/bin/router/src/pipeline/csrf_prevention.rs +++ b/bin/router/src/pipeline/csrf_prevention.rs @@ -1,7 +1,7 @@ use hive_router_config::csrf::CSRFPreventionConfig; use ntex::web::HttpRequest; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; // NON_PREFLIGHTED_CONTENT_TYPES are content types that do not require a preflight // OPTIONS request. These are content types that are considered "simple" by the CORS @@ -15,9 +15,9 @@ const NON_PREFLIGHTED_CONTENT_TYPES: [&str; 3] = [ #[inline] pub fn perform_csrf_prevention( - req: &mut HttpRequest, + req: &HttpRequest, csrf_config: &CSRFPreventionConfig, -) -> Result<(), PipelineError> { +) -> Result<(), PipelineErrorVariant> { // If CSRF prevention is not configured or disabled, skip the checks. if !csrf_config.enabled || csrf_config.required_headers.is_empty() { return Ok(()); } @@ -39,7 +39,7 @@ pub fn perform_csrf_prevention( if has_required_header { Ok(()) } else { - Err(req.new_pipeline_error(PipelineErrorVariant::CsrfPreventionFailed)) + Err(PipelineErrorVariant::CsrfPreventionFailed) } } diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs new file mode 100644 index 000000000..b22b18a3a --- /dev/null +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use http::Method; +use ntex::util::Bytes; +use ntex::web::types::Query; +use ntex::web::HttpRequest; +use tracing::{trace, warn}; + +use crate::pipeline::error::PipelineErrorVariant; +use crate::pipeline::header::AssertRequestJson; + +#[derive(serde::Deserialize, Debug)] +struct GETQueryParams { + pub query: Option<String>, + #[serde(rename = "operationName")] + pub operation_name: Option<String>, + pub variables: Option<String>, + pub extensions: Option<String>, +} + +impl TryInto<GraphQLParams> for GETQueryParams { + type Error = PipelineErrorVariant; + + fn try_into(self) -> Result<GraphQLParams, Self::Error> { + let variables = match self.variables.as_deref() { + Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { + Ok(vars) => vars, + Err(e) => { + return Err(PipelineErrorVariant::FailedToParseVariables(e)); + } + }, + _ => HashMap::new(), + }; + + let extensions = match self.extensions.as_deref() { + Some(e_str) if !e_str.is_empty() => match sonic_rs::from_str(e_str) { + Ok(exts) => Some(exts), + Err(e) => { + return Err(PipelineErrorVariant::FailedToParseExtensions(e)); + } + }, + _ => None, + }; + + let execution_request = GraphQLParams { + query: self.query, + operation_name: self.operation_name, + variables, + extensions, + }; + + Ok(execution_request) + } +} + +pub trait GetQueryStr { + fn get_query(&self) -> Result<&str, PipelineErrorVariant>; +} + +impl GetQueryStr for GraphQLParams { + fn get_query(&self) -> Result<&str, PipelineErrorVariant> { + self.query + .as_deref() + .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) + } +} + +#[inline] +pub fn deserialize_graphql_params( + req: &HttpRequest, + body_bytes: Bytes, +) -> Result<GraphQLParams, PipelineErrorVariant> { + let http_method = req.method(); + let graphql_params: GraphQLParams = match *http_method { + Method::GET => { + trace!("processing GET GraphQL operation"); + let query_params_str = req + .uri() + .query() + .ok_or_else(|| PipelineErrorVariant::GetInvalidQueryParams)?; + let query_params = Query::<GETQueryParams>::from_query(query_params_str) + .map_err(PipelineErrorVariant::GetUnprocessableQueryParams)? + .0; + + trace!("parsed GET query params: {:?}", query_params); + + query_params.try_into()? + } + Method::POST => { + trace!("Processing POST GraphQL request"); + + req.assert_json_content_type()?; + + let execution_request = unsafe { + sonic_rs::from_slice_unchecked::<GraphQLParams>(&body_bytes).map_err(|e| { + warn!("Failed to parse body: {}", e); + PipelineErrorVariant::FailedToParseBody(e) + })? + }; + + execution_request + } + _ => { + warn!("unsupported HTTP method: {}", http_method); + + return Err(PipelineErrorVariant::UnsupportedHttpMethod( + http_method.to_owned(), + )); + } + }; + + Ok(graphql_params) +}
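Review note: both transports now funnel into the executor-owned `GraphQLParams`. For illustration, a GET request such as `GET /graphql?query={__typename}&operationName=Q&variables={"id":1}` should deserialize into roughly the following (sketch; `variables` arrives as a JSON string and is parsed into a map):

    GraphQLParams {
        query: Some("{__typename}".to_string()),
        operation_name: Some("Q".to_string()),
        variables: /* {"id": 1} parsed via sonic_rs */,
        extensions: None,
    }

Note that `query` is optional at this layer; call sites that need it go through `get_query()` and receive a uniform `GetMissingQueryParam("query")` error when it is absent.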
diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index eec36ea76..1d856d62e 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -10,15 +10,12 @@ use hive_router_query_planner::{ }; use http::{HeaderName, Method, StatusCode}; use ntex::{ - http::ResponseBuilder, - web::{self, error::QueryPayloadError, HttpRequest}, + http::{Response, ResponseBuilder}, + web::error::QueryPayloadError, }; use serde::{Deserialize, Serialize}; -use crate::pipeline::{ - header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, - progressive_override::LabelEvaluationError, -}; +use crate::pipeline::progressive_override::LabelEvaluationError; #[derive(Debug)] pub struct PipelineError { @@ -26,18 +23,6 @@ pub struct PipelineError { pub error: PipelineErrorVariant, } -pub trait PipelineErrorFromAcceptHeader { - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError; } - -impl PipelineErrorFromAcceptHeader for HttpRequest { - #[inline] - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError { - let accept_ok = !self.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - PipelineError { accept_ok, error } - } -} - #[derive(Debug, thiserror::Error)] pub enum PipelineErrorVariant { // HTTP-related errors @@ -78,7 +63,7 @@ pub enum PipelineErrorVariant { #[error("Failed to execute a plan: {0}")] PlanExecutionError(PlanExecutionError), #[error("Failed to produce a plan: {0}")] - PlannerError(Arc<PlannerError>), + PlannerError(PlannerError), #[error(transparent)] LabelEvaluationError(LabelEvaluationError), @@ -156,11 +141,11 @@ pub struct FailedExecutionResult { pub errors: Option<Vec<GraphQLError>>, } -impl PipelineError { - pub fn into_response(self) -> web::HttpResponse { - let status = self.error.default_status_code(self.accept_ok); +impl From<PipelineError> for Response { + fn from(val: PipelineError) -> Self { + let status = val.error.default_status_code(val.accept_ok); - if let PipelineErrorVariant::ValidationErrors(validation_errors) = self.error { + if let PipelineErrorVariant::ValidationErrors(validation_errors) = val.error { let validation_error_result = FailedExecutionResult { errors: Some(validation_errors.iter().map(|error| error.into()).collect()), }; @@ -168,8 +153,8 @@ return ResponseBuilder::new(status).json(&validation_error_result); } - let code = self.error.graphql_error_code(); - let message = self.error.graphql_error_message(); + let code = val.error.graphql_error_code(); + let message = val.error.graphql_error_message(); let graphql_error = GraphQLError::from_message_and_extensions( message, diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 42ace79ce..d870a8d45 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -1,16 +1,16 @@ use std::collections::HashMap; -use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; use crate::shared_state::RouterSharedState; -use hive_router_plan_executor::execute_query_plan;
use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::execution::plan::QueryPlanExecutionContext; +use hive_router_plan_executor::executors::http::HttpResponse; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; +use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; use ntex::web::HttpRequest; @@ -24,16 +24,18 @@ enum ExposeQueryPlanMode { DryRun, } +#[allow(clippy::too_many_arguments)] #[inline] pub async fn execute_plan( req: &HttpRequest, supergraph: &SupergraphData, - app_state: &Arc<RouterSharedState>, - normalized_payload: &Arc<GraphQLNormalizationPayload>, - query_plan_payload: &Arc<QueryPlan>, + app_state: &RouterSharedState, + normalized_payload: &GraphQLNormalizationPayload, + query_plan_payload: &QueryPlan, variable_payload: &CoerceVariablesPayload, client_request_details: &ClientRequestDetails<'_, '_>, -) -> Result<PlanExecutionOutput, PipelineError> { + plugin_req_state: &Option<PluginRequestState<'_>>, +) -> Result<HttpResponse, PipelineErrorVariant> { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -65,7 +67,7 @@ metadata: &supergraph.metadata, }; - let jwt_forward_plan: Option<JwtAuthForwardingPlan> = if app_state + let jwt_auth_forwarding: Option<JwtAuthForwardingPlan> = if app_state .router_config .jwt .is_jwt_extensions_forwarding_enabled() { @@ -79,13 +81,15 @@ .forward_claims_to_upstream_extensions .field_name, ) - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))? + .map_err(PipelineErrorVariant::JwtForwardingError)?
} else { None }; - execute_query_plan(QueryPlanExecutionContext { + let ctx = QueryPlanExecutionContext { + plugin_req_state, query_plan: query_plan_payload, + operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, variable_values: &variable_payload.variables_map, @@ -93,12 +97,12 @@ client_request: client_request_details, introspection_context: &introspection_context, operation_type_name: normalized_payload.root_type_name, - jwt_auth_forwarding: &jwt_forward_plan, + jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, - }) - .await - .map_err(|err| { + }; + + ctx.execute_query_plan().await.map_err(|err| { tracing::error!("Failed to execute query plan: {}", err); - req.new_pipeline_error(PipelineErrorVariant::PlanExecutionError(err)) + PipelineErrorVariant::PlanExecutionError(err) }) }
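Review note: turning the free `execute_query_plan(ctx)` function into `ctx.execute_query_plan()` means `QueryPlanExecutionContext` now carries everything a hook-aware executor needs, including `plugin_req_state` and the normalized `operation_for_plan`. The call shape, condensed from the hunk above (sketch; elided fields are the ones listed there):

    let ctx = QueryPlanExecutionContext {
        plugin_req_state,
        query_plan: query_plan_payload,
        // ...projection, headers, variables, introspection fields as above...
        executors: &supergraph.subgraph_executor_map,
    };
    let response: HttpResponse = ctx.execute_query_plan().await?;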
diff --git a/bin/router/src/pipeline/execution_request.rs b/bin/router/src/pipeline/execution_request.rs deleted file mode 100644 index c17a6f355..000000000 --- a/bin/router/src/pipeline/execution_request.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::collections::HashMap; - -use http::Method; -use ntex::util::Bytes; -use ntex::web::types::Query; -use ntex::web::HttpRequest; -use serde::{Deserialize, Deserializer}; -use sonic_rs::Value; -use tracing::{trace, warn}; - -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::header::AssertRequestJson; - -#[derive(serde::Deserialize, Debug)] -struct GETQueryParams { - pub query: Option<String>, - #[serde(rename = "camelCase")] - pub operation_name: Option<String>, - pub variables: Option<String>, - pub extensions: Option<String>, -} - -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ExecutionRequest { - pub query: String, - pub operation_name: Option<String>, - #[serde(default, deserialize_with = "deserialize_null_default")] - pub variables: HashMap<String, Value>, - // TODO: We don't use extensions yet, but we definitely will in the future. - #[allow(dead_code)] - pub extensions: Option<HashMap<String, Value>>, -} - -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error> -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::<T>::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) -} - -impl TryInto<ExecutionRequest> for GETQueryParams { - type Error = PipelineErrorVariant; - - fn try_into(self) -> Result<ExecutionRequest, Self::Error> { - let query = match self.query { - Some(q) => q, - None => return Err(PipelineErrorVariant::GetMissingQueryParam("query")), - }; - - let variables = match self.variables.as_deref() { - Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { - Ok(vars) => vars, - Err(e) => { - return Err(PipelineErrorVariant::FailedToParseVariables(e)); - } - }, - _ => HashMap::new(), - }; - - let extensions = match self.extensions.as_deref() { - Some(e_str) if !e_str.is_empty() => match sonic_rs::from_str(e_str) { - Ok(exts) => Some(exts), - Err(e) => { - return Err(PipelineErrorVariant::FailedToParseExtensions(e)); - } - }, - _ => None, - }; - - let execution_request = ExecutionRequest { - query, - operation_name: self.operation_name, - variables, - extensions, - }; - - Ok(execution_request) - } -} - -#[inline] -pub async fn get_execution_request( - req: &mut HttpRequest, - body_bytes: Bytes, -) -> Result<ExecutionRequest, PipelineError> { - let http_method = req.method(); - let execution_request: ExecutionRequest = match *http_method { - Method::GET => { - trace!("processing GET GraphQL operation"); - let query_params_str = req.uri().query().ok_or_else(|| { - req.new_pipeline_error(PipelineErrorVariant::GetInvalidQueryParams) - })?; - let query_params = Query::<GETQueryParams>::from_query(query_params_str) - .map_err(|e| { - req.new_pipeline_error(PipelineErrorVariant::GetUnprocessableQueryParams(e)) - })? - .0; - - trace!("parsed GET query params: {:?}", query_params); - - query_params - .try_into() - .map_err(|err| req.new_pipeline_error(err))? - } - Method::POST => { - trace!("Processing POST GraphQL request"); - - req.assert_json_content_type()?; - - let execution_request = unsafe { - sonic_rs::from_slice_unchecked::<ExecutionRequest>(&body_bytes).map_err(|e| { - warn!("Failed to parse body: {}", e); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) - })? - }; - - execution_request - } - _ => { - warn!("unsupported HTTP method: {}", http_method); - - return Err( - req.new_pipeline_error(PipelineErrorVariant::UnsupportedHttpMethod( - http_method.to_owned(), - )), - ); - } - }; - - Ok(execution_request) -} diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 92a591235..19ea8c7af 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -6,7 +6,7 @@ use lazy_static::lazy_static; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; lazy_static!
{ pub static ref APPLICATION_JSON_STR: &'static str = "application/json"; @@ -34,31 +34,29 @@ impl RequestAccepts for HttpRequest { } pub trait AssertRequestJson { - fn assert_json_content_type(&self) -> Result<(), PipelineError>; + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant>; } impl AssertRequestJson for HttpRequest { #[inline] - fn assert_json_content_type(&self) -> Result<(), PipelineError> { + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant> { match self.headers().get(CONTENT_TYPE) { Some(value) => { - let content_type_str = value.to_str().map_err(|_| { - self.new_pipeline_error(PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE)) - })?; + let content_type_str = value + .to_str() + .map_err(|_| PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE))?; if !content_type_str.contains(*APPLICATION_JSON_STR) { warn!( "Invalid content type on a POST request: {}", content_type_str ); - return Err( - self.new_pipeline_error(PipelineErrorVariant::UnsupportedContentType) - ); + return Err(PipelineErrorVariant::UnsupportedContentType); } Ok(()) } None => { trace!("POST without content type detected"); - Err(self.new_pipeline_error(PipelineErrorVariant::MissingContentTypeHeader)) + Err(PipelineErrorVariant::MissingContentTypeHeader) } } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2b4721972..104bcd19a 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,8 +1,16 @@ use std::sync::Arc; -use hive_router_plan_executor::execution::{ - client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, +use hive_router_plan_executor::{ + execution::client_request_details::{ + ClientRequestDetails, JwtRequestDetails, OperationDetails, + }, + executors::http::HttpResponse, + hooks::{ + on_graphql_params::{OnGraphQLParamsEndHookPayload, OnGraphQLParamsStartHookPayload}, + on_supergraph_load::SupergraphData, + }, + plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest}, + plugin_trait::{EndControlFlow, StartControlFlow}, }; use hive_router_query_planner::{ state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, @@ -18,29 +26,29 @@ use crate::{ pipeline::{ coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, - error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, + deserialize_graphql_params::{deserialize_graphql_params, GetQueryStr}, + error::PipelineErrorVariant, execution::execute_plan, - execution_request::get_execution_request, header::{ RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, }, normalize::normalize_request_with_cache, - parser::parse_operation_with_cache, + parser::{parse_operation_with_cache, ParseResult}, progressive_override::request_override_context, - query_plan::plan_operation_with_cache, + query_plan::{plan_operation_with_cache, QueryPlanResult}, validation::validate_operation_with_cache, }, - schema_state::{SchemaState, SupergraphData}, + schema_state::SchemaState, shared_state::RouterSharedState, }; pub mod coerce_variables; pub mod cors; pub mod csrf_prevention; +pub mod deserialize_graphql_params; pub mod error; pub mod execution; -pub mod execution_request; pub mod header; pub mod normalize; pub mod parser; @@ -52,96 +60,192 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn 
graphql_request_handler( - req: &mut HttpRequest, + req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: &Arc<RouterSharedState>, - schema_state: &Arc<SchemaState>, -) -> web::HttpResponse { + shared_state: &RouterSharedState, + schema_state: &SchemaState, +) -> Result<web::HttpResponse, PipelineErrorVariant> { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return web::HttpResponse::Ok() + return Ok(web::HttpResponse::Ok() .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML); + .body(GRAPHIQL_HTML)); } else { - return web::HttpResponse::NotFound().into(); + return Ok(web::HttpResponse::NotFound().into()); } } - if let Some(jwt) = &shared_state.jwt_auth_runtime { + let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { match jwt.validate_request(req) { - Ok(_) => (), - Err(err) => return err.make_response(), + Ok(jwt_context) => jwt_context, + Err(err) => return Ok(err.make_response()), } + } else { + None + }; + + let response_content_type: &'static HeaderValue = + if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { + &APPLICATION_GRAPHQL_RESPONSE_JSON + } else { + &APPLICATION_JSON + }; + + let mut plugin_req_state = None; + if let Some(plugins) = shared_state.plugins.as_ref() { + let plugin_context = req + .extensions() + .get::<Arc<PluginContext>>() + .cloned() + .expect("Plugin manager should be loaded"); + + plugin_req_state = Some(PluginRequestState { + plugins: plugins.clone(), + router_http_request: RouterHttpRequest { + uri: req.uri(), + method: req.method(), + version: req.version(), + headers: req.headers(), + match_info: req.match_info(), + query_string: req.query_string(), + path: req.path(), + }, + context: plugin_context, + }); } - match execute_pipeline(req, body_bytes, supergraph, shared_state, schema_state).await { - Ok(response) => { - let response_bytes = Bytes::from(response.body); - let response_headers = response.headers; - - let response_content_type: &'static HeaderValue = - if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { - &APPLICATION_GRAPHQL_RESPONSE_JSON - } else { - &APPLICATION_JSON - }; - - let mut response_builder = web::HttpResponse::Ok(); - for (header_name, header_value) in response_headers { - if let Some(header_name) = header_name { - response_builder.header(header_name, header_value); - } - } + let response = execute_pipeline( + req, + body_bytes, + supergraph, + shared_state, + schema_state, + jwt_context, + plugin_req_state, + ) + .await?; + let response_status = response.status; + let response_bytes = response.body; + let response_headers = response.headers; - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes) + let mut response_builder = web::HttpResponse::Ok(); + for (header_name, header_value) in response_headers { + if let Some(header_name) = header_name { + response_builder.header(header_name, header_value); } - Err(err) => err.into_response(), } + + Ok(response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .status(response_status) + .body(response_bytes.to_vec())) }
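Review note: every hook added in this patch (on_http_request, on_graphql_params, on_graphql_parse, on_graphql_validation, on_query_plan) follows the same two-phase protocol, which execute_pipeline below instantiates several times. The generic shape, distilled from those loops (sketch; `on_some_hook` is a stand-in, not a real method name):

    let mut on_end_callbacks = vec![];
    for plugin in plugins {
        let result = plugin.on_some_hook(payload).await;
        payload = result.payload; // plugins may rewrite the payload
        match result.control_flow {
            StartControlFlow::Continue => {}                  // fall through to next plugin
            StartControlFlow::EndResponse(r) => return Ok(r), // short-circuit the pipeline
            StartControlFlow::OnEnd(cb) => on_end_callbacks.push(cb), // observe the end phase
        }
    }
    // ...perform (or skip) the default work, then drive the EndControlFlow callbacks.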
#[inline] #[allow(clippy::await_holding_refcell_ref)] pub async fn execute_pipeline( - req: &mut HttpRequest, - body_bytes: Bytes, + req: &HttpRequest, + body: Bytes, supergraph: &SupergraphData, - shared_state: &Arc<RouterSharedState>, - schema_state: &Arc<SchemaState>, -) -> Result<PlanExecutionOutput, PipelineError> { + shared_state: &RouterSharedState, + schema_state: &SchemaState, + jwt_context: Option<JwtRequestContext>, + plugin_req_state: Option<PluginRequestState<'_>>, +) -> Result<HttpResponse, PipelineErrorVariant> { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; - let mut execution_request = get_execution_request(req, body_bytes).await?; - let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?; - validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) - .await?; + /* Handle on_deserialize hook in the plugins - START */ + let mut deserialization_end_callbacks = vec![]; - let normalize_payload = normalize_request_with_cache( - req, + let mut graphql_params = None; + let mut body = body; + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + let mut deserialization_payload: OnGraphQLParamsStartHookPayload = + OnGraphQLParamsStartHookPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + body, + graphql_params: None, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_params(deserialization_payload).await; + deserialization_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { + return Ok(response); + } + StartControlFlow::OnEnd(callback) => { + deserialization_end_callbacks.push(callback); + } + } + } + graphql_params = deserialization_payload.graphql_params; + body = deserialization_payload.body; + } + let mut graphql_params = match graphql_params { + Some(params) => params, + None => deserialize_graphql_params(req, body)?, + }; + + if let Some(plugin_req_state) = &plugin_req_state { + let mut payload = OnGraphQLParamsEndHookPayload { + graphql_params, + context: &plugin_req_state.context, + }; + for deserialization_end_callback in deserialization_end_callbacks { + let result = deserialization_end_callback(payload); + payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { /* continue to next plugin */ } + EndControlFlow::EndResponse(response) => { + return Ok(response); + } + } + } + graphql_params = payload.graphql_params; + } + + /* Handle on_deserialize hook in the plugins - END */ + + let parser_result = + parse_operation_with_cache(shared_state, &graphql_params, &plugin_req_state).await?; + + let parser_payload = match parser_result { + ParseResult::Payload(payload) => payload, + ParseResult::Response(response) => { + return Ok(response); + } + }; + + validate_operation_with_cache( supergraph, schema_state, - &execution_request, + shared_state, &parser_payload, + &plugin_req_state, ) .await?; + + let normalize_payload = + normalize_request_with_cache(supergraph, schema_state, &graphql_params, &parser_payload) + .await?; + let variable_payload = - coerce_request_variables(req, supergraph, &mut execution_request, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); - let req_extensions = req.extensions(); - let jwt_context = req_extensions.get::<JwtRequestContext>(); let jwt_request_details = match jwt_context { Some(jwt_context) => JwtRequestDetails::Authenticated { - token: jwt_context.token_raw.as_str(), - prefix: jwt_context.token_prefix.as_deref(), scopes: jwt_context.extract_scopes(), - claims: &jwt_context + claims: jwt_context .get_claims_value() - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, + .map_err(PipelineErrorVariant::JwtForwardingError)?, + token:
jwt_context.token_raw, + prefix: jwt_context.token_prefix, }, None => JwtRequestDetails::Unauthenticated, }; @@ -158,26 +262,32 @@ Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: &execution_request.query, + query: graphql_params.get_query()?, }, - jwt: &jwt_request_details, + jwt: jwt_request_details, }; let progressive_override_ctx = request_override_context( &shared_state.override_labels_evaluator, &client_request_details, ) - .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; + .map_err(PipelineErrorVariant::LabelEvaluationError)?; - let query_plan_payload = plan_operation_with_cache( - req, + let query_plan_result = plan_operation_with_cache( supergraph, schema_state, &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, + &plugin_req_state, ) .await?; + let query_plan_payload = match query_plan_result { + QueryPlanResult::QueryPlan(plan) => plan, + QueryPlanResult::Response(response) => { + return Ok(response); + } + }; let execution_result = execute_plan( req, @@ -187,6 +297,7 @@ &query_plan_payload, &variable_payload, &client_request_details, + &plugin_req_state, ) .await?; diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 4fc2cc5ef..97cbb80ac 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,17 +1,17 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; use hive_router_query_planner::ast::operation::OperationDefinition; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::SchemaState; use tracing::{error, trace}; #[derive(Debug)] @@ -25,16 +25,15 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc<SchemaState>, - execution_params: &ExecutionRequest, + schema_state: &SchemaState, + graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, -) -> Result<Arc<GraphQLNormalizationPayload>, PipelineError> { +) -> Result<Arc<GraphQLNormalizationPayload>, PipelineErrorVariant> { - let cache_key = match &execution_params.operation_name { + let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); operation_name.hash(&mut hasher); hasher.finish() } @@ -54,7 +53,7 @@ None => match normalize_operation( &supergraph.planner.supergraph, &parser_payload.parsed_operation, - execution_params.operation_name.as_deref(), + graphql_params.operation_name.as_deref(), ) { Ok(doc) => { trace!( @@ -86,7 +85,7 @@ error!("Failed to normalize GraphQL operation: {}", err); trace!("{:?}",
err); - Err(req.new_pipeline_error(PipelineErrorVariant::NormalizationError(err))) + Err(PipelineErrorVariant::NormalizationError(err)) } }, } diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 6e8a37141..2ccef500e 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,12 +2,18 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; +use hive_router_plan_executor::executors::http::HttpResponse; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_parse::{ + OnGraphQLParseEndHookPayload, OnGraphQLParseStartHookPayload, +}; +use hive_router_plan_executor::plugin_context::PluginRequestState; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use hive_router_query_planner::utils::parsing::safe_parse_operation; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use crate::pipeline::deserialize_graphql_params::GetQueryStr; +use crate::pipeline::error::PipelineErrorVariant; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -17,15 +23,20 @@ pub struct GraphQLParserPayload { pub cache_key: u64, } +pub enum ParseResult { + Payload(GraphQLParserPayload), + Response(HttpResponse), +} + #[inline] pub async fn parse_operation_with_cache( - req: &HttpRequest, - app_state: &Arc<RouterSharedState>, - execution_params: &ExecutionRequest, -) -> Result<GraphQLParserPayload, PipelineError> { + app_state: &RouterSharedState, + graphql_params: &GraphQLParams, + plugin_req_state: &Option<PluginRequestState<'_>>, +) -> Result<ParseResult, PipelineErrorVariant> { let cache_key = { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); hasher.finish() }; @@ -33,12 +44,64 @@ trace!("Found cached parsed operation for query"); cached } else { - let parsed = safe_parse_operation(&execution_params.query).map_err(|err| { - error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) - })?; - trace!("sucessfully parsed GraphQL operation"); - let parsed_arc = Arc::new(parsed); + let mut document = None; + let mut on_end_callbacks = vec![]; + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartHookPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + graphql_params, + document, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_parse(start_payload).await; + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { + // continue to next plugin + } + StartControlFlow::EndResponse(response) => { + return Ok(ParseResult::Response(response)); + } + StartControlFlow::OnEnd(callback) => { + // store the callback to be called later + on_end_callbacks.push(callback); + } + } + } + document = start_payload.document; + } + + let document = match document { + Some(parsed) => parsed, + None => { + let query_str = graphql_params.get_query()?; + let parsed = safe_parse_operation(query_str).map_err(|err| { + error!("Failed to parse GraphQL operation: {}", err); + PipelineErrorVariant::FailedToParseOperation(err) + })?; + trace!("successfully
parsed GraphQL operation"); + parsed + } + }; + let mut end_payload = OnGraphQLParseEndHookPayload { document }; + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { + // continue to next callback + } + EndControlFlow::EndResponse(response) => { + return Ok(ParseResult::Response(response)); + } + } + } + let document = end_payload.document; + /* Handle on_graphql_parse hook in the plugins - END */ + + let parsed_arc = Arc::new(document); app_state .parse_cache .insert(cache_key, parsed_arc.clone()) @@ -46,8 +109,8 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(GraphQLParserPayload { + Ok(ParseResult::Payload(GraphQLParserPayload { parsed_operation, cache_key, - }) + })) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..4743d8672 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'_, '_>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index b2f730be7..87e25b627 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -1,24 +1,41 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::SchemaState; +use hive_router_plan_executor::executors::http::HttpResponse; +use hive_router_plan_executor::hooks::on_query_plan::{ + OnQueryPlanEndHookPayload, OnQueryPlanStartHookPayload, +}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginRequestState; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +pub enum QueryPlanResult { + QueryPlan(Arc), + Response(HttpResponse), +} + +pub enum QueryPlanGetterError { + Planner(PlannerError), + Response(HttpResponse), +} + #[inline] -pub async fn plan_operation_with_cache( - req: &HttpRequest, +pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, - schema_state: &Arc, - normalized_operation: &Arc, + schema_state: &SchemaState, + normalized_operation: &GraphQLNormalizationPayload, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, -) -> Result, PipelineError> { + plugin_req_state: &Option>, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -30,7 
+47,7 @@ pub async fn plan_operation_with_cache( let plan_result = schema_state .plan_cache - .try_get_with(plan_cache_key, async move { + .try_get_with(plan_cache_key, async { if is_pure_introspection { return Ok(Arc::new(QueryPlan { kind: "QueryPlan".to_string(), @@ -38,20 +55,80 @@ })); } - supergraph - .planner - .plan_from_normalized_operation( + let mut query_plan: Option<QueryPlan> = None; + let mut on_end_callbacks = vec![]; + + if let Some(plugin_req_state) = plugin_req_state { + /* Handle on_query_plan hook in the plugins - START */ + let mut start_payload = OnQueryPlanStartHookPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, filtered_operation_for_plan, - (&request_override_context.clone()).into(), + planner_override_context: (&request_override_context.clone()).into(), cancellation_token, - ) - .map(Arc::new) + query_plan, + planner: &supergraph.planner, + }; + + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_query_plan(start_payload).await; + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { + // continue to next plugin + } + StartControlFlow::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + query_plan = start_payload.query_plan; + } + + let query_plan = match query_plan { + Some(plan) => plan, + None => supergraph + .planner + .plan_from_normalized_operation( + filtered_operation_for_plan, + (&request_override_context.clone()).into(), + cancellation_token, + ) + .map_err(QueryPlanGetterError::Planner)?, + }; + + let mut end_payload = OnQueryPlanEndHookPayload { query_plan }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { + // continue to next callback + } + EndControlFlow::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + } + } + + Ok(Arc::new(end_payload.query_plan)) + /* Handle on_query_plan hook in the plugins - END */ }) .await; match plan_result { - Ok(plan) => Ok(plan), - Err(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Err(e) => match e.as_ref() { + QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), + QueryPlanGetterError::Response(response) => { + Ok(QueryPlanResult::Response(response.clone())) + } + }, } }
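Review note: `QueryPlanGetterError` exists because the closure handed to the plan cache's `try_get_with` can only smuggle a plugin short-circuit out through its error channel; wrapping it as `Response(...)` keeps plugin-produced responses from being cached as plans, while `Planner(...)` preserves the real error path. The caller in execute_pipeline then splits the result again (sketch of that call site):

    let plan = match plan_operation_with_cache(/* ... */).await? {
        QueryPlanResult::QueryPlan(plan) => plan,
        QueryPlanResult::Response(response) => return Ok(response), // a plugin answered
    };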
hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginRequestState; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - app_state: &Arc, + schema_state: &SchemaState, + app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, -) -> Result<(), PipelineError> { + plugin_req_state: &Option>, +) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -37,12 +43,60 @@ pub async fn validate_operation_with_cache( parser_payload.cache_key ); - let res = validate( - consumer_schema_ast, - &parser_payload.parsed_operation, - &app_state.validation_plan, - ); - let arc_res = Arc::new(res); + let mut on_end_callbacks = vec![]; + let document = &parser_payload.parsed_operation; + let errors = if let Some(plugin_req_state) = plugin_req_state.as_ref() { + /* Handle on_graphql_validate hook in the plugins - START */ + let mut start_payload = OnGraphQLValidationStartHookPayload::new( + plugin_req_state, + consumer_schema_ast, + document, + &app_state.validation_plan, + ); + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_validation(start_payload).await; + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { + // continue to next plugin + } + StartControlFlow::EndResponse(response) => { + return Ok(Some(response)); + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + match start_payload.errors { + Some(errors) => errors, + None => validate( + consumer_schema_ast, + start_payload.document, + start_payload.get_validation_plan(), + ), + } + } else { + validate(consumer_schema_ast, document, &app_state.validation_plan) + }; + + let mut end_payload = OnGraphQLValidationEndHookPayload { errors }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { + // continue to next callback + } + EndControlFlow::EndResponse(response) => { + return Ok(Some(response)); + } + } + } + /* Handle on_graphql_validate hook in the plugins - END */ + + let arc_res = Arc::new(end_payload.errors); schema_state .validate_cache @@ -59,10 +113,8 @@ pub async fn validate_operation_with_cache( ); trace!("Validation errors: {:?}", validation_result); - return Err( - req.new_pipeline_error(PipelineErrorVariant::ValidationErrors(validation_result)) - ); + return Err(PipelineErrorVariant::ValidationErrors(validation_result)); } - Ok(()) + Ok(None) } diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs new file mode 100644 index 000000000..6ffeef508 --- /dev/null +++ b/bin/router/src/plugins/mod.rs @@ -0,0 +1,2 @@ +pub mod plugins_service; +pub mod registry; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs new file mode 100644 index 000000000..b88728379 --- /dev/null +++ b/bin/router/src/plugins/plugins_service.rs @@ -0,0 +1,121 @@ +use std::sync::Arc; + +use hive_router_plan_executor::{ + hooks::on_http_request::{OnHttpRequestHookPayload, OnHttpResponseHookPayload}, + plugin_context::PluginContext, + plugin_trait::{EndControlFlow, StartControlFlow}, +}; +use ntex::{ + 
http::ResponseBuilder, + service::{Service, ServiceCtx}, + web::{self, DefaultError}, + Middleware, +}; + +use crate::RouterSharedState; + +pub struct PluginService; + +impl Middleware for PluginService { + type Service = PluginMiddleware; + + fn create(&self, service: S) -> Self::Service { + PluginMiddleware { service } + } +} + +pub struct PluginMiddleware { + // This is special: We need this to avoid lifetime issues. + service: S, +} + +impl Service> for PluginMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + let plugins = req + .app_state::>() + .and_then(|shared_state| shared_state.plugins.clone()); + + if let Some(plugins) = plugins.as_ref() { + let plugin_context = Arc::new(PluginContext::default()); + req.extensions_mut().insert(plugin_context.clone()); + + let mut start_payload = OnHttpRequestHookPayload { + router_http_request: req, + context: &plugin_context, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.iter() { + let result = plugin.on_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { + // continue to next plugin + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + StartControlFlow::EndResponse(response) => { + let mut resp_builder = ResponseBuilder::new(response.status); + for (key, value) in response.headers { + if let Some(key) = key { + resp_builder.header(key, value); + } + } + let response = start_payload + .router_http_request + .into_response(resp_builder.body(response.body.to_vec())); + return Ok(response); + } + } + } + + let req = start_payload.router_http_request; + + let response = ctx.call(&self.service, req).await?; + + let mut end_payload = OnHttpResponseHookPayload { + response, + context: &plugin_context, + }; + + for callback in on_end_callbacks.into_iter().rev() { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { + // continue to next callback + } + EndControlFlow::EndResponse(response) => { + let mut resp_builder = ResponseBuilder::new(response.status); + for (key, value) in response.headers { + if let Some(key) = key { + resp_builder.header(key, value); + } + } + let response = resp_builder.body(response.body.to_vec()); + end_payload.response = end_payload.response.into_response(response); + return Ok(end_payload.response); + } + } + } + + return Ok(end_payload.response); + } + + ctx.call(&self.service, req).await + } +} diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs new file mode 100644 index 000000000..0383c8000 --- /dev/null +++ b/bin/router/src/plugins/registry.rs @@ -0,0 +1,80 @@ +use std::collections::HashMap; + +use hive_router_config::HiveRouterConfig; +use hive_router_plan_executor::plugin_trait::{RouterPluginBoxed, RouterPluginWithConfig}; +use serde_json::Value; +use tracing::info; + +type PluginFactory = Box serde_json::Result>>; + +pub struct PluginRegistry { + map: HashMap<&'static str, PluginFactory>, +} + +impl Default for PluginRegistry { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PluginRegistryError { + #[error("Failed to initialize plugin '{0}': {1}")] + Config(String, serde_json::Error), + #[error( + "Plugin 
'{0}' is not registered in the registry but is specified in the configuration" + )] + MissingInRegistry(String), +} + +impl PluginRegistry { + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + pub fn register(mut self) -> Self { + self.map.insert( + P::plugin_name(), + Box::new(|plugin_config: Value| { + let config: P::Config = serde_json::from_value(plugin_config)?; + match P::from_config(config) { + Some(plugin) => Ok(Some(Box::new(plugin))), + None => Ok(None), + } + }), + ); + self + } + pub fn initialize_plugins( + &self, + router_config: &HiveRouterConfig, + ) -> Result>, PluginRegistryError> { + let mut plugins: Vec = vec![]; + + for (plugin_name, plugin_config_value) in router_config.plugins.iter() { + if let Some(factory) = self.map.get(plugin_name.as_str()) { + match factory(plugin_config_value.clone()) { + Ok(plugin) => { + info!("Loaded plugin: {}", plugin_name); + match plugin { + Some(plugin) => plugins.push(plugin), + None => info!("Plugin '{}' is disabled, skipping", plugin_name), + } + } + Err(err) => { + return Err(PluginRegistryError::Config(plugin_name.clone(), err)); + } + } + } else { + return Err(PluginRegistryError::MissingInRegistry(plugin_name.clone())); + } + } + + if plugins.is_empty() { + Ok(None) + } else { + Ok(Some(plugins)) + } + } +} diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index f14cc6cf0..ed88b29b7 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,10 +1,14 @@ use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; -use graphql_tools::validation::utils::ValidationError; +use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ executors::error::SubgraphExecutorError, - introspection::schema::{SchemaMetadata, SchemaWithMetadata}, + hooks::on_supergraph_load::{ + OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, + }, + introspection::schema::SchemaWithMetadata, + plugin_trait::{EndControlFlow, StartControlFlow}, SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -26,6 +30,7 @@ use crate::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, }, + RouterSharedState, }; pub struct SchemaState { @@ -35,12 +40,6 @@ pub struct SchemaState { pub normalize_cache: Cache>, } -pub struct SupergraphData { - pub metadata: SchemaMetadata, - pub planner: Planner, - pub subgraph_executor_map: SubgraphExecutorMap, -} - #[derive(Debug, thiserror::Error)] pub enum SupergraphManagerError { #[error("Failed to load supergraph: {0}")] @@ -65,6 +64,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, + app_state: Arc, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -85,9 +85,57 @@ impl SchemaState { while let Some(new_sdl) = rx.recv().await { debug!("Received new supergraph SDL, building new supergraph state..."); - match Self::build_data(router_config.clone(), &new_sdl) { - Ok(new_data) => { - swappable_data_spawn_clone.store(Arc::new(Some(new_data))); + let mut new_ast = parse_schema(&new_sdl); + + let mut on_end_callbacks = vec![]; + + if let Some(plugins) = app_state.plugins.as_ref() { + let mut start_payload = OnSupergraphLoadStartHookPayload { + current_supergraph_data: 
swappable_data_spawn_clone.clone(), + new_ast, + }; + for plugin in plugins.as_ref() { + let result = plugin.on_supergraph_reload(start_payload); + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { + // continue to next plugin + } + StartControlFlow::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + new_ast = start_payload.new_ast; + } + + match Self::build_data(router_config.clone(), &new_ast) { + Ok(new_supergraph_data) => { + let mut end_payload = OnSupergraphLoadEndHookPayload { + new_supergraph_data, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { + // continue to next callback + } + EndControlFlow::EndResponse(_) => { + unreachable!( + "Plugins should not end supergraph reload processing" + ); + } + } + } + + let new_supergraph_data = end_payload.new_supergraph_data; + + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); task_plan_cache.invalidate_all(); @@ -112,11 +160,10 @@ impl SchemaState { fn build_data( router_config: Arc, - supergraph_sdl: &str, + parsed_supergraph_sdl: &Document, ) -> Result { - let parsed_supergraph_sdl = parse_schema(supergraph_sdl); - let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); - let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; + let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); + let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index f36bda6cd..bffcfc4b4 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -3,6 +3,7 @@ use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; +use hive_router_plan_executor::plugin_trait::RouterPluginBoxed; use moka::future::Cache; use std::sync::Arc; @@ -18,12 +19,14 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, + pub plugins: Option>>, } impl RouterSharedState { pub fn new( router_config: Arc, jwt_auth_runtime: Option, + plugins: Option>, ) -> Result { Ok(Self { validation_plan: graphql_tools::validation::rules::default_rules_validation_plan(), @@ -36,6 +39,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, + plugins: plugins.map(Arc::new), }) } } diff --git a/docs/README.md b/docs/README.md index fb474b9d2..0a61d1baa 100644 --- a/docs/README.md +++ b/docs/README.md @@ -13,6 +13,7 @@ |[**log**](#log)|`object`|The router logger configuration.
Default: `{"filter":null,"format":"json","level":"info"}`
|| |[**override\_labels**](#override_labels)|`object`|Configuration for overriding labels.
|| |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| +|[**plugins**](#plugins)|`object`|Configuration for custom plugins
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| |[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"dedupe_enabled":true,"max_connections_per_host":100,"pool_idle_timeout_seconds":50}`
|| @@ -102,6 +103,7 @@ override_subgraph_urls: .original_url } +plugins: {} query_planner: allow_expose: false timeout: 10s @@ -1641,6 +1643,13 @@ products: |----|----|-----------|--------| |**url**||Overrides for the URL of the subgraph.

For convenience, a plain string in your configuration will be treated as a static URL.

### Static URL Example
```yaml
url: "https://api.example.com/graphql"
```

### Dynamic Expression Example

The expression has access to the following variables:
- `request`: The incoming HTTP request, including headers and other metadata.
- `original_url`: The original URL of the subgraph (from the supergraph SDL).

```yaml
url:
expression: \|
if .request.headers."x-region" == "us-east" {
"https://products-us-east.example.com/graphql"
} else if .request.headers."x-region" == "eu-west" {
"https://products-eu-west.example.com/graphql"
} else {
.original_url
}
|yes| + +## plugins: object + +Configuration for custom plugins + + +**Additional Properties:** allowed ## query\_planner: object diff --git a/e2e/Cargo.toml b/e2e/Cargo.toml index 5a604afc1..f3c6e37a0 100644 --- a/e2e/Cargo.toml +++ b/e2e/Cargo.toml @@ -17,10 +17,26 @@ lazy_static = { workspace = true } jsonwebtoken = { workspace = true } insta = { workspace = true } reqwest = { workspace = true } +serde = { workspace = true } +dashmap = { workspace = true } +async-trait = { workspace = true } +http = { workspace = true } +bytes = { workspace = true } +graphql-tools = { workspace = true } +graphql-parser = { workspace = true } +serde_json = { workspace = true } hive-router = { path = "../bin/router" } hive-router-config = { path = "../lib/router-config" } +hive-router-plan-executor = { path = "../lib/executor" } +hive-router-query-planner = { path = "../lib/query-planner" } subgraphs = { path = "../bench/subgraphs" } mockito = "1.7.0" tempfile = "3.23.0" +redis = { version= "0.32.7", features = ["r2d2"]} +r2d2 = "0.8.10" +multer = "3.1.0" +futures-util = "0.3.31" +bollard = "0.19.4" + diff --git a/e2e/src/file_supergraph.rs b/e2e/src/file_supergraph.rs index 2ceb77f14..a9f825fb8 100644 --- a/e2e/src/file_supergraph.rs +++ b/e2e/src/file_supergraph.rs @@ -21,12 +21,15 @@ mod file_supergraph_e2e_tests { let first_supergraph = include_str!("../supergraph.graphql"); fs::write(&supergraph_file_path, first_supergraph).expect("failed to write supergraph"); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: file path: {supergraph_file_path} "#, - )) + ), + None, + ) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; @@ -64,13 +67,16 @@ mod file_supergraph_e2e_tests { fs::write(&supergraph_file_path, "type Query { f: String }") .expect("failed to write supergraph"); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: file path: {supergraph_file_path} poll_interval: 100ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; diff --git a/e2e/src/hive_cdn_supergraph.rs b/e2e/src/hive_cdn_supergraph.rs index 03c1efafb..363492f96 100644 --- a/e2e/src/hive_cdn_supergraph.rs +++ b/e2e/src/hive_cdn_supergraph.rs @@ -24,13 +24,16 @@ mod hive_cdn_supergraph_e2e_tests { .with_body(include_str!("../supergraph.graphql")) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -77,14 +80,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(304) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 100ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -135,14 +141,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(304) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 100ms "#, - )) + ), + None, + ) .await .expect("failed to start 
router"); @@ -199,14 +208,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_body(include_str!("../supergraph.graphql")) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 800ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -276,8 +288,9 @@ mod hive_cdn_supergraph_e2e_tests { .with_body("type Query { dummy: String }") .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key @@ -285,7 +298,9 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 10 "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -304,8 +319,9 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(500) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key @@ -313,7 +329,9 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 3 "#, - )) + ), + None, + ) .await .expect("failed to start router"); diff --git a/e2e/src/jwt.rs b/e2e/src/jwt.rs index eb4f5eba4..3c0471d72 100644 --- a/e2e/src/jwt.rs +++ b/e2e/src/jwt.rs @@ -27,7 +27,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_forward_claims_to_subgraph_via_extensions() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_forward.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_forward.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -69,9 +69,10 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_details() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml") - .await - .unwrap(); + let app = + init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) + .await + .unwrap(); wait_for_readiness(&app.app).await; let exp = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -132,9 +133,10 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_scopes() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml") - .await - .unwrap(); + let app = + init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) + .await + .unwrap(); wait_for_readiness(&app.app).await; // First request with a token and "scope: read:accounts" @@ -243,7 +245,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_without_token_when_auth_is_required() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -270,7 +272,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_malformed_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -301,7 +303,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn 
rejects_request_with_invalid_signature() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -324,7 +326,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn accepts_request_with_valid_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -358,7 +360,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_expired_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); @@ -391,7 +393,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_wrong_issuer() { - let app = init_router_from_config_file("configs/jwt_auth_issuer.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_issuer.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -413,7 +415,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_wrong_audience() { - let app = init_router_from_config_file("configs/jwt_auth_audience.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_audience.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; diff --git a/e2e/src/lib.rs b/e2e/src/lib.rs index 9086e01f4..3875f71a1 100644 --- a/e2e/src/lib.rs +++ b/e2e/src/lib.rs @@ -7,6 +7,8 @@ mod jwt; #[cfg(test)] mod override_subgraph_urls; #[cfg(test)] +mod plugins; +#[cfg(test)] mod probes; #[cfg(test)] mod supergraph; diff --git a/e2e/src/override_subgraph_urls.rs b/e2e/src/override_subgraph_urls.rs index cd3e0789d..1932608f1 100644 --- a/e2e/src/override_subgraph_urls.rs +++ b/e2e/src/override_subgraph_urls.rs @@ -16,6 +16,7 @@ mod override_subgraph_urls_e2e_tests { let subgraphs_server = SubgraphsServer::start_with_port(4100).await; let app = init_router_from_config_file( "configs/override_subgraph_urls/override_static.router.yaml", + None, ) .await .unwrap(); @@ -48,6 +49,7 @@ mod override_subgraph_urls_e2e_tests { let subgraphs_server = SubgraphsServer::start_with_port(4100).await; let app = init_router_from_config_file( "configs/override_subgraph_urls/override_dynamic_header.router.yaml", + None, ) .await .unwrap(); diff --git a/e2e/src/plugins/allowed_clients.json b/e2e/src/plugins/allowed_clients.json new file mode 100644 index 000000000..381a87082 --- /dev/null +++ b/e2e/src/plugins/allowed_clients.json @@ -0,0 +1 @@ +["urql", "graphql-request"] \ No newline at end of file diff --git a/e2e/src/plugins/apollo_sandbox.rs b/e2e/src/plugins/apollo_sandbox.rs new file mode 100644 index 000000000..327cb1c44 --- /dev/null +++ b/e2e/src/plugins/apollo_sandbox.rs @@ -0,0 +1,202 @@ +use std::collections::HashMap; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; +use http::HeaderMap; +use reqwest::StatusCode; +pub(crate) use sonic_rs::{Deserialize, Serialize}; + +#[derive(Default, Serialize, Deserialize, Debug, Clone)] +#[serde(default, rename_all = "camelCase")] +pub struct ApolloSandboxOptions { + pub enabled: bool, + pub initial_endpoint: String, + /** + * By default, the embedded Sandbox does not show the **Include 
cookies** toggle in its connection settings. Set `hideCookieToggle` to `false` to enable users of your embedded Sandbox instance to toggle the **Include cookies** setting.
+     */
+    pub hide_cookie_toggle: bool,
+    /**
+     * By default, the embedded Sandbox has a URL input box that is editable by users. Set `endpointIsEditable` to `false` to prevent users of your embedded Sandbox instance from changing the endpoint URL.
+     */
+    pub endpoint_is_editable: bool,
+    /**
+     * You can set `includeCookies` to `true` if you instead want Sandbox to pass `{ credentials: 'include' }` for its requests. If you pass the `handleRequest` option, this option is ignored. Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials). This config option is deprecated in favor of using the connection settings cookie toggle in Sandbox and setting the default value via `initialState.includeCookies`.
+     */
+    pub include_cookies: bool,
+    /**
+     * An object containing additional options related to the state of the embedded Sandbox on page load.
+     */
+    pub initial_state: ApolloSandboxInitialStateOptions,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, Default)]
+#[serde(rename_all = "camelCase")]
+pub struct ApolloSandboxInitialStateOptions {
+    pub enabled: bool,
+    /**
+     * Set this value to `true` if you want Sandbox to pass `{ credentials: 'include' }` for its requests by default. If you set `hideCookieToggle` to `false`, users can override this default setting with the **Include cookies** toggle. (By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.) If you also pass the `handleRequest` option, this option is ignored. Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials).
+     */
+    pub include_cookies: bool,
+    /**
+     * A URI-encoded operation to populate in Sandbox's editor on load. If you omit this, Sandbox initially loads an example query based on your schema. Example:
+     * ```js
+     * initialState: {
+     *   document: `
+     *     query ExampleQuery {
+     *       books {
+     *         title
+     *       }
+     *     }
+     *   `
+     * }
+     * ```
+     */
+    pub document: Option<String>,
+    /**
+     * A URI-encoded, serialized object containing initial variable values to populate in Sandbox on load. If provided, these variables should apply to the initial query you provide for [`document`](https://www.apollographql.com/docs/apollo-sandbox#document). Example:
+     *
+     * ```js
+     * initialState: {
+     *   variables: {
+     *     userID: "abc123"
+     *   },
+     * }
+     * ```
+     */
+    pub variables: Option<String>,
+    /**
+     * A URI-encoded, serialized object containing initial HTTP header values to populate in Sandbox on load. Example:
+     *
+     * ```js
+     * initialState: {
+     *   headers: {
+     *     authorization: "Bearer abc123";
+     *   }
+     * }
+     * ```
+     */
+    pub headers: Option<String>,
+    /**
+     * The ID of a collection, paired with an operation ID to populate in Sandbox on load. 
You can find these values from a registered graph in Studio by clicking the **...** menu next to an operation in the Explorer of that graph and selecting **View operation details**.Example: + * + * ```js + * initialState: { + * collectionId: 'abc1234', + * operationId: 'xyz1234' + * } + * ``` + */ + pub collection_id: Option, + pub operation_id: Option, + /** + * If `true`, the embedded Sandbox periodically polls your `initialEndpoint` for schema updates.The default value is `true`.Example: + * + * ```js + * initialState: { + * pollForSchemaUpdates: false; + * } + * ``` + */ + pub poll_for_schema_updates: bool, + /** + * Headers that are applied by default to every operation executed by the embedded Sandbox. Users can turn off the application of these headers, but they can't modify their values.The embedded Sandbox always includes these headers in its introspection queries to your `initialEndpoint`.Example: + * + * ```js + * initialState: { + * sharedHeaders: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub shared_headers: HashMap, +} + +impl RouterPluginWithConfig for ApolloSandboxPlugin { + type Config = ApolloSandboxOptions; + fn plugin_name() -> &'static str { + "apollo_sandbox" + } + fn from_config(config: ApolloSandboxOptions) -> Option { + if config.enabled { + Some(ApolloSandboxPlugin { + serialized_options: sonic_rs::to_string(&config) + .unwrap_or_else(|_| "{}".to_string()), + }) + } else { + None + } + } +} + +pub struct ApolloSandboxPlugin { + serialized_options: String, +} + +impl RouterPlugin for ApolloSandboxPlugin { + fn on_http_request<'req>( + &'req self, + payload: OnHttpRequestHookPayload<'req>, + ) -> OnHttpRequestHookResult<'req> { + if payload.router_http_request.path() == "/apollo-sandbox" { + let config = + sonic_rs::to_string(&self.serialized_options).unwrap_or_else(|_| "{}".to_string()); + let html = format!( + r#" +
+ + + "#, + config + ); + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/html".parse().unwrap()); + return payload.end_response(HttpResponse { + body: html.into_bytes().into(), + headers, + status: StatusCode::OK, + }); + } + payload.cont() + } +} + +#[cfg(test)] +mod apollo_sandbox_tests { + use hive_router::PluginRegistry; + + #[ntex::test] + async fn renders_apollo_sandbox_page() { + use crate::testkit::init_router_from_config_inline; + use ntex::web::test; + + let app = init_router_from_config_inline( + r#" + plugins: + apollo_sandbox: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + + let req = test::TestRequest::get().uri("/apollo-sandbox").to_request(); + let response = app.call(req).await.expect("failed to call /apollo-sandbox"); + let status = response.status(); + + let body_bytes = test::read_body(response).await; + let body_str = std::str::from_utf8(&body_bytes).expect("response body is not valid UTF-8"); + + assert_eq!(status, 200); + assert!(body_str.contains("EmbeddedSandbox")); + } +} diff --git a/e2e/src/plugins/apq.rs b/e2e/src/plugins/apq.rs new file mode 100644 index 000000000..683606f30 --- /dev/null +++ b/e2e/src/plugins/apq.rs @@ -0,0 +1,241 @@ +use dashmap::DashMap; +use http::StatusCode; +use serde::Deserialize; +use serde_json::json; +use sonic_rs::{JsonContainerTrait, JsonValueTrait}; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct APQPluginConfig { + pub enabled: bool, +} + +pub struct APQPlugin { + cache: DashMap, +} + +impl RouterPluginWithConfig for APQPlugin { + type Config = APQPluginConfig; + fn plugin_name() -> &'static str { + "apq" + } + fn from_config(config: Self::Config) -> Option { + if config.enabled { + Some(APQPlugin { + cache: DashMap::new(), + }) + } else { + None + } + } +} + +#[async_trait::async_trait] +impl RouterPlugin for APQPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { + payload.on_end(|mut payload| { + let persisted_query_ext = payload + .graphql_params + .extensions + .as_ref() + .and_then(|ext| ext.get("persistedQuery")) + .and_then(|pq| pq.as_object()); + if let Some(persisted_query_ext) = persisted_query_ext { + match persisted_query_ext.get(&"version").and_then(|v| v.as_i64()) { + Some(1) => {} + _ => { + let body = json!({ + "errors": [ + { + "message": "Unsupported persisted query version", + "extensions": { + "code": "UNSUPPORTED_PERSISTED_QUERY_VERSION" + } + } + ] + }); + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), + status: StatusCode::BAD_REQUEST, + headers: http::HeaderMap::new(), + }); + } + } + let sha256_hash = match persisted_query_ext + .get(&"sha256Hash") + .and_then(|h| h.as_str()) + { + Some(h) => h, + None => { + let body = json!({ + "errors": [ + { + "message": "Missing sha256Hash in persisted query", + "extensions": { + "code": "MISSING_PERSISTED_QUERY_HASH" + } + } + ] + }); + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), + status: StatusCode::BAD_REQUEST, + headers: http::HeaderMap::new(), + }); + } + }; + if let Some(query_param) = &payload.graphql_params.query { + // Store 
the query in the cache + self.cache + .insert(sha256_hash.to_string(), query_param.to_string()); + } else { + // Try to get the query from the cache + if let Some(cached_query) = self.cache.get(sha256_hash) { + // Update the graphql_params with the cached query + payload.graphql_params.query = Some(cached_query.value().to_string()); + } else { + let body = json!({ + "errors": [ + { + "message": "PersistedQueryNotFound", + "extensions": { + "code": "PERSISTED_QUERY_NOT_FOUND" + } + } + ] + }); + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), + status: StatusCode::NOT_FOUND, + headers: http::HeaderMap::new(), + }); + } + } + } + + payload.cont() + }) + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + + use hive_router::PluginRegistry; + use ntex::web::test; + use serde_json::json; + #[ntex::test] + async fn sends_not_found_error_if_query_missing() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + apq: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let body = json!( + { + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38", + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + let body = test::read_body(resp).await; + let body_json: serde_json::Value = + serde_json::from_slice(&body).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + json!({ + "errors": [ + { + "message": "PersistedQueryNotFound", + "extensions": { + "code": "PERSISTED_QUERY_NOT_FOUND" + } + } + ] + }), + "Expected PersistedQueryNotFound error" + ); + } + #[ntex::test] + async fn saves_persisted_query() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + apq: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let query = "{ users { id } }"; + let sha256_hash = "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"; + let body = json!( + { + "query": query, + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": sha256_hash, + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!( + resp.status().is_success(), + "Expected 200 OK when sending full query" + ); + + // Now send only the hash and expect it to be found + let body = json!( + { + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": sha256_hash, + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!( + resp.status().is_success(), + "Expected 200 OK when sending persisted query hash" + ); + } +} diff --git a/e2e/src/plugins/async_auth.rs b/e2e/src/plugins/async_auth.rs new file mode 100644 index 000000000..07871d6a2 --- /dev/null +++ b/e2e/src/plugins/async_auth.rs @@ -0,0 +1,215 
@@ +// From https://github.com/apollographql/router/blob/dev/examples/async-auth/rust/src/allow_client_id_from_file.rs +use serde::Deserialize; +use sonic_rs::json; +use std::path::PathBuf; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct AllowClientIdConfig { + pub enabled: bool, + pub header: String, + pub path: String, +} + +impl RouterPluginWithConfig for AllowClientIdFromFilePlugin { + type Config = AllowClientIdConfig; + fn plugin_name() -> &'static str { + "allow_client_id_from_file" + } + fn from_config(config: AllowClientIdConfig) -> Option { + if config.enabled { + Some(AllowClientIdFromFilePlugin { + header_key: config.header, + allowed_ids_path: PathBuf::from(config.path), + }) + } else { + None + } + } +} + +pub struct AllowClientIdFromFilePlugin { + header_key: String, + allowed_ids_path: PathBuf, +} + +#[async_trait::async_trait] +impl RouterPlugin for AllowClientIdFromFilePlugin { + // Whenever it is a GraphQL request, + // We don't use on_http_request here because we want to run this only when it is a GraphQL request + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { + let header = payload.router_http_request.headers.get(&self.header_key); + match header { + Some(client_id) => { + let client_id_str = client_id.to_str(); + match client_id_str { + Ok(client_id) => { + let allowed_clients: Vec = sonic_rs::from_str( + std::fs::read_to_string(self.allowed_ids_path.clone()) + .unwrap() + .as_str(), + ) + .unwrap(); + + if !allowed_clients.contains(&client_id.to_string()) { + // Prepare an HTTP 403 response with a GraphQL error message + let body = json!( + { + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": "UNAUTHORIZED_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), + headers: http::HeaderMap::new(), + status: http::StatusCode::FORBIDDEN, + }); + } + } + Err(_not_a_string_error) => { + let message = format!("'{}' value is not a string", &self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "BAD_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), + headers: http::HeaderMap::new(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + None => { + let message = format!("Missing '{}' header", &self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + } + ); + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), + headers: http::HeaderMap::new(), + status: http::StatusCode::UNAUTHORIZED, + }); + } + } + payload.cont() + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{ + init_graphql_request, init_router_from_config_inline, wait_for_readiness, SubgraphsServer, + }; + + use hive_router::PluginRegistry; + use ntex::web::test; + use serde_json::Value; + #[ntex::test] + async fn should_allow_only_allowed_client_ids() { + SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" 
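+            # The allow-list is the JSON array in the file at `path`; the client id is read from the request header named by `header`.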
+ plugins: + allow_client_id_from_file: + enabled: true + path: "./src/plugins/allowed_clients.json" + header: "x-client-id" + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + // Test with an allowed client id + let req = init_graphql_request("{ users { id } }", None).header("x-client-id", "urql"); + let resp = test::call_service(&app.app, req.to_request()).await; + let status = resp.status(); + assert!(status.is_success(), "Expected 200 OK for allowed client id"); + // Test with a disallowed client id + let req = init_graphql_request("{ users { id } }", None) + .header("x-client-id", "forbidden-client"); + let resp = test::call_service(&app.app, req.to_request()).await; + assert_eq!( + resp.status(), + http::StatusCode::FORBIDDEN, + "Expected 403 FORBIDDEN for disallowed client id" + ); + let body_bytes = test::read_body(resp).await; + let body_json: Value = + serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + serde_json::json!({ + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": "UNAUTHORIZED_CLIENT_ID" + } + } + ] + }), + "Expected error message for disallowed client id" + ); + // Test with missing client id + let req = init_graphql_request("{ users { id } }", None); + let resp = test::call_service(&app.app, req.to_request()).await; + assert_eq!( + resp.status(), + http::StatusCode::UNAUTHORIZED, + "Expected 401 UNAUTHORIZED for missing client id" + ); + let body_bytes = test::read_body(resp).await; + let body_json: Value = + serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + serde_json::json!({ + "errors": [ + { + "message": "Missing 'x-client-id' header", + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + }), + "Expected error message for missing client id" + ); + } +} diff --git a/e2e/src/plugins/context_data.rs b/e2e/src/plugins/context_data.rs new file mode 100644 index 000000000..ab120122d --- /dev/null +++ b/e2e/src/plugins/context_data.rs @@ -0,0 +1,135 @@ +// From https://github.com/apollographql/router/blob/dev/examples/context/rust/src/context_data.rs + +use serde::Deserialize; + +use hive_router_plan_executor::{ + hooks::{ + on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + on_subgraph_execute::{ + OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload, + OnSubgraphExecuteStartHookResult, + }, + }, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct ContextDataPluginConfig { + pub enabled: bool, +} + +pub struct ContextDataPlugin {} + +pub struct ContextData { + incoming_data: String, + response_count: u64, +} + +impl RouterPluginWithConfig for ContextDataPlugin { + type Config = ContextDataPluginConfig; + fn plugin_name() -> &'static str { + "context_data" + } + fn from_config(config: ContextDataPluginConfig) -> Option { + if config.enabled { + Some(ContextDataPlugin {}) + } else { + None + } + } +} + +#[async_trait::async_trait] +impl RouterPlugin for ContextDataPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { + let context_data = ContextData { + incoming_data: "world".to_string(), + response_count: 0, + }; + + payload.context.insert(context_data); + + payload.on_end(|payload| { 
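+            // End hook: runs after the GraphQL params phase completes; pull the typed ContextData back out of the request context and bump its response counter.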
+ let context_data = payload.context.get_mut::(); + if let Some(mut context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { + let context_data_entry = payload.context.get_ref::(); + if let Some(ref context_data_entry) = context_data_entry { + tracing::info!("hello {}", context_data_entry.incoming_data); // Hello world! + let new_header_value = format!("Hello {}", context_data_entry.incoming_data); + payload.execution_request.headers.insert( + "x-hello", + http::HeaderValue::from_str(&new_header_value).unwrap(), + ); + } + payload.on_end(|payload: OnSubgraphExecuteEndHookPayload<'exec>| { + let context_data = payload.context.get_mut::(); + if let Some(mut context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use ntex::web::test; + #[ntex::test] + async fn should_add_context_data_and_modify_subgraph_request() { + let subgraphs = SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" + plugins: + context_data: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + + wait_for_readiness(&app.app).await; + + let resp = test::call_service( + &app.app, + crate::testkit::init_graphql_request("{ users { id } }", None).to_request(), + ) + .await; + + assert!(resp.status().is_success(), "Expected 200 OK"); + + let request_logs = subgraphs + .get_subgraph_requests_log("accounts") + .await + .expect("expected requests sent to accounts subgraph"); + assert_eq!( + request_logs.len(), + 1, + "expected 1 request to accounts subgraph" + ); + let hello_header_value = request_logs[0] + .headers + .get("x-hello") + .expect("expected x-hello header to be present in subgraph request") + .to_str() + .expect("header value should be valid string"); + assert_eq!(hello_header_value, "Hello world"); + } +} diff --git a/e2e/src/plugins/forbid_anonymous_operations.rs b/e2e/src/plugins/forbid_anonymous_operations.rs new file mode 100644 index 000000000..6befdd4a6 --- /dev/null +++ b/e2e/src/plugins/forbid_anonymous_operations.rs @@ -0,0 +1,125 @@ +// Same with https://github.com/apollographql/router/blob/dev/examples/forbid-anonymous-operations/rust/src/forbid_anonymous_operations.rs + +use http::StatusCode; +use serde::Deserialize; +use serde_json::json; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct ForbidAnonymousOperationsPluginConfig { + pub enabled: bool, +} +pub struct ForbidAnonymousOperationsPlugin {} + +impl RouterPluginWithConfig for ForbidAnonymousOperationsPlugin { + type Config = ForbidAnonymousOperationsPluginConfig; + fn plugin_name() -> &'static str { + "forbid_anonymous_operations" + } + fn from_config(config: Self::Config) -> Option { + if config.enabled { + Some(ForbidAnonymousOperationsPlugin {}) + } else { + None + 
} + } +} + +#[async_trait::async_trait] +impl RouterPlugin for ForbidAnonymousOperationsPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { + // After the GraphQL parameters have been parsed, we can check if the operation is anonymous + // So we use `on_end` + payload.on_end(|payload| { + let maybe_operation_name = &payload + .graphql_params + .operation_name + .as_ref(); + + if maybe_operation_name + .is_none_or(|operation_name| operation_name.is_empty()) + { + // let's log the error + tracing::error!("Operation is not allowed!"); + + // Prepare an HTTP 400 response with a GraphQL error message + let body = json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }); + return payload.end_response(HttpResponse { + body: body.to_string().into(), + headers: http::HeaderMap::new(), + status: StatusCode::BAD_REQUEST, + }); + } + // we're good to go! + tracing::info!("operation is allowed!"); + payload.cont() + }) + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use http::StatusCode; + use ntex::web::test; + use serde_json::{json, Value}; + #[ntex::test] + async fn should_forbid_anonymous_operations() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + forbid_anonymous_operations: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + + let resp = test::call_service( + &app.app, + test::TestRequest::post() + .uri("/graphql") + .set_payload(r#"{"query":"{ __schema { types { name } } }"}"#) + .header("content-type", "application/json") + .to_request(), + ) + .await; + + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let json_body: Value = serde_json::from_slice(&test::read_body(resp).await).unwrap(); + assert_eq!( + json_body, + json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }) + ); + } +} diff --git a/e2e/src/plugins/mod.rs b/e2e/src/plugins/mod.rs new file mode 100644 index 000000000..eee9d7f7a --- /dev/null +++ b/e2e/src/plugins/mod.rs @@ -0,0 +1,11 @@ +pub mod apollo_sandbox; +pub mod apq; +pub mod async_auth; +pub mod context_data; +pub mod forbid_anonymous_operations; +pub mod multipart; +pub mod one_of; +pub mod propagate_status_code; +pub mod response_cache; +pub mod root_field_limit; +pub mod subgraph_response_cache; diff --git a/e2e/src/plugins/multipart.rs b/e2e/src/plugins/multipart.rs new file mode 100644 index 000000000..2cf929a95 --- /dev/null +++ b/e2e/src/plugins/multipart.rs @@ -0,0 +1,261 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::{ + on_graphql_params::{ + GraphQLParams, OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult, + }, + on_subgraph_http_request::{ + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult, + }, + }, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; +use multer::Multipart; +use serde::Deserialize; +use serde_json::json; +use tracing::error; + +#[derive(Deserialize)] +pub struct MultipartPluginConfig { + pub enabled: bool, +} +pub struct MultipartPlugin {} + +pub 
struct MultipartFile { + pub filename: Option, + pub content_type: Option, + pub content: Bytes, +} + +pub struct MultipartContext { + pub file_map: HashMap>, + pub files: HashMap, +} + +impl RouterPluginWithConfig for MultipartPlugin { + type Config = MultipartPluginConfig; + fn plugin_name() -> &'static str { + "multipart" + } + fn from_config(config: MultipartPluginConfig) -> Option { + if config.enabled { + Some(MultipartPlugin {}) + } else { + None + } + } +} + +#[async_trait::async_trait] +impl RouterPlugin for MultipartPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + mut payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { + if let Some(content_type) = payload.router_http_request.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.starts_with("multipart/form-data") { + let boundary = multer::parse_boundary(content_type_str).unwrap(); + let body = payload.body.clone(); + let stream = futures_util::stream::once(async move { + Ok::(Bytes::from(body.to_vec())) + }); + let mut multipart = Multipart::new(stream, boundary); + while let Some(field) = multipart.next_field().await.unwrap() { + let field_name = field.name().unwrap().to_string(); + let filename = field.file_name().map(|s| s.to_string()); + let content_type = field.content_type().map(|s| s.to_string()); + let data = field.bytes().await.unwrap(); + match field_name.as_str() { + "operations" => { + let graphql_params: GraphQLParams = + sonic_rs::from_slice(&data).unwrap(); + payload.graphql_params = Some(graphql_params); + } + "map" => { + let file_map: HashMap> = + sonic_rs::from_slice(&data).unwrap(); + payload.context.insert(MultipartContext { + file_map, + files: HashMap::new(), + }); + } + field_name => { + let multipart_ctx = payload.context.get_mut::(); + if let Some(mut multipart_ctx) = multipart_ctx { + let multipart_file = MultipartFile { + filename, + content_type, + content: data, + }; + multipart_ctx + .files + .insert(field_name.to_string(), multipart_file); + } + } + } + } + } + } + } + payload.cont() + } + + async fn on_subgraph_http_request<'exec>( + &'exec self, + mut payload: OnSubgraphHttpRequestHookPayload<'exec>, + ) -> OnSubgraphHttpRequestHookResult<'exec> { + if let Some(variables) = &payload.execution_request.variables { + let multipart_ctx = payload.context.get_ref::(); + if let Some(multipart_ctx) = multipart_ctx { + let mut file_map: HashMap> = HashMap::new(); + for variable_name in variables.keys() { + // Matching variables that are file references + for (files_ref, op_refs) in &multipart_ctx.file_map { + for op_ref in op_refs { + if op_ref.starts_with(format!("variables.{}", variable_name).as_str()) { + let op_refs_in_curr_map = + file_map.entry(files_ref.to_string()).or_default(); + op_refs_in_curr_map.push(op_ref.to_string()); + } + } + } + } + if !file_map.is_empty() { + let mut form = reqwest::multipart::Form::new(); + form = form.text( + "operations", + String::from_utf8(payload.body.clone()).unwrap(), + ); + let file_map_str: String = sonic_rs::to_string(&file_map).unwrap(); + form = form.text("map", file_map_str); + for (file_ref, _op_refs) in file_map { + if let Some(file_field) = multipart_ctx.files.get(&file_ref) { + let mut part = + reqwest::multipart::Part::bytes(file_field.content.to_vec()); + if let Some(file_name) = &file_field.filename { + part = part.file_name(file_name.to_string()); + } + if let Some(content_type) = &file_field.content_type { + part = 
part.mime_str(&content_type.to_string()).unwrap(); + } + form = form.part(file_ref, part); + } + } + let resp = reqwest::Client::new() + .post(payload.endpoint.to_string()) + // Using query as endpoint URL + .multipart(form) + .send() + .await; + match resp { + Ok(resp) => { + payload.response = Some(HttpResponse { + status: resp.status(), + headers: resp.headers().clone(), + body: resp.bytes().await.unwrap(), + }); + } + Err(err) => { + error!("Failed to send multipart request to subgraph: {}", err); + let body = json!({ + "errors": [{ + "message": format!("Failed to send multipart request to subgraph: {}", err) + }] + }); + return payload.end_response(HttpResponse { + status: reqwest::StatusCode::INTERNAL_SERVER_ERROR, + headers: reqwest::header::HeaderMap::new(), + body: serde_json::to_vec(&body).unwrap().into(), + }); + } + } + } + } + } + payload.cont() + } +} + +#[cfg(test)] +mod tests { + use futures_util::StreamExt; + use hive_router::PluginRegistry; + use ntex::web::test; + + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + + #[ntex::test] + async fn forward_files() { + let subgraphs_server = SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" + plugins: + multipart: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let form = reqwest::multipart::Form::new() + .text("operations", r#"{"query":"mutation ($file: Upload) { upload(file: $file) }","variables":{"file":null}}"#) + .text("map", r#"{"0":["variables.file"]}"#) + .part( + "0", + reqwest::multipart::Part::bytes("file content".as_bytes().to_vec()) + .file_name("test.txt") + .mime_str("text/plain") + .unwrap(), + ); + + let boundary = form.boundary().to_string(); + let form_stream = form.into_stream(); + + let mut form_bytes = vec![]; + let mut stream = form_stream; + while let Some(item) = stream.next().await { + let chunk = item.expect("Failed to read chunk"); + form_bytes.extend_from_slice(&chunk); + } + + let req = test::TestRequest::post() + .uri("/graphql") + .header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + ) + .set_payload(form_bytes); + + let resp = test::call_service(&app.app, req.to_request()).await; + + let body = test::read_body(resp).await; + let body_str = String::from_utf8_lossy(&body); + let body_json = serde_json::from_str::(&body_str).unwrap(); + let upload_file_path = &body_json["data"]["upload"].as_str().unwrap(); + assert!( + upload_file_path.contains("test.txt"), + "Response should contain the filename" + ); + let file_content = tokio::fs::read(upload_file_path).await.unwrap(); + assert_eq!( + file_content, b"file content", + "File content should match the uploaded content" + ); + assert_eq!( + subgraphs_server + .get_subgraph_requests_log("products") + .await + .unwrap() + .len(), + 1, + "Expected 1 request to products subgraph" + ); + } +} diff --git a/e2e/src/plugins/one_of.rs b/e2e/src/plugins/one_of.rs new file mode 100644 index 000000000..08f6fd5f8 --- /dev/null +++ b/e2e/src/plugins/one_of.rs @@ -0,0 +1,350 @@ +// This example will show `@oneOf` input type validation in two steps: +// 1. During validation step +// 2. During execution step + +// We handle execution too to validate input objects at runtime as well. 
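+// For reference, the directive these checks key off of is declared in the
+// schema as `directive @oneOf on INPUT_OBJECT` (assumed standard shape; the
+// declaration itself is not part of this diff).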
+/* + Let's say we have the following input type with `@oneOf` directive: + input PaymentMethod @oneOf { + creditCard: CreditCardInput + bankTransfer: BankTransferInput + paypal: PayPalInput + } + + During validation, if a variable of type `PaymentMethod` is provided with multiple fields set, + we will raise a validation error. + + ```graphql + mutation MakePayment { + makePayment(method: { + creditCard: { number: "1234", expiry: "12/24" }, + paypal: { email: "john@doe.com" } + }) { + success + } + } + ``` + + But since variables can be dynamic, we also validate during execution. If the input object has multiple fields set, + we return an error in the response. + + ```graphql + mutation MakePayment($method: PaymentMethod!) { + makePayment(method: $method) { + success + } + } + ``` + + with variables: + { + "method": { + "creditCard": { "number": "1234", "expiry": "12/24" }, + "paypal": { "email": "john@doe.com" } + } + } +*/ + +use std::{collections::BTreeMap, sync::RwLock}; + +use graphql_parser::{ + query::Value, + schema::{Definition, TypeDefinition}, +}; +use graphql_tools::ast::visit_document; +use graphql_tools::{ + ast::{OperationVisitor, OperationVisitorContext}, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::{ + on_execute::{OnExecuteStartHookPayload, OnExecuteStartHookResult}, + on_graphql_validation::{ + OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult, + }, + on_supergraph_load::{OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload}, + }, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload, StartHookResult}, +}; +use serde::Deserialize; +use sonic_rs::{json, JsonContainerTrait}; + +#[derive(Deserialize)] +pub struct OneOfPluginConfig { + pub enabled: bool, +} + +impl RouterPluginWithConfig for OneOfPlugin { + type Config = OneOfPluginConfig; + fn plugin_name() -> &'static str { + "oneof" + } + fn from_config(config: OneOfPluginConfig) -> Option<Self> { + if config.enabled { + Some(OneOfPlugin { + one_of_types: RwLock::new(vec![]), + }) + } else { + None + } + } +} + +pub struct OneOfPlugin { + pub one_of_types: RwLock<Vec<String>>, +} + +#[async_trait::async_trait] +impl RouterPlugin for OneOfPlugin { + // 1. During validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartHookPayload<'exec>, + ) -> OnGraphQLValidationStartHookResult<'exec> { + let rule = OneOfValidationRule { + one_of_types: self.one_of_types.read().unwrap().clone(), + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // 2. During execution step + async fn on_execute<'exec>( + &'exec self, + payload: OnExecuteStartHookPayload<'exec>, + ) -> OnExecuteStartHookResult<'exec> { + if let (Some(variable_values), Some(variable_defs)) = ( + &payload.variable_values, + &payload.operation_for_plan.variable_definitions, + ) { + for def in variable_defs { + let variable_named_type = def.variable_type.inner_type(); + let one_of_types = self.one_of_types.read().unwrap(); + if one_of_types.contains(&variable_named_type.to_string()) { + let var_name = &def.name; + if let Some(value) = variable_values.get(var_name).and_then(|v| v.as_object()) { + let keys_num = value.len(); + if keys_num > 1 { + let err_msg = format!( + "Variable '${}' of input object type '{}' with @oneOf directive has multiple fields set: {:?}. 
Only one field must be set.", + var_name, + variable_named_type, + keys_num + ); + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_FIELDS_SET_IN_ONEOF" + } + }] + })) + .unwrap() + .into(), + headers: Default::default(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + } + } + payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartHookPayload, + ) -> StartHookResult<'exec, OnSupergraphLoadStartHookPayload, OnSupergraphLoadEndHookPayload> + { + for def in start_payload.new_ast.definitions.iter() { + if let Definition::TypeDefinition(TypeDefinition::InputObject(input_obj)) = def { + for directive in input_obj.directives.iter() { + if directive.name == "oneOf" { + self.one_of_types + .write() + .unwrap() + .push(input_obj.name.clone()); + } + } + } + } + start_payload.cont() + } +} + +struct OneOfValidationRule { + one_of_types: Vec<String>, +} + +impl ValidationRule for OneOfValidationRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_FIELDS_SET_IN_ONEOF" + } + fn validate( + &self, + op_ctx: &mut OperationVisitorContext<'_>, + validation_error_context: &mut ValidationErrorContext, + ) { + visit_document( + &mut OneOfValidation { + one_of_types: self.one_of_types.clone(), + }, + op_ctx.operation, + op_ctx, + validation_error_context, + ); + } +} + +struct OneOfValidation { + one_of_types: Vec<String>, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for OneOfValidation { + fn enter_object_value( + &mut self, + visitor_context: &mut OperationVisitorContext<'a>, + user_context: &mut ValidationErrorContext, + fields: &BTreeMap<String, Value>, + ) { + if let Some(TypeDefinition::InputObject(input_type)) = visitor_context.current_input_type() + { + if self.one_of_types.contains(&input_type.name) { + let mut set_fields = vec![]; + for (field_name, field_value) in fields.iter() { + if !matches!(field_value, Value::Null) { + set_fields.push(field_name.clone()); + } + } + if set_fields.len() > 1 { + let err_msg = format!( + "Input object of type '{}' with @oneOf directive has multiple fields set: {:?}. 
Only one field must be set.", + input_type.name, + set_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_FIELDS_SET_IN_ONEOF", + locations: vec![], + message: err_msg, + }); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use hive_router::PluginRegistry; + use serde_json::{from_slice, Value}; + + use crate::testkit::{init_router_from_config_inline, wait_for_readiness}; + + #[ntex::test] + async fn one_of_validates_in_validation_rule() { + let app = init_router_from_config_inline( + r#" + plugins: + oneof: + enabled: true + "#, + Some(PluginRegistry::new().register::<super::OneOfPlugin>()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request( + r#" + mutation OneOfTest { + oneofTest(input: { + string: "test", + int: 42 + }) { + string + int + float + boolean + id + } + } + "#, + None, + ); + + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + let body = ntex::web::test::read_body(resp).await; + let body_val: Value = from_slice(&body).expect("Response body should be valid JSON"); + let errors = body_val + .get("errors") + .expect("Response should contain errors"); + let first_error = errors + .as_array() + .expect("Errors should be an array") + .first() + .expect("There should be at least one error"); + let message = first_error + .get("message") + .expect("Error should have a message") + .as_str() + .expect("Message should be a string"); + assert!(message.contains("multiple fields set")); + } + + #[ntex::test] + async fn one_of_validates_during_execution() { + let app = init_router_from_config_inline( + r#" + plugins: + oneof: + enabled: true + "#, + Some(PluginRegistry::new().register::<super::OneOfPlugin>()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request( + r#" + mutation OneOfTest($input: OneOfTestInput!) 
{ + oneofTest(input: $input) { + string + int + float + boolean + id + } + } + "#, + Some(sonic_rs::json!({ + "input": { + "string": "test", + "int": 42 + } + })), + ); + + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + let body = ntex::web::test::read_body(resp).await; + let body_val: Value = from_slice(&body).expect("Response body should be valid JSON"); + let errors = body_val + .get("errors") + .expect("Response should contain errors"); + let first_error = errors + .as_array() + .expect("Errors should be an array") + .first() + .expect("There should be at least one error"); + let message = first_error + .get("message") + .expect("Error should have a message") + .as_str() + .expect("Message should be a string"); + assert!(message.contains("multiple fields set")); + } +} diff --git a/e2e/src/plugins/propagate_status_code.rs b/e2e/src/plugins/propagate_status_code.rs new file mode 100644 index 000000000..701902c6f --- /dev/null +++ b/e2e/src/plugins/propagate_status_code.rs @@ -0,0 +1,179 @@ +// From https://github.com/apollographql/router/blob/dev/examples/status-code-propagation/rust/src/propagate_status_code.rs + +use http::StatusCode; +use serde::Deserialize; + +use hive_router_plan_executor::{ + hooks::{ + on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult}, + on_subgraph_http_request::{ + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult, + }, + }, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct PropagateStatusCodePluginConfig { + pub enabled: bool, + pub status_codes: Vec<u64>, +} + +impl RouterPluginWithConfig for PropagateStatusCodePlugin { + type Config = PropagateStatusCodePluginConfig; + fn plugin_name() -> &'static str { + "propagate_status_code" + } + fn from_config(config: PropagateStatusCodePluginConfig) -> Option<Self> { + if !config.enabled { + return None; + } + let status_codes = config + .status_codes + .into_iter() + .filter_map(|code| StatusCode::from_u16(code as u16).ok()) + .collect(); + Some(PropagateStatusCodePlugin { status_codes }) + } +} + +pub struct PropagateStatusCodePlugin { + pub status_codes: Vec<StatusCode>, +} + +pub struct PropagateStatusCodeCtx { + pub status_code: StatusCode, +} + +#[async_trait::async_trait] +impl RouterPlugin for PropagateStatusCodePlugin { + async fn on_subgraph_http_request<'exec>( + &'exec self, + payload: OnSubgraphHttpRequestHookPayload<'exec>, + ) -> OnSubgraphHttpRequestHookResult<'exec> { + payload.on_end(|payload| { + let status_code = payload.response.status; + // if a response contains a status code we're watching... 
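+ // ...record it in the per-request context; the end hook of on_http_request below then + // copies the recorded (most severe) code onto the client response.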
+ if self.status_codes.contains(&status_code) { + // Checking if there is already a context entry + let ctx = payload.context.get_mut::<PropagateStatusCodeCtx>(); + if let Some(mut ctx) = ctx { + // Update the status code if the new one is more severe (higher) + if status_code.as_u16() > ctx.status_code.as_u16() { + ctx.status_code = status_code; + } + } else { + // Insert a new context entry + let new_ctx = PropagateStatusCodeCtx { status_code }; + payload.context.insert(new_ctx); + } + } + payload.cont() + }) + } + fn on_http_request<'exec>( + &'exec self, + payload: OnHttpRequestHookPayload<'exec>, + ) -> OnHttpRequestHookResult<'exec> { + payload.on_end(|mut payload| { + // Checking if there is a context entry + let ctx = payload.context.get_ref::<PropagateStatusCodeCtx>(); + if let Some(ctx) = ctx { + // Update the HTTP response status code + *payload.response.response_mut().status_mut() = ctx.status_code; + } + payload.cont() + }) + } +} + +#[cfg(test)] +mod tests { + #[ntex::test] + async fn propagates_highest_status_code() { + let mut subgraphs_server = mockito::Server::new_async().await; + let accounts_mock_207 = subgraphs_server + .mock("POST", "/accounts") + .with_status(207) + .with_body(r#"{"data": {"users": [{"id": "1"}]}}"#) + .create_async() + .await; + let products_mock_206 = subgraphs_server + .mock("POST", "/products") + .with_status(206) + .with_body(r#"{"data": {"topProducts": [{"upc": "a"}]}}"#) + .create_async() + .await; + let app = crate::testkit::init_router_from_config_inline( + &format!( + r#" + override_subgraph_urls: + accounts: + url: http://{}/accounts + products: + url: http://{}/products + plugins: + propagate_status_code: + enabled: true + status_codes: [206, 207] + "#, + subgraphs_server.host_with_port(), + subgraphs_server.host_with_port() + ), + Some(hive_router::PluginRegistry::new().register::<super::PropagateStatusCodePlugin>()), + ) + .await + .expect("failed to start router"); + crate::testkit::wait_for_readiness(&app.app).await; + + let req = + crate::testkit::init_graphql_request("{ users { id } topProducts { upc } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert_eq!(resp.status().as_u16(), 207); + accounts_mock_207.assert_async().await; + products_mock_206.assert_async().await; + } + #[ntex::test] + async fn ignores_unlisted_status_codes() { + let mut subgraphs_server = mockito::Server::new_async().await; + let accounts_mock_208 = subgraphs_server + .mock("POST", "/accounts") + .with_status(208) + .with_body(r#"{"data": {"users": [{"id": "1"}]}}"#) + .create_async() + .await; + let products_mock_209 = subgraphs_server + .mock("POST", "/products") + .with_status(209) + .with_body(r#"{"data": {"topProducts": [{"upc": "a"}]}}"#) + .create_async() + .await; + let app = crate::testkit::init_router_from_config_inline( + &format!( + r#" + override_subgraph_urls: + accounts: + url: http://{}/accounts + products: + url: http://{}/products + plugins: + propagate_status_code: + enabled: true + status_codes: [208] + "#, + subgraphs_server.host_with_port(), + subgraphs_server.host_with_port() + ), + Some(hive_router::PluginRegistry::new().register::<super::PropagateStatusCodePlugin>()), + ) + .await + .expect("failed to start router"); + crate::testkit::wait_for_readiness(&app.app).await; + let req = + crate::testkit::init_graphql_request("{ users { id } topProducts { upc } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert_eq!(resp.status().as_u16(), 208); + accounts_mock_208.assert_async().await; + products_mock_209.assert_async().await; + } +} diff --git 
a/e2e/src/plugins/response_cache.rs b/e2e/src/plugins/response_cache.rs new file mode 100644 index 000000000..4a09b9e0f --- /dev/null +++ b/e2e/src/plugins/response_cache.rs @@ -0,0 +1,268 @@ +use dashmap::DashMap; +use http::{HeaderMap, StatusCode}; +use redis::Commands; +use serde::Deserialize; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::{ + on_execute::{ + OnExecuteEndHookPayload, OnExecuteStartHookPayload, OnExecuteStartHookResult, + }, + on_supergraph_load::{OnSupergraphLoadStartHookPayload, OnSupergraphLoadStartHookResult}, + }, + plugin_trait::{EndHookPayload, RouterPluginWithConfig, StartHookPayload}, + plugins::plugin_trait::RouterPlugin, + utils::consts::TYPENAME_FIELD_NAME, +}; +use tracing::trace; + +#[derive(Deserialize)] +pub struct ResponseCachePluginOptions { + pub enabled: bool, + pub redis_url: String, + #[serde(default = "default_ttl_seconds")] + pub default_ttl_seconds: u64, +} + +fn default_ttl_seconds() -> u64 { + 5 +} + +impl RouterPluginWithConfig for ResponseCachePlugin { + type Config = ResponseCachePluginOptions; + fn plugin_name() -> &'static str { + "response_cache_plugin" + } + fn from_config(config: ResponseCachePluginOptions) -> Option<Self> { + if !config.enabled { + return None; + } + let redis_client = + redis::Client::open(config.redis_url).expect("Failed to create Redis client"); + let pool = r2d2::Pool::builder() + .build(redis_client) + .unwrap_or_else(|err| panic!("Failed to create Redis connection pool: {}", err)); + Some(Self { + redis: pool, + ttl_per_type: DashMap::new(), + default_ttl_seconds: config.default_ttl_seconds, + }) + } +} + +pub struct ResponseCachePlugin { + redis: r2d2::Pool<redis::Client>, + ttl_per_type: DashMap<String, u64>, + default_ttl_seconds: u64, +} + +#[async_trait::async_trait] +impl RouterPlugin for ResponseCachePlugin { + async fn on_execute<'exec>( + &'exec self, + payload: OnExecuteStartHookPayload<'exec>, + ) -> OnExecuteStartHookResult<'exec> { + let key = format!( + "response_cache:{}:{:?}", + payload.query_plan, payload.variable_values + ); + if let Ok(mut conn) = self.redis.get() { + trace!("Checking cache for key: {}", key); + let cache_result: Result<Vec<u8>, redis::RedisError> = conn.get(&key); + match cache_result { + Ok(body) => { + if body.is_empty() { + trace!("Cache miss for key: {}", key); + } else { + trace!( + "Cache hit for key: {} -> {}", + key, + String::from_utf8_lossy(&body) + ); + return payload.end_response(HttpResponse { + body: body.into(), + headers: HeaderMap::new(), + status: StatusCode::OK, + }); + } + } + Err(err) => { + trace!("Error accessing cache for key {}: {}", key, err); + } + } + return payload.on_end(move |mut payload: OnExecuteEndHookPayload<'exec>| { + // Do not cache if there are errors + if !payload.errors.is_empty() { + trace!("Not caching response due to errors"); + return payload.cont(); + } + + if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { + trace!("Caching response for key: {}", key); + // Decide on the ttl somehow + // Get the type names + let mut max_ttl = 0; + + // Imagine this code is traversing the response data to find type names + if let Some(obj) = payload.data.as_object() { + if let Some(typename) = obj + .iter() + .position(|(k, _)| k == &TYPENAME_FIELD_NAME) + .and_then(|idx| obj[idx].1.as_str()) + { + if let Some(ttl) = self.ttl_per_type.get(typename).map(|v| *v) { + max_ttl = max_ttl.max(ttl); + } + } + } + + // If no ttl found, default + if max_ttl == 0 { + max_ttl = self.default_ttl_seconds; + } + trace!("Using TTL of {} seconds for key: {}", max_ttl, 
key); + + // Insert the ttl into extensions for client awareness + payload + .extensions + .get_or_insert_default() + .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); + + // Set the cache with the decided ttl + let result = + conn.set_ex::<&str, Vec<u8>, ()>(&key, serialized.clone(), max_ttl); + if let Err(err) = result { + trace!("Failed to set cache for key {}: {}", key, err); + } else { + trace!("Cached response for key: {} with TTL: {}", key, max_ttl); + } + } + payload.cont() + }); + } + payload.cont() + } + fn on_supergraph_reload<'a>( + &'a self, + payload: OnSupergraphLoadStartHookPayload, + ) -> OnSupergraphLoadStartHookResult<'a> { + // Visit the schema and update ttl_per_type based on some directive + payload.new_ast.definitions.iter().for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); + } + } + } + } + } + } + } + } + }); + + payload.cont() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use http::StatusCode; + use tracing::trace; + + use crate::testkit::{ + wait_for_readiness, SubgraphsServer, TestDockerContainer, TestDockerContainerOpts, + }; + + #[ntex::test] + async fn test_caching_with_default_ttl() { + let container = TestDockerContainer::async_new(TestDockerContainerOpts { + name: "redis_resp_caching_test".to_string(), + image: "redis/redis-stack:latest".to_string(), + ports: HashMap::from([(6379, 6379)]), + env: vec!["ALLOW_EMPTY_PASSWORD=yes".to_string()], + ..Default::default() + }) + .await + .expect("failed to start redis container"); + + // Redis flush all to ensure clean state + container + .exec(vec!["redis-cli", "FLUSHALL"]) + .await + .expect("Failed to flush redis"); + let subgraphs_server = SubgraphsServer::start().await; + + let app = crate::testkit::init_router_from_config_inline( + r#" + plugins: + response_cache_plugin: + enabled: true + redis_url: "redis://0.0.0.0:6379" + default_ttl_seconds: 2 + "#, + Some(hive_router::PluginRegistry::new().register::<super::ResponseCachePlugin>()), + ) + .await + .expect("failed to start router"); + + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + trace!("First response received"); + assert_eq!(resp.status(), StatusCode::OK); + let resp_body = ntex::web::test::read_body(resp).await; + trace!( + "Response body read: {:?}", + String::from_utf8_lossy(&resp_body) + ); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!(subgraph_requests.len(), 1, "Expected one subgraph request"); + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp2 = ntex::web::test::call_service(&app.app, req.to_request()).await; + trace!("Second response received"); + assert!(resp2.status().is_success()); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!( + subgraph_requests.len(), + 1, + "Expected only one subgraph 
request due to caching" + ); + trace!("Waiting for cache to expire..."); + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp3 = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert!(resp3.status().is_success()); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!( + subgraph_requests.len(), + 2, + "Expected a second subgraph request after cache expiry" + ); + container.stop().await; + } + #[ntex::test] + async fn respect_directives_on_supergraph_reload() { + todo!(); + } +} diff --git a/e2e/src/plugins/root_field_limit.rs b/e2e/src/plugins/root_field_limit.rs new file mode 100644 index 000000000..6baf07ecb --- /dev/null +++ b/e2e/src/plugins/root_field_limit.rs @@ -0,0 +1,211 @@ +use graphql_tools::{ + ast::{visit_document, OperationVisitor, OperationVisitorContext, TypeDefinitionExtension}, + static_graphql, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; +use hive_router_query_planner::ast::selection_item::SelectionItem; +use serde::Deserialize; +use sonic_rs::json; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::{ + on_graphql_validation::{ + OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult, + }, + on_query_plan::{OnQueryPlanStartHookPayload, OnQueryPlanStartHookResult}, + }, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +// This example shows two ways of limiting the number of root fields in a query: +// 1. During validation step +// 2. During query planning step + +#[async_trait::async_trait] +impl RouterPlugin for RootFieldLimitPlugin { + // Using validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartHookPayload<'exec>, + ) -> OnGraphQLValidationStartHookResult<'exec> { + let rule = RootFieldLimitRule { + max_root_fields: self.max_root_fields, + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // Or during query planning + async fn on_query_plan<'exec>( + &'exec self, + payload: OnQueryPlanStartHookPayload<'exec>, + ) -> OnQueryPlanStartHookResult<'exec> { + let mut cnt = 0; + for selection in payload + .filtered_operation_for_plan + .selection_set + .items + .iter() + { + match selection { + SelectionItem::Field(_) => { + cnt += 1; + if cnt > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + cnt, self.max_root_fields + ); + tracing::warn!("{}", err_msg); + let body = json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_ROOT_FIELDS" + } + }] + }); + // Return error + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), + headers: http::HeaderMap::new(), + status: http::StatusCode::PAYLOAD_TOO_LARGE, + }); + } + } + SelectionItem::InlineFragment(_) => { + unreachable!("Inline fragments should have been inlined before query planning"); + } + SelectionItem::FragmentSpread(_) => { + unreachable!("Fragment spreads should have been inlined before query planning"); + } + } + } + payload.cont() + } +} + +#[derive(Deserialize)] +pub struct RootFieldLimitPluginConfig { + enabled: bool, + max_root_fields: usize, +} + +impl RouterPluginWithConfig for RootFieldLimitPlugin { + type Config = RootFieldLimitPluginConfig; + 
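+ // `plugin_name` is the key looked up under `plugins:` in the router YAML config, and + // `from_config` returning `None` leaves the plugin disabled for this router instance.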
fn plugin_name() -> &'static str { + "root_field_limit" + } + fn from_config(config: Self::Config) -> Option<Self> { + if !config.enabled { + return None; + } + Some(RootFieldLimitPlugin { + max_root_fields: config.max_root_fields, + }) + } +} + +pub struct RootFieldLimitPlugin { + max_root_fields: usize, +} + +pub struct RootFieldLimitRule { + max_root_fields: usize, +} + +struct RootFieldSelections { + max_root_fields: usize, + count: usize, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for RootFieldSelections { + fn enter_field( + &mut self, + visitor_context: &mut OperationVisitorContext, + user_context: &mut ValidationErrorContext, + field: &static_graphql::query::Field, + ) { + let parent_type_name = visitor_context.current_parent_type().map(|t| t.name()); + if parent_type_name == Some("Query") { + self.count += 1; + if self.count > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + self.count, self.max_root_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_ROOT_FIELDS", + locations: vec![field.position], + message: err_msg, + }); + } + } + } +} + +impl ValidationRule for RootFieldLimitRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + ctx: &mut OperationVisitorContext<'_>, + error_collector: &mut ValidationErrorContext, + ) { + visit_document( + &mut RootFieldSelections { + max_root_fields: self.max_root_fields, + count: 0, + }, + ctx.operation, + ctx, + error_collector, + ); + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use ntex::web::test; + #[ntex::test] + async fn rejects_query_with_too_many_root_fields() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + root_field_limit: + enabled: true + max_root_fields: 1 + "#, + Some(PluginRegistry::new().register::<super::RootFieldLimitPlugin>()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let resp = test::call_service( + &app.app, + test::TestRequest::post() + .uri("/graphql") + .set_payload( + r#"{"query":"query TooManyRootFields { users { id } topProducts { upc } }"}"#, + ) + .header("content-type", "application/json") + .to_request(), + ) + .await; + let json_body: serde_json::Value = + serde_json::from_slice(&test::read_body(resp).await).unwrap(); + + let error_msg = json_body["errors"][0]["message"].as_str().unwrap(); + assert!( + error_msg.contains("Query has too many root fields"), + "Unexpected error message: {}", + error_msg + ); + } +} diff --git a/e2e/src/plugins/subgraph_response_cache.rs b/e2e/src/plugins/subgraph_response_cache.rs new file mode 100644 index 000000000..b7b8a0289 --- /dev/null +++ b/e2e/src/plugins/subgraph_response_cache.rs @@ -0,0 +1,97 @@ +use dashmap::DashMap; +use serde::Deserialize; + +use hive_router_plan_executor::{ + executors::http::HttpResponse, + hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload, + OnSubgraphExecuteStartHookResult, + }, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, +}; + +#[derive(Deserialize)] +pub struct SubgraphResponseCachePluginConfig { + enabled: bool, +} + +impl RouterPluginWithConfig for SubgraphResponseCachePlugin { + type Config = SubgraphResponseCachePluginConfig; + fn plugin_name() -> &'static str { + "subgraph_response_cache" + } + fn 
from_config(config: SubgraphResponseCachePluginConfig) -> Option<Self> { + if config.enabled { + Some(SubgraphResponseCachePlugin { + cache: DashMap::new(), + }) + } else { + None + } + } +} + +pub struct SubgraphResponseCachePlugin { + cache: DashMap<String, HttpResponse>, +} + +#[async_trait::async_trait] +impl RouterPlugin for SubgraphResponseCachePlugin { + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { + let key = format!( + "subgraph_response_cache:{}:{:?}", + payload.execution_request.query, payload.execution_request.variables + ); + if let Some(cached_response) = self.cache.get(&key) { + // Here payload.execution_result is still an Option + // Setting it bypasses the actual subgraph request + payload.execution_result = Some(cached_response.clone()); + return payload.cont(); + } + payload.on_end(move |payload: OnSubgraphExecuteEndHookPayload| { + // Here payload.execution_result is no longer an Option + self.cache.insert(key, payload.execution_result.clone()); + payload.cont() + }) + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{ + init_graphql_request, init_router_from_config_inline, wait_for_readiness, SubgraphsServer, + }; + use hive_router::PluginRegistry; + use ntex::web::test; + + // Tests on_subgraph_execute's override behavior + #[ntex::test] + async fn caches_subgraph_responses() { + let subgraphs = SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + subgraph_response_cache: + enabled: true + "#, + Some(PluginRegistry::new().register::<super::SubgraphResponseCachePlugin>()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let req = init_graphql_request("{ users { id } }", None); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!(resp.status().is_success()); + let req = init_graphql_request("{ users { id } }", None); + let resp2 = test::call_service(&app.app, req.to_request()).await; + assert!(resp2.status().is_success()); + let subgraph_requests = subgraphs + .get_subgraph_requests_log("accounts") + .await + .expect("failed to get subgraph requests log"); + assert_eq!(subgraph_requests.len(), 1); + } +} diff --git a/e2e/src/probes.rs b/e2e/src/probes.rs index 86c6a8d8f..027154d35 100644 --- a/e2e/src/probes.rs +++ b/e2e/src/probes.rs @@ -23,14 +23,17 @@ mod probes_e2e_tests { }) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -76,14 +79,17 @@ .with_status(500) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); diff --git a/e2e/src/supergraph.rs b/e2e/src/supergraph.rs index 07b345aac..9cabfe952 100644 --- a/e2e/src/supergraph.rs +++ b/e2e/src/supergraph.rs @@ -29,14 +29,17 @@ .with_body("type Query { dummyNew: NewType } type NewType { id: ID! 
}") .create(); - let test_app = init_router_from_config_inline(&format!( - r#"supergraph: + let test_app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - )) + ), + None, + ) .await .expect("failed to start router"); @@ -191,14 +194,17 @@ .create(); let test_app = Arc::new( - init_router_from_config_inline(&format!( - r#"supergraph: + init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 300ms "#, - )) + ), + None, + ) .await .expect("failed to start router"), ); diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index 638138801..1f73c6d18 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,8 +1,15 @@ -use std::{path::PathBuf, sync::Arc, time::Duration}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; +use bollard::{ + exec::{CreateExecOptions, StartExecResults}, + query_parameters::CreateImageOptionsBuilder, + secret::{ContainerCreateBody, ContainerCreateResponse, CreateImageInfo, HostConfig, PortMap}, + Docker, +}; +use futures_util::TryStreamExt; use hive_router::{ background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, - RouterSharedState, SchemaState, + plugins::plugins_service::PluginService, PluginRegistry, RouterSharedState, SchemaState, }; use hive_router_config::{load_config, parse_yaml_config, HiveRouterConfig}; use ntex::{ @@ -74,7 +81,7 @@ where pub struct SubgraphsServer { shutdown_tx: Option>, - subgraph_shared_state: SubgraphsServiceState, + subgraph_shared_state: Arc<SubgraphsServiceState>, } impl Drop for SubgraphsServer { @@ -116,14 +123,16 @@ } pub async fn get_subgraph_requests_log(&self, subgraph_name: &str) -> Option> { - let log = self.subgraph_shared_state.request_log.lock().await; - - log.get(&format!("/{}", subgraph_name)).cloned() + self.subgraph_shared_state + .request_log + .get(&format!("/{}", subgraph_name)) + .map(|entry| entry.value().clone()) + } } pub async fn init_router_from_config_file( config_path: &str, + plugin_registry: Option<PluginRegistry>, ) -> Result< TestRouterApp< impl ntex::Service, >, Box<dyn std::error::Error>, > { let supergraph_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(config_path); let router_config = load_config(Some(supergraph_path.to_str().unwrap().to_string()))?; - init_router_from_config(router_config).await + init_router_from_config(router_config, plugin_registry).await } pub async fn init_router_from_config_inline( config_yaml: &str, + plugin_registry: Option<PluginRegistry>, ) -> Result< TestRouterApp< impl ntex::Service, >, Box<dyn std::error::Error>, > { let router_config = parse_yaml_config(config_yaml.to_string())?; - init_router_from_config(router_config).await + init_router_from_config(router_config, plugin_registry).await } pub struct TestRouterApp { @@ -173,6 +183,7 @@ impl TestRouterApp { pub async fn init_router_from_config( router_config: HiveRouterConfig, + plugin_registry: Option<PluginRegistry>, ) -> Result< TestRouterApp< impl ntex::Service, >, Box<dyn std::error::Error>, > { let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry).await?; let ntex_app = test::init_service( 
web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app), @@ -204,3 +216,133 @@ self.bg_tasks_manager.shutdown(); } } + +#[derive(Default)] +pub struct TestDockerContainerOpts { + pub name: String, + pub image: String, + pub ports: HashMap<u16, u16>, + pub env: Vec<String>, +} + +pub struct TestDockerContainer { + docker: Docker, + container: ContainerCreateResponse, +} + +impl TestDockerContainer { + pub async fn async_new(opts: TestDockerContainerOpts) -> Result<Self, bollard::errors::Error> { + let docker = + Docker::connect_with_local_defaults().expect("Failed to connect to Docker daemon"); + let mut port_bindings = PortMap::new(); + for (container_port, host_port) in opts.ports.iter() { + port_bindings.insert( + format!("{}/tcp", container_port), + Some(vec![bollard::models::PortBinding { + host_port: Some(host_port.to_string()), + ..Default::default() + }]), + ); + } + let _: Vec<CreateImageInfo> = docker + .create_image( + Some( + CreateImageOptionsBuilder::default() + .from_image(&opts.image) + .build(), + ), + None, + None, + ) + .try_collect() + .await + .expect("Failed to pull the image"); + let container_exists = docker + .list_containers(Some(bollard::query_parameters::ListContainersOptions { + all: true, + ..Default::default() + })) + .await? + .into_iter() + .any(|c| { + c.names + .unwrap_or_default() + .iter() + .any(|name| name.trim_start_matches('/').eq(&opts.name)) + }); + if container_exists { + docker + .remove_container( + &opts.name, + Some(bollard::query_parameters::RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await + .expect("Failed to remove existing container"); + } + let container = docker + .create_container( + Some( + bollard::query_parameters::CreateContainerOptionsBuilder::default() + .name(&opts.name) + .build(), + ), + ContainerCreateBody { + image: Some(opts.image.to_string()), + host_config: Some(HostConfig { + port_bindings: Some(port_bindings), + ..Default::default() + }), + env: Some(opts.env), + ..Default::default() + }, + ) + .await + .expect("Failed to create the container"); + docker + .start_container( + &container.id, + None::<bollard::query_parameters::StartContainerOptions>, + ) + .await + .expect("Failed to start the container"); + Ok(Self { docker, container }) + } + pub async fn exec(&self, cmd: Vec<&str>) -> Result<(), bollard::errors::Error> { + let exec = self + .docker + .create_exec( + &self.container.id, + CreateExecOptions { + attach_stdout: Some(true), + attach_stderr: Some(true), + cmd: Some(cmd), + ..Default::default() + }, + ) + .await?; + match self.docker.start_exec(&exec.id, None).await? { + StartExecResults::Attached { mut output, .. } => { + while let Some(msg) = output.try_next().await? { + print!("{}", msg); + } + } + _ => {} + } + Ok(()) + } + pub async fn stop(&self) { + self.docker + .remove_container( + &self.container.id, + Some(bollard::query_parameters::RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await + .expect("Failed to remove the container"); + } +} diff --git a/e2e/supergraph.graphql b/e2e/supergraph.graphql index 1fe16b12b..655a18c1a 100644 --- a/e2e/supergraph.graphql +++ b/e2e/supergraph.graphql @@ -2,6 +2,7 @@ schema @link(url: "https://specs.apollo.dev/link/v1.0") @link(url: "https://specs.apollo.dev/join/v0.3", for: EXECUTION) { query: Query + mutation: Mutation } directive @join__enumValue(graph: join__Graph!) 
repeatable on ENUM_VALUE @@ -113,3 +114,31 @@ type User birthday: Int @join__field(graph: ACCOUNTS) reviews: [Review] @join__field(graph: REVIEWS) } + +scalar Upload + +type Mutation + @join__type(graph: PRODUCTS) { + upload(file: Upload): String @join__field(graph: PRODUCTS) + + oneofTest(input: OneOfTestInput!): OneOfTestResult + @join__field(graph: PRODUCTS) +} + +directive @oneOf on INPUT_OBJECT + +input OneOfTestInput @oneOf @join__type(graph: PRODUCTS) { + string: String + int: Int + float: Float + boolean: Boolean + id: ID +} + +type OneOfTestResult @join__type(graph: PRODUCTS) { + string: String @join__field(graph: PRODUCTS) + int: Int @join__field(graph: PRODUCTS) + float: Float @join__field(graph: PRODUCTS) + boolean: Boolean @join__field(graph: PRODUCTS) + id: ID @join__field(graph: PRODUCTS) +} \ No newline at end of file diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 27f7af1bf..629d65e5d 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -30,10 +30,15 @@ xxhash-rust = { workspace = true } tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } vrl = { workspace = true } +reqwest = { workspace = true, features = ["multipart"] } +serde_json = { workspace = true } ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } + +arc-swap = "1.7.1" +ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" hyper-tls = { version = "0.6.0", features = ["vendored"] } @@ -44,7 +49,7 @@ hyper-util = { version = "0.1.16", features = [ "http2", "tokio", ] } -bytes = "1.10.1" +bytes = { workspace = true } itoa = "1.0.15" ryu = "1.0.20" indexmap = "2.10.0" diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 6985376cc..20b7dcf98 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -18,14 +18,14 @@ pub struct ClientRequestDetails<'exec, 'req> { pub url: &'req http::Uri, pub headers: &'req NtexHeaderMap, pub operation: OperationDetails<'exec>, - pub jwt: &'exec JwtRequestDetails<'req>, + pub jwt: JwtRequestDetails, } -pub enum JwtRequestDetails<'exec> { +pub enum JwtRequestDetails { Authenticated { - token: &'exec str, - prefix: Option<&'exec str>, - claims: &'exec sonic_rs::Value, + token: String, + prefix: Option<String>, + claims: sonic_rs::Value, scopes: Option<Vec<String>>, }, Unauthenticated, @@ -67,7 +67,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ])); // .request.jwt - let jwt_value = match details.jwt { + let jwt_value = match &details.jwt { JwtRequestDetails::Authenticated { token, prefix, @@ -78,7 +78,7 @@ ("token".into(), token.to_string().into()), ( "prefix".into(), - prefix.unwrap_or_default().to_string().into(), + prefix.as_deref().unwrap_or_default().to_string().into(), ), ("claims".into(), sonic_value_to_vrl_value(claims)), ( diff --git a/lib/executor/src/execution/jwt_forward.rs b/lib/executor/src/execution/jwt_forward.rs index 24c19ff9f..9aefc601c 100644 --- a/lib/executor/src/execution/jwt_forward.rs +++ b/lib/executor/src/execution/jwt_forward.rs @@ -8,7 +8,7 @@ pub struct JwtAuthForwardingPlan { pub extension_field_value: Value, } -impl JwtRequestDetails<'_> { +impl JwtRequestDetails { pub fn build_forwarding_plan( &self, extension_field_name: &str, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 
f86356312..188240ce9 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -2,9 +2,12 @@ use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use hive_router_query_planner::planner::plan_nodes::{ - ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, - QueryPlan, SequenceNode, +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + planner::plan_nodes::{ + ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, + PlanNode, QueryPlan, SequenceNode, + }, }; use http::HeaderMap; use serde::Deserialize; @@ -18,19 +21,19 @@ use crate::{ jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, }, - executors::{ - common::{HttpExecutionRequest, HttpExecutionResponse}, - map::SubgraphExecutorMap, - }, + executors::{common::SubgraphExecutionRequest, http::HttpResponse, map::SubgraphExecutorMap}, headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, }, + hooks::on_execute::{OnExecuteEndHookPayload, OnExecuteStartHookPayload}, introspection::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, + plugin_context::PluginRequestState, + plugin_trait::{EndControlFlow, StartControlFlow}, projection::{ plan::FieldProjectionPlan, request::{project_requires, RequestProjectionContext}, @@ -49,7 +52,9 @@ }; pub struct QueryPlanExecutionContext<'exec, 'req> { + pub plugin_req_state: &'exec Option<Arc<PluginRequestState>>, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub projection_plan: &'exec Vec<FieldProjectionPlan>, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option<HashMap<String, Value>>, @@ -58,67 +63,132 @@ pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, - pub jwt_auth_forwarding: &'exec Option<JwtAuthForwardingPlan>, + pub jwt_auth_forwarding: Option<JwtAuthForwardingPlan>, } -pub struct PlanExecutionOutput { - pub body: Vec<u8>, - pub headers: HeaderMap, -} +impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { + pub async fn execute_query_plan(self) -> Result { + let mut init_value = if let Some(introspection_query) = self.introspection_context.query { + resolve_introspection(introspection_query, self.introspection_context) + } else { + Value::Null + }; -pub async fn execute_query_plan<'exec, 'req>( - ctx: QueryPlanExecutionContext<'exec, 'req>, -) -> Result { - let init_value = if let Some(introspection_query) = ctx.introspection_context.query { - resolve_introspection(introspection_query, ctx.introspection_context) - } else { - Value::Null - }; + let mut query_plan = self.query_plan; + + let dedupe_subgraph_requests = self.operation_type_name == "Query"; + let mut extensions = self.extensions; + + let mut on_end_callbacks = vec![]; + + if let Some(plugin_req_state) = self.plugin_req_state.as_ref() { + let mut start_payload = OnExecuteStartHookPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + query_plan, + operation_for_plan: self.operation_for_plan, + data: init_value, + errors: Vec::new(), + extensions, + variable_values: self.variable_values, + dedupe_subgraph_requests, + }; + + for plugin in plugin_req_state.plugins.iter() { + let result = 
plugin.on_execute(start_payload).await; + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { + return Ok(response); + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + query_plan = start_payload.query_plan; - let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); - let executor = Executor::new( - ctx.variable_values, - ctx.executors, - ctx.introspection_context.metadata, - ctx.client_request, - ctx.headers_plan, - ctx.jwt_auth_forwarding, - // Deduplicate subgraph requests only if the operation type is a query - ctx.operation_type_name == "Query", - ); - - if ctx.query_plan.node.is_some() { - executor - .execute(&mut exec_ctx, ctx.query_plan.node.as_ref()) - .await?; - } + init_value = start_payload.data; + + extensions = start_payload.extensions; + } + + let mut exec_ctx = ExecutionContext::new(query_plan, init_value); + let executor = Executor::new( + self.variable_values, + self.executors, + self.introspection_context.metadata, + self.client_request, + self.headers_plan, + self.jwt_auth_forwarding, + // Deduplicate subgraph requests only if the operation type is a query + self.operation_type_name == "Query", + self.plugin_req_state, + ); + + if query_plan.node.is_some() { + executor + .execute(&mut exec_ctx, query_plan.node.as_ref()) + .await?; + } + + let mut response_headers = HeaderMap::new(); + modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + .with_plan_context(LazyPlanContext { + subgraph_name: || None, + affected_path: || None, + })?; + + let mut data = exec_ctx.final_response; + let mut errors = exec_ctx.errors; + let mut response_size_estimate = exec_ctx.response_storage.estimate_final_response_size(); + + if !on_end_callbacks.is_empty() { + let mut end_payload = OnExecuteEndHookPayload { + data, + errors, + extensions, + response_size_estimate, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + EndControlFlow::Continue => { /* continue to next callback */ } + EndControlFlow::EndResponse(output) => { + return Ok(output); + } + } + } + + data = end_payload.data; + errors = end_payload.errors; + extensions = end_payload.extensions; + response_size_estimate = end_payload.response_size_estimate; + } - let mut response_headers = HeaderMap::new(); - modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + let body = project_by_operation( + &data, + errors, + &extensions, + self.operation_type_name, + self.projection_plan, + self.variable_values, + response_size_estimate, + ) .with_plan_context(LazyPlanContext { subgraph_name: || None, affected_path: || None, })?; - let final_response = &exec_ctx.final_response; - let body = project_by_operation( - final_response, - exec_ctx.errors, - &ctx.extensions, - ctx.operation_type_name, - ctx.projection_plan, - ctx.variable_values, - exec_ctx.response_storage.estimate_final_response_size(), - ) - .with_plan_context(LazyPlanContext { - subgraph_name: || None, - affected_path: || None, - })?; - - Ok(PlanExecutionOutput { - body, - headers: response_headers, - }) + Ok(HttpResponse { + body: body.into(), + headers: response_headers, + status: http::StatusCode::OK, + }) + } } pub struct Executor<'exec, 'req> { @@ -127,8 +197,9 @@ pub struct Executor<'exec, 'req> { executors: &'exec 
SubgraphExecutorMap, client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>, + jwt_forwarding_plan: Option<JwtAuthForwardingPlan>, dedupe_subgraph_requests: bool, + plugin_req_state: &'exec Option<Arc<PluginRequestState>>, } struct ConcurrencyScope<'exec, T> { @@ -155,20 +226,15 @@ impl<'exec, T> ConcurrencyScope<'exec, T> { } } -struct SubgraphOutput { - body: Bytes, - headers: HeaderMap, -} - struct FetchJob { fetch_node_id: i64, subgraph_name: String, - response: SubgraphOutput, + response: HttpResponse, } struct FlattenFetchJob { flatten_node_path: FlattenNodePath, - response: SubgraphOutput, + response: HttpResponse, fetch_node_id: i64, subgraph_name: String, representation_hashes: Vec<u64>, @@ -181,18 +247,13 @@ enum ExecutionJob { None, } -impl From<ExecutionJob> for SubgraphOutput { +impl From<ExecutionJob> for HttpResponse { fn from(value: ExecutionJob) -> Self { match value { - ExecutionJob::Fetch(j) => Self { - body: j.response.body, - headers: j.response.headers, - }, - ExecutionJob::FlattenFetch(j) => Self { - body: j.response.body, - headers: j.response.headers, - }, + ExecutionJob::Fetch(j) => j.response, + ExecutionJob::FlattenFetch(j) => j.response, ExecutionJob::None => Self { + status: http::StatusCode::OK, body: Bytes::new(), headers: HeaderMap::new(), }, @@ -200,15 +261,6 @@ } } -impl From<HttpExecutionResponse> for SubgraphOutput { - fn from(res: HttpExecutionResponse) -> Self { - Self { - body: res.body, - headers: res.headers, - } - } -} - struct PreparedFlattenData { representations: Vec, representation_hashes: Vec<u64>, @@ -216,14 +268,16 @@ } impl<'exec, 'req> Executor<'exec, 'req> { + #[allow(clippy::too_many_arguments)] pub fn new( variable_values: &'exec Option<HashMap<String, Value>>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>, + jwt_forwarding_plan: Option<JwtAuthForwardingPlan>, dedupe_subgraph_requests: bool, + plugin_req_state: &'exec Option<Arc<PluginRequestState>>, ) -> Self { Executor { variable_values, @@ -233,6 +287,7 @@ headers_plan, dedupe_subgraph_requests, jwt_forwarding_plan, + plugin_req_state, } } @@ -700,7 +755,7 @@ let variable_refs = select_fetch_variables(self.variable_values, node.variable_usages.as_ref()); - let mut subgraph_request = HttpExecutionRequest { + let mut subgraph_request = SubgraphExecutionRequest { query: node.operation.document_str.as_str(), dedupe: self.dedupe_subgraph_requests, operation_name: node.operation_name.as_deref(), @@ -722,9 +777,13 @@ subgraph_name: node.service_name.clone(), response: self .executors - .execute(&node.service_name, subgraph_request, self.client_request) - .await - .into(), + .execute( + &node.service_name, + subgraph_request, + self.client_request, + self.plugin_req_state, + ) + .await, })) } diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index bdcd4d819..0015d8139 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -1,16 +1,19 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; -use bytes::Bytes; -use http::HeaderMap; +use http::{HeaderMap, Uri}; use sonic_rs::Value; +use crate::{executors::http::HttpResponse, plugin_context::PluginRequestState}; + #[async_trait] pub trait SubgraphExecutor { + fn endpoint(&self) -> &Uri; async fn 
execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, - ) -> HttpExecutionResponse; + execution_request: SubgraphExecutionRequest<'a>, + plugin_req_state: &'a Option<Arc<PluginRequestState>>, + ) -> HttpResponse; fn to_boxed_arc<'a>(self) -> Arc<Box<dyn SubgraphExecutor + Send + Sync + 'a>> where @@ -26,7 +29,7 @@ pub type SubgraphExecutorBoxedArc = Arc<Box<dyn SubgraphExecutor + Send + Sync>>; pub type SubgraphRequestExtensions = HashMap<String, Value>; -pub struct HttpExecutionRequest<'a> { +pub struct SubgraphExecutionRequest<'a> { pub query: &'a str, pub dedupe: bool, pub operation_name: Option<&'a str>, @@ -37,15 +40,10 @@ -impl HttpExecutionRequest<'_> { +impl SubgraphExecutionRequest<'_> { pub fn add_request_extensions_field(&mut self, key: String, value: Value) { self.extensions .get_or_insert_with(HashMap::new) .insert(key, value); } } - -pub struct HttpExecutionResponse { - pub body: Bytes, - pub headers: HeaderMap, -} diff --git a/lib/executor/src/executors/dedupe.rs b/lib/executor/src/executors/dedupe.rs index a60599f19..420ede67d 100644 --- a/lib/executor/src/executors/dedupe.rs +++ b/lib/executor/src/executors/dedupe.rs @@ -1,16 +1,8 @@ use ahash::AHasher; -use bytes::Bytes; -use http::{HeaderMap, Method, StatusCode, Uri}; +use http::{HeaderMap, Method, Uri}; use std::collections::BTreeMap; use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; -#[derive(Debug, Clone)] -pub struct SharedResponse { - pub status: StatusCode, - pub headers: HeaderMap, - pub body: Bytes, -} - pub fn request_fingerprint( method: &Method, url: &Uri, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 29b392567..af9193194 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -1,7 +1,11 @@ use std::sync::Arc; -use crate::executors::common::HttpExecutionResponse; -use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; +use crate::executors::dedupe::{request_fingerprint, ABuildHasher}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpResponseHookPayload, +}; +use crate::plugin_context::PluginRequestState; +use crate::plugin_trait::{EndControlFlow, StartControlFlow}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -9,8 +13,8 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::HeaderMap; use http::HeaderValue; +use http::{HeaderMap, StatusCode}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; @@ -19,7 +23,7 @@ use hyper_util::client::legacy::{connect::HttpConnector, Client}; use tokio::sync::Semaphore; use tracing::debug; -use crate::executors::common::HttpExecutionRequest; +use crate::executors::common::SubgraphExecutionRequest; use crate::executors::error::SubgraphExecutorError; use crate::response::graphql_error::GraphQLError; use crate::utils::consts::CLOSE_BRACE; @@ -28,7 +32,6 @@ use crate::utils::consts::COMMA; use crate::utils::consts::QUOTE; use crate::{executors::common::SubgraphExecutor, json_writer::write_and_escape_string}; -#[derive(Debug)] pub struct HTTPSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -36,7 +39,7 @@ pub header_map: HeaderMap, pub semaphore: Arc<Semaphore>, pub config: Arc<HiveRouterConfig>, - pub in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<SharedResponse>>, ABuildHasher>>, + pub in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<HttpResponse>>, ABuildHasher>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -51,7 +54,7 @@ impl HTTPSubgraphExecutor { http_client: 
Arc<HttpClient>, semaphore: Arc<Semaphore>, config: Arc<HiveRouterConfig>, - in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<SharedResponse>>, ABuildHasher>>, + in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<HttpResponse>>, ABuildHasher>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -76,7 +79,7 @@ fn build_request_body<'a>( &self, - execution_request: &HttpExecutionRequest<'a>, + execution_request: &SubgraphExecutionRequest<'a>, ) -> Result<Vec<u8>, SubgraphExecutorError> { let mut body = Vec::with_capacity(4096); body.put(FIRST_QUOTE_STR); @@ -133,57 +136,6 @@ Ok(body) } - async fn _send_request( - &self, - body: Vec<u8>, - headers: HeaderMap, - ) -> Result<SharedResponse, SubgraphExecutorError> { - let mut req = hyper::Request::builder() - .method(http::Method::POST) - .uri(&self.endpoint) - .version(Version::HTTP_11) - .body(Full::new(Bytes::from(body))) - .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(self.endpoint.to_string(), e.to_string()) - })?; - - *req.headers_mut() = headers; - - debug!("making http request to {}", self.endpoint.to_string()); - - let res = self.http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) - })?; - - debug!( - "http request to {} completed, status: {}", - self.endpoint.to_string(), - res.status() - ); - - let (parts, body) = res.into_parts(); - let body = body - .collect() - .await - .map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) - })? - .to_bytes(); - - if body.is_empty() { - return Err(SubgraphExecutorError::RequestFailure( - self.endpoint.to_string(), - "Empty response body".to_string(), - )); - } - - Ok(SharedResponse { - status: parts.status, - body, - headers: parts.headers, - }) - } - fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { let graphql_error: GraphQLError = error.into(); let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); @@ -207,49 +159,178 @@ } } +async fn send_request( + http_client: &Client<HttpsConnector<HttpConnector>, Full<Bytes>>, + subgraph_name: &str, + endpoint: &http::Uri, + mut method: http::Method, + mut body: Vec<u8>, + mut execution_request: SubgraphExecutionRequest<'_>, + plugin_req_state: &Option<Arc<PluginRequestState>>, +) -> Result<HttpResponse, SubgraphExecutorError> { + let mut on_end_callbacks = vec![]; + let mut response = None; + + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + let mut start_payload = OnSubgraphHttpRequestHookPayload { + subgraph_name, + endpoint, + method, + body, + execution_request, + context: &plugin_req_state.context, + response, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload).await; + start_payload = result.payload; + match result.control_flow { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { + return Ok(response); + } + StartControlFlow::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + method = start_payload.method; + body = start_payload.body; + execution_request = start_payload.execution_request; + response = start_payload.response; + } + + let mut response = match response { + Some(response) => response, + None => { + let mut req = hyper::Request::builder() + .method(method) + .uri(endpoint) + .version(Version::HTTP_11) + .body(Full::new(Bytes::from(body))) + .map_err(|e| { + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) + })?; + + *req.headers_mut() = execution_request.headers; + + debug!("making http request to {}", endpoint.to_string()); + + let res = 
+            let res = http_client.request(req).await.map_err(|e| {
+                SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string())
+            })?;
+
+            debug!(
+                "http request to {} completed, status: {}",
+                endpoint.to_string(),
+                res.status()
+            );
+
+            let (parts, body) = res.into_parts();
+            let body = body
+                .collect()
+                .await
+                .map_err(|e| {
+                    SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string())
+                })?
+                .to_bytes();
+
+            if body.is_empty() {
+                return Err(SubgraphExecutorError::RequestFailure(
+                    endpoint.to_string(),
+                    "Empty response body".to_string(),
+                ));
+            }
+
+            HttpResponse {
+                status: parts.status,
+                body,
+                headers: parts.headers,
+            }
+        }
+    };
+
+    if let Some(plugin_req_state) = plugin_req_state.as_ref() {
+        let mut end_payload = OnSubgraphHttpResponseHookPayload {
+            response,
+            context: &plugin_req_state.context,
+        };
+
+        for callback in on_end_callbacks {
+            let result = callback(end_payload);
+            end_payload = result.payload;
+            match result.control_flow {
+                EndControlFlow::Continue => { /* continue to next callback */ }
+                EndControlFlow::EndResponse(response) => {
+                    return Ok(response);
+                }
+            }
+        }
+
+        response = end_payload.response;
+    }
+
+    Ok(response)
+}
+
 #[async_trait]
 impl SubgraphExecutor for HTTPSubgraphExecutor {
+    fn endpoint(&self) -> &http::Uri {
+        &self.endpoint
+    }
     #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))]
     async fn execute<'a>(
         &self,
-        execution_request: HttpExecutionRequest<'a>,
-    ) -> HttpExecutionResponse {
+        mut execution_request: SubgraphExecutionRequest<'a>,
+        plugin_req_state: &'a Option<PluginRequestState<'a>>,
+    ) -> HttpResponse {
         let body = match self.build_request_body(&execution_request) {
             Ok(body) => body,
             Err(e) => {
                 self.log_error(&e);
-                return HttpExecutionResponse {
+                return HttpResponse {
                     body: self.error_to_graphql_bytes(e),
                     headers: Default::default(),
+                    status: StatusCode::OK,
                 };
             }
         };
 
-        let mut headers = execution_request.headers;
         self.header_map.iter().for_each(|(key, value)| {
-            headers.insert(key, value.clone());
+            execution_request.headers.insert(key, value.clone());
         });
 
+        let method = http::Method::POST;
+
         if !self.config.traffic_shaping.dedupe_enabled || !execution_request.dedupe {
             // This unwrap is safe because the semaphore is never closed during the application's lifecycle.
             // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`.
             let _permit = self.semaphore.acquire().await.unwrap();
-            return match self._send_request(body, headers).await {
-                Ok(shared_response) => HttpExecutionResponse {
-                    body: shared_response.body,
-                    headers: shared_response.headers,
-                },
+            return match send_request(
+                &self.http_client,
+                &self.subgraph_name,
+                &self.endpoint,
+                method,
+                body,
+                execution_request,
+                plugin_req_state,
+            )
+            .await
+            {
+                Ok(shared_response) => shared_response,
                 Err(e) => {
                     self.log_error(&e);
-                    HttpExecutionResponse {
+                    HttpResponse {
                         body: self.error_to_graphql_bytes(e),
                         headers: Default::default(),
+                        status: StatusCode::OK,
                     }
                 }
             };
         }
 
-        let fingerprint = request_fingerprint(&http::Method::POST, &self.endpoint, &headers, &body);
+        let fingerprint =
+            request_fingerprint(&method, &self.endpoint, &execution_request.headers, &body);
 
         // Clone the cell from the map, dropping the lock from the DashMap immediately.
         // Prevents any deadlocks.
@@ -266,7 +347,16 @@ impl SubgraphExecutor for HTTPSubgraphExecutor {
             // This unwrap is safe because the semaphore is never closed during the application's lifecycle.
             // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`.
             let _permit = self.semaphore.acquire().await.unwrap();
-            self._send_request(body, headers).await
+            send_request(
+                &self.http_client,
+                &self.subgraph_name,
+                &self.endpoint,
+                method,
+                body,
+                execution_request,
+                plugin_req_state,
+            )
+            .await
         };
 
         // It's important to remove the entry from the map before returning the result.
         // This ensures that once the OnceCell is set, no future requests can join it.
@@ -277,17 +367,22 @@
             .await;
 
         match response_result {
-            Ok(shared_response) => HttpExecutionResponse {
-                body: shared_response.body.clone(),
-                headers: shared_response.headers.clone(),
-            },
+            Ok(shared_response) => shared_response.clone(),
             Err(e) => {
                 self.log_error(&e);
-                HttpExecutionResponse {
+                HttpResponse {
                     body: self.error_to_graphql_bytes(e.clone()),
                     headers: Default::default(),
+                    status: StatusCode::OK,
                 }
             }
         }
     }
 }
+
+#[derive(Clone)]
+pub struct HttpResponse {
+    pub status: StatusCode,
+    pub headers: HeaderMap,
+    pub body: Bytes,
+}
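To illustrate the control flow `send_request` now drives, here is a sketch of a plugin that observes subgraph HTTP traffic via an `OnEnd` callback. It uses only types introduced in this diff; the plugin itself and the `hive_router_plan_executor` crate path are assumptions.

```rust
use async_trait::async_trait;
use hive_router_plan_executor::hooks::on_subgraph_http_request::{
    OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult,
};
use hive_router_plan_executor::plugin_trait::{EndHookPayload, RouterPlugin, StartHookPayload};

struct SubgraphLogger;

#[async_trait]
impl RouterPlugin for SubgraphLogger {
    async fn on_subgraph_http_request<'exec>(
        &'exec self,
        start_payload: OnSubgraphHttpRequestHookPayload<'exec>,
    ) -> OnSubgraphHttpRequestHookResult<'exec> {
        let subgraph = start_payload.subgraph_name.to_string();
        // Register an end callback; `send_request` invokes it once the
        // subgraph's HTTP response (or an early response) is available.
        start_payload.on_end(move |end_payload| {
            tracing::debug!("subgraph {} replied: {}", subgraph, end_payload.response.status);
            end_payload.cont()
        })
    }
}
```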
diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs
index a3c297ad1..b666189c7 100644
--- a/lib/executor/src/executors/map.rs
+++ b/lib/executor/src/executors/map.rs
@@ -29,13 +29,16 @@ use vrl::{
 use crate::{
     execution::client_request_details::ClientRequestDetails,
     executors::{
-        common::{
-            HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc,
-        },
-        dedupe::{ABuildHasher, SharedResponse},
+        common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
+        dedupe::ABuildHasher,
         error::SubgraphExecutorError,
-        http::{HTTPSubgraphExecutor, HttpClient},
+        http::{HTTPSubgraphExecutor, HttpClient, HttpResponse},
     },
+    hooks::on_subgraph_execute::{
+        OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload,
+    },
+    plugin_context::PluginRequestState,
+    plugin_trait::{EndControlFlow, StartControlFlow},
     response::graphql_error::GraphQLError,
 };
 
@@ -59,7 +62,7 @@ pub struct SubgraphExecutorMap {
     client: Arc<HttpClient>,
     semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
     max_connections_per_host: usize,
-    in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<Result<SharedResponse, SubgraphExecutorError>>>, ABuildHasher>>,
+    in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<Result<HttpResponse, SubgraphExecutorError>>>, ABuildHasher>>,
 }
 
 impl SubgraphExecutorMap {
@@ -115,36 +118,102 @@ impl SubgraphExecutorMap {
         Ok(subgraph_executor_map)
     }
 
-    pub async fn execute<'a, 'req>(
+    pub async fn execute<'exec, 'req>(
         &self,
         subgraph_name: &str,
-        execution_request: HttpExecutionRequest<'a>,
-        client_request: &ClientRequestDetails<'a, 'req>,
-    ) -> HttpExecutionResponse {
-        match self.get_or_create_executor(subgraph_name, client_request) {
-            Ok(Some(executor)) => executor.execute(execution_request).await,
+        execution_request: SubgraphExecutionRequest<'exec>,
+        client_request: &ClientRequestDetails<'exec, 'req>,
+        plugin_req_state: &Option<PluginRequestState<'exec>>,
+    ) -> HttpResponse {
+        let mut executor = match self.get_or_create_executor(subgraph_name, client_request) {
+            Ok(Some(executor)) => executor,
             Err(err) => {
                 error!(
                     "Subgraph executor error for subgraph '{}': {}",
                     subgraph_name, err,
                 );
-                self.internal_server_error_response(err.into(), subgraph_name)
+                return self.internal_server_error_response(err.into(), subgraph_name);
             }
             Ok(None) => {
                 error!(
                     "Subgraph executor not found for subgraph '{}'",
                     subgraph_name
                 );
-                self.internal_server_error_response("Internal server error".into(), subgraph_name)
+                return self
+                    .internal_server_error_response("Internal server error".into(), subgraph_name);
+            }
+        };
+
+        let mut on_end_callbacks = vec![];
+
+        let mut execution_request = execution_request;
+        let mut execution_result = None;
+
+        if let Some(plugin_req_state) = plugin_req_state.as_ref() {
+            let mut start_payload = OnSubgraphExecuteStartHookPayload {
+                router_http_request: &plugin_req_state.router_http_request,
+                context: &plugin_req_state.context,
+                subgraph_name,
+                executor,
+                execution_request,
+                execution_result,
+            };
+            for plugin in plugin_req_state.plugins.as_ref() {
+                let result = plugin.on_subgraph_execute(start_payload).await;
+                start_payload = result.payload;
+                match result.control_flow {
+                    StartControlFlow::Continue => {
+                        // continue to next plugin
+                    }
+                    StartControlFlow::EndResponse(response) => {
+                        // TODO: FFIX
+                        return response;
+                    }
+                    StartControlFlow::OnEnd(callback) => {
+                        on_end_callbacks.push(callback);
+                    }
+                }
             }
+            execution_request = start_payload.execution_request;
+            execution_result = start_payload.execution_result;
+            executor = start_payload.executor;
         }
+
+        let mut execution_result = match execution_result {
+            Some(execution_result) => execution_result,
+            None => executor.execute(execution_request, plugin_req_state).await,
+        };
+
+        if let Some(plugin_req_state) = plugin_req_state.as_ref() {
+            let mut end_payload = OnSubgraphExecuteEndHookPayload {
+                context: &plugin_req_state.context,
+                execution_result,
+            };
+
+            for callback in on_end_callbacks {
+                let result = callback(end_payload);
+                end_payload = result.payload;
+                match result.control_flow {
+                    EndControlFlow::Continue => {
+                        // continue to next callback
+                    }
+                    EndControlFlow::EndResponse(response) => {
+                        // TODO: FFIX
+                        return response;
+                    }
+                }
+            }
+
+            execution_result = end_payload.execution_result;
+        }
+
+        execution_result
     }
 
     fn internal_server_error_response(
         &self,
         graphql_error: GraphQLError,
         subgraph_name: &str,
-    ) -> HttpExecutionResponse {
+    ) -> HttpResponse {
         let errors = vec![graphql_error.add_subgraph_name(subgraph_name)];
         let errors_bytes = sonic_rs::to_vec(&errors).unwrap();
         let mut buffer = BytesMut::new();
@@ -152,9 +221,10 @@ impl SubgraphExecutorMap {
         buffer.put_slice(&errors_bytes);
         buffer.put_slice(b"}");
 
-        HttpExecutionResponse {
+        HttpResponse {
             body: buffer.freeze(),
             headers: Default::default(),
+            status: http::StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
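The pre-filled `execution_result` gives plugins a short-circuit path: if a hook sets it, `SubgraphExecutorMap::execute` skips the real executor entirely. A hedged sketch, assuming a hypothetical mock plugin and the `hive_router_plan_executor` crate path:

```rust
use async_trait::async_trait;
use bytes::Bytes;
use hive_router_plan_executor::executors::http::HttpResponse;
use hive_router_plan_executor::hooks::on_subgraph_execute::{
    OnSubgraphExecuteStartHookPayload, OnSubgraphExecuteStartHookResult,
};
use hive_router_plan_executor::plugin_trait::{RouterPlugin, StartHookPayload};
use http::StatusCode;

struct MockAccounts;

#[async_trait]
impl RouterPlugin for MockAccounts {
    async fn on_subgraph_execute<'exec>(
        &'exec self,
        mut start_payload: OnSubgraphExecuteStartHookPayload<'exec>,
    ) -> OnSubgraphExecuteStartHookResult<'exec> {
        if start_payload.subgraph_name == "accounts" {
            // Pre-fill the result: the router then hands this response to the
            // pipeline instead of calling the HTTP executor.
            start_payload.execution_result = Some(HttpResponse {
                status: StatusCode::OK,
                headers: Default::default(),
                body: Bytes::from_static(b"{\"data\":{\"__typename\":\"Query\"}}"),
            });
        }
        start_payload.cont()
    }
}
```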
diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs
index 62f9fe701..0338bf6df 100644
--- a/lib/executor/src/headers/mod.rs
+++ b/lib/executor/src/headers/mod.rs
@@ -82,7 +82,7 @@ mod tests {
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut out = HeaderMap::new();
@@ -116,7 +116,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
     let mut out = HeaderMap::new();
     modify_subgraph_request_headers(&plan, "any", &client_details, &mut out).unwrap();
@@ -163,7 +163,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut out = HeaderMap::new();
@@ -201,7 +201,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut out = HeaderMap::new();
@@ -235,7 +235,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut out = HeaderMap::new();
@@ -275,7 +275,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     // For "accounts" subgraph, the specific rule should apply.
@@ -319,7 +319,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut accumulator = ResponseHeaderAggregator::default();
@@ -384,7 +384,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut accumulator = ResponseHeaderAggregator::default();
@@ -448,7 +448,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut accumulator = ResponseHeaderAggregator::default();
@@ -505,7 +505,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut accumulator = ResponseHeaderAggregator::default();
@@ -563,7 +563,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut accumulator = ResponseHeaderAggregator::default();
@@ -622,7 +622,7 @@
             query: "{ __typename }",
             kind: "query",
         },
-        jwt: &JwtRequestDetails::Unauthenticated,
+        jwt: JwtRequestDetails::Unauthenticated,
     };
 
     let mut out = HeaderMap::new();
diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs
index 637ab0d58..7f362be73 100644
--- a/lib/executor/src/headers/request.rs
+++ b/lib/executor/src/headers/request.rs
@@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers(
     Ok(())
 }
 
-pub struct RequestExpressionContext<'a, 'req> {
-    pub subgraph_name: &'a str,
-    pub client_request: &'a ClientRequestDetails<'a, 'req>,
+pub struct RequestExpressionContext<'exec, 'req> {
+    pub subgraph_name: &'exec str,
+    pub client_request: &'exec ClientRequestDetails<'exec, 'req>,
 }
 
 trait ApplyRequestHeader {
@@ -117,7 +117,7 @@ impl ApplyRequestHeader for RequestPropagateRegex {
         ctx: &RequestExpressionContext,
         output_headers: &mut HeaderMap,
     ) -> Result<(), HeaderRuleRuntimeError> {
-        for (header_name, header_value) in ctx.client_request.headers {
+        for (header_name, header_value) in ctx.client_request.headers.iter() {
             if is_denied_header(header_name) {
                 continue;
             }
diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs
index 6a5c34444..b4942837f 100644
--- a/lib/executor/src/headers/response.rs
+++ b/lib/executor/src/headers/response.rs
@@ -50,10 +50,10 @@ pub fn apply_subgraph_response_headers(
     Ok(())
 }
 
-pub struct ResponseExpressionContext<'a, 'req> {
-    pub subgraph_name: &'a str,
-    pub client_request: &'a ClientRequestDetails<'a, 'req>,
-    pub subgraph_headers: &'a HeaderMap,
+pub struct ResponseExpressionContext<'exec, 'req> {
+    pub subgraph_name: &'exec str,
+    pub client_request: &'exec ClientRequestDetails<'exec, 'req>,
+    pub subgraph_headers: &'exec HeaderMap,
 }
 
 trait ApplyResponseHeader {
diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs
index 4f912a463..bdcbdadc0 100644
--- a/lib/executor/src/lib.rs
+++ b/lib/executor/src/lib.rs
@@ -4,10 +4,11 @@ pub mod executors;
 pub mod headers;
 pub mod introspection;
 pub mod json_writer;
+pub mod plugins;
 pub mod projection;
 pub mod response;
 pub mod utils;
 pub mod variables;
 
-pub use execution::plan::execute_query_plan;
 pub use executors::map::SubgraphExecutorMap;
+pub use plugins::*;
diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs
new file mode 100644
index 000000000..64851d0fd
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/mod.rs
@@ -0,0 +1,9 @@
+pub mod on_execute;
+pub mod on_graphql_params;
+pub mod on_graphql_parse;
+pub mod on_graphql_validation;
+pub mod on_http_request;
+pub mod on_query_plan;
+pub mod on_subgraph_execute;
+pub mod on_subgraph_http_request;
+pub mod on_supergraph_load;
diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs
new file mode 100644
index 000000000..e1392af14
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_execute.rs
@@ -0,0 +1,41 @@
+use std::collections::HashMap;
+
+use hive_router_query_planner::ast::operation::OperationDefinition;
+use hive_router_query_planner::planner::plan_nodes::QueryPlan;
+
+use crate::plugin_context::{PluginContext, RouterHttpRequest};
+use crate::plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult};
+use crate::response::graphql_error::GraphQLError;
+use crate::response::value::Value;
+
+pub struct OnExecuteStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+    pub query_plan: &'exec QueryPlan,
+    pub operation_for_plan: &'exec OperationDefinition,
+
+    pub data: Value<'exec>,
+    pub errors: Vec<GraphQLError>,
+    pub extensions: Option<HashMap<String, Value<'exec>>>,
+
+    pub variable_values: &'exec Option<HashMap<String, Value<'exec>>>,
+
+    pub dedupe_subgraph_requests: bool,
+}
+
+impl<'exec> StartHookPayload<OnExecuteEndHookPayload<'exec>> for OnExecuteStartHookPayload<'exec> {}
+
+pub type OnExecuteStartHookResult<'exec> =
+    StartHookResult<'exec, OnExecuteStartHookPayload<'exec>, OnExecuteEndHookPayload<'exec>>;
+
+pub struct OnExecuteEndHookPayload<'exec> {
+    pub data: Value<'exec>,
+    pub errors: Vec<GraphQLError>,
+    pub extensions: Option<HashMap<String, Value<'exec>>>,
+
+    pub response_size_estimate: usize,
+}
+
+impl<'exec> EndHookPayload for OnExecuteEndHookPayload<'exec> {}
+
+pub type OnExecuteEndHookResult<'exec> = EndHookResult<OnExecuteEndHookPayload<'exec>>;
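A sketch of how the execute-phase hook pair might be consumed: a hypothetical plugin that logs the end payload's `response_size_estimate` once execution finishes. Only types from this diff are used; the plugin and crate path are assumptions.

```rust
use async_trait::async_trait;
use hive_router_plan_executor::hooks::on_execute::{
    OnExecuteStartHookPayload, OnExecuteStartHookResult,
};
use hive_router_plan_executor::plugin_trait::{EndHookPayload, RouterPlugin, StartHookPayload};

struct ResponseSizeLogger;

#[async_trait]
impl RouterPlugin for ResponseSizeLogger {
    async fn on_execute<'exec>(
        &'exec self,
        start_payload: OnExecuteStartHookPayload<'exec>,
    ) -> OnExecuteStartHookResult<'exec> {
        // Defer to the end payload, which carries the final data/errors and
        // the estimated response size.
        start_payload.on_end(|end_payload| {
            tracing::info!(
                "executed with {} error(s), ~{} bytes",
                end_payload.errors.len(),
                end_payload.response_size_estimate
            );
            end_payload.cont()
        })
    }
}
```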
diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs
new file mode 100644
index 000000000..ea44426f2
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs
@@ -0,0 +1,123 @@
+use core::fmt;
+
+use std::collections::HashMap;
+
+use ntex::util::Bytes;
+use serde::{de, Deserialize, Deserializer};
+use sonic_rs::Value;
+
+use crate::plugin_context::PluginContext;
+use crate::plugin_context::RouterHttpRequest;
+use crate::plugin_trait::EndHookPayload;
+use crate::plugin_trait::EndHookResult;
+use crate::plugin_trait::StartHookPayload;
+use crate::plugin_trait::StartHookResult;
+
+#[derive(Debug, Clone, Default)]
+pub struct GraphQLParams {
+    pub query: Option<String>,
+    pub operation_name: Option<String>,
+    pub variables: HashMap<String, Value>,
+    // TODO: We don't use extensions yet, but we definitely will in the future.
+    #[allow(dead_code)]
+    pub extensions: Option<HashMap<String, Value>>,
+}
+
+// Workaround for https://github.com/cloudwego/sonic-rs/issues/114
+impl<'de> Deserialize<'de> for GraphQLParams {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct GraphQLErrorExtensionsVisitor;
+
+        impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor {
+            type Value = GraphQLParams;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a map for GraphQLErrorExtensions")
+            }
+
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: de::MapAccess<'de>,
+            {
+                let mut query = None;
+                let mut operation_name = None;
+                let mut variables: Option<HashMap<String, Value>> = None;
+                let mut extensions: Option<HashMap<String, Value>> = None;
+                let mut extra_params = HashMap::new();
+
+                while let Some(key) = map.next_key::<String>()? {
+                    match key.as_str() {
+                        "query" => {
+                            if query.is_some() {
+                                return Err(de::Error::duplicate_field("query"));
+                            }
+                            query = map.next_value::<Option<String>>()?;
+                        }
+                        "operationName" => {
+                            if operation_name.is_some() {
+                                return Err(de::Error::duplicate_field("operationName"));
+                            }
+                            operation_name = map.next_value::<Option<String>>()?;
+                        }
+                        "variables" => {
+                            if variables.is_some() {
+                                return Err(de::Error::duplicate_field("variables"));
+                            }
+                            variables = map.next_value::<Option<HashMap<String, Value>>>()?;
+                        }
+                        "extensions" => {
+                            if extensions.is_some() {
+                                return Err(de::Error::duplicate_field("extensions"));
+                            }
+                            extensions = map.next_value::<Option<HashMap<String, Value>>>()?;
+                        }
+                        other => {
+                            let value: Value = map.next_value()?;
+                            extra_params.insert(other.to_string(), value);
+                        }
+                    }
+                }
+
+                Ok(GraphQLParams {
+                    query,
+                    operation_name,
+                    variables: variables.unwrap_or_default(),
+                    extensions,
+                })
+            }
+        }
+
+        deserializer.deserialize_map(GraphQLErrorExtensionsVisitor)
+    }
+}
+
+pub struct OnGraphQLParamsStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+    pub body: Bytes,
+    pub graphql_params: Option<GraphQLParams>,
+}
+
+impl<'exec> StartHookPayload<OnGraphQLParamsEndHookPayload<'exec>>
+    for OnGraphQLParamsStartHookPayload<'exec>
+{
+}
+
+pub type OnGraphQLParamsStartHookResult<'exec> = StartHookResult<
+    'exec,
+    OnGraphQLParamsStartHookPayload<'exec>,
+    OnGraphQLParamsEndHookPayload<'exec>,
+>;
+
+pub struct OnGraphQLParamsEndHookPayload<'exec> {
+    pub graphql_params: GraphQLParams,
+    pub context: &'exec PluginContext,
+}
+
+impl<'exec> EndHookPayload for OnGraphQLParamsEndHookPayload<'exec> {}
+
+pub type OnGraphQLParamsEndHookResult<'exec> = EndHookResult<OnGraphQLParamsEndHookPayload<'exec>>;
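A quick usage sketch of the manual `Deserialize` workaround above, assuming the crate path `hive_router_plan_executor`; the JSON input is illustrative.

```rust
use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams;

fn main() {
    // Unknown top-level keys (here "foo") are consumed into `extra_params`
    // and dropped, rather than tripping sonic-rs' map deserialization.
    let raw = br#"{"query":"{ __typename }","variables":{"id":1},"foo":42}"#;
    let params: GraphQLParams = sonic_rs::from_slice(raw).expect("valid GraphQL params");
    assert_eq!(params.query.as_deref(), Some("{ __typename }"));
    assert_eq!(params.variables.len(), 1);
}
```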
diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs
new file mode 100644
index 000000000..9ee55780a
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs
@@ -0,0 +1,30 @@
+use graphql_tools::static_graphql::query::Document;
+
+use crate::{
+    hooks::on_graphql_params::GraphQLParams,
+    plugin_context::{PluginContext, RouterHttpRequest},
+    plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult},
+};
+
+pub struct OnGraphQLParseStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+    pub graphql_params: &'exec GraphQLParams,
+    pub document: Option<Document>,
+}
+
+impl<'exec> StartHookPayload<OnGraphQLParseEndHookPayload>
+    for OnGraphQLParseStartHookPayload<'exec>
+{
+}
+
+pub type OnGraphQLParseHookResult<'exec> =
+    StartHookResult<'exec, OnGraphQLParseStartHookPayload<'exec>, OnGraphQLParseEndHookPayload>;
+
+pub struct OnGraphQLParseEndHookPayload {
+    pub document: Document,
+}
+
+impl EndHookPayload for OnGraphQLParseEndHookPayload {}
+
+pub type OnGraphQLParseEndHookResult = EndHookResult<OnGraphQLParseEndHookPayload>;
diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs
new file mode 100644
index 000000000..9431dd29f
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs
@@ -0,0 +1,85 @@
+use graphql_tools::{
+    static_graphql::query::Document,
+    validation::{
+        rules::{default_rules_validation_plan, ValidationRule},
+        utils::ValidationError,
+        validate::ValidationPlan,
+    },
+};
+use hive_router_query_planner::state::supergraph_state::SchemaDocument;
+
+use crate::{
+    plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest},
+    plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult},
+};
+
+pub struct OnGraphQLValidationStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+    pub schema: &'exec SchemaDocument,
+    pub document: &'exec Document,
+    default_validation_plan: &'exec ValidationPlan,
+    new_validation_plan: Option<ValidationPlan>,
+    pub errors: Option<Vec<ValidationError>>,
+}
+
+impl<'exec> StartHookPayload<OnGraphQLValidationEndHookPayload>
+    for OnGraphQLValidationStartHookPayload<'exec>
+{
+}
+
+pub type OnGraphQLValidationStartHookResult<'exec> = StartHookResult<
+    'exec,
+    OnGraphQLValidationStartHookPayload<'exec>,
+    OnGraphQLValidationEndHookPayload,
+>;
+
+impl<'exec> OnGraphQLValidationStartHookPayload<'exec> {
+    pub fn new(
+        plugin_req_state: &'exec PluginRequestState<'exec>,
+        schema: &'exec SchemaDocument,
+        document: &'exec Document,
+        default_validation_plan: &'exec ValidationPlan,
+    ) -> Self {
+        OnGraphQLValidationStartHookPayload {
+            router_http_request: &plugin_req_state.router_http_request,
+            context: &plugin_req_state.context,
+            schema,
+            document,
+            default_validation_plan,
+            new_validation_plan: None,
+            errors: None,
+        }
+    }
+
+    pub fn add_validation_rule(&mut self, rule: Box<dyn ValidationRule>) {
+        self.new_validation_plan
+            .get_or_insert_with(default_rules_validation_plan)
+            .add_rule(rule);
+    }
+
+    pub fn filter_validation_rules<F>(&mut self, mut f: F)
+    where
+        F: FnMut(&Box<dyn ValidationRule>) -> bool,
+    {
+        let plan = self
+            .new_validation_plan
+            .get_or_insert_with(default_rules_validation_plan);
+        plan.rules.retain(|rule| f(rule));
+    }
+
+    pub fn get_validation_plan(&self) -> &ValidationPlan {
+        match &self.new_validation_plan {
+            Some(plan) => plan,
+            None => self.default_validation_plan,
+        }
+    }
+}
+
+pub struct OnGraphQLValidationEndHookPayload {
+    pub errors: Vec<ValidationError>,
+}
+
+impl EndHookPayload for OnGraphQLValidationEndHookPayload {}
+
+pub type OnGraphQLValidationHookEndResult = EndHookResult<OnGraphQLValidationEndHookPayload>;
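For the validation hook, the start/end split means a plugin can observe the outcome without touching the rule set. A hedged sketch (plugin and crate path hypothetical; only hook types from this diff are used):

```rust
use async_trait::async_trait;
use hive_router_plan_executor::hooks::on_graphql_validation::{
    OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult,
};
use hive_router_plan_executor::plugin_trait::{EndHookPayload, RouterPlugin, StartHookPayload};

struct ValidationMetrics;

#[async_trait]
impl RouterPlugin for ValidationMetrics {
    async fn on_graphql_validation<'exec>(
        &'exec self,
        start_payload: OnGraphQLValidationStartHookPayload<'exec>,
    ) -> OnGraphQLValidationStartHookResult<'exec> {
        // The end payload carries the validation errors produced by whichever
        // ValidationPlan `get_validation_plan()` resolved to.
        start_payload.on_end(|end_payload| {
            if !end_payload.errors.is_empty() {
                tracing::warn!("{} validation error(s)", end_payload.errors.len());
            }
            end_payload.cont()
        })
    }
}
```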
diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs
new file mode 100644
index 000000000..8473a69fc
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_http_request.rs
@@ -0,0 +1,25 @@
+use ntex::web::{self, DefaultError, WebRequest};
+
+use crate::{
+    plugin_context::PluginContext,
+    plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult},
+};
+
+pub struct OnHttpRequestHookPayload<'req> {
+    pub router_http_request: WebRequest<DefaultError>,
+    pub context: &'req PluginContext,
+}
+
+impl<'req> StartHookPayload<OnHttpResponseHookPayload<'req>> for OnHttpRequestHookPayload<'req> {}
+
+pub type OnHttpRequestHookResult<'req> =
+    StartHookResult<'req, OnHttpRequestHookPayload<'req>, OnHttpResponseHookPayload<'req>>;
+
+pub struct OnHttpResponseHookPayload<'req> {
+    pub response: web::WebResponse,
+    pub context: &'req PluginContext,
+}
+
+impl<'req> EndHookPayload for OnHttpResponseHookPayload<'req> {}
+
+pub type OnHttpResponseHookResult<'req> = EndHookResult<OnHttpResponseHookPayload<'req>>;
diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs
new file mode 100644
index 000000000..103eafdf2
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_query_plan.rs
@@ -0,0 +1,34 @@
+use hive_router_query_planner::{
+    ast::operation::OperationDefinition,
+    graph::PlannerOverrideContext,
+    planner::{plan_nodes::QueryPlan, Planner},
+    utils::cancellation::CancellationToken,
+};
+
+use crate::{
+    plugin_context::{PluginContext, RouterHttpRequest},
+    plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult},
+};
+
+pub struct OnQueryPlanStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+    pub filtered_operation_for_plan: &'exec OperationDefinition,
+    pub planner_override_context: PlannerOverrideContext,
+    pub cancellation_token: &'exec CancellationToken,
+    pub query_plan: Option<QueryPlan>,
+    pub planner: &'exec Planner,
+}
+
+impl<'exec> StartHookPayload<OnQueryPlanEndHookPayload> for OnQueryPlanStartHookPayload<'exec> {}
+
+pub type OnQueryPlanStartHookResult<'exec> =
+    StartHookResult<'exec, OnQueryPlanStartHookPayload<'exec>, OnQueryPlanEndHookPayload>;
+
+pub struct OnQueryPlanEndHookPayload {
+    pub query_plan: QueryPlan,
+}
+
+impl EndHookPayload for OnQueryPlanEndHookPayload {}
+
+pub type OnQueryPlanEndHookResult = EndHookResult<OnQueryPlanEndHookPayload>;
diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs
new file mode 100644
index 000000000..b4e08f320
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs
@@ -0,0 +1,40 @@
+use crate::{
+    executors::{
+        common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc},
+        http::HttpResponse,
+    },
+    plugin_context::{PluginContext, RouterHttpRequest},
+    plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult},
+};
+
+pub struct OnSubgraphExecuteStartHookPayload<'exec> {
+    pub router_http_request: &'exec RouterHttpRequest<'exec>,
+    pub context: &'exec PluginContext,
+
+    pub subgraph_name: &'exec str,
+    pub executor: SubgraphExecutorBoxedArc,
+
+    pub execution_request: SubgraphExecutionRequest<'exec>,
+    pub execution_result: Option<HttpResponse>,
+}
+
+impl<'exec> StartHookPayload<OnSubgraphExecuteEndHookPayload<'exec>>
+    for OnSubgraphExecuteStartHookPayload<'exec>
+{
+}
+
+pub type OnSubgraphExecuteStartHookResult<'exec> = StartHookResult<
+    'exec,
+    OnSubgraphExecuteStartHookPayload<'exec>,
+    OnSubgraphExecuteEndHookPayload<'exec>,
+>;
+
+pub struct OnSubgraphExecuteEndHookPayload<'exec> {
+    pub execution_result: HttpResponse,
+    pub context: &'exec PluginContext,
+}
+
+impl<'exec> EndHookPayload for OnSubgraphExecuteEndHookPayload<'exec> {}
+
+pub type OnSubgraphExecuteEndHookResult<'exec> =
+    EndHookResult<OnSubgraphExecuteEndHookPayload<'exec>>;
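The `query_plan: Option<QueryPlan>` field follows the same pre-fill pattern as the subgraph hooks, which presumably lets a plugin inject a cached plan so the router can skip the planner. A hedged sketch; the cache lookup is a hypothetical stub, not router functionality:

```rust
use async_trait::async_trait;
use hive_router_plan_executor::hooks::on_query_plan::{
    OnQueryPlanStartHookPayload, OnQueryPlanStartHookResult,
};
use hive_router_plan_executor::plugin_trait::{RouterPlugin, StartHookPayload};
use hive_router_query_planner::ast::operation::OperationDefinition;
use hive_router_query_planner::planner::plan_nodes::QueryPlan;

// Placeholder: a real plugin would consult an actual plan cache here.
fn lookup_cached_plan(_op: &OperationDefinition) -> Option<QueryPlan> {
    None
}

struct PlanCacheStub;

#[async_trait]
impl RouterPlugin for PlanCacheStub {
    async fn on_query_plan<'exec>(
        &'exec self,
        mut start_payload: OnQueryPlanStartHookPayload<'exec>,
    ) -> OnQueryPlanStartHookResult<'exec> {
        if let Some(plan) = lookup_cached_plan(start_payload.filtered_operation_for_plan) {
            start_payload.query_plan = Some(plan);
        }
        start_payload.cont()
    }
}
```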
diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs
new file mode 100644
index 000000000..6e9d6a092
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs
@@ -0,0 +1,40 @@
+use crate::{
+    executors::{common::SubgraphExecutionRequest, http::HttpResponse},
+    plugin_context::PluginContext,
+    plugin_trait::{EndHookPayload, StartHookPayload},
+};
+
+pub struct OnSubgraphHttpRequestHookPayload<'exec> {
+    pub subgraph_name: &'exec str,
+
+    pub endpoint: &'exec http::Uri,
+    pub method: http::Method,
+    pub body: Vec<u8>,
+    pub execution_request: SubgraphExecutionRequest<'exec>,
+
+    pub context: &'exec PluginContext,
+
+    // Early response
+    pub response: Option<HttpResponse>,
+}
+
+impl<'exec> StartHookPayload<OnSubgraphHttpResponseHookPayload<'exec>>
+    for OnSubgraphHttpRequestHookPayload<'exec>
+{
+}
+
+pub type OnSubgraphHttpRequestHookResult<'exec> = crate::plugin_trait::StartHookResult<
+    'exec,
+    OnSubgraphHttpRequestHookPayload<'exec>,
+    OnSubgraphHttpResponseHookPayload<'exec>,
+>;
+
+pub struct OnSubgraphHttpResponseHookPayload<'exec> {
+    pub context: &'exec PluginContext,
+    pub response: HttpResponse,
+}
+
+impl<'exec> EndHookPayload for OnSubgraphHttpResponseHookPayload<'exec> {}
+
+pub type OnSubgraphHttpResponseHookResult<'exec> =
+    crate::plugin_trait::EndHookResult<OnSubgraphHttpResponseHookPayload<'exec>>;
diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs
new file mode 100644
index 000000000..2425c9c3f
--- /dev/null
+++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs
@@ -0,0 +1,39 @@
+use std::sync::Arc;
+
+use arc_swap::ArcSwap;
+use graphql_tools::static_graphql::schema::Document;
+use hive_router_query_planner::planner::Planner;
+
+use crate::{
+    introspection::schema::SchemaMetadata,
+    plugin_trait::{EndHookPayload, StartHookPayload},
+    SubgraphExecutorMap,
+};
+
+pub struct SupergraphData {
+    pub metadata: SchemaMetadata,
+    pub planner: Planner,
+    pub subgraph_executor_map: SubgraphExecutorMap,
+}
+
+pub struct OnSupergraphLoadStartHookPayload {
+    pub current_supergraph_data: Arc<ArcSwap<Option<SupergraphData>>>,
+    pub new_ast: Document,
+}
+
+impl StartHookPayload<OnSupergraphLoadEndHookPayload> for OnSupergraphLoadStartHookPayload {}
+
+pub type OnSupergraphLoadStartHookResult<'exec> = crate::plugin_trait::StartHookResult<
+    'exec,
+    OnSupergraphLoadStartHookPayload,
+    OnSupergraphLoadEndHookPayload,
+>;
+
+pub struct OnSupergraphLoadEndHookPayload {
+    pub new_supergraph_data: SupergraphData,
+}
+
+impl EndHookPayload for OnSupergraphLoadEndHookPayload {}
+
+pub type OnSupergraphLoadEndHookResult =
+    crate::plugin_trait::EndHookResult<OnSupergraphLoadEndHookPayload>;
diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs
new file mode 100644
index 000000000..008dc147a
--- /dev/null
+++ b/lib/executor/src/plugins/mod.rs
@@ -0,0 +1,3 @@
+pub mod hooks;
+pub mod plugin_context;
+pub mod plugin_trait;
diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs
new file mode 100644
index 000000000..b5babeca9
--- /dev/null
+++ b/lib/executor/src/plugins/plugin_context.rs
@@ -0,0 +1,159 @@
+use std::{
+    any::{Any, TypeId},
+    ops::{Deref, DerefMut},
+    sync::Arc,
+};
+
+use dashmap::{
+    mapref::one::{Ref, RefMut},
+    DashMap,
+};
+use http::Uri;
+use ntex::router::Path;
+use ntex_http::HeaderMap;
+
+use crate::plugin_trait::RouterPluginBoxed;
+
+pub struct RouterHttpRequest<'exec> {
+    pub uri: &'exec Uri,
+    pub method: &'exec http::Method,
+    pub version: http::Version,
+    pub headers: &'exec HeaderMap,
+    pub path: &'exec str,
+    pub query_string: &'exec str,
+    pub match_info: &'exec Path<Uri>,
+}
+
+#[derive(Default)]
+pub struct PluginContext {
+    inner: DashMap<TypeId, Box<dyn Any + Send + Sync>>,
+}
+
+pub struct PluginContextRefEntry<'a, T> {
+    pub entry: Ref<'a, TypeId, Box<dyn Any + Send + Sync>>,
+    phantom: std::marker::PhantomData<T>,
+}
+
+impl<'a, T: Any + Send + Sync> AsRef<T> for PluginContextRefEntry<'a, T> {
+    fn as_ref(&self) -> &T {
+        let boxed_any = self.entry.value();
+        boxed_any
+            .downcast_ref::<T>()
+            .expect("type mismatch in PluginContextRefEntry")
+    }
+}
+
+impl<'a, T: Any + Send + Sync> Deref for PluginContextRefEntry<'a, T> {
+    type Target = T;
+    fn deref(&self) -> &Self::Target {
+        self.as_ref()
+    }
+}
+
+pub struct PluginContextMutEntry<'a, T> {
+    pub entry: RefMut<'a, TypeId, Box<dyn Any + Send + Sync>>,
+    phantom: std::marker::PhantomData<T>,
+}
+
+impl<'a, T: Any + Send + Sync> AsRef<T> for PluginContextMutEntry<'a, T> {
+    fn as_ref(&self) -> &T {
+        let boxed_any = self.entry.value();
+        boxed_any
+            .downcast_ref::<T>()
+            .expect("type mismatch in PluginContextMutEntry")
+    }
+}
+
+impl<'a, T: Any + Send + Sync> Deref for PluginContextMutEntry<'a, T> {
+    type Target = T;
+    fn deref(&self) -> &Self::Target {
+        self.as_ref()
+    }
+}
+
+impl<'a, T: Any + Send + Sync> AsMut<T> for PluginContextMutEntry<'a, T> {
+    fn as_mut(&mut self) -> &mut T {
+        let boxed_any = self.entry.value_mut();
+        boxed_any
+            .downcast_mut::<T>()
+            .expect("type mismatch in PluginContextMutEntry")
+    }
+}
+
+impl<'a, T: Any + Send + Sync> DerefMut for PluginContextMutEntry<'a, T> {
+    fn deref_mut(&mut self) -> &mut T {
+        self.as_mut()
+    }
+}
+
+impl PluginContext {
+    pub fn contains<T: Any + Send + Sync>(&self) -> bool {
+        let type_id = TypeId::of::<T>();
+        self.inner.contains_key(&type_id)
+    }
+    pub fn insert<T: Any + Send + Sync>(&self, value: T) -> Option<Box<T>> {
+        let type_id = TypeId::of::<T>();
+        self.inner
+            .insert(type_id, Box::new(value))
+            .and_then(|boxed_any| boxed_any.downcast::<T>().ok())
+    }
+    pub fn get_ref<'a, T: Any + Send + Sync>(&'a self) -> Option<PluginContextRefEntry<'a, T>> {
+        let type_id = TypeId::of::<T>();
+        self.inner.get(&type_id).map(|entry| PluginContextRefEntry {
+            entry,
+            phantom: std::marker::PhantomData,
+        })
+    }
+    pub fn get_mut<'a, T: Any + Send + Sync>(&'a self) -> Option<PluginContextMutEntry<'a, T>> {
+        let type_id = TypeId::of::<T>();
+        self.inner
+            .get_mut(&type_id)
+            .map(|entry| PluginContextMutEntry {
+                entry,
+                phantom: std::marker::PhantomData,
+            })
+    }
+}
+
+pub struct PluginRequestState<'req> {
+    pub plugins: Arc<Vec<RouterPluginBoxed>>,
+    pub router_http_request: RouterHttpRequest<'req>,
+    pub context: Arc<PluginContext>,
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn inserts_and_gets_immut_ref() {
+        use super::PluginContext;
+
+        struct TestCtx {
+            pub value: u32,
+        }
+
+        let ctx = PluginContext::default();
+        ctx.insert(TestCtx { value: 42 });
+
+        let ctx_ref: &TestCtx = &ctx.get_ref().unwrap();
+        assert_eq!(ctx_ref.value, 42);
+    }
+    #[test]
+    fn inserts_and_mutates_with_mut_ref() {
+        use super::PluginContext;
+
+        struct TestCtx {
+            pub value: u32,
+        }
+
+        let ctx = PluginContext::default();
+        ctx.insert(TestCtx { value: 42 });
+
+        {
+            let ctx_mut: &mut TestCtx = &mut ctx.get_mut().unwrap();
+            ctx_mut.value = 100;
+        }
+
+        let ctx_ref: &TestCtx = &ctx.get_ref().unwrap();
+        assert_eq!(ctx_ref.value, 100);
+    }
+}
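Because `PluginContext` keys entries by `TypeId`, a plugin can stash its own per-request state behind a private type without colliding with other plugins. A small sketch using only the API above (the `RequestTimings` type is hypothetical):

```rust
use hive_router_plan_executor::plugin_context::PluginContext;

#[derive(Default)]
struct RequestTimings {
    subgraph_calls: u32,
}

fn count_subgraph_call(ctx: &PluginContext) {
    // Two steps rather than one atomic upsert; good enough for a sketch.
    if !ctx.contains::<RequestTimings>() {
        ctx.insert(RequestTimings::default());
    }
    if let Some(mut timings) = ctx.get_mut::<RequestTimings>() {
        timings.subgraph_calls += 1; // DerefMut gives &mut RequestTimings
    }
}
```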
diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs
new file mode 100644
index 000000000..72517b776
--- /dev/null
+++ b/lib/executor/src/plugins/plugin_trait.rs
@@ -0,0 +1,164 @@
+use serde::de::DeserializeOwned;
+
+use crate::{
+    executors::http::HttpResponse,
+    hooks::{
+        on_execute::{OnExecuteStartHookPayload, OnExecuteStartHookResult},
+        on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult},
+        on_graphql_parse::{OnGraphQLParseHookResult, OnGraphQLParseStartHookPayload},
+        on_graphql_validation::{
+            OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult,
+        },
+        on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult},
+        on_query_plan::{OnQueryPlanStartHookPayload, OnQueryPlanStartHookResult},
+        on_subgraph_execute::{
+            OnSubgraphExecuteStartHookPayload, OnSubgraphExecuteStartHookResult,
+        },
+        on_subgraph_http_request::{
+            OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult,
+        },
+        on_supergraph_load::{OnSupergraphLoadStartHookPayload, OnSupergraphLoadStartHookResult},
+    },
+};
+
+pub struct StartHookResult<'exec, TStartPayload, TEndPayload> {
+    pub payload: TStartPayload,
+    pub control_flow: StartControlFlow<'exec, TEndPayload>,
+}
+
+pub enum StartControlFlow<'exec, TEndPayload> {
+    Continue,
+    EndResponse(HttpResponse),
+    OnEnd(Box<dyn FnOnce(TEndPayload) -> EndHookResult<TEndPayload> + Send + 'exec>),
+}
+
+pub trait StartHookPayload<TEndPayload>
+where
+    Self: Sized,
+{
+    fn cont<'exec>(self) -> StartHookResult<'exec, Self, TEndPayload> {
+        StartHookResult {
+            payload: self,
+            control_flow: StartControlFlow::Continue,
+        }
+    }
+
+    fn end_response<'exec>(
+        self,
+        output: HttpResponse,
+    ) -> StartHookResult<'exec, Self, TEndPayload> {
+        StartHookResult {
+            payload: self,
+            control_flow: StartControlFlow::EndResponse(output),
+        }
+    }
+
+    fn on_end<'exec, F>(self, f: F) -> StartHookResult<'exec, Self, TEndPayload>
+    where
+        F: FnOnce(TEndPayload) -> EndHookResult<TEndPayload> + Send + 'exec,
+    {
+        StartHookResult {
+            payload: self,
+            control_flow: StartControlFlow::OnEnd(Box::new(f)),
+        }
+    }
+}
+
+pub struct EndHookResult<TEndPayload> {
+    pub payload: TEndPayload,
+    pub control_flow: EndControlFlow,
+}
+
+pub enum EndControlFlow {
+    Continue,
+    EndResponse(HttpResponse),
+}
+
+pub trait EndHookPayload
+where
+    Self: Sized,
+{
+    fn cont(self) -> EndHookResult<Self> {
+        EndHookResult {
+            payload: self,
+            control_flow: EndControlFlow::Continue,
+        }
+    }
+
+    fn end_response(self, output: HttpResponse) -> EndHookResult<Self> {
+        EndHookResult {
+            payload: self,
+            control_flow: EndControlFlow::EndResponse(output),
+        }
+    }
+}
+
+pub trait RouterPluginWithConfig
+where
+    Self: Sized,
+    Self: RouterPlugin,
+{
+    fn plugin_name() -> &'static str;
+    type Config: DeserializeOwned;
+    fn from_config(config: Self::Config) -> Option<Self>;
+}
+
+#[async_trait::async_trait]
+pub trait RouterPlugin {
+    fn on_http_request<'req>(
+        &'req self,
+        start_payload: OnHttpRequestHookPayload<'req>,
+    ) -> OnHttpRequestHookResult<'req> {
+        start_payload.cont()
+    }
+    async fn on_graphql_params<'exec>(
+        &'exec self,
+        start_payload: OnGraphQLParamsStartHookPayload<'exec>,
+    ) -> OnGraphQLParamsStartHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_graphql_parse<'exec>(
+        &'exec self,
+        start_payload: OnGraphQLParseStartHookPayload<'exec>,
+    ) -> OnGraphQLParseHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_graphql_validation<'exec>(
+        &'exec self,
+        start_payload: OnGraphQLValidationStartHookPayload<'exec>,
+    ) -> OnGraphQLValidationStartHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_query_plan<'exec>(
+        &'exec self,
+        start_payload: OnQueryPlanStartHookPayload<'exec>,
+    ) -> OnQueryPlanStartHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_execute<'exec>(
+        &'exec self,
+        start_payload: OnExecuteStartHookPayload<'exec>,
+    ) -> OnExecuteStartHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_subgraph_execute<'exec>(
+        &'exec self,
+        start_payload: OnSubgraphExecuteStartHookPayload<'exec>,
+    ) -> OnSubgraphExecuteStartHookResult<'exec> {
+        start_payload.cont()
+    }
+    async fn on_subgraph_http_request<'exec>(
+        &'exec self,
+        start_payload: OnSubgraphHttpRequestHookPayload<'exec>,
+    ) -> OnSubgraphHttpRequestHookResult<'exec> {
+        start_payload.cont()
+    }
+    fn on_supergraph_reload<'exec>(
+        &'exec self,
+        start_payload: OnSupergraphLoadStartHookPayload,
+    ) -> OnSupergraphLoadStartHookResult<'exec> {
+        start_payload.cont()
+    }
+}
+
+pub type RouterPluginBoxed = Box<dyn RouterPlugin + Send + Sync>;
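Since every hook defaults to `cont()`, a minimal config-driven plugin only has to implement `RouterPluginWithConfig`. A hedged sketch; the `banner` plugin and its config shape are hypothetical:

```rust
use hive_router_plan_executor::plugin_trait::{RouterPlugin, RouterPluginWithConfig};
use serde::Deserialize;

#[derive(Deserialize)]
struct BannerConfig {
    enabled: bool,
}

struct BannerPlugin;

// Empty impl: every hook falls back to the default `cont()` pass-through.
#[async_trait::async_trait]
impl RouterPlugin for BannerPlugin {}

impl RouterPluginWithConfig for BannerPlugin {
    type Config = BannerConfig;

    fn plugin_name() -> &'static str {
        "banner"
    }

    fn from_config(config: Self::Config) -> Option<Self> {
        // Returning None opts the plugin out entirely.
        config.enabled.then_some(BannerPlugin)
    }
}
```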
diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs
index 537244c9e..d113288f7 100644
--- a/lib/router-config/src/lib.rs
+++ b/lib/router-config/src/lib.rs
@@ -92,6 +92,10 @@ pub struct HiveRouterConfig {
     /// Configuration for overriding labels.
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub override_labels: OverrideLabelsConfig,
+
+    /// Configuration for custom plugins
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub plugins: HashMap,
 }
 
 #[derive(Debug, thiserror::Error)]
diff --git a/lib/router-config/src/primitives/file_path.rs b/lib/router-config/src/primitives/file_path.rs
index f8140562a..13aac9c37 100644
--- a/lib/router-config/src/primitives/file_path.rs
+++ b/lib/router-config/src/primitives/file_path.rs
@@ -9,6 +9,7 @@ use serde::{
     de::{self, Visitor},
     Deserialize, Deserializer, Serialize,
 };
+use tracing::info;
 
 #[derive(Debug, Clone, Serialize)]
 pub struct FilePath {
@@ -70,6 +71,11 @@ impl<'de> Visitor<'de> for FilePathVisitor {
     {
         CONTEXT_START_PATH.with(|ctx| {
             if let Some(start_path) = ctx.borrow().as_ref() {
+                info!(
+                    "Deserializing FilePath '{}' with start path '{}'",
+                    v,
+                    start_path.display()
+                );
                 match FilePath::resolve_relative(start_path, v, true) {
                     Ok(file_path) => Ok(file_path),
                     Err(err) => Err(E::custom(format!("Failed to canonicalize path: {}", err))),
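The value type of the new `plugins` map is elided in this diff. Assuming it deserializes to a `serde_json::Value`-like tree, wiring a plugin's config section to `RouterPluginWithConfig` might look roughly like this (reusing the hypothetical `BannerPlugin`/`BannerConfig` sketch above; the helper and value type are assumptions, not from this PR):

```rust
use hive_router_config::HiveRouterConfig;

// Hypothetical wiring: pull this plugin's raw section out of the top-level
// `plugins` map, deserialize it, and let the plugin decide whether to enable.
fn build_banner_plugin(config: &HiveRouterConfig) -> Option<BannerPlugin> {
    let raw = config.plugins.get(BannerPlugin::plugin_name())?;
    let parsed: BannerConfig = serde_json::from_value(raw.clone()).ok()?;
    BannerPlugin::from_config(parsed)
}
```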