From d34732ac3eb19edef397656f36ef5117b2c98a76 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 13 Oct 2025 16:16:49 +0300 Subject: [PATCH 01/31] Response Cache Plugin --- Cargo.lock | 24 ++++ lib/executor/Cargo.toml | 3 + lib/executor/src/lib.rs | 2 + lib/executor/src/plugins/mod.rs | 2 + lib/executor/src/plugins/response_cache.rs | 136 +++++++++++++++++++++ lib/executor/src/plugins/traits.rs | 61 +++++++++ 6 files changed, 228 insertions(+) create mode 100644 lib/executor/src/plugins/mod.rs create mode 100644 lib/executor/src/plugins/response_cache.rs create mode 100644 lib/executor/src/plugins/traits.rs diff --git a/Cargo.lock b/Cargo.lock index d828dd8d2..d5c05dfff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2064,8 +2064,10 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "ntex", "ntex-http", "ordered-float", + "redis", "regex-automata", "ryu", "serde", @@ -4217,6 +4219,22 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redis" +version = "0.32.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014cc767fefab6a3e798ca45112bccad9c6e0e218fbd49720042716c73cfef44" +dependencies = [ + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2 0.6.1", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4902,6 +4920,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 27f7af1bf..7dcfc03fb 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -34,6 +34,8 @@ vrl = { workspace = true } ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } + +ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" hyper-tls = { version = "0.6.0", features = ["vendored"] } @@ -49,6 +51,7 @@ itoa = "1.0.15" ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" +redis = "0.32.7" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 4f912a463..c245a7483 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -4,6 +4,7 @@ pub mod executors; pub mod headers; pub mod introspection; pub mod json_writer; +pub mod plugins; pub mod projection; pub mod response; pub mod utils; @@ -11,3 +12,4 @@ pub mod variables; pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; +pub use plugins::response_cache::*; diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs new file mode 100644 index 000000000..0e59883b7 --- /dev/null +++ b/lib/executor/src/plugins/mod.rs @@ -0,0 +1,2 @@ +pub mod response_cache; +pub mod traits; diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/response_cache.rs new file mode 100644 index 000000000..545d018a3 --- /dev/null +++ b/lib/executor/src/plugins/response_cache.rs @@ -0,0 +1,136 @@ +use dashmap::DashMap; +use ntex::web::HttpResponse; +use redis::Commands; +use sonic_rs::json; + +use crate::{ + plugins::traits::{ + ControlFlow, OnExecuteEnd, OnExecuteEndPayload, OnExecuteStart, OnExecuteStartPayload, + OnSchemaReload, OnSchemaReloadPayload, + }, + utils::consts::TYPENAME_FIELD_NAME, +}; + +pub struct ResponseCachePlugin { + redis_client: redis::Client, + ttl_per_type: DashMap, +} + +impl ResponseCachePlugin { + pub fn try_new(redis_url: &str) -> Result { + let redis_client = redis::Client::open(redis_url)?; + Ok(Self { + redis_client, + ttl_per_type: DashMap::new(), + }) + } +} + +pub struct ResponseCacheContext { + key: String, +} + +impl OnExecuteStart for ResponseCachePlugin { + fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow { + let key = format!( + "response_cache:{}:{:?}", + payload.query_plan, payload.variable_values + ); + payload + .router_http_request + .extensions_mut() + .insert(ResponseCacheContext { key: key.clone() }); + if let Ok(mut conn) = self.redis_client.get_connection() { + let cached_response: Option> = conn.get(&key).ok(); + if let Some(cached_response) = cached_response { + return ControlFlow::Break( + HttpResponse::Ok() + .header("Content-Type", "application/json") + .body(cached_response), + ); + } + } + ControlFlow::Continue + } +} + +impl OnExecuteEnd for ResponseCachePlugin { + fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow { + // Do not cache if there are errors + if !payload.errors.is_empty() { + return ControlFlow::Continue; + } + if let Some(key) = payload + .router_http_request + .extensions() + .get::() + .map(|ctx| &ctx.key) + { + if let Ok(mut conn) = self.redis_client.get_connection() { + if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { + // Decide on the ttl somehow + // Get the type names + let mut max_ttl = 0; + + // Imagine this code is traversing the response data to find type names + if let Some(obj) = payload.data.as_object() { + if let Some(typename) = obj + .iter() + .position(|(k, _)| k == &TYPENAME_FIELD_NAME) + .and_then(|idx| obj[idx].1.as_str()) + { + if let Some(ttl) = self.ttl_per_type.get(typename).map(|v| *v) { + max_ttl = max_ttl.max(ttl); + } + } + } + + // If no ttl found, default to 60 seconds + if max_ttl == 0 { + max_ttl = 60; + } + + // Insert the ttl into extensions for client awareness + payload + .extensions + .insert("response_cache_ttl".to_string(), json!(max_ttl)); + + // Set the cache with the decided ttl + let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); + } + } + } + ControlFlow::Continue + } +} + +impl OnSchemaReload for ResponseCachePlugin { + fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { + // Visit the schema and update ttl_per_type based on some directive + payload + .new_schema + .document + .definitions + .iter() + .for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); + } + } + } + } + } + } + } + } + }); + } +} diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs new file mode 100644 index 000000000..4357e5eab --- /dev/null +++ b/lib/executor/src/plugins/traits.rs @@ -0,0 +1,61 @@ +use std::{collections::HashMap, sync::Arc}; + +use hive_router_query_planner::consumer_schema::ConsumerSchema; +use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use ntex::web::HttpRequest; +use ntex::web::HttpResponse; + +use crate::response::graphql_error::GraphQLError; +use crate::response::value::Value; + +pub enum ControlFlow { + Continue, + Break(HttpResponse), +} + +pub struct ExecutionResult<'exec> { + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: &'exec mut Option>>, +} + +pub struct OnExecuteStartPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: Option<&'exec mut sonic_rs::Value>, + + pub skip_execution: bool, + + pub variable_values: &'exec Option>, +} + +pub trait OnExecuteStart { + fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow; +} + +pub struct OnExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec Value<'exec>, + pub errors: &'exec Vec, + pub extensions: &'exec mut HashMap, + + pub variable_values: &'exec Option>, +} + +pub trait OnExecuteEnd { + fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow; +} + +pub struct OnSchemaReloadPayload { + pub old_schema: &'static ConsumerSchema, + pub new_schema: &'static mut ConsumerSchema, +} + +pub trait OnSchemaReload { + fn on_schema_reload(&self, payload: OnSchemaReloadPayload); +} From d6a6b7a79dd3f0c9e7aa514af1a1577fde8e9fe9 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 14:08:52 +0300 Subject: [PATCH 02/31] Iteration --- lib/executor/src/plugins/response_cache.rs | 49 +++++++--------------- lib/executor/src/plugins/traits.rs | 25 ++++++----- 2 files changed, 26 insertions(+), 48 deletions(-) diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/response_cache.rs index 545d018a3..01719deca 100644 --- a/lib/executor/src/plugins/response_cache.rs +++ b/lib/executor/src/plugins/response_cache.rs @@ -1,12 +1,10 @@ use dashmap::DashMap; use ntex::web::HttpResponse; use redis::Commands; -use sonic_rs::json; use crate::{ plugins::traits::{ - ControlFlow, OnExecuteEnd, OnExecuteEndPayload, OnExecuteStart, OnExecuteStartPayload, - OnSchemaReload, OnSchemaReloadPayload, + ControlFlow, OnExecutePayload, OnSchemaReloadPayload, RouterPlugin }, utils::consts::TYPENAME_FIELD_NAME, }; @@ -26,20 +24,15 @@ impl ResponseCachePlugin { } } -pub struct ResponseCacheContext { - key: String, -} - -impl OnExecuteStart for ResponseCachePlugin { - fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow { +impl RouterPlugin for ResponseCachePlugin { + fn on_execute<'exec>( + &self, + payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { let key = format!( "response_cache:{}:{:?}", payload.query_plan, payload.variable_values ); - payload - .router_http_request - .extensions_mut() - .insert(ResponseCacheContext { key: key.clone() }); if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { @@ -49,24 +42,12 @@ impl OnExecuteStart for ResponseCachePlugin { .body(cached_response), ); } - } - ControlFlow::Continue - } -} + ControlFlow::OnEnd(Box::new(move |payload: OnExecutePayload| { + // Do not cache if there are errors + if !payload.errors.is_empty() { + return ControlFlow::Continue; + } -impl OnExecuteEnd for ResponseCachePlugin { - fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow { - // Do not cache if there are errors - if !payload.errors.is_empty() { - return ControlFlow::Continue; - } - if let Some(key) = payload - .router_http_request - .extensions() - .get::() - .map(|ctx| &ctx.key) - { - if let Ok(mut conn) = self.redis_client.get_connection() { if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { // Decide on the ttl somehow // Get the type names @@ -93,18 +74,16 @@ impl OnExecuteEnd for ResponseCachePlugin { // Insert the ttl into extensions for client awareness payload .extensions - .insert("response_cache_ttl".to_string(), json!(max_ttl)); + .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); // Set the cache with the decided ttl let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); } - } + ControlFlow::Continue + })); } ControlFlow::Continue } -} - -impl OnSchemaReload for ResponseCachePlugin { fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { // Visit the schema and update ttl_per_type based on some directive payload diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs index 4357e5eab..dd3db1731 100644 --- a/lib/executor/src/plugins/traits.rs +++ b/lib/executor/src/plugins/traits.rs @@ -8,9 +8,10 @@ use ntex::web::HttpResponse; use crate::response::graphql_error::GraphQLError; use crate::response::value::Value; -pub enum ControlFlow { +pub enum ControlFlow<'a, TPayload> { Continue, Break(HttpResponse), + OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), } pub struct ExecutionResult<'exec> { @@ -19,21 +20,27 @@ pub struct ExecutionResult<'exec> { pub extensions: &'exec mut Option>>, } -pub struct OnExecuteStartPayload<'exec> { +pub struct OnExecutePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub query_plan: Arc, pub data: &'exec mut Value<'exec>, pub errors: &'exec mut Vec, - pub extensions: Option<&'exec mut sonic_rs::Value>, + pub extensions: &'exec mut HashMap, pub skip_execution: bool, pub variable_values: &'exec Option>, } -pub trait OnExecuteStart { - fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow; +pub trait RouterPlugin { + fn on_execute<'exec>( + &self, + _payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + ControlFlow::Continue + } + fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} } pub struct OnExecuteEndPayload<'exec> { @@ -47,15 +54,7 @@ pub struct OnExecuteEndPayload<'exec> { pub variable_values: &'exec Option>, } -pub trait OnExecuteEnd { - fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow; -} - pub struct OnSchemaReloadPayload { pub old_schema: &'static ConsumerSchema, pub new_schema: &'static mut ConsumerSchema, } - -pub trait OnSchemaReload { - fn on_schema_reload(&self, payload: OnSchemaReloadPayload); -} From 0ef2138233d76a257418121ac013c9e1b64f81d9 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 17:09:42 +0300 Subject: [PATCH 03/31] New example --- lib/executor/src/lib.rs | 2 +- lib/executor/src/plugins/examples/mod.rs | 2 + .../plugins/{ => examples}/response_cache.rs | 7 +-- .../examples/subgraph_response_cache.rs | 32 ++++++++++ lib/executor/src/plugins/hooks/mod.rs | 3 + lib/executor/src/plugins/hooks/on_execute.rs | 33 ++++++++++ .../src/plugins/hooks/on_schema_reload.rs | 6 ++ .../src/plugins/hooks/on_subgraph_execute.rs | 47 +++++++++++++++ lib/executor/src/plugins/mod.rs | 5 +- lib/executor/src/plugins/plugin_trait.rs | 28 +++++++++ lib/executor/src/plugins/traits.rs | 60 ------------------- 11 files changed, 158 insertions(+), 67 deletions(-) create mode 100644 lib/executor/src/plugins/examples/mod.rs rename lib/executor/src/plugins/{ => examples}/response_cache.rs (95%) create mode 100644 lib/executor/src/plugins/examples/subgraph_response_cache.rs create mode 100644 lib/executor/src/plugins/hooks/mod.rs create mode 100644 lib/executor/src/plugins/hooks/on_execute.rs create mode 100644 lib/executor/src/plugins/hooks/on_schema_reload.rs create mode 100644 lib/executor/src/plugins/hooks/on_subgraph_execute.rs create mode 100644 lib/executor/src/plugins/plugin_trait.rs delete mode 100644 lib/executor/src/plugins/traits.rs diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index c245a7483..1f29c192e 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -12,4 +12,4 @@ pub mod variables; pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; -pub use plugins::response_cache::*; +pub use plugins::*; diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs new file mode 100644 index 000000000..3d54dfbed --- /dev/null +++ b/lib/executor/src/plugins/examples/mod.rs @@ -0,0 +1,2 @@ +pub mod response_cache; +pub mod subgraph_response_cache; \ No newline at end of file diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs similarity index 95% rename from lib/executor/src/plugins/response_cache.rs rename to lib/executor/src/plugins/examples/response_cache.rs index 01719deca..ed8a106a6 100644 --- a/lib/executor/src/plugins/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -3,10 +3,9 @@ use ntex::web::HttpResponse; use redis::Commands; use crate::{ - plugins::traits::{ - ControlFlow, OnExecutePayload, OnSchemaReloadPayload, RouterPlugin - }, - utils::consts::TYPENAME_FIELD_NAME, + hooks::{on_execute::OnExecutePayload, on_schema_reload::OnSchemaReloadPayload}, plugins::plugin_trait::{ + ControlFlow, RouterPlugin + }, utils::consts::TYPENAME_FIELD_NAME }; pub struct ResponseCachePlugin { diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs new file mode 100644 index 000000000..32739df1e --- /dev/null +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -0,0 +1,32 @@ +use dashmap::DashMap; + +use crate::{hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, SubgraphExecutorResponse, SubgraphResponse}, plugin_trait::{ControlFlow, RouterPlugin}}; + +struct SubgraphResponseCachePlugin { + cache: DashMap>, +} + +impl RouterPlugin for SubgraphResponseCachePlugin { + fn on_subgraph_execute<'exec>( + &self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + let key = format!( + "subgraph_response_cache:{}:{}:{:?}", + payload.subgraph_name, payload.execution_request.operation_name.unwrap_or(""), payload.execution_request.variables + ); + if let Some(cached_response) = self.cache.get(&key) { + *payload.response = Some(SubgraphExecutorResponse::RawResponse(cached_response)); + // Return early with the cached response + return ControlFlow::Continue; + } else { + ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphExecuteEndPayload| { + let cacheable = payload.response.errors.is_none_or(|errors| errors.is_empty()); + if cacheable { + self.cache.insert(key, *payload.response); + } + ControlFlow::Continue + })) + } + } +} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs new file mode 100644 index 000000000..1954154d8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -0,0 +1,3 @@ +pub mod on_execute; +pub mod on_schema_reload; +pub mod on_subgraph_execute; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs new file mode 100644 index 000000000..0d5759de2 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -0,0 +1,33 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use ntex::web::HttpRequest; + +use crate::response::{value::Value}; +use crate::response::graphql_error::GraphQLError; + +pub struct OnExecutePayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: &'exec mut HashMap, + + pub skip_execution: bool, + + pub variable_values: &'exec Option>, +} + +pub struct OnExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec Value<'exec>, + pub errors: &'exec Vec, + pub extensions: &'exec mut HashMap, + + pub variable_values: &'exec Option>, +} + diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs new file mode 100644 index 000000000..29863d964 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_schema_reload.rs @@ -0,0 +1,6 @@ +use hive_router_query_planner::consumer_schema::ConsumerSchema; + +pub struct OnSchemaReloadPayload { + pub old_schema: &'static ConsumerSchema, + pub new_schema: &'static mut ConsumerSchema, +} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs new file mode 100644 index 000000000..4fde5afb8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -0,0 +1,47 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use hive_router_query_planner::ast::operation::SubgraphFetchOperation; +use ntex::web::HttpRequest; + +use crate::{executors::dedupe::SharedResponse, response::{graphql_error::GraphQLError, value::Value}}; + + + +pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // The node that initiates this subgraph execution + pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, + // This will be tricky to implement with the current structure, + // but I'm sure we'll figure it out + pub response: &'exec mut Option>, +} + +pub struct SubgraphExecutionRequest<'exec> { + pub query: &'exec str, + // We can add the original operation here too + pub operation: &'exec SubgraphFetchOperation, + + pub dedupe: bool, + pub operation_name: Option<&'exec str>, + pub variables: Option>, + pub extensions: Option>, + pub representations: Option>, +} + +pub struct SubgraphResponse<'exec> { + pub data: Value<'exec>, + pub errors: Option>, + pub extensions: Option>>, +} + +pub struct OnSubgraphExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // The node that initiates this subgraph execution + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + // This will be tricky to implement with the current structure, + // but I'm sure we'll figure it out + pub response: &'exec SubgraphResponse<'exec>, +} diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 0e59883b7..6c35286af 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,2 +1,3 @@ -pub mod response_cache; -pub mod traits; +pub mod examples; +pub mod plugin_trait; +pub mod hooks; \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs new file mode 100644 index 000000000..8d5a951e7 --- /dev/null +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -0,0 +1,28 @@ +use ntex::web::HttpResponse; + +use crate::hooks::on_execute::OnExecutePayload; +use crate::hooks::on_schema_reload::OnSchemaReloadPayload; +use crate::hooks::on_subgraph_execute::OnSubgraphExecuteEndPayload; +use crate::hooks::on_subgraph_execute::OnSubgraphExecuteStartPayload; + +pub enum ControlFlow<'a, TPayload> { + Continue, + Break(HttpResponse), + OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), +} + +pub trait RouterPlugin { + fn on_execute<'exec>( + &self, + _payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + ControlFlow::Continue + } + fn on_subgraph_execute<'exec>( + &self, + _payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + ControlFlow::Continue + } + fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} +} \ No newline at end of file diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs deleted file mode 100644 index dd3db1731..000000000 --- a/lib/executor/src/plugins/traits.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use hive_router_query_planner::consumer_schema::ConsumerSchema; -use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use ntex::web::HttpRequest; -use ntex::web::HttpResponse; - -use crate::response::graphql_error::GraphQLError; -use crate::response::value::Value; - -pub enum ControlFlow<'a, TPayload> { - Continue, - Break(HttpResponse), - OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), -} - -pub struct ExecutionResult<'exec> { - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut Option>>, -} - -pub struct OnExecutePayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, - - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub skip_execution: bool, - - pub variable_values: &'exec Option>, -} - -pub trait RouterPlugin { - fn on_execute<'exec>( - &self, - _payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { - ControlFlow::Continue - } - fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} -} - -pub struct OnExecuteEndPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, - - pub data: &'exec Value<'exec>, - pub errors: &'exec Vec, - pub extensions: &'exec mut HashMap, - - pub variable_values: &'exec Option>, -} - -pub struct OnSchemaReloadPayload { - pub old_schema: &'static ConsumerSchema, - pub new_schema: &'static mut ConsumerSchema, -} From 22279bf956a4546144de14917fb8da378766a6ec Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 17:30:13 +0300 Subject: [PATCH 04/31] Another plugin --- .../examples/subgraph_response_cache.rs | 36 +++++++++--------- lib/executor/src/plugins/hooks/mod.rs | 2 +- ...execute.rs => on_subgraph_http_request.rs} | 38 +++++++++---------- lib/executor/src/plugins/plugin_trait.rs | 15 ++++---- 4 files changed, 44 insertions(+), 47 deletions(-) rename lib/executor/src/plugins/hooks/{on_subgraph_execute.rs => on_subgraph_http_request.rs} (54%) diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 32739df1e..d4742b2ad 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,32 +1,30 @@ use dashmap::DashMap; -use crate::{hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, SubgraphExecutorResponse, SubgraphResponse}, plugin_trait::{ControlFlow, RouterPlugin}}; +use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ControlFlow, RouterPlugin}}; -struct SubgraphResponseCachePlugin { - cache: DashMap>, +pub struct SubgraphResponseCachePlugin { + cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_execute<'exec>( - &self, - payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + fn on_subgraph_http_request<'exec>( + &'static self, + payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { let key = format!( - "subgraph_response_cache:{}:{}:{:?}", - payload.subgraph_name, payload.execution_request.operation_name.unwrap_or(""), payload.execution_request.variables + "subgraph_response_cache:{}:{:?}", + payload.execution_request.query, payload.execution_request.variables ); if let Some(cached_response) = self.cache.get(&key) { - *payload.response = Some(SubgraphExecutorResponse::RawResponse(cached_response)); - // Return early with the cached response + // Here payload.response is Option + // So it is bypassing the actual subgraph request + *payload.response = Some(cached_response.clone()); return ControlFlow::Continue; - } else { - ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphExecuteEndPayload| { - let cacheable = payload.response.errors.is_none_or(|errors| errors.is_empty()); - if cacheable { - self.cache.insert(key, *payload.response); - } - ControlFlow::Continue - })) } + ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphHttpResponsePayload| { + // Here payload.response is not Option + self.cache.insert(key, payload.response.clone()); + ControlFlow::Continue + })) } } \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 1954154d8..5a4d94c22 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,3 +1,3 @@ pub mod on_execute; pub mod on_schema_reload; -pub mod on_subgraph_execute; \ No newline at end of file +pub mod on_subgraph_http_request; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs similarity index 54% rename from lib/executor/src/plugins/hooks/on_subgraph_execute.rs rename to lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 4fde5afb8..326e516f0 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,28 +1,34 @@ use std::collections::HashMap; -use bytes::Bytes; use hive_router_query_planner::ast::operation::SubgraphFetchOperation; +use http::{HeaderMap, Uri}; use ntex::web::HttpRequest; -use crate::{executors::dedupe::SharedResponse, response::{graphql_error::GraphQLError, value::Value}}; +use crate:: + executors::dedupe::SharedResponse +; +pub struct OnSubgraphHttpRequestPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // At this point, there is no point of mutating this + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + pub endpoint: &'exec mut Uri, + // By default, it is POST + pub method: &'exec mut http::Method, + pub headers: &'exec mut HeaderMap, + pub request_body: &'exec mut Vec, -pub struct OnSubgraphExecuteStartPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub subgraph_name: &'exec str, - // The node that initiates this subgraph execution - pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, - // This will be tricky to implement with the current structure, - // but I'm sure we'll figure it out - pub response: &'exec mut Option>, + // Early response + pub response: &'exec mut Option, } pub struct SubgraphExecutionRequest<'exec> { pub query: &'exec str, // We can add the original operation here too pub operation: &'exec SubgraphFetchOperation, - + pub dedupe: bool, pub operation_name: Option<&'exec str>, pub variables: Option>, @@ -30,18 +36,12 @@ pub struct SubgraphExecutionRequest<'exec> { pub representations: Option>, } -pub struct SubgraphResponse<'exec> { - pub data: Value<'exec>, - pub errors: Option>, - pub extensions: Option>>, -} - -pub struct OnSubgraphExecuteEndPayload<'exec> { +pub struct OnSubgraphHttpResponsePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, // The node that initiates this subgraph execution pub execution_request: &'exec SubgraphExecutionRequest<'exec>, // This will be tricky to implement with the current structure, // but I'm sure we'll figure it out - pub response: &'exec SubgraphResponse<'exec>, + pub response: &'exec mut SharedResponse, } diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 8d5a951e7..53b0720d1 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -2,13 +2,12 @@ use ntex::web::HttpResponse; use crate::hooks::on_execute::OnExecutePayload; use crate::hooks::on_schema_reload::OnSchemaReloadPayload; -use crate::hooks::on_subgraph_execute::OnSubgraphExecuteEndPayload; -use crate::hooks::on_subgraph_execute::OnSubgraphExecuteStartPayload; +use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; -pub enum ControlFlow<'a, TPayload> { +pub enum ControlFlow<'exec, TPayload> { Continue, Break(HttpResponse), - OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), + OnEnd(Box ControlFlow<'exec, ()> + 'exec>), } pub trait RouterPlugin { @@ -18,10 +17,10 @@ pub trait RouterPlugin { ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { ControlFlow::Continue } - fn on_subgraph_execute<'exec>( - &self, - _payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + fn on_subgraph_http_request<'exec>( + &'static self, + _payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { ControlFlow::Continue } fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} From 8f0c0f99cd87f2400ec1979e54adc91142b753da Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 01:55:12 +0300 Subject: [PATCH 05/31] More --- bin/router/src/lib.rs | 7 +- bin/router/src/pipeline/coerce_variables.rs | 6 +- ...quest.rs => deserialize_graphql_params.rs} | 61 ++++------ bin/router/src/pipeline/mod.rs | 96 ++++++++++----- bin/router/src/pipeline/normalize.rs | 10 +- bin/router/src/pipeline/parser.rs | 77 ++++++++++-- bin/router/src/shared_state.rs | 3 + lib/executor/src/execution/plan.rs | 4 +- lib/executor/src/executors/common.rs | 6 +- lib/executor/src/executors/http.rs | 6 +- lib/executor/src/executors/map.rs | 4 +- lib/executor/src/plugins/examples/apq.rs | 55 +++++++++ lib/executor/src/plugins/examples/mod.rs | 3 +- .../src/plugins/examples/response_cache.rs | 32 ++--- .../examples/subgraph_response_cache.rs | 12 +- lib/executor/src/plugins/hooks/mod.rs | 8 +- .../src/plugins/hooks/on_deserialization.rs | 45 +++++++ lib/executor/src/plugins/hooks/on_execute.rs | 22 ++-- .../src/plugins/hooks/on_graphql_parse.rs | 19 +++ .../plugins/hooks/on_graphql_validation.rs | 25 ++++ .../src/plugins/hooks/on_http_request.rs | 16 +++ .../src/plugins/hooks/on_query_plan.rs | 23 ++++ .../src/plugins/hooks/on_schema_reload.rs | 6 +- .../src/plugins/hooks/on_subgraph_execute.rs | 34 ++++++ .../plugins/hooks/on_subgraph_http_request.rs | 24 +--- lib/executor/src/plugins/plugin_trait.rs | 114 +++++++++++++++--- 26 files changed, 555 insertions(+), 163 deletions(-) rename bin/router/src/pipeline/{execution_request.rs => deserialize_graphql_params.rs} (67%) create mode 100644 lib/executor/src/plugins/examples/apq.rs create mode 100644 lib/executor/src/plugins/hooks/on_deserialization.rs create mode 100644 lib/executor/src/plugins/hooks/on_graphql_parse.rs create mode 100644 lib/executor/src/plugins/hooks/on_graphql_validation.rs create mode 100644 lib/executor/src/plugins/hooks/on_http_request.rs create mode 100644 lib/executor/src/plugins/hooks/on_query_plan.rs create mode 100644 lib/executor/src/plugins/hooks/on_subgraph_execute.rs diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 6a3f7f5c0..da799b4cc 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -27,8 +27,7 @@ pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ - util::Bytes, - web::{self, HttpRequest}, + util::Bytes, web::{self, HttpRequest} }; use tracing::{info, warn}; @@ -121,7 +120,9 @@ pub async fn configure_app_from_config( } pub fn configure_ntex_app(cfg: &mut web::ServiceConfig) { - cfg.route("/graphql", web::to(graphql_endpoint_handler)) + cfg + .route("/graphql", web::to(graphql_endpoint_handler)) .route("/health", web::to(health_check_handler)) .route("/readiness", web::to(readiness_check_handler)); } + diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index 8c472695e..fa85223e0 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -9,7 +10,6 @@ use sonic_rs::Value; use tracing::{error, trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::schema_state::SupergraphData; @@ -22,7 +22,7 @@ pub struct CoerceVariablesPayload { pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, - execution_params: &mut ExecutionRequest, + graphql_params: &mut GraphQLParams, normalized_operation: &Arc, ) -> Result { if req.method() == Method::GET { @@ -37,7 +37,7 @@ pub fn coerce_request_variables( match collect_variables( &normalized_operation.operation_for_plan, - &mut execution_params.variables, + &mut graphql_params.variables, &supergraph.metadata, ) { Ok(values) => { diff --git a/bin/router/src/pipeline/execution_request.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs similarity index 67% rename from bin/router/src/pipeline/execution_request.rs rename to bin/router/src/pipeline/deserialize_graphql_params.rs index c17a6f355..1a769e5a2 100644 --- a/bin/router/src/pipeline/execution_request.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -1,11 +1,10 @@ use std::collections::HashMap; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use http::Method; use ntex::util::Bytes; use ntex::web::types::Query; use ntex::web::HttpRequest; -use serde::{Deserialize, Deserializer}; -use sonic_rs::Value; use tracing::{trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; @@ -20,36 +19,10 @@ struct GETQueryParams { pub extensions: Option, } -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ExecutionRequest { - pub query: String, - pub operation_name: Option, - #[serde(default, deserialize_with = "deserialize_null_default")] - pub variables: HashMap, - // TODO: We don't use extensions yet, but we definitely will in the future. - #[allow(dead_code)] - pub extensions: Option>, -} - -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) -} - -impl TryInto for GETQueryParams { +impl TryInto for GETQueryParams { type Error = PipelineErrorVariant; - fn try_into(self) -> Result { - let query = match self.query { - Some(q) => q, - None => return Err(PipelineErrorVariant::GetMissingQueryParam("query")), - }; - + fn try_into(self) -> Result { let variables = match self.variables.as_deref() { Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { Ok(vars) => vars, @@ -70,8 +43,8 @@ impl TryInto for GETQueryParams { _ => None, }; - let execution_request = ExecutionRequest { - query, + let execution_request = GraphQLParams { + query: self.query, operation_name: self.operation_name, variables, extensions, @@ -81,13 +54,25 @@ impl TryInto for GETQueryParams { } } +pub trait GetQueryStr { + fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant>; +} + +impl GetQueryStr for GraphQLParams { + fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant> { + self.query + .as_deref() + .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) + } +} + #[inline] -pub async fn get_execution_request( - req: &mut HttpRequest, +pub fn deserialize_graphql_params( + req: &HttpRequest, body_bytes: Bytes, -) -> Result { +) -> Result { let http_method = req.method(); - let execution_request: ExecutionRequest = match *http_method { + let graphql_params: GraphQLParams = match *http_method { Method::GET => { trace!("processing GET GraphQL operation"); let query_params_str = req.uri().query().ok_or_else(|| { @@ -111,7 +96,7 @@ pub async fn get_execution_request( req.assert_json_content_type()?; let execution_request = unsafe { - sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { + sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { warn!("Failed to parse body: {}", e); req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) })? @@ -130,5 +115,5 @@ pub async fn get_execution_request( } }; - Ok(execution_request) + Ok(graphql_params) } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2b4721972..61cb353bf 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,11 +1,17 @@ use std::sync::Arc; -use hive_router_plan_executor::execution::{ - client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, +use hive_router_plan_executor::{ + execution::{ + client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, + plan::PlanExecutionOutput, + }, + hooks::on_deserialization::{ + OnDeserializationEndPayload, OnDeserializationStartPayload + }, + plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ - state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, + state::supergraph_state::OperationKind, utils::cancellation::CancellationToken }; use http::{header::CONTENT_TYPE, HeaderValue, Method}; use ntex::{ @@ -16,20 +22,9 @@ use ntex::{ use crate::{ jwt::context::JwtRequestContext, pipeline::{ - coerce_variables::coerce_request_variables, - csrf_prevention::perform_csrf_prevention, - error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, - execution::execute_plan, - execution_request::get_execution_request, - header::{ - RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, - APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, - }, - normalize::normalize_request_with_cache, - parser::parse_operation_with_cache, - progressive_override::request_override_context, - query_plan::plan_operation_with_cache, - validation::validate_operation_with_cache, + coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ + APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE + }, normalize::normalize_request_with_cache, parser::parse_operation_with_cache, progressive_override::request_override_context, query_plan::plan_operation_with_cache, validation::validate_operation_with_cache }, schema_state::{SchemaState, SupergraphData}, shared_state::RouterSharedState, @@ -40,7 +35,7 @@ pub mod cors; pub mod csrf_prevention; pub mod error; pub mod execution; -pub mod execution_request; +pub mod deserialize_graphql_params; pub mod header; pub mod normalize; pub mod parser; @@ -104,17 +99,61 @@ pub async fn graphql_request_handler( #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline( - req: &mut HttpRequest, - body_bytes: Bytes, +pub async fn execute_pipeline<'req>( + req: &'req mut HttpRequest, + body: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, + shared_state: &'req Arc, schema_state: &Arc, ) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; - let mut execution_request = get_execution_request(req, body_bytes).await?; - let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?; + /* Handle on_deserialize hook in the plugins - START */ + let mut deserialization_end_callbacks = vec![]; + let mut deserialization_payload: OnDeserializationStartPayload<'req> = OnDeserializationStartPayload { + router_http_request: req, + body, + graphql_params: None, + }; + for plugin in &shared_state.plugins { + let result = plugin.on_deserialization(deserialization_payload); + deserialization_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + deserialization_end_callbacks.push(callback); + } + } + } + let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { + deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") + }); + + let mut payload: OnDeserializationEndPayload<'req> = OnDeserializationEndPayload { + router_http_request: req, + graphql_params, + }; + for deserialization_end_callback in deserialization_end_callbacks { + let result = deserialization_end_callback(payload); + payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + let mut graphql_params = payload.graphql_params; + /* Handle on_deserialize hook in the plugins - END */ + + let parser_payload = parse_operation_with_cache(req, shared_state, &graphql_params).await?; validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; @@ -122,12 +161,13 @@ pub async fn execute_pipeline( req, supergraph, schema_state, - &execution_request, + &graphql_params, &parser_payload, ) .await?; + let variable_payload = - coerce_request_variables(req, supergraph, &mut execution_request, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); @@ -158,7 +198,7 @@ pub async fn execute_pipeline( Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: &execution_request.query, + query: graphql_params.get_query().map_err(|err| req.new_pipeline_error(err))?, }, jwt: &jwt_request_details, }; diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 4fc2cc5ef..f3a07ea95 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,6 +1,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; @@ -9,7 +10,6 @@ use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::pipeline::parser::GraphQLParserPayload; use crate::schema_state::{SchemaState, SupergraphData}; use tracing::{error, trace}; @@ -28,13 +28,13 @@ pub async fn normalize_request_with_cache( req: &HttpRequest, supergraph: &SupergraphData, schema_state: &Arc, - execution_params: &ExecutionRequest, + graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, ) -> Result, PipelineError> { - let cache_key = match &execution_params.operation_name { + let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); operation_name.hash(&mut hasher); hasher.finish() } @@ -54,7 +54,7 @@ pub async fn normalize_request_with_cache( None => match normalize_operation( &supergraph.planner.supergraph, &parser_payload.parsed_operation, - execution_params.operation_name.as_deref(), + graphql_params.operation_name.as_deref(), ) { Ok(doc) => { trace!( diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 6e8a37141..1f3357428 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,12 +2,15 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +use crate::pipeline::deserialize_graphql_params::GetQueryStr; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -21,11 +24,11 @@ pub struct GraphQLParserPayload { pub async fn parse_operation_with_cache( req: &HttpRequest, app_state: &Arc, - execution_params: &ExecutionRequest, + graphql_params: &GraphQLParams, ) -> Result { let cache_key = { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); hasher.finish() }; @@ -33,12 +36,68 @@ pub async fn parse_operation_with_cache( trace!("Found cached parsed operation for query"); cached } else { - let parsed = safe_parse_operation(&execution_params.query).map_err(|err| { - error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) - })?; - trace!("sucessfully parsed GraphQL operation"); - let parsed_arc = Arc::new(parsed); + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartPayload { + router_http_request: req, + graphql_params, + document: None, + }; + let mut on_end_callbacks = vec![]; + for plugin in &app_state.plugins { + let result = plugin.on_graphql_parse(start_payload); + start_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + todo!() + } + ControlFlowResult::OnEnd(callback) => { + // store the callback to be called later + on_end_callbacks.push(callback); + } + } + } + let document = match start_payload.document { + Some(parsed) => parsed, + None => { + let query_str = graphql_params.get_query().map_err(|err| { + req.new_pipeline_error(err) + })?; + let parsed = safe_parse_operation(query_str).map_err(|err| { + error!("Failed to parse GraphQL operation: {}", err); + req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) + })?; + trace!("successfully parsed GraphQL operation"); + parsed + } + }; + let mut end_payload = OnGraphQLParseEndPayload { + router_http_request: req, + graphql_params, + document, + }; + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + todo!() + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!(); + } + } + } + let document = end_payload.document; + /* Handle on_graphql_parse hook in the plugins - END */ + + let parsed_arc = Arc::new(document); app_state .parse_cache .insert(cache_key, parsed_arc.clone()) diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index f36bda6cd..06446102a 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -3,6 +3,7 @@ use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; +use hive_router_plan_executor::plugin_trait::RouterPlugin; use moka::future::Cache; use std::sync::Arc; @@ -18,6 +19,7 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, + pub plugins: Vec>, } impl RouterSharedState { @@ -36,6 +38,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, + plugins: Vec::new(), }) } } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index f86356312..6bc314516 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -19,7 +19,7 @@ use crate::{ rewrites::FetchRewriteExt, }, executors::{ - common::{HttpExecutionRequest, HttpExecutionResponse}, + common::{SubgraphExecutionRequest, HttpExecutionResponse}, map::SubgraphExecutorMap, }, headers::{ @@ -700,7 +700,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { let variable_refs = select_fetch_variables(self.variable_values, node.variable_usages.as_ref()); - let mut subgraph_request = HttpExecutionRequest { + let mut subgraph_request = SubgraphExecutionRequest { query: node.operation.document_str.as_str(), dedupe: self.dedupe_subgraph_requests, operation_name: node.operation_name.as_deref(), diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index bdcd4d819..ba13b8707 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -9,7 +9,7 @@ use sonic_rs::Value; pub trait SubgraphExecutor { async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse; fn to_boxed_arc<'a>(self) -> Arc> @@ -26,7 +26,7 @@ pub type SubgraphExecutorBoxedArc = Arc>; pub type SubgraphRequestExtensions = HashMap; -pub struct HttpExecutionRequest<'a> { +pub struct SubgraphExecutionRequest<'a> { pub query: &'a str, pub dedupe: bool, pub operation_name: Option<&'a str>, @@ -37,7 +37,7 @@ pub struct HttpExecutionRequest<'a> { pub extensions: Option, } -impl HttpExecutionRequest<'_> { +impl SubgraphExecutionRequest<'_> { pub fn add_request_extensions_field(&mut self, key: String, value: Value) { self.extensions .get_or_insert_with(HashMap::new) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 29b392567..5947cd7d3 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -19,7 +19,7 @@ use hyper_util::client::legacy::{connect::HttpConnector, Client}; use tokio::sync::Semaphore; use tracing::debug; -use crate::executors::common::HttpExecutionRequest; +use crate::executors::common::SubgraphExecutionRequest; use crate::executors::error::SubgraphExecutorError; use crate::response::graphql_error::GraphQLError; use crate::utils::consts::CLOSE_BRACE; @@ -76,7 +76,7 @@ impl HTTPSubgraphExecutor { fn build_request_body<'a>( &self, - execution_request: &HttpExecutionRequest<'a>, + execution_request: &SubgraphExecutionRequest<'a>, ) -> Result, SubgraphExecutorError> { let mut body = Vec::with_capacity(4096); body.put(FIRST_QUOTE_STR); @@ -212,7 +212,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))] async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse { let body = match self.build_request_body(&execution_request) { Ok(body) => body, diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index a3c297ad1..2e8ac78ae 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -30,7 +30,7 @@ use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ common::{ - HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, + SubgraphExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, @@ -118,7 +118,7 @@ impl SubgraphExecutorMap { pub async fn execute<'a, 'req>( &self, subgraph_name: &str, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, client_request: &ClientRequestDetails<'a, 'req>, ) -> HttpExecutionResponse { match self.get_or_create_executor(subgraph_name, client_request) { diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs new file mode 100644 index 000000000..7d6ac9256 --- /dev/null +++ b/lib/executor/src/plugins/examples/apq.rs @@ -0,0 +1,55 @@ +use dashmap::DashMap; +use sonic_rs::{JsonContainerTrait, JsonValueTrait}; + +use crate::{ + hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct APQPlugin { + cache: DashMap, +} + +impl RouterPlugin for APQPlugin { + fn on_deserialization<'exec>( + &'exec self, + start_payload: OnDeserializationStartPayload<'exec>, + ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> + { + start_payload.on_end(|mut end_payload| { + let persisted_query_ext = end_payload.graphql_params.extensions.as_ref() + .and_then(|ext| ext.get("persistedQuery")) + .and_then(|pq| pq.as_object()); + if let Some(persisted_query_ext) = persisted_query_ext { + match persisted_query_ext.get(&"version").and_then(|v| v.as_str()) { + Some("1") => {} + _ => { + // TODO: Error for unsupported version + return end_payload.cont(); + } + } + let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { + Some(h) => h, + None => { + return end_payload.cont(); + } + }; + if let Some(query_param) = &end_payload.graphql_params.query { + // Store the query in the cache + self.cache.insert(sha256_hash.to_string(), query_param.to_string()); + } else { + // Try to get the query from the cache + if let Some(cached_query) = self.cache.get(sha256_hash) { + // Update the graphql_params with the cached query + end_payload.graphql_params.query = Some(cached_query.value().to_string()); + } else { + // Error + return end_payload.cont(); + } + } + } + + end_payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 3d54dfbed..68e3e7092 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,2 +1,3 @@ pub mod response_cache; -pub mod subgraph_response_cache; \ No newline at end of file +pub mod subgraph_response_cache; +pub mod apq; \ No newline at end of file diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index ed8a106a6..5942e6d91 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -1,11 +1,9 @@ use dashmap::DashMap; -use ntex::web::HttpResponse; +use http::HeaderMap; use redis::Commands; use crate::{ - hooks::{on_execute::OnExecutePayload, on_schema_reload::OnSchemaReloadPayload}, plugins::plugin_trait::{ - ControlFlow, RouterPlugin - }, utils::consts::TYPENAME_FIELD_NAME + execution::plan::PlanExecutionOutput, hooks::{on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_schema_reload::OnSchemaReloadPayload}, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME }; pub struct ResponseCachePlugin { @@ -25,9 +23,9 @@ impl ResponseCachePlugin { impl RouterPlugin for ResponseCachePlugin { fn on_execute<'exec>( - &self, - payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + &'exec self, + payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { let key = format!( "response_cache:{}:{:?}", payload.query_plan, payload.variable_values @@ -35,16 +33,18 @@ impl RouterPlugin for ResponseCachePlugin { if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { - return ControlFlow::Break( - HttpResponse::Ok() - .header("Content-Type", "application/json") - .body(cached_response), + return payload.end_response( + + PlanExecutionOutput { + body: cached_response, + headers: HeaderMap::new(), + } ); } - ControlFlow::OnEnd(Box::new(move |payload: OnExecutePayload| { + return payload.on_end(move |payload: OnExecuteEndPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { - return ControlFlow::Continue; + return payload.cont(); } if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { @@ -78,10 +78,10 @@ impl RouterPlugin for ResponseCachePlugin { // Set the cache with the decided ttl let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); } - ControlFlow::Continue - })); + payload.cont() + }); } - ControlFlow::Continue + payload.cont() } fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { // Visit the schema and update ttl_per_type based on some directive diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index d4742b2ad..55d98a893 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,6 +1,6 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ControlFlow, RouterPlugin}}; +use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { cache: DashMap, @@ -10,7 +10,7 @@ impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_http_request<'exec>( &'static self, payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -19,12 +19,12 @@ impl RouterPlugin for SubgraphResponseCachePlugin { // Here payload.response is Option // So it is bypassing the actual subgraph request *payload.response = Some(cached_response.clone()); - return ControlFlow::Continue; + return payload.cont(); } - ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphHttpResponsePayload| { + payload.on_end(move |payload: OnSubgraphHttpResponsePayload<'exec>| { // Here payload.response is not Option self.cache.insert(key, payload.response.clone()); - ControlFlow::Continue - })) + payload.cont() + }) } } \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 5a4d94c22..65ccf6f4d 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,3 +1,9 @@ pub mod on_execute; pub mod on_schema_reload; -pub mod on_subgraph_http_request; \ No newline at end of file +pub mod on_subgraph_http_request; +pub mod on_http_request; +pub mod on_deserialization; +pub mod on_graphql_parse; +pub mod on_graphql_validation; +pub mod on_query_plan; +pub mod on_subgraph_execute; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_deserialization.rs b/lib/executor/src/plugins/hooks/on_deserialization.rs new file mode 100644 index 000000000..84991ff56 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_deserialization.rs @@ -0,0 +1,45 @@ +use std::collections::HashMap; + +use ntex::util::Bytes; +use serde::Deserialize; +use serde::Deserializer; +use sonic_rs::Value; + +use crate::plugin_trait::EndPayload; +use crate::plugin_trait::StartPayload; + +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GraphQLParams { + pub query: Option, + pub operation_name: Option, + #[serde(default, deserialize_with = "deserialize_null_default")] + pub variables: HashMap, + // TODO: We don't use extensions yet, but we definitely will in the future. + #[allow(dead_code)] + pub extensions: Option>, +} + +fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result +where + T: Default + Deserialize<'de>, + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + Ok(opt.unwrap_or_default()) +} + +pub struct OnDeserializationStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub body: Bytes, + pub graphql_params: Option, +} + +impl<'exec> StartPayload> for OnDeserializationStartPayload<'exec> {} + +pub struct OnDeserializationEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: GraphQLParams, +} + +impl<'exec> EndPayload for OnDeserializationEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 0d5759de2..5057075e3 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,33 +1,41 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use ntex::web::HttpRequest; +use crate::plugin_trait::{EndPayload, StartPayload}; use crate::response::{value::Value}; use crate::response::graphql_error::GraphQLError; -pub struct OnExecutePayload<'exec> { +pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, pub data: &'exec mut Value<'exec>, pub errors: &'exec mut Vec, pub extensions: &'exec mut HashMap, - pub skip_execution: bool, + pub skip_execution: &'exec mut bool, pub variable_values: &'exec Option>, + + pub dedupe_subgraph_requests: &'exec mut bool, } +impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} + pub struct OnExecuteEndPayload<'exec> { pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, + - pub data: &'exec Value<'exec>, - pub errors: &'exec Vec, + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, pub extensions: &'exec mut HashMap, pub variable_values: &'exec Option>, + + pub dedupe_subgraph_requests: &'exec mut bool, } +impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs new file mode 100644 index 000000000..8719cdac3 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -0,0 +1,19 @@ +use graphql_tools::static_graphql::query::Document; + +use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; + +pub struct OnGraphQLParseStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub document: Option, +} + +impl<'exec> StartPayload> for OnGraphQLParseStartPayload<'exec> {} + +pub struct OnGraphQLParseEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub document: Document, +} + +impl<'exec> EndPayload for OnGraphQLParseEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs new file mode 100644 index 000000000..e5ecf898f --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -0,0 +1,25 @@ +use graphql_tools::{static_graphql::query::Document, validation::{utils::ValidationError, validate::ValidationPlan}}; +use hive_router_query_planner::state::supergraph_state::SchemaDocument; + +use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; + +pub struct OnGraphQLValidationStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub schema: &'exec SchemaDocument, + pub document: &'exec Document, + pub validation_plan: &'exec mut ValidationPlan, + pub errors: &'exec mut Option> +} + +impl<'exec> StartPayload> for OnGraphQLValidationStartPayload<'exec> {} + +pub struct OnGraphQLValidationEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub schema: &'exec SchemaDocument, + pub document: &'exec Document, + pub errors: &'exec mut Vec, +} + +impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs new file mode 100644 index 000000000..847e7465e --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -0,0 +1,16 @@ +use ntex::{http::Response, web::HttpRequest}; + +use crate::plugin_trait::{EndPayload, StartPayload}; + +pub struct OnHttpRequestPayload<'exec> { + pub router_http_request: &'exec HttpRequest, +} + +impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} + +pub struct OnHttpResponse<'exec> { + pub router_http_request: &'exec HttpRequest, + pub response: &'exec mut Response, +} + +impl<'exec> EndPayload for OnHttpResponse<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs new file mode 100644 index 000000000..7963524ad --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -0,0 +1,23 @@ +use graphql_tools::static_graphql::query::Document; +use hive_router_query_planner::planner::{Planner, plan_nodes::QueryPlan}; + +use crate::plugin_trait::{EndPayload, StartPayload}; + +pub struct OnQueryPlanStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub document: &'exec Document, + // Other params + pub query_plan: &'exec mut Option, + pub planner: &'exec Planner, +} + +impl<'exec> StartPayload> for OnQueryPlanStartPayload<'exec> {} + +pub struct OnQueryPlanEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub document: &'exec Document, + // Other params + pub query_plan: &'exec mut QueryPlan, +} + +impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs index 29863d964..a96d6c240 100644 --- a/lib/executor/src/plugins/hooks/on_schema_reload.rs +++ b/lib/executor/src/plugins/hooks/on_schema_reload.rs @@ -1,6 +1,6 @@ use hive_router_query_planner::consumer_schema::ConsumerSchema; -pub struct OnSchemaReloadPayload { - pub old_schema: &'static ConsumerSchema, - pub new_schema: &'static mut ConsumerSchema, +pub struct OnSchemaReloadPayload<'a> { + pub old_schema: &'a ConsumerSchema, + pub new_schema: &'a mut ConsumerSchema, } diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs new file mode 100644 index 000000000..167340bc8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -0,0 +1,34 @@ +use bytes::Bytes; +use hive_router_query_planner::planner::plan_nodes::FetchNode; + +use crate::{executors::common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc}, plugin_trait::{EndPayload, StartPayload}, response::subgraph_response::SubgraphResponse}; + + +pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub executor: &'exec SubgraphExecutorBoxedArc, + pub subgraph_name: &'exec str, + + pub node: &'exec mut FetchNode, + pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, + pub response: &'exec mut Option>, +} + +impl<'exec> StartPayload> for OnSubgraphExecuteStartPayload<'exec> {} + +pub enum SubgraphExecutorResponse<'exec> { + Bytes(Bytes), + SubgraphResponse(SubgraphResponse<'exec>), +} + +pub struct OnSubgraphExecuteEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub executor: &'exec SubgraphExecutorBoxedArc, + pub subgraph_name: &'exec str, + + pub node: &'exec FetchNode, + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + pub response: &'exec mut SubgraphExecutorResponse<'exec>, +} + +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 326e516f0..ac720b870 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,11 +1,8 @@ -use std::collections::HashMap; - -use hive_router_query_planner::ast::operation::SubgraphFetchOperation; use http::{HeaderMap, Uri}; use ntex::web::HttpRequest; -use crate:: - executors::dedupe::SharedResponse +use crate::{ + executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} ; pub struct OnSubgraphHttpRequestPayload<'exec> { @@ -24,24 +21,13 @@ pub struct OnSubgraphHttpRequestPayload<'exec> { pub response: &'exec mut Option, } -pub struct SubgraphExecutionRequest<'exec> { - pub query: &'exec str, - // We can add the original operation here too - pub operation: &'exec SubgraphFetchOperation, - - pub dedupe: bool, - pub operation_name: Option<&'exec str>, - pub variables: Option>, - pub extensions: Option>, - pub representations: Option>, -} +impl<'exec> StartPayload> for OnSubgraphHttpRequestPayload<'exec> {} pub struct OnSubgraphHttpResponsePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, - // The node that initiates this subgraph execution pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - // This will be tricky to implement with the current structure, - // but I'm sure we'll figure it out pub response: &'exec mut SharedResponse, } + +impl<'exec> EndPayload for OnSubgraphHttpResponsePayload<'exec> {} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 53b0720d1..220c5b88c 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,27 +1,113 @@ -use ntex::web::HttpResponse; - -use crate::hooks::on_execute::OnExecutePayload; +use crate::execution::plan::PlanExecutionOutput; +use crate::hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}; +use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; +use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use crate::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; +use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; use crate::hooks::on_schema_reload::OnSchemaReloadPayload; use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; -pub enum ControlFlow<'exec, TPayload> { +pub struct HookResult<'exec, TStartPayload, TEndPayload> { + pub start_payload: TStartPayload, + pub control_flow: ControlFlowResult<'exec, TEndPayload>, +} + +pub enum ControlFlowResult<'exec, TEndPayload> { Continue, - Break(HttpResponse), - OnEnd(Box ControlFlow<'exec, ()> + 'exec>), + EndResponse(PlanExecutionOutput), + OnEnd(Box HookResult<'exec, TEndPayload, ()> + 'exec>), +} + +pub trait StartPayload + where Self: Sized + { + + fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } + + fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> + where F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + 'exec, + { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::OnEnd(Box::new(f)), + } + } +} + +pub trait EndPayload + where Self: Sized + { + fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } } +// Add sync send etc pub trait RouterPlugin { - fn on_execute<'exec>( + fn on_http_request<'exec>( + &self, + start_payload: OnHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { + start_payload.cont() + } + fn on_deserialization<'exec>( + &'exec self, + start_payload: OnDeserializationStartPayload<'exec>, + ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> { + start_payload.cont() + } + fn on_graphql_parse<'exec>( &self, - _payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { - ControlFlow::Continue + start_payload: OnGraphQLParseStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { + start_payload.cont() + } + fn on_graphql_validation<'exec>( + &self, + start_payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload<'exec>> { + start_payload.cont() + } + fn on_query_plan<'exec>( + &self, + start_payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + start_payload.cont() + } + fn on_execute<'exec>( + &'exec self, + start_payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload.cont() } fn on_subgraph_http_request<'exec>( &'static self, - _payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { - ControlFlow::Continue + start_payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + start_payload.cont() } - fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} + fn on_schema_reload<'a>(&'a self, _start_payload: OnSchemaReloadPayload) {} } \ No newline at end of file From 0bbe2f8684f6023a49b36174f46a1213d9684d4b Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 19:59:16 +0300 Subject: [PATCH 06/31] More --- Cargo.lock | 1 + bin/router/src/lib.rs | 4 +- bin/router/src/pipeline/coerce_variables.rs | 4 +- .../pipeline/deserialize_graphql_params.rs | 2 +- bin/router/src/pipeline/error.rs | 2 +- bin/router/src/pipeline/execution.rs | 4 +- bin/router/src/pipeline/mod.rs | 41 ++++--- bin/router/src/pipeline/normalize.rs | 5 +- bin/router/src/pipeline/parser.rs | 30 +++-- bin/router/src/pipeline/query_plan.rs | 101 ++++++++++++++-- bin/router/src/pipeline/validation.rs | 70 ++++++++++- bin/router/src/schema_state.rs | 79 +++++++++--- bin/router/src/shared_state.rs | 4 +- lib/executor/Cargo.toml | 1 + .../src/execution/client_request_details.rs | 2 +- lib/executor/src/execution/error.rs | 2 +- lib/executor/src/execution/plan.rs | 92 ++++++++++---- lib/executor/src/executors/http.rs | 5 +- lib/executor/src/executors/map.rs | 81 +++++++++++-- lib/executor/src/plugins/examples/apq.rs | 24 ++-- .../src/plugins/examples/response_cache.rs | 29 +++-- .../examples/subgraph_response_cache.rs | 14 +-- lib/executor/src/plugins/hooks/mod.rs | 4 +- lib/executor/src/plugins/hooks/on_execute.rs | 24 ++-- ...eserialization.rs => on_graphql_params.rs} | 9 +- .../src/plugins/hooks/on_graphql_parse.rs | 2 +- .../plugins/hooks/on_graphql_validation.rs | 64 ++++++++-- .../src/plugins/hooks/on_http_request.rs | 2 +- .../src/plugins/hooks/on_query_plan.rs | 18 +-- .../src/plugins/hooks/on_schema_reload.rs | 6 - .../src/plugins/hooks/on_subgraph_execute.rs | 33 ++--- .../plugins/hooks/on_subgraph_http_request.rs | 26 ++-- .../src/plugins/hooks/on_supergraph_load.rs | 27 +++++ lib/executor/src/plugins/plugin_trait.rs | 114 +++++++++++------- 34 files changed, 655 insertions(+), 271 deletions(-) rename lib/executor/src/plugins/hooks/{on_deserialization.rs => on_graphql_params.rs} (76%) delete mode 100644 lib/executor/src/plugins/hooks/on_schema_reload.rs create mode 100644 lib/executor/src/plugins/hooks/on_supergraph_load.rs diff --git a/Cargo.lock b/Cargo.lock index d5c05dfff..b81088873 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2046,6 +2046,7 @@ name = "hive-router-plan-executor" version = "6.0.1" dependencies = [ "ahash", + "arc-swap", "async-trait", "bumpalo", "bytes", diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index da799b4cc..fb19df5c8 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -111,10 +111,10 @@ pub async fn configure_app_from_config( }; let router_config_arc = Arc::new(router_config); + let shared_state = Arc::new(RouterSharedState::new(router_config_arc.clone(), jwt_runtime)?); let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?; + SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone(), shared_state.clone()).await?; let schema_state_arc = Arc::new(schema_state); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?); Ok((shared_state, schema_state_arc)) } diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index fa85223e0..b159f244e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -11,7 +12,6 @@ use tracing::{error, trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; #[derive(Clone, Debug)] pub struct CoerceVariablesPayload { diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs index 1a769e5a2..3c0eb5f12 100644 --- a/bin/router/src/pipeline/deserialize_graphql_params.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use http::Method; use ntex::util::Bytes; use ntex::web::types::Query; diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index eec36ea76..71e0a197d 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -78,7 +78,7 @@ pub enum PipelineErrorVariant { #[error("Failed to execute a plan: {0}")] PlanExecutionError(PlanExecutionError), #[error("Failed to produce a plan: {0}")] - PlannerError(Arc), + PlannerError(PlannerError), #[error(transparent)] LabelEvaluationError(LabelEvaluationError), diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 42ace79ce..56f92fece 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -4,12 +4,12 @@ use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; use crate::shared_state::RouterSharedState; use hive_router_plan_executor::execute_query_plan; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; @@ -85,6 +85,7 @@ pub async fn execute_plan( }; execute_query_plan(QueryPlanExecutionContext { + router_http_request: req, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, @@ -95,6 +96,7 @@ pub async fn execute_plan( operation_type_name: normalized_payload.root_type_name, jwt_auth_forwarding: &jwt_forward_plan, executors: &supergraph.subgraph_executor_map, + plugins: &app_state.plugins, }) .await .map_err(|err| { diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 61cb353bf..ddc949c4c 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -5,9 +5,9 @@ use hive_router_plan_executor::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, plan::PlanExecutionOutput, }, - hooks::on_deserialization::{ - OnDeserializationEndPayload, OnDeserializationStartPayload - }, + hooks::{on_graphql_params::{ + OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload + }, on_supergraph_load::SupergraphData}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ @@ -24,9 +24,9 @@ use crate::{ pipeline::{ coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE - }, normalize::normalize_request_with_cache, parser::parse_operation_with_cache, progressive_override::request_override_context, query_plan::plan_operation_with_cache, validation::validate_operation_with_cache + }, normalize::normalize_request_with_cache, parser::{ParseResult, parse_operation_with_cache}, progressive_override::request_override_context, query_plan::{QueryPlanResult, plan_operation_with_cache}, validation::validate_operation_with_cache }, - schema_state::{SchemaState, SupergraphData}, + schema_state::{SchemaState}, shared_state::RouterSharedState, }; @@ -110,14 +110,14 @@ pub async fn execute_pipeline<'req>( /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; - let mut deserialization_payload: OnDeserializationStartPayload<'req> = OnDeserializationStartPayload { + let mut deserialization_payload: OnGraphQLParamsStartPayload<'req> = OnGraphQLParamsStartPayload { router_http_request: req, body, graphql_params: None, }; - for plugin in &shared_state.plugins { - let result = plugin.on_deserialization(deserialization_payload); - deserialization_payload = result.start_payload; + for plugin in shared_state.plugins.as_ref() { + let result = plugin.on_graphql_params(deserialization_payload); + deserialization_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { @@ -132,13 +132,12 @@ pub async fn execute_pipeline<'req>( deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") }); - let mut payload: OnDeserializationEndPayload<'req> = OnDeserializationEndPayload { - router_http_request: req, + let mut payload = OnGraphQLParamsEndPayload { graphql_params, }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); - payload = result.start_payload; + payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { @@ -153,7 +152,13 @@ pub async fn execute_pipeline<'req>( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let parser_payload = parse_operation_with_cache(req, shared_state, &graphql_params).await?; + let parser_payload = match parse_operation_with_cache(req, shared_state, &graphql_params).await? { + ParseResult::Payload(payload) => payload, + ParseResult::Response(response) => { + return Ok(response); + } + }; + validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; @@ -209,15 +214,21 @@ pub async fn execute_pipeline<'req>( ) .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; - let query_plan_payload = plan_operation_with_cache( + let query_plan_payload = match plan_operation_with_cache( req, supergraph, schema_state, &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, + shared_state, ) - .await?; + .await? { + QueryPlanResult::QueryPlan(query_plan_payload) => query_plan_payload, + QueryPlanResult::Response(response) => { + return Ok(response); + } + }; let execution_result = execute_plan( req, diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index f3a07ea95..c57e2d566 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,7 +1,8 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; @@ -11,7 +12,7 @@ use xxhash_rust::xxh3::Xxh3; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; use tracing::{error, trace}; #[derive(Debug)] diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 1f3357428..18365a3e2 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,7 +2,8 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; @@ -20,12 +21,17 @@ pub struct GraphQLParserPayload { pub cache_key: u64, } +pub enum ParseResult { + Payload(GraphQLParserPayload), + Response(PlanExecutionOutput), +} + #[inline] pub async fn parse_operation_with_cache( req: &HttpRequest, app_state: &Arc, graphql_params: &GraphQLParams, -) -> Result { +) -> Result { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); @@ -43,15 +49,15 @@ pub async fn parse_operation_with_cache( document: None, }; let mut on_end_callbacks = vec![]; - for plugin in &app_state.plugins { + for plugin in app_state.plugins.as_ref() { let result = plugin.on_graphql_parse(start_payload); - start_payload = result.start_payload; + start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin } ControlFlowResult::EndResponse(response) => { - todo!() + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -80,13 +86,13 @@ pub async fn parse_operation_with_cache( }; for callback in on_end_callbacks { let result = callback(end_payload); - end_payload = result.start_payload; + end_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next callback } ControlFlowResult::EndResponse(response) => { - todo!() + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -105,8 +111,10 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(GraphQLParserPayload { - parsed_operation, - cache_key, - }) + Ok( + ParseResult::Payload(GraphQLParserPayload { + parsed_operation, + cache_key, + }) + ) } diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index b2f730be7..58b4e7475 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -4,12 +4,28 @@ use std::sync::Arc; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; +use crate::RouterSharedState; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanStartPayload; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +pub enum QueryPlanResult { + QueryPlan(Arc), + Response(PlanExecutionOutput), +} + +pub enum QueryPlanGetterError { + Planner(PlannerError), + Response(PlanExecutionOutput), +} + #[inline] pub async fn plan_operation_with_cache( req: &HttpRequest, @@ -18,7 +34,8 @@ pub async fn plan_operation_with_cache( normalized_operation: &Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, -) -> Result, PipelineError> { + app_state: &Arc, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -38,20 +55,80 @@ pub async fn plan_operation_with_cache( })); } - supergraph - .planner - .plan_from_normalized_operation( - filtered_operation_for_plan, - (&request_override_context.clone()).into(), - cancellation_token, - ) - .map(Arc::new) + /* Handle on_query_plan hook in the plugins - START */ + let mut start_payload = OnQueryPlanStartPayload { + router_http_request: req, + filtered_operation_for_plan, + planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan: None, + planner: &supergraph.planner, + }; + + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_query_plan(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + let query_plan = match start_payload.query_plan { + Some(plan) => plan, + None => supergraph + .planner + .plan_from_normalized_operation( + filtered_operation_for_plan, + (&request_override_context.clone()).into(), + cancellation_token, + ) + .map_err(|e| QueryPlanGetterError::Planner(e))?, + }; + + let mut end_payload = hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanEndPayload { + router_http_request: req, + filtered_operation_for_plan, + planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan, + planner: &supergraph.planner, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + + Ok(Arc::new(end_payload.query_plan)) + /* Handle on_query_plan hook in the plugins - END */ }) .await; match plan_result { - Ok(plan) => Ok(plan), - Err(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Err(e) => match e.as_ref() { + QueryPlanGetterError::Planner(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + QueryPlanGetterError::Response(response) => Ok(QueryPlanResult::Response(response.clone())), + }, } } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 85d44c2f1..f97aa1661 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -2,9 +2,13 @@ use std::sync::Arc; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use ntex::web::HttpRequest; use tracing::{error, trace}; @@ -15,7 +19,7 @@ pub async fn validate_operation_with_cache( schema_state: &Arc, app_state: &Arc, parser_payload: &GraphQLParserPayload, -) -> Result<(), PipelineError> { +) -> Result, PipelineError> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -36,13 +40,67 @@ pub async fn validate_operation_with_cache( "validation result of hash {} does not exists in cache", parser_payload.cache_key ); - - let res = validate( + + /* Handle on_graphql_validate hook in the plugins - START */ + let mut start_payload = OnGraphQLValidationStartPayload::new( + req, consumer_schema_ast, &parser_payload.parsed_operation, &app_state.validation_plan, ); - let arc_res = Arc::new(res); + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_graphql_validation(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let errors = match start_payload.errors { + Some(errors) => errors, + None => { + validate( + consumer_schema_ast, + &start_payload.document, + start_payload.get_validation_plan(), + ) + } + }; + + let mut end_payload = OnGraphQLValidationEndPayload { + router_http_request: req, + schema: consumer_schema_ast, + document: &parser_payload.parsed_operation, + errors, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + /* Handle on_graphql_validate hook in the plugins - END */ + + let arc_res = Arc::new(end_payload.errors); schema_state .validate_cache @@ -64,5 +122,5 @@ pub async fn validate_operation_with_cache( ); } - Ok(()) + Ok(None) } diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index f14cc6cf0..3f79a06aa 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,11 +1,9 @@ use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; -use graphql_tools::validation::utils::ValidationError; +use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ - executors::error::SubgraphExecutorError, - introspection::schema::{SchemaMetadata, SchemaWithMetadata}, - SubgraphExecutorMap, + SubgraphExecutorMap, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData}, introspection::schema::SchemaWithMetadata, plugin_trait::{ControlFlowResult, RouterPlugin} }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::{ @@ -20,12 +18,10 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, trace}; use crate::{ - background_tasks::{BackgroundTask, BackgroundTasksManager}, - pipeline::normalize::GraphQLNormalizationPayload, - supergraph::{ + RouterSharedState, background_tasks::{BackgroundTask, BackgroundTasksManager}, pipeline::normalize::GraphQLNormalizationPayload, supergraph::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, - }, + } }; pub struct SchemaState { @@ -35,12 +31,6 @@ pub struct SchemaState { pub normalize_cache: Cache>, } -pub struct SupergraphData { - pub metadata: SchemaMetadata, - pub planner: Planner, - pub subgraph_executor_map: SubgraphExecutorMap, -} - #[derive(Debug, thiserror::Error)] pub enum SupergraphManagerError { #[error("Failed to load supergraph: {0}")] @@ -65,6 +55,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, + app_state: Arc ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -85,9 +76,58 @@ impl SchemaState { while let Some(new_sdl) = rx.recv().await { debug!("Received new supergraph SDL, building new supergraph state..."); - match Self::build_data(router_config.clone(), &new_sdl) { - Ok(new_data) => { - swappable_data_spawn_clone.store(Arc::new(Some(new_data))); + let new_ast = parse_schema(&new_sdl); + + let mut start_payload = OnSupergraphLoadStartPayload { + current_supergraph_data: swappable_data_spawn_clone.clone(), + new_ast, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_supergraph_reload(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + }, + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + }, + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let new_ast = start_payload.new_ast; + + match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { + Ok(new_supergraph_data) => { + let mut end_payload = OnSupergraphLoadEndPayload { + new_supergraph_data, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + }, + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + }, + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } + } + } + + let new_supergraph_data = end_payload.new_supergraph_data; + + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); task_plan_cache.invalidate_all(); @@ -112,15 +152,16 @@ impl SchemaState { fn build_data( router_config: Arc, - supergraph_sdl: &str, + parsed_supergraph_sdl: &Document, + plugins: Arc>>, ) -> Result { - let parsed_supergraph_sdl = parse_schema(supergraph_sdl); let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, router_config, + plugins.clone(), )?; Ok(SupergraphData { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 06446102a..877ffa0e3 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -19,7 +19,7 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, - pub plugins: Vec>, + pub plugins: Arc>>, } impl RouterSharedState { @@ -38,7 +38,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, - plugins: Vec::new(), + plugins: Arc::new(vec![]), }) } } diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 7dcfc03fb..39d51ee7e 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -35,6 +35,7 @@ ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } +arc-swap = "1.7.1" ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 6985376cc..35540dab2 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap}; use bytes::Bytes; use http::Method; diff --git a/lib/executor/src/execution/error.rs b/lib/executor/src/execution/error.rs index 63460eb48..aaaf9a729 100644 --- a/lib/executor/src/execution/error.rs +++ b/lib/executor/src/execution/error.rs @@ -116,7 +116,7 @@ impl IntoPlanExecutionError for Result { let kind = PlanExecutionErrorKind::ProjectionFailure(source); PlanExecutionError::new(kind, context) }) - } + } } impl IntoPlanExecutionError for Result { diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 6bc314516..520f429c8 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -7,48 +7,44 @@ use hive_router_query_planner::planner::plan_nodes::{ QueryPlan, SequenceNode, }; use http::HeaderMap; +use ntex::web::HttpRequest; use serde::Deserialize; use sonic_rs::ValueRef; use crate::{ - context::ExecutionContext, - execution::{ + context::ExecutionContext, execution::{ client_request_details::ClientRequestDetails, error::{IntoPlanExecutionError, LazyPlanContext, PlanExecutionError}, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, - }, - executors::{ - common::{SubgraphExecutionRequest, HttpExecutionResponse}, + }, executors::{ + common::{HttpExecutionResponse, SubgraphExecutionRequest}, map::SubgraphExecutorMap, - }, - headers::{ + }, headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, - }, - introspection::{ - resolve::{resolve_introspection, IntrospectionContext}, + }, hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, introspection::{ + resolve::{IntrospectionContext, resolve_introspection}, schema::SchemaMetadata, - }, - projection::{ + }, plugin_trait::{ControlFlowResult, RouterPlugin}, projection::{ plan::FieldProjectionPlan, - request::{project_requires, RequestProjectionContext}, + request::{RequestProjectionContext, project_requires}, response::project_by_operation, - }, - response::{ + }, response::{ graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath}, merge::deep_merge, subgraph_response::SubgraphResponse, value::Value, - }, - utils::{ + }, utils::{ consts::{CLOSE_BRACKET, OPEN_BRACKET}, traverse::{traverse_and_callback, traverse_and_callback_mut}, - }, + } }; pub struct QueryPlanExecutionContext<'exec, 'req> { + pub router_http_request: &'exec HttpRequest, + pub plugins: &'exec Vec>, pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, @@ -61,6 +57,7 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub jwt_auth_forwarding: &'exec Option, } +#[derive(Clone)] pub struct PlanExecutionOutput { pub body: Vec, pub headers: HeaderMap, @@ -75,6 +72,36 @@ pub async fn execute_query_plan<'exec, 'req>( Value::Null }; + let dedupe_subgraph_requests = ctx.operation_type_name == "Query"; + + let mut start_payload = OnExecuteStartPayload { + router_http_request: ctx.router_http_request, + query_plan: ctx.query_plan, + data: init_value, + errors: Vec::new(), + extensions: ctx.extensions.clone(), + variable_values: ctx.variable_values, + dedupe_subgraph_requests, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in ctx.plugins { + let result = plugin.on_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ }, + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let init_value = start_payload.data; + let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); let executor = Executor::new( ctx.variable_values, @@ -100,15 +127,36 @@ pub async fn execute_query_plan<'exec, 'req>( affected_path: || None, })?; - let final_response = &exec_ctx.final_response; + let mut end_payload = OnExecuteEndPayload { + data: exec_ctx.final_response, + errors: exec_ctx.errors, + extensions: start_payload.extensions, + response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ }, + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + let body = project_by_operation( - final_response, - exec_ctx.errors, + &end_payload.data, + end_payload.errors, &ctx.extensions, ctx.operation_type_name, ctx.projection_plan, ctx.variable_values, - exec_ctx.response_storage.estimate_final_response_size(), + end_payload.response_size_estimate, ) .with_plan_context(LazyPlanContext { subgraph_name: || None, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 5947cd7d3..c09e01067 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; +use crate::plugin_trait::RouterPlugin; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -28,7 +29,6 @@ use crate::utils::consts::COMMA; use crate::utils::consts::QUOTE; use crate::{executors::common::SubgraphExecutor, json_writer::write_and_escape_string}; -#[derive(Debug)] pub struct HTTPSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -37,6 +37,7 @@ pub struct HTTPSubgraphExecutor { pub semaphore: Arc, pub config: Arc, pub in_flight_requests: Arc>, ABuildHasher>>, + pub plugins: Arc>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -52,6 +53,7 @@ impl HTTPSubgraphExecutor { semaphore: Arc, config: Arc, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -71,6 +73,7 @@ impl HTTPSubgraphExecutor { semaphore, config, in_flight_requests, + plugins, } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 2e8ac78ae..6f780b76f 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,16 +27,14 @@ use vrl::{ }; use crate::{ - execution::client_request_details::ClientRequestDetails, - executors::{ + execution::client_request_details::ClientRequestDetails, executors::{ common::{ - SubgraphExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, + HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient}, - }, - response::graphql_error::GraphQLError, + }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError }; type SubgraphName = String; @@ -60,10 +58,14 @@ pub struct SubgraphExecutorMap { semaphores_by_origin: DashMap>, max_connections_per_host: usize, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, } impl SubgraphExecutorMap { - pub fn new(config: Arc) -> Self { + pub fn new( + config: Arc, + plugins: Arc>>, + ) -> Self { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) .pool_timer(TokioTimer::new()) @@ -85,14 +87,16 @@ impl SubgraphExecutorMap { semaphores_by_origin: Default::default(), max_connections_per_host, in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())), + plugins, } } pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, config: Arc, + plugins: Arc>>, ) -> Result { - let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone()); + let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), plugins); for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.into_iter() { let endpoint_str = config @@ -121,8 +125,40 @@ impl SubgraphExecutorMap { execution_request: SubgraphExecutionRequest<'a>, client_request: &ClientRequestDetails<'a, 'req>, ) -> HttpExecutionResponse { - match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request).await, + let mut start_payload = OnSubgraphExecuteStartPayload { + subgraph_name: subgraph_name.to_string(), + execution_request, + execution_result: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugins.as_ref() { + let result = plugin.on_subgraph_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let execution_request = start_payload.execution_request; + + let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { + Ok(Some(executor)) => executor + .execute(execution_request) + .await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -137,7 +173,33 @@ impl SubgraphExecutorMap { ); self.internal_server_error_response("Internal server error".into(), subgraph_name) } + }; + + let mut end_payload = OnSubgraphExecuteEndPayload { + execution_result + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; + } + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } + } } + + end_payload.execution_result } fn internal_server_error_response( @@ -324,6 +386,7 @@ impl SubgraphExecutorMap { semaphore, self.config.clone(), self.in_flight_requests.clone(), + self.plugins.clone(), ); self.executors_by_subgraph diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index 7d6ac9256..d5400d314 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -2,7 +2,7 @@ use dashmap::DashMap; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use crate::{ - hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, }; @@ -11,13 +11,13 @@ pub struct APQPlugin { } impl RouterPlugin for APQPlugin { - fn on_deserialization<'exec>( + fn on_graphql_params<'exec>( &'exec self, - start_payload: OnDeserializationStartPayload<'exec>, - ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { - start_payload.on_end(|mut end_payload| { - let persisted_query_ext = end_payload.graphql_params.extensions.as_ref() + payload.on_end(|mut payload| { + let persisted_query_ext = payload.graphql_params.extensions.as_ref() .and_then(|ext| ext.get("persistedQuery")) .and_then(|pq| pq.as_object()); if let Some(persisted_query_ext) = persisted_query_ext { @@ -25,31 +25,31 @@ impl RouterPlugin for APQPlugin { Some("1") => {} _ => { // TODO: Error for unsupported version - return end_payload.cont(); + return payload.cont(); } } let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { Some(h) => h, None => { - return end_payload.cont(); + return payload.cont(); } }; - if let Some(query_param) = &end_payload.graphql_params.query { + if let Some(query_param) = &payload.graphql_params.query { // Store the query in the cache self.cache.insert(sha256_hash.to_string(), query_param.to_string()); } else { // Try to get the query from the cache if let Some(cached_query) = self.cache.get(sha256_hash) { // Update the graphql_params with the cached query - end_payload.graphql_params.query = Some(cached_query.value().to_string()); + payload.graphql_params.query = Some(cached_query.value().to_string()); } else { // Error - return end_payload.cont(); + return payload.cont(); } } } - end_payload.cont() + payload.cont() }) } } diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index 5942e6d91..d9d611307 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -3,7 +3,13 @@ use http::HeaderMap; use redis::Commands; use crate::{ - execution::plan::PlanExecutionOutput, hooks::{on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_schema_reload::OnSchemaReloadPayload}, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME + execution::plan::PlanExecutionOutput, + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{EndPayload, HookResult, StartPayload}, + plugins::plugin_trait::RouterPlugin, + utils::consts::TYPENAME_FIELD_NAME, }; pub struct ResponseCachePlugin { @@ -33,15 +39,12 @@ impl RouterPlugin for ResponseCachePlugin { if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { - return payload.end_response( - - PlanExecutionOutput { - body: cached_response, - headers: HeaderMap::new(), - } - ); + return payload.end_response(PlanExecutionOutput { + body: cached_response, + headers: HeaderMap::new(), + }); } - return payload.on_end(move |payload: OnExecuteEndPayload<'exec>| { + return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { return payload.cont(); @@ -73,6 +76,7 @@ impl RouterPlugin for ResponseCachePlugin { // Insert the ttl into extensions for client awareness payload .extensions + .get_or_insert_default() .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); // Set the cache with the decided ttl @@ -83,11 +87,10 @@ impl RouterPlugin for ResponseCachePlugin { } payload.cont() } - fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { + fn on_supergraph_reload<'a>(&'a self, payload: OnSupergraphLoadStartPayload) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { // Visit the schema and update ttl_per_type based on some directive payload - .new_schema - .document + .new_ast .definitions .iter() .for_each(|def| { @@ -110,5 +113,7 @@ impl RouterPlugin for ResponseCachePlugin { } } }); + + payload.cont() } } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 55d98a893..037a314d0 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,16 +1,16 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{executors::dedupe::SharedResponse, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_http_request<'exec>( - &'static self, - payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + fn on_subgraph_execute<'exec>( + &'exec self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -21,9 +21,9 @@ impl RouterPlugin for SubgraphResponseCachePlugin { *payload.response = Some(cached_response.clone()); return payload.cont(); } - payload.on_end(move |payload: OnSubgraphHttpResponsePayload<'exec>| { + payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option - self.cache.insert(key, payload.response.clone()); + self.cache.insert(key, payload.execution_result.body.as_ref()); payload.cont() }) } diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 65ccf6f4d..453c84c98 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,8 +1,8 @@ pub mod on_execute; -pub mod on_schema_reload; +pub mod on_supergraph_load; pub mod on_subgraph_http_request; pub mod on_http_request; -pub mod on_deserialization; +pub mod on_graphql_params; pub mod on_graphql_parse; pub mod on_graphql_validation; pub mod on_query_plan; diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 5057075e3..dfcdaceb8 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -11,31 +11,23 @@ pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec HttpRequest, pub query_plan: &'exec QueryPlan, - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub skip_execution: &'exec mut bool, + pub data: Value<'exec>, + pub errors: Vec, + pub extensions: Option>, pub variable_values: &'exec Option>, - pub dedupe_subgraph_requests: &'exec mut bool, + pub dedupe_subgraph_requests: bool, } impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} pub struct OnExecuteEndPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: &'exec QueryPlan, - - - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub variable_values: &'exec Option>, + pub data: Value<'exec>, + pub errors: Vec, + pub extensions: Option>, - pub dedupe_subgraph_requests: &'exec mut bool, + pub response_size_estimate: usize, } impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_deserialization.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs similarity index 76% rename from lib/executor/src/plugins/hooks/on_deserialization.rs rename to lib/executor/src/plugins/hooks/on_graphql_params.rs index 84991ff56..5e6ce1c47 100644 --- a/lib/executor/src/plugins/hooks/on_deserialization.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -29,17 +29,16 @@ where Ok(opt.unwrap_or_default()) } -pub struct OnDeserializationStartPayload<'exec> { +pub struct OnGraphQLParamsStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, pub body: Bytes, pub graphql_params: Option, } -impl<'exec> StartPayload> for OnDeserializationStartPayload<'exec> {} +impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} -pub struct OnDeserializationEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, +pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, } -impl<'exec> EndPayload for OnDeserializationEndPayload<'exec> {} +impl EndPayload for OnGraphQLParamsEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index 8719cdac3..162a7eee2 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -1,6 +1,6 @@ use graphql_tools::static_graphql::query::Document; -use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::{hooks::on_graphql_params::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; pub struct OnGraphQLParseStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index e5ecf898f..a789cb5fd 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -1,25 +1,71 @@ -use graphql_tools::{static_graphql::query::Document, validation::{utils::ValidationError, validate::ValidationPlan}}; +use graphql_tools::{ + static_graphql::query::Document, + validation::{rules::{ValidationRule, default_rules_validation_plan}, utils::ValidationError, validate::ValidationPlan}, +}; use hive_router_query_planner::state::supergraph_state::SchemaDocument; -use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnGraphQLValidationStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, pub schema: &'exec SchemaDocument, pub document: &'exec Document, - pub validation_plan: &'exec mut ValidationPlan, - pub errors: &'exec mut Option> + default_validation_plan: &'exec ValidationPlan, + new_validation_plan: Option, + pub errors: Option>, } -impl<'exec> StartPayload> for OnGraphQLValidationStartPayload<'exec> {} +impl<'exec> StartPayload> + for OnGraphQLValidationStartPayload<'exec> +{ +} + +impl<'exec> OnGraphQLValidationStartPayload<'exec> { + pub fn new( + router_http_request: &'exec ntex::web::HttpRequest, + schema: &'exec SchemaDocument, + document: &'exec Document, + default_validation_plan: &'exec ValidationPlan, + ) -> Self { + OnGraphQLValidationStartPayload { + router_http_request, + schema, + document, + default_validation_plan, + new_validation_plan: None, + errors: None, + } + } + + pub fn add_validation_rule(&mut self, rule: Box) { + self.new_validation_plan + .get_or_insert_with(|| default_rules_validation_plan()) + .add_rule(rule); + } + + pub fn filter_validation_rules(&mut self, mut f: F) + where + F: FnMut(&Box) -> bool, + { + let plan = self + .new_validation_plan + .get_or_insert_with(|| default_rules_validation_plan()); + plan.rules.retain(|rule| f(rule)); + } + + pub fn get_validation_plan(&self) -> &ValidationPlan { + match &self.new_validation_plan { + Some(plan) => plan, + None => self.default_validation_plan, + } + } +} pub struct OnGraphQLValidationEndPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, pub schema: &'exec SchemaDocument, pub document: &'exec Document, - pub errors: &'exec mut Vec, + pub errors: Vec, } -impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 847e7465e..29a8344e5 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -3,7 +3,7 @@ use ntex::{http::Response, web::HttpRequest}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnHttpRequestPayload<'exec> { - pub router_http_request: &'exec HttpRequest, + pub client_request: &'exec HttpRequest, } impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index 7963524ad..39ae3c2d6 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -1,13 +1,13 @@ -use graphql_tools::static_graphql::query::Document; -use hive_router_query_planner::planner::{Planner, plan_nodes::QueryPlan}; +use hive_router_query_planner::{ast::operation::OperationDefinition, graph::PlannerOverrideContext, planner::{Planner, plan_nodes::QueryPlan}, utils::cancellation::CancellationToken}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnQueryPlanStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub document: &'exec Document, - // Other params - pub query_plan: &'exec mut Option, + pub filtered_operation_for_plan: &'exec OperationDefinition, + pub planner_override_context: PlannerOverrideContext, + pub cancellation_token: &'exec CancellationToken, + pub query_plan: Option, pub planner: &'exec Planner, } @@ -15,9 +15,11 @@ impl<'exec> StartPayload> for OnQueryPlanStartPaylo pub struct OnQueryPlanEndPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub document: &'exec Document, - // Other params - pub query_plan: &'exec mut QueryPlan, + pub filtered_operation_for_plan: &'exec OperationDefinition, + pub planner_override_context: PlannerOverrideContext, + pub cancellation_token: &'exec CancellationToken, + pub query_plan: QueryPlan, + pub planner: &'exec Planner, } impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs deleted file mode 100644 index a96d6c240..000000000 --- a/lib/executor/src/plugins/hooks/on_schema_reload.rs +++ /dev/null @@ -1,6 +0,0 @@ -use hive_router_query_planner::consumer_schema::ConsumerSchema; - -pub struct OnSchemaReloadPayload<'a> { - pub old_schema: &'a ConsumerSchema, - pub new_schema: &'a mut ConsumerSchema, -} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 167340bc8..6a514006f 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,34 +1,19 @@ -use bytes::Bytes; -use hive_router_query_planner::planner::plan_nodes::FetchNode; -use crate::{executors::common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc}, plugin_trait::{EndPayload, StartPayload}, response::subgraph_response::SubgraphResponse}; + +use crate::{executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, plugin_trait::{EndPayload, StartPayload}}; pub struct OnSubgraphExecuteStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub executor: &'exec SubgraphExecutorBoxedArc, - pub subgraph_name: &'exec str, + pub subgraph_name: String, - pub node: &'exec mut FetchNode, - pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, - pub response: &'exec mut Option>, + pub execution_request: SubgraphExecutionRequest<'exec>, + pub execution_result: Option, } -impl<'exec> StartPayload> for OnSubgraphExecuteStartPayload<'exec> {} - -pub enum SubgraphExecutorResponse<'exec> { - Bytes(Bytes), - SubgraphResponse(SubgraphResponse<'exec>), -} - -pub struct OnSubgraphExecuteEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub executor: &'exec SubgraphExecutorBoxedArc, - pub subgraph_name: &'exec str, +impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} - pub node: &'exec FetchNode, - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - pub response: &'exec mut SubgraphExecutorResponse<'exec>, +pub struct OnSubgraphExecuteEndPayload { + pub execution_result: HttpExecutionResponse, } -impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index ac720b870..44bfa9dc9 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,4 +1,6 @@ -use http::{HeaderMap, Uri}; +use bytes::Bytes; +use http::{HeaderMap, Request, Uri}; +use http_body_util::Full; use ntex::web::HttpRequest; use crate::{ @@ -6,28 +8,18 @@ use crate::{ ; pub struct OnSubgraphHttpRequestPayload<'exec> { - pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, // At this point, there is no point of mutating this - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - - pub endpoint: &'exec mut Uri, - // By default, it is POST - pub method: &'exec mut http::Method, - pub headers: &'exec mut HeaderMap, - pub request_body: &'exec mut Vec, + pub request: Request>, // Early response - pub response: &'exec mut Option, + pub response: Option, } -impl<'exec> StartPayload> for OnSubgraphHttpRequestPayload<'exec> {} +impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} -pub struct OnSubgraphHttpResponsePayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub subgraph_name: &'exec str, - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - pub response: &'exec mut SharedResponse, +pub struct OnSubgraphHttpResponsePayload { + pub response: SharedResponse, } -impl<'exec> EndPayload for OnSubgraphHttpResponsePayload<'exec> {} +impl<'exec> EndPayload for OnSubgraphHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs new file mode 100644 index 000000000..e4e68ca35 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -0,0 +1,27 @@ +use std::sync::Arc; + +use graphql_tools::static_graphql::schema::Document; +use hive_router_query_planner::{planner::Planner}; +use arc_swap::{ArcSwap}; + +use crate::{SubgraphExecutorMap, introspection::schema::SchemaMetadata, plugin_trait::{EndPayload, StartPayload}}; + + +pub struct SupergraphData { + pub metadata: SchemaMetadata, + pub planner: Planner, + pub subgraph_executor_map: SubgraphExecutorMap, +} + +pub struct OnSupergraphLoadStartPayload { + pub current_supergraph_data: Arc>>, + pub new_ast: Document, +} + +impl StartPayload for OnSupergraphLoadStartPayload {} + +pub struct OnSupergraphLoadEndPayload { + pub new_supergraph_data: SupergraphData, +} + +impl EndPayload for OnSupergraphLoadEndPayload {} \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 220c5b88c..d56856652 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,113 +1,141 @@ use crate::execution::plan::PlanExecutionOutput; -use crate::hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}; use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; +use crate::hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}; use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; -use crate::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use crate::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; -use crate::hooks::on_schema_reload::OnSchemaReloadPayload; -use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, +}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; +use crate::hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}; pub struct HookResult<'exec, TStartPayload, TEndPayload> { - pub start_payload: TStartPayload, + pub payload: TStartPayload, pub control_flow: ControlFlowResult<'exec, TEndPayload>, } pub enum ControlFlowResult<'exec, TEndPayload> { Continue, EndResponse(PlanExecutionOutput), - OnEnd(Box HookResult<'exec, TEndPayload, ()> + 'exec>), + OnEnd(Box HookResult<'exec, TEndPayload, ()> + Send + 'exec>), } pub trait StartPayload - where Self: Sized - { - +where + Self: Sized, +{ fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::Continue, } } - fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, TEndPayload> { + fn end_response<'exec>( + self, + output: PlanExecutionOutput, + ) -> HookResult<'exec, Self, TEndPayload> { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::EndResponse(output), } } fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> - where F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + 'exec, + where + F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + Send + 'exec, { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::OnEnd(Box::new(f)), } } } pub trait EndPayload - where Self: Sized - { - fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { - HookResult { - start_payload: self, - control_flow: ControlFlowResult::Continue, - } +where + Self: Sized, +{ + fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::Continue, } + } - fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { - HookResult { - start_payload: self, - control_flow: ControlFlowResult::EndResponse(output), - } + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::EndResponse(output), } + } } -// Add sync send etc pub trait RouterPlugin { fn on_http_request<'exec>( - &self, + &self, start_payload: OnHttpRequestPayload<'exec>, ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { start_payload.cont() } - fn on_deserialization<'exec>( - &'exec self, - start_payload: OnDeserializationStartPayload<'exec>, - ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> { + fn on_graphql_params<'exec>( + &'exec self, + start_payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { start_payload.cont() } fn on_graphql_parse<'exec>( - &self, + &self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { start_payload.cont() } fn on_graphql_validation<'exec>( - &self, + &self, start_payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload<'exec>> { + ) -> HookResult< + 'exec, + OnGraphQLValidationStartPayload<'exec>, + OnGraphQLValidationEndPayload<'exec>, + > { start_payload.cont() } fn on_query_plan<'exec>( - &self, + &self, start_payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { start_payload.cont() } fn on_execute<'exec>( - &'exec self, + &'exec self, start_payload: OnExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload.cont() + } + fn on_subgraph_execute<'exec>( + &'exec self, + start_payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> + { start_payload.cont() } fn on_subgraph_http_request<'exec>( - &'static self, + &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> + { + start_payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { start_payload.cont() } - fn on_schema_reload<'a>(&'a self, _start_payload: OnSchemaReloadPayload) {} -} \ No newline at end of file +} From 134dda5cd6c93989f360906102c0a6a18fdf48a4 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 22:14:26 +0300 Subject: [PATCH 07/31] More --- lib/executor/src/executors/common.rs | 1 + lib/executor/src/executors/http.rs | 145 +++++++++++++----- .../examples/subgraph_response_cache.rs | 10 +- .../plugins/hooks/on_subgraph_http_request.rs | 5 +- 4 files changed, 111 insertions(+), 50 deletions(-) diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index ba13b8707..9044c062e 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -45,6 +45,7 @@ impl SubgraphExecutionRequest<'_> { } } +#[derive(Clone)] pub struct HttpExecutionResponse { pub body: Bytes, pub headers: HeaderMap, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index c09e01067..1e00e6fce 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,7 +2,8 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; -use crate::plugin_trait::RouterPlugin; +use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -10,7 +11,7 @@ use tokio::sync::OnceCell; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::HeaderMap; +use http::{HeaderMap, StatusCode}; use http::HeaderValue; use http_body_util::BodyExt; use http_body_util::Full; @@ -136,31 +137,87 @@ impl HTTPSubgraphExecutor { Ok(body) } - async fn _send_request( - &self, - body: Vec, - headers: HeaderMap, - ) -> Result { + fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { + let graphql_error: GraphQLError = error.into(); + let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); + graphql_error.message = "Failed to execute request to subgraph".to_string(); + + let errors = vec![graphql_error]; + // This unwrap is safe as GraphQLError serialization shouldn't fail. + let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); + let mut buffer = BytesMut::new(); + buffer.put_slice(b"{\"errors\":"); + buffer.put_slice(&errors_bytes); + buffer.put_slice(b"}"); + buffer.freeze() + } + + fn log_error(&self, error: &SubgraphExecutorError) { + tracing::error!( + error = error as &dyn std::error::Error, + "Subgraph executor error" + ); + } +} + +async fn send_request( + http_client: &Client, Full>, + subgraph_name: &str, + endpoint: &http::Uri, + method: http::Method, + body: Vec, + headers: HeaderMap, + plugins: Arc>>, +) -> Result { let mut req = hyper::Request::builder() - .method(http::Method::POST) - .uri(&self.endpoint) + .method(method) + .uri(endpoint) .version(Version::HTTP_11) .body(Full::new(Bytes::from(body))) .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(self.endpoint.to_string(), e.to_string()) + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) })?; *req.headers_mut() = headers; - debug!("making http request to {}", self.endpoint.to_string()); + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + debug!("making http request to {}", endpoint.to_string()); - let res = self.http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) + let req = start_payload.request; + + let res = http_client.request(req).await.map_err(|e| { + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) })?; debug!( "http request to {} completed, status: {}", - self.endpoint.to_string(), + endpoint.to_string(), res.status() ); @@ -169,45 +226,47 @@ impl HTTPSubgraphExecutor { .collect() .await .map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) })? .to_bytes(); if body.is_empty() { return Err(SubgraphExecutorError::RequestFailure( - self.endpoint.to_string(), + endpoint.to_string(), "Empty response body".to_string(), )); } - Ok(SharedResponse { + let response = SharedResponse { status: parts.status, - body, + body: body, headers: parts.headers, - }) - } + }; - fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { - let graphql_error: GraphQLError = error.into(); - let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); - graphql_error.message = "Failed to execute request to subgraph".to_string(); + let mut end_payload = OnSubgraphHttpResponsePayload { + response, + }; - let errors = vec![graphql_error]; - // This unwrap is safe as GraphQLError serialization shouldn't fail. - let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); - let mut buffer = BytesMut::new(); - buffer.put_slice(b"{\"errors\":"); - buffer.put_slice(&errors_bytes); - buffer.put_slice(b"}"); - buffer.freeze() - } + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(response) => { + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } - fn log_error(&self, error: &SubgraphExecutorError) { - tracing::error!( - error = error as &dyn std::error::Error, - "Subgraph executor error" - ); - } + Ok(end_payload.response) } #[async_trait] @@ -233,11 +292,13 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { headers.insert(key, value.clone()); }); + let method = http::Method::POST; + if !self.config.traffic_shaping.dedupe_enabled || !execution_request.dedupe { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - return match self._send_request(body, headers).await { + return match send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, @@ -252,7 +313,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { }; } - let fingerprint = request_fingerprint(&http::Method::POST, &self.endpoint, &headers, &body); + let fingerprint = request_fingerprint(&method, &self.endpoint, &headers, &body); // Clone the cell from the map, dropping the lock from the DashMap immediately. // Prevents any deadlocks. @@ -269,7 +330,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - self._send_request(body, headers).await + send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await }; // It's important to remove the entry from the map before returning the result. // This ensures that once the OnceCell is set, no future requests can join it. diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 037a314d0..71ff8c1d9 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,15 +1,15 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{executors::{common::HttpExecutionResponse}, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { - cache: DashMap, + cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_execute<'exec>( &'exec self, - payload: OnSubgraphExecuteStartPayload<'exec>, + mut payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", @@ -18,12 +18,12 @@ impl RouterPlugin for SubgraphResponseCachePlugin { if let Some(cached_response) = self.cache.get(&key) { // Here payload.response is Option // So it is bypassing the actual subgraph request - *payload.response = Some(cached_response.clone()); + payload.execution_result = Some(cached_response.clone()); return payload.cont(); } payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option - self.cache.insert(key, payload.execution_result.body.as_ref()); + self.cache.insert(key, payload.execution_result.clone()); payload.cont() }) } diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 44bfa9dc9..f7c798ed7 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,10 +1,9 @@ use bytes::Bytes; -use http::{HeaderMap, Request, Uri}; +use http::{Request}; use http_body_util::Full; -use ntex::web::HttpRequest; use crate::{ - executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} + executors::{dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} ; pub struct OnSubgraphHttpRequestPayload<'exec> { From 9791e1ceb96da52581cfeb89a202856152e76a2a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Wed, 19 Nov 2025 18:56:36 +0300 Subject: [PATCH 08/31] Owned values --- bin/router/src/jwt/mod.rs | 21 +- bin/router/src/lib.rs | 53 ++-- bin/router/src/pipeline/coerce_variables.rs | 8 +- bin/router/src/pipeline/csrf_prevention.rs | 8 +- .../pipeline/deserialize_graphql_params.rs | 33 +-- bin/router/src/pipeline/error.rs | 33 +-- bin/router/src/pipeline/execution.rs | 35 +-- bin/router/src/pipeline/header.rs | 18 +- bin/router/src/pipeline/mod.rs | 219 +++++++++------ bin/router/src/pipeline/normalize.rs | 12 +- bin/router/src/pipeline/parser.rs | 57 ++-- .../src/pipeline/progressive_override.rs | 8 +- bin/router/src/pipeline/query_plan.rs | 45 ++- bin/router/src/pipeline/validation.rs | 41 ++- bin/router/src/schema_state.rs | 37 ++- .../src/execution/client_request_details.rs | 28 +- lib/executor/src/execution/error.rs | 2 +- lib/executor/src/execution/jwt_forward.rs | 2 +- lib/executor/src/execution/plan.rs | 257 ++++++++++-------- lib/executor/src/executors/http.rs | 196 +++++++------ lib/executor/src/executors/map.rs | 39 +-- lib/executor/src/headers/expression.rs | 4 +- lib/executor/src/headers/mod.rs | 96 +++---- lib/executor/src/headers/request.rs | 6 +- lib/executor/src/headers/response.rs | 4 +- lib/executor/src/lib.rs | 1 - lib/executor/src/plugins/examples/apq.rs | 18 +- lib/executor/src/plugins/examples/mod.rs | 2 +- .../src/plugins/examples/multipart.rs | 0 .../src/plugins/examples/response_cache.rs | 38 +-- .../examples/subgraph_response_cache.rs | 16 +- lib/executor/src/plugins/hooks/mod.rs | 8 +- lib/executor/src/plugins/hooks/on_execute.rs | 7 +- .../src/plugins/hooks/on_graphql_params.rs | 92 ++++++- .../src/plugins/hooks/on_graphql_parse.rs | 15 +- .../plugins/hooks/on_graphql_validation.rs | 28 +- .../src/plugins/hooks/on_http_request.rs | 2 +- .../src/plugins/hooks/on_query_plan.rs | 20 +- .../src/plugins/hooks/on_subgraph_execute.rs | 12 +- .../plugins/hooks/on_subgraph_http_request.rs | 11 +- .../src/plugins/hooks/on_supergraph_load.rs | 13 +- lib/executor/src/plugins/mod.rs | 2 +- lib/executor/src/plugins/plugin_trait.rs | 21 +- 43 files changed, 868 insertions(+), 700 deletions(-) create mode 100644 lib/executor/src/plugins/examples/multipart.rs diff --git a/bin/router/src/jwt/mod.rs b/bin/router/src/jwt/mod.rs index d95854b1c..d7804e156 100644 --- a/bin/router/src/jwt/mod.rs +++ b/bin/router/src/jwt/mod.rs @@ -265,26 +265,27 @@ impl JwtAuthRuntime { Ok(token_data) } - pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> { + pub fn validate_request( + &self, + request: &HttpRequest, + ) -> Result, JwtError> { let valid_jwks = self.jwks.all(); match self.authenticate(&valid_jwks, request) { - Ok((token_payload, maybe_token_prefix, token)) => { - request.extensions_mut().insert(JwtRequestContext { - token_payload, - token_raw: token, - token_prefix: maybe_token_prefix, - }); - } + Ok((token_payload, maybe_token_prefix, token)) => Ok(Some(JwtRequestContext { + token_payload, + token_raw: token, + token_prefix: maybe_token_prefix, + })), Err(e) => { warn!("jwt token error: {:?}", e); if self.config.require_authentication.is_some_and(|v| v) { return Err(e); } + + Ok(None) } } - - Ok(()) } } diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index fb19df5c8..416021b3a 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -19,7 +19,11 @@ use crate::{ }, jwt::JwtAuthRuntime, logger::configure_logging, - pipeline::graphql_request_handler, + pipeline::{ + error::PipelineError, + graphql_request_handler, + header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + }, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; @@ -27,12 +31,13 @@ pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ - util::Bytes, web::{self, HttpRequest} + util::Bytes, + web::{self, HttpRequest}, }; use tracing::{info, warn}; async fn graphql_endpoint_handler( - mut request: HttpRequest, + req: HttpRequest, body_bytes: Bytes, schema_state: web::types::State>, app_state: web::types::State>, @@ -44,26 +49,35 @@ async fn graphql_endpoint_handler( if let Some(early_response) = app_state .cors_runtime .as_ref() - .and_then(|cors| cors.get_early_response(&request)) + .and_then(|cors| cors.get_early_response(&req)) { return early_response; } - let mut res = graphql_request_handler( - &mut request, + let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); + + let result = match graphql_request_handler( + req, body_bytes, supergraph, - app_state.get_ref(), - schema_state.get_ref(), + app_state.get_ref().clone(), + schema_state.get_ref().clone(), ) - .await; + .await + { + Ok(response_with_req) => response_with_req, + Err(error) => return PipelineError { accept_ok, error }.into(), + }; + + let mut response = result.result; + let req = result.request; // Apply CORS headers to the final response if CORS is configured. if let Some(cors) = app_state.cors_runtime.as_ref() { - cors.set_headers(&request, res.headers_mut()); + cors.set_headers(&req, response.headers_mut()); } - res + response } else { warn!("No supergraph available yet, unable to process request"); @@ -111,18 +125,23 @@ pub async fn configure_app_from_config( }; let router_config_arc = Arc::new(router_config); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc.clone(), jwt_runtime)?); - let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone(), shared_state.clone()).await?; + let shared_state = Arc::new(RouterSharedState::new( + router_config_arc.clone(), + jwt_runtime, + )?); + let schema_state = SchemaState::new_from_config( + bg_tasks_manager, + router_config_arc.clone(), + shared_state.clone(), + ) + .await?; let schema_state_arc = Arc::new(schema_state); Ok((shared_state, schema_state_arc)) } pub fn configure_ntex_app(cfg: &mut web::ServiceConfig) { - cfg - .route("/graphql", web::to(graphql_endpoint_handler)) + cfg.route("/graphql", web::to(graphql_endpoint_handler)) .route("/health", web::to(health_check_handler)) .route("/readiness", web::to(readiness_check_handler)); } - diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index b159f244e..d10fbb6c4 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -10,7 +10,7 @@ use ntex::web::HttpRequest; use sonic_rs::Value; use tracing::{error, trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; #[derive(Clone, Debug)] @@ -24,14 +24,14 @@ pub fn coerce_request_variables( supergraph: &SupergraphData, graphql_params: &mut GraphQLParams, normalized_operation: &Arc, -) -> Result { +) -> Result { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = normalized_operation.operation_for_plan.operation_kind { error!("Mutation is not allowed over GET, stopping"); - return Err(req.new_pipeline_error(PipelineErrorVariant::MutationNotAllowedOverHttpGet)); + return Err(PipelineErrorVariant::MutationNotAllowedOverHttpGet); } } @@ -55,7 +55,7 @@ pub fn coerce_request_variables( "failed to collect variables from incoming request: {}", err_msg ); - Err(req.new_pipeline_error(PipelineErrorVariant::VariablesCoercionError(err_msg))) + Err(PipelineErrorVariant::VariablesCoercionError(err_msg)) } } } diff --git a/bin/router/src/pipeline/csrf_prevention.rs b/bin/router/src/pipeline/csrf_prevention.rs index 51561dd99..37c063b09 100644 --- a/bin/router/src/pipeline/csrf_prevention.rs +++ b/bin/router/src/pipeline/csrf_prevention.rs @@ -1,7 +1,7 @@ use hive_router_config::csrf::CSRFPreventionConfig; use ntex::web::HttpRequest; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; // NON_PREFLIGHTED_CONTENT_TYPES are content types that do not require a preflight // OPTIONS request. These are content types that are considered "simple" by the CORS @@ -15,9 +15,9 @@ const NON_PREFLIGHTED_CONTENT_TYPES: [&str; 3] = [ #[inline] pub fn perform_csrf_prevention( - req: &mut HttpRequest, + req: &HttpRequest, csrf_config: &CSRFPreventionConfig, -) -> Result<(), PipelineError> { +) -> Result<(), PipelineErrorVariant> { // If CSRF prevention is not configured or disabled, skip the checks. if !csrf_config.enabled || csrf_config.required_headers.is_empty() { return Ok(()); @@ -39,7 +39,7 @@ pub fn perform_csrf_prevention( if has_required_header { Ok(()) } else { - Err(req.new_pipeline_error(PipelineErrorVariant::CsrfPreventionFailed)) + Err(PipelineErrorVariant::CsrfPreventionFailed) } } diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs index 3c0eb5f12..b22b18a3a 100644 --- a/bin/router/src/pipeline/deserialize_graphql_params.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -7,7 +7,7 @@ use ntex::web::types::Query; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::header::AssertRequestJson; #[derive(serde::Deserialize, Debug)] @@ -55,11 +55,11 @@ impl TryInto for GETQueryParams { } pub trait GetQueryStr { - fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant>; + fn get_query(&self) -> Result<&str, PipelineErrorVariant>; } impl GetQueryStr for GraphQLParams { - fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant> { + fn get_query(&self) -> Result<&str, PipelineErrorVariant> { self.query .as_deref() .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) @@ -70,25 +70,22 @@ impl GetQueryStr for GraphQLParams { pub fn deserialize_graphql_params( req: &HttpRequest, body_bytes: Bytes, -) -> Result { +) -> Result { let http_method = req.method(); let graphql_params: GraphQLParams = match *http_method { Method::GET => { trace!("processing GET GraphQL operation"); - let query_params_str = req.uri().query().ok_or_else(|| { - req.new_pipeline_error(PipelineErrorVariant::GetInvalidQueryParams) - })?; + let query_params_str = req + .uri() + .query() + .ok_or_else(|| PipelineErrorVariant::GetInvalidQueryParams)?; let query_params = Query::::from_query(query_params_str) - .map_err(|e| { - req.new_pipeline_error(PipelineErrorVariant::GetUnprocessableQueryParams(e)) - })? + .map_err(PipelineErrorVariant::GetUnprocessableQueryParams)? .0; trace!("parsed GET query params: {:?}", query_params); - query_params - .try_into() - .map_err(|err| req.new_pipeline_error(err))? + query_params.try_into()? } Method::POST => { trace!("Processing POST GraphQL request"); @@ -98,7 +95,7 @@ pub fn deserialize_graphql_params( let execution_request = unsafe { sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { warn!("Failed to parse body: {}", e); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) + PipelineErrorVariant::FailedToParseBody(e) })? }; @@ -107,11 +104,9 @@ pub fn deserialize_graphql_params( _ => { warn!("unsupported HTTP method: {}", http_method); - return Err( - req.new_pipeline_error(PipelineErrorVariant::UnsupportedHttpMethod( - http_method.to_owned(), - )), - ); + return Err(PipelineErrorVariant::UnsupportedHttpMethod( + http_method.to_owned(), + )); } }; diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index 71e0a197d..1d856d62e 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -10,15 +10,12 @@ use hive_router_query_planner::{ }; use http::{HeaderName, Method, StatusCode}; use ntex::{ - http::ResponseBuilder, - web::{self, error::QueryPayloadError, HttpRequest}, + http::{Response, ResponseBuilder}, + web::error::QueryPayloadError, }; use serde::{Deserialize, Serialize}; -use crate::pipeline::{ - header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, - progressive_override::LabelEvaluationError, -}; +use crate::pipeline::progressive_override::LabelEvaluationError; #[derive(Debug)] pub struct PipelineError { @@ -26,18 +23,6 @@ pub struct PipelineError { pub error: PipelineErrorVariant, } -pub trait PipelineErrorFromAcceptHeader { - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError; -} - -impl PipelineErrorFromAcceptHeader for HttpRequest { - #[inline] - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError { - let accept_ok = !self.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - PipelineError { accept_ok, error } - } -} - #[derive(Debug, thiserror::Error)] pub enum PipelineErrorVariant { // HTTP-related errors @@ -156,11 +141,11 @@ pub struct FailedExecutionResult { pub errors: Option>, } -impl PipelineError { - pub fn into_response(self) -> web::HttpResponse { - let status = self.error.default_status_code(self.accept_ok); +impl From for Response { + fn from(val: PipelineError) -> Self { + let status = val.error.default_status_code(val.accept_ok); - if let PipelineErrorVariant::ValidationErrors(validation_errors) = self.error { + if let PipelineErrorVariant::ValidationErrors(validation_errors) = val.error { let validation_error_result = FailedExecutionResult { errors: Some(validation_errors.iter().map(|error| error.into()).collect()), }; @@ -168,8 +153,8 @@ impl PipelineError { return ResponseBuilder::new(status).json(&validation_error_result); } - let code = self.error.graphql_error_code(); - let message = self.error.graphql_error_message(); + let code = val.error.graphql_error_code(); + let message = val.error.graphql_error_message(); let graphql_error = GraphQLError::from_message_and_extensions( message, diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 56f92fece..e69a89179 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -2,13 +2,14 @@ use std::collections::HashMap; use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::shared_state::RouterSharedState; -use hive_router_plan_executor::execute_query_plan; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, QueryPlanExecutionContext, ResultWithRequest, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -26,14 +27,14 @@ enum ExposeQueryPlanMode { #[inline] pub async fn execute_plan( - req: &HttpRequest, + req: HttpRequest, supergraph: &SupergraphData, - app_state: &Arc, - normalized_payload: &Arc, - query_plan_payload: &Arc, + app_state: Arc, + normalized_payload: Arc, + query_plan_payload: Arc, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'_, '_>, -) -> Result { + client_request_details: &ClientRequestDetails<'_>, +) -> Result, PipelineErrorVariant> { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -65,7 +66,7 @@ pub async fn execute_plan( metadata: &supergraph.metadata, }; - let jwt_forward_plan: Option = if app_state + let jwt_auth_forwarding: Option = if app_state .router_config .jwt .is_jwt_extensions_forwarding_enabled() @@ -79,12 +80,12 @@ pub async fn execute_plan( .forward_claims_to_upstream_extensions .field_name, ) - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))? + .map_err(PipelineErrorVariant::JwtForwardingError)? } else { None }; - execute_query_plan(QueryPlanExecutionContext { + let ctx = QueryPlanExecutionContext { router_http_request: req, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, @@ -94,13 +95,13 @@ pub async fn execute_plan( client_request: client_request_details, introspection_context: &introspection_context, operation_type_name: normalized_payload.root_type_name, - jwt_auth_forwarding: &jwt_forward_plan, + jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, plugins: &app_state.plugins, - }) - .await - .map_err(|err| { + }; + + ctx.execute_query_plan().await.map_err(|err| { tracing::error!("Failed to execute query plan: {}", err); - req.new_pipeline_error(PipelineErrorVariant::PlanExecutionError(err)) + PipelineErrorVariant::PlanExecutionError(err) }) } diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 92a591235..19ea8c7af 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -6,7 +6,7 @@ use lazy_static::lazy_static; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; lazy_static! { pub static ref APPLICATION_JSON_STR: &'static str = "application/json"; @@ -34,31 +34,29 @@ impl RequestAccepts for HttpRequest { } pub trait AssertRequestJson { - fn assert_json_content_type(&self) -> Result<(), PipelineError>; + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant>; } impl AssertRequestJson for HttpRequest { #[inline] - fn assert_json_content_type(&self) -> Result<(), PipelineError> { + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant> { match self.headers().get(CONTENT_TYPE) { Some(value) => { - let content_type_str = value.to_str().map_err(|_| { - self.new_pipeline_error(PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE)) - })?; + let content_type_str = value + .to_str() + .map_err(|_| PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE))?; if !content_type_str.contains(*APPLICATION_JSON_STR) { warn!( "Invalid content type on a POST request: {}", content_type_str ); - return Err( - self.new_pipeline_error(PipelineErrorVariant::UnsupportedContentType) - ); + return Err(PipelineErrorVariant::UnsupportedContentType); } Ok(()) } None => { trace!("POST without content type detected"); - Err(self.new_pipeline_error(PipelineErrorVariant::MissingContentTypeHeader)) + Err(PipelineErrorVariant::MissingContentTypeHeader) } } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ddc949c4c..b97f7e67f 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -3,15 +3,16 @@ use std::sync::Arc; use hive_router_plan_executor::{ execution::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, + plan::{PlanExecutionOutput, ResultWithRequest, WithResult}, + }, + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_supergraph_load::SupergraphData, }, - hooks::{on_graphql_params::{ - OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload - }, on_supergraph_load::SupergraphData}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ - state::supergraph_state::OperationKind, utils::cancellation::CancellationToken + state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, }; use http::{header::CONTENT_TYPE, HeaderValue, Method}; use ntex::{ @@ -22,20 +23,31 @@ use ntex::{ use crate::{ jwt::context::JwtRequestContext, pipeline::{ - coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ - APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE - }, normalize::normalize_request_with_cache, parser::{ParseResult, parse_operation_with_cache}, progressive_override::request_override_context, query_plan::{QueryPlanResult, plan_operation_with_cache}, validation::validate_operation_with_cache + coerce_variables::coerce_request_variables, + csrf_prevention::perform_csrf_prevention, + deserialize_graphql_params::{deserialize_graphql_params, GetQueryStr}, + error::PipelineErrorVariant, + execution::execute_plan, + header::{ + RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, + APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, + }, + normalize::normalize_request_with_cache, + parser::{parse_operation_with_cache, ParseResult}, + progressive_override::request_override_context, + query_plan::{plan_operation_with_cache, QueryPlanResult}, + validation::validate_operation_with_cache, }, - schema_state::{SchemaState}, + schema_state::SchemaState, shared_state::RouterSharedState, }; pub mod coerce_variables; pub mod cors; pub mod csrf_prevention; +pub mod deserialize_graphql_params; pub mod error; pub mod execution; -pub mod deserialize_graphql_params; pub mod header; pub mod normalize; pub mod parser; @@ -47,70 +59,82 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn graphql_request_handler( - req: &mut HttpRequest, + req: HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, - schema_state: &Arc, -) -> web::HttpResponse { + shared_state: Arc, + schema_state: Arc, +) -> Result, PipelineErrorVariant> { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return web::HttpResponse::Ok() - .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML); + return Ok(req.with_result( + web::HttpResponse::Ok() + .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) + .body(GRAPHIQL_HTML), + )); } else { - return web::HttpResponse::NotFound().into(); + return Ok(req.with_result(web::HttpResponse::NotFound().into())); } } - if let Some(jwt) = &shared_state.jwt_auth_runtime { - match jwt.validate_request(req) { - Ok(_) => (), - Err(err) => return err.make_response(), + let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { + match jwt.validate_request(&req) { + Ok(jwt_context) => jwt_context, + Err(err) => return Ok(req.with_result(err.make_response())), } - } + } else { + None + }; - match execute_pipeline(req, body_bytes, supergraph, shared_state, schema_state).await { - Ok(response) => { - let response_bytes = Bytes::from(response.body); - let response_headers = response.headers; - - let response_content_type: &'static HeaderValue = - if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { - &APPLICATION_GRAPHQL_RESPONSE_JSON - } else { - &APPLICATION_JSON - }; - - let mut response_builder = web::HttpResponse::Ok(); - for (header_name, header_value) in response_headers { - if let Some(header_name) = header_name { - response_builder.header(header_name, header_value); - } - } + let response_content_type: &'static HeaderValue = + if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { + &APPLICATION_GRAPHQL_RESPONSE_JSON + } else { + &APPLICATION_JSON + }; - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes) + let execution_result_with_req = execute_pipeline( + req, + body_bytes, + supergraph, + shared_state, + schema_state, + jwt_context, + ) + .await?; + let response = execution_result_with_req.result; + let response_bytes = Bytes::from(response.body); + let response_headers = response.headers; + + let mut response_builder = web::HttpResponse::Ok(); + for (header_name, header_value) in response_headers { + if let Some(header_name) = header_name { + response_builder.header(header_name, header_value); } - Err(err) => err.into_response(), } + + Ok(execution_result_with_req.request.with_result( + response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .body(response_bytes), + )) } #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline<'req>( - req: &'req mut HttpRequest, +pub async fn execute_pipeline( + req: HttpRequest, body: Bytes, supergraph: &SupergraphData, - shared_state: &'req Arc, - schema_state: &Arc, -) -> Result { - perform_csrf_prevention(req, &shared_state.router_config.csrf)?; + shared_state: Arc, + schema_state: Arc, + jwt_context: Option, +) -> Result, PipelineErrorVariant> { + perform_csrf_prevention(&req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; - let mut deserialization_payload: OnGraphQLParamsStartPayload<'req> = OnGraphQLParamsStartPayload { + let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { router_http_request: req, body, graphql_params: None, @@ -121,7 +145,9 @@ pub async fn execute_pipeline<'req>( match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(response); + return Ok(deserialization_payload + .router_http_request + .with_result(response)); } ControlFlowResult::OnEnd(callback) => { deserialization_end_callbacks.push(callback); @@ -129,20 +155,24 @@ pub async fn execute_pipeline<'req>( } } let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { - deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") + deserialize_graphql_params( + &deserialization_payload.router_http_request, + deserialization_payload.body, + ) + .expect("Failed to parse execution request") }); - let mut payload = OnGraphQLParamsEndPayload { - graphql_params, - }; + let mut payload = OnGraphQLParamsEndPayload { graphql_params }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, + return Ok(deserialization_payload + .router_http_request + .with_result(response)); + } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again unreachable!("on_end callback returned OnEnd again"); @@ -152,49 +182,58 @@ pub async fn execute_pipeline<'req>( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let parser_payload = match parse_operation_with_cache(req, shared_state, &graphql_params).await? { + let req = deserialization_payload.router_http_request; + let parser_result = + parse_operation_with_cache(req, shared_state.clone(), &graphql_params).await?; + + let mut req = parser_result.request; + + let parser_payload = match parser_result.result { ParseResult::Payload(payload) => payload, ParseResult::Response(response) => { - return Ok(response); + return Ok(req.with_result(response)); } }; - validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) - .await?; + validate_operation_with_cache( + &mut req, + supergraph, + schema_state.clone(), + shared_state.clone(), + &parser_payload, + ) + .await?; let normalize_payload = normalize_request_with_cache( - req, supergraph, - schema_state, + schema_state.clone(), &graphql_params, &parser_payload, ) .await?; - + let variable_payload = - coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; + coerce_request_variables(&req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); - let req_extensions = req.extensions(); - let jwt_context = req_extensions.get::(); let jwt_request_details = match jwt_context { Some(jwt_context) => JwtRequestDetails::Authenticated { - token: jwt_context.token_raw.as_str(), - prefix: jwt_context.token_prefix.as_deref(), scopes: jwt_context.extract_scopes(), - claims: &jwt_context + claims: jwt_context .get_claims_value() - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, + .map_err(PipelineErrorVariant::JwtForwardingError)?, + token: jwt_context.token_raw, + prefix: jwt_context.token_prefix, }, None => JwtRequestDetails::Unauthenticated, }; let client_request_details = ClientRequestDetails { - method: req.method(), - url: req.uri(), - headers: req.headers(), + method: req.method().clone(), + url: req.uri().clone(), + headers: req.headers().clone(), operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -203,39 +242,41 @@ pub async fn execute_pipeline<'req>( Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: graphql_params.get_query().map_err(|err| req.new_pipeline_error(err))?, + query: graphql_params.get_query()?, }, - jwt: &jwt_request_details, + jwt: jwt_request_details, }; let progressive_override_ctx = request_override_context( &shared_state.override_labels_evaluator, &client_request_details, ) - .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; + .map_err(PipelineErrorVariant::LabelEvaluationError)?; - let query_plan_payload = match plan_operation_with_cache( + let query_plan_result = plan_operation_with_cache( req, supergraph, - schema_state, - &normalize_payload, + schema_state.clone(), + normalize_payload.clone(), &progressive_override_ctx, &query_plan_cancellation_token, - shared_state, + shared_state.clone(), ) - .await? { - QueryPlanResult::QueryPlan(query_plan_payload) => query_plan_payload, + .await?; + let req = query_plan_result.request; + let query_plan_payload = match query_plan_result.result { + QueryPlanResult::QueryPlan(plan) => plan, QueryPlanResult::Response(response) => { - return Ok(response); + return Ok(req.with_result(response)); } }; let execution_result = execute_plan( req, supergraph, - shared_state, - &normalize_payload, - &query_plan_payload, + shared_state.clone(), + normalize_payload.clone(), + query_plan_payload, &variable_payload, &client_request_details, ) diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index c57e2d566..54093d065 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -7,12 +7,11 @@ use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; use hive_router_query_planner::ast::operation::OperationDefinition; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use tracing::{error, trace}; #[derive(Debug)] @@ -26,12 +25,11 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, + schema_state: Arc, graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, -) -> Result, PipelineError> { +) -> Result, PipelineErrorVariant> { let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); @@ -87,7 +85,7 @@ pub async fn normalize_request_with_cache( error!("Failed to normalize GraphQL operation: {}", err); trace!("{:?}", err); - Err(req.new_pipeline_error(PipelineErrorVariant::NormalizationError(err))) + Err(PipelineErrorVariant::NormalizationError(err)) } }, } diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 18365a3e2..e6194fbd5 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,16 +2,20 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, ResultWithRequest, WithResult, +}; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; -use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use hive_router_plan_executor::hooks::on_graphql_parse::{ + OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, +}; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::deserialize_graphql_params::GetQueryStr; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -28,26 +32,26 @@ pub enum ParseResult { #[inline] pub async fn parse_operation_with_cache( - req: &HttpRequest, - app_state: &Arc, + req: HttpRequest, + app_state: Arc, graphql_params: &GraphQLParams, -) -> Result { +) -> Result, PipelineErrorVariant> { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); hasher.finish() }; + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartPayload { + router_http_request: req, + graphql_params, + document: None, + }; let parsed_operation = if let Some(cached) = app_state.parse_cache.get(&cache_key).await { trace!("Found cached parsed operation for query"); cached } else { - /* Handle on_graphql_parse hook in the plugins - START */ - let mut start_payload = OnGraphQLParseStartPayload { - router_http_request: req, - graphql_params, - document: None, - }; let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { let result = plugin.on_graphql_parse(start_payload); @@ -57,7 +61,9 @@ pub async fn parse_operation_with_cache( // continue to next plugin } ControlFlowResult::EndResponse(response) => { - return Ok(ParseResult::Response(response)); + return Ok(start_payload + .router_http_request + .with_result(ParseResult::Response(response))); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -65,25 +71,20 @@ pub async fn parse_operation_with_cache( } } } + let document = match start_payload.document { Some(parsed) => parsed, None => { - let query_str = graphql_params.get_query().map_err(|err| { - req.new_pipeline_error(err) - })?; + let query_str = graphql_params.get_query()?; let parsed = safe_parse_operation(query_str).map_err(|err| { error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) + PipelineErrorVariant::FailedToParseOperation(err) })?; trace!("successfully parsed GraphQL operation"); parsed } }; - let mut end_payload = OnGraphQLParseEndPayload { - router_http_request: req, - graphql_params, - document, - }; + let mut end_payload = OnGraphQLParseEndPayload { document }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; @@ -92,7 +93,9 @@ pub async fn parse_operation_with_cache( // continue to next callback } ControlFlowResult::EndResponse(response) => { - return Ok(ParseResult::Response(response)); + return Ok(start_payload + .router_http_request + .with_result(ParseResult::Response(response))); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -111,10 +114,10 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok( - ParseResult::Payload(GraphQLParserPayload { + Ok(start_payload + .router_http_request + .with_result(ParseResult::Payload(GraphQLParserPayload { parsed_operation, cache_key, - }) - ) + }))) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..dc28a0d60 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context<'exec>( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'exec>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; @@ -158,9 +158,9 @@ impl OverrideLabelsEvaluator { }) } - pub(crate) fn evaluate<'exec, 'req>( + pub(crate) fn evaluate<'exec>( &self, - client_request: &ClientRequestDetails<'exec, 'req>, + client_request: &ClientRequestDetails<'exec>, ) -> Result, LabelEvaluationError> { let mut active_flags = self.static_enabled_labels.clone(); diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 58b4e7475..807946296 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -1,13 +1,17 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use crate::RouterSharedState; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; -use hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanStartPayload; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, ResultWithRequest, WithResult, +}; +use hive_router_plan_executor::hooks::on_query_plan::{ + OnQueryPlanEndPayload, OnQueryPlanStartPayload, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -28,14 +32,14 @@ pub enum QueryPlanGetterError { #[inline] pub async fn plan_operation_with_cache( - req: &HttpRequest, + mut req: HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - normalized_operation: &Arc, + schema_state: Arc, + normalized_operation: Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, - app_state: &Arc, -) -> Result { + app_state: Arc, +) -> Result, PipelineErrorVariant> { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -47,7 +51,7 @@ pub async fn plan_operation_with_cache( let plan_result = schema_state .plan_cache - .try_get_with(plan_cache_key, async move { + .try_get_with(plan_cache_key, async { if is_pure_introspection { return Ok(Arc::new(QueryPlan { kind: "QueryPlan".to_string(), @@ -57,7 +61,7 @@ pub async fn plan_operation_with_cache( /* Handle on_query_plan hook in the plugins - START */ let mut start_payload = OnQueryPlanStartPayload { - router_http_request: req, + router_http_request: &mut req, filtered_operation_for_plan, planner_override_context: (&request_override_context.clone()).into(), cancellation_token, @@ -90,17 +94,10 @@ pub async fn plan_operation_with_cache( (&request_override_context.clone()).into(), cancellation_token, ) - .map_err(|e| QueryPlanGetterError::Planner(e))?, + .map_err(QueryPlanGetterError::Planner)?, }; - let mut end_payload = hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanEndPayload { - router_http_request: req, - filtered_operation_for_plan, - planner_override_context: (&request_override_context.clone()).into(), - cancellation_token, - query_plan, - planner: &supergraph.planner, - }; + let mut end_payload = OnQueryPlanEndPayload { query_plan }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -124,10 +121,12 @@ pub async fn plan_operation_with_cache( .await; match plan_result { - Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Ok(plan) => Ok(req.with_result(QueryPlanResult::QueryPlan(plan))), Err(e) => match e.as_ref() { - QueryPlanGetterError::Planner(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), - QueryPlanGetterError::Response(response) => Ok(QueryPlanResult::Response(response.clone())), + QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), + QueryPlanGetterError::Response(response) => { + Ok(req.with_result(QueryPlanResult::Response(response.clone()))) + } }, } } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index f97aa1661..3efbddcca 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -1,12 +1,14 @@ use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; use hive_router_plan_executor::execution::plan::PlanExecutionOutput; -use hive_router_plan_executor::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use hive_router_plan_executor::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use ntex::web::HttpRequest; @@ -14,12 +16,12 @@ use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &HttpRequest, + req: &mut HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - app_state: &Arc, + schema_state: Arc, + app_state: Arc, parser_payload: &GraphQLParserPayload, -) -> Result, PipelineError> { +) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -40,7 +42,7 @@ pub async fn validate_operation_with_cache( "validation result of hash {} does not exists in cache", parser_payload.cache_key ); - + /* Handle on_graphql_validate hook in the plugins - START */ let mut start_payload = OnGraphQLValidationStartPayload::new( req, @@ -67,21 +69,14 @@ pub async fn validate_operation_with_cache( let errors = match start_payload.errors { Some(errors) => errors, - None => { - validate( - consumer_schema_ast, - &start_payload.document, - start_payload.get_validation_plan(), - ) - } + None => validate( + consumer_schema_ast, + start_payload.document, + start_payload.get_validation_plan(), + ), }; - let mut end_payload = OnGraphQLValidationEndPayload { - router_http_request: req, - schema: consumer_schema_ast, - document: &parser_payload.parsed_operation, - errors, - }; + let mut end_payload = OnGraphQLValidationEndPayload { errors }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -117,9 +112,7 @@ pub async fn validate_operation_with_cache( ); trace!("Validation errors: {:?}", validation_result); - return Err( - req.new_pipeline_error(PipelineErrorVariant::ValidationErrors(validation_result)) - ); + return Err(PipelineErrorVariant::ValidationErrors(validation_result)); } Ok(None) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 3f79a06aa..69db5db3b 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -3,7 +3,13 @@ use async_trait::async_trait; use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ - SubgraphExecutorMap, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData}, introspection::schema::SchemaWithMetadata, plugin_trait::{ControlFlowResult, RouterPlugin} + executors::error::SubgraphExecutorError, + hooks::on_supergraph_load::{ + OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, + }, + introspection::schema::SchemaWithMetadata, + plugin_trait::{ControlFlowResult, RouterPlugin}, + SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::{ @@ -18,10 +24,13 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, trace}; use crate::{ - RouterSharedState, background_tasks::{BackgroundTask, BackgroundTasksManager}, pipeline::normalize::GraphQLNormalizationPayload, supergraph::{ + background_tasks::{BackgroundTask, BackgroundTasksManager}, + pipeline::normalize::GraphQLNormalizationPayload, + supergraph::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, - } + }, + RouterSharedState, }; pub struct SchemaState { @@ -55,7 +64,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, - app_state: Arc + app_state: Arc, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -91,10 +100,10 @@ impl SchemaState { match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin - }, + } ControlFlowResult::EndResponse(_) => { unreachable!("Plugins should not end supergraph reload processing"); - }, + } ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); } @@ -115,12 +124,16 @@ impl SchemaState { match result.control_flow { ControlFlowResult::Continue => { // continue to next callback - }, + } ControlFlowResult::EndResponse(_) => { - unreachable!("Plugins should not end supergraph reload processing"); - }, + unreachable!( + "Plugins should not end supergraph reload processing" + ); + } ControlFlowResult::OnEnd(_) => { - unreachable!("End callbacks should not register further end callbacks"); + unreachable!( + "End callbacks should not register further end callbacks" + ); } } } @@ -155,8 +168,8 @@ impl SchemaState { parsed_supergraph_sdl: &Document, plugins: Arc>>, ) -> Result { - let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); - let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; + let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); + let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 35540dab2..71f28b746 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap}; +use std::collections::BTreeMap; use bytes::Bytes; use http::Method; @@ -13,28 +13,28 @@ pub struct OperationDetails<'exec> { pub kind: &'static str, } -pub struct ClientRequestDetails<'exec, 'req> { - pub method: &'req Method, - pub url: &'req http::Uri, - pub headers: &'req NtexHeaderMap, +pub struct ClientRequestDetails<'exec> { + pub method: Method, + pub url: http::Uri, + pub headers: NtexHeaderMap, pub operation: OperationDetails<'exec>, - pub jwt: &'exec JwtRequestDetails<'req>, + pub jwt: JwtRequestDetails, } -pub enum JwtRequestDetails<'exec> { +pub enum JwtRequestDetails { Authenticated { - token: &'exec str, - prefix: Option<&'exec str>, - claims: &'exec sonic_rs::Value, + token: String, + prefix: Option, + claims: sonic_rs::Value, scopes: Option>, }, Unauthenticated, } -impl From<&ClientRequestDetails<'_, '_>> for Value { +impl From<&ClientRequestDetails<'_>> for Value { fn from(details: &ClientRequestDetails) -> Self { // .request.headers - let headers_value = client_header_map_to_vrl_value(details.headers); + let headers_value = client_header_map_to_vrl_value(&details.headers); // .request.url let url_value = Self::Object(BTreeMap::from([ @@ -67,7 +67,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ])); // .request.jwt - let jwt_value = match details.jwt { + let jwt_value = match &details.jwt { JwtRequestDetails::Authenticated { token, prefix, @@ -78,7 +78,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ("token".into(), token.to_string().into()), ( "prefix".into(), - prefix.unwrap_or_default().to_string().into(), + prefix.as_deref().unwrap_or_default().to_string().into(), ), ("claims".into(), sonic_value_to_vrl_value(claims)), ( diff --git a/lib/executor/src/execution/error.rs b/lib/executor/src/execution/error.rs index aaaf9a729..63460eb48 100644 --- a/lib/executor/src/execution/error.rs +++ b/lib/executor/src/execution/error.rs @@ -116,7 +116,7 @@ impl IntoPlanExecutionError for Result { let kind = PlanExecutionErrorKind::ProjectionFailure(source); PlanExecutionError::new(kind, context) }) - } + } } impl IntoPlanExecutionError for Result { diff --git a/lib/executor/src/execution/jwt_forward.rs b/lib/executor/src/execution/jwt_forward.rs index 24c19ff9f..9aefc601c 100644 --- a/lib/executor/src/execution/jwt_forward.rs +++ b/lib/executor/src/execution/jwt_forward.rs @@ -8,7 +8,7 @@ pub struct JwtAuthForwardingPlan { pub extension_field_value: Value, } -impl JwtRequestDetails<'_> { +impl JwtRequestDetails { pub fn build_forwarding_plan( &self, extension_field_name: &str, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 520f429c8..5ea8bd758 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -1,4 +1,7 @@ -use std::collections::{BTreeSet, HashMap}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; @@ -12,49 +15,58 @@ use serde::Deserialize; use sonic_rs::ValueRef; use crate::{ - context::ExecutionContext, execution::{ + context::ExecutionContext, + execution::{ client_request_details::ClientRequestDetails, error::{IntoPlanExecutionError, LazyPlanContext, PlanExecutionError}, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, - }, executors::{ + }, + executors::{ common::{HttpExecutionResponse, SubgraphExecutionRequest}, map::SubgraphExecutorMap, - }, headers::{ + }, + headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, - }, hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, introspection::{ - resolve::{IntrospectionContext, resolve_introspection}, + }, + hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + introspection::{ + resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, - }, plugin_trait::{ControlFlowResult, RouterPlugin}, projection::{ + }, + plugin_trait::{ControlFlowResult, RouterPlugin}, + projection::{ plan::FieldProjectionPlan, - request::{RequestProjectionContext, project_requires}, + request::{project_requires, RequestProjectionContext}, response::project_by_operation, - }, response::{ + }, + response::{ graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath}, merge::deep_merge, subgraph_response::SubgraphResponse, value::Value, - }, utils::{ + }, + utils::{ consts::{CLOSE_BRACKET, OPEN_BRACKET}, traverse::{traverse_and_callback, traverse_and_callback_mut}, - } + }, }; -pub struct QueryPlanExecutionContext<'exec, 'req> { - pub router_http_request: &'exec HttpRequest, +pub struct QueryPlanExecutionContext<'exec> { + pub router_http_request: HttpRequest, pub plugins: &'exec Vec>, - pub query_plan: &'exec QueryPlan, + pub query_plan: Arc, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, pub extensions: Option>, - pub client_request: &'exec ClientRequestDetails<'exec, 'req>, + pub client_request: &'exec ClientRequestDetails<'exec>, pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, - pub jwt_auth_forwarding: &'exec Option, + pub jwt_auth_forwarding: Option, } #[derive(Clone)] @@ -63,119 +75,140 @@ pub struct PlanExecutionOutput { pub headers: HeaderMap, } -pub async fn execute_query_plan<'exec, 'req>( - ctx: QueryPlanExecutionContext<'exec, 'req>, -) -> Result { - let init_value = if let Some(introspection_query) = ctx.introspection_context.query { - resolve_introspection(introspection_query, ctx.introspection_context) - } else { - Value::Null - }; +pub struct ResultWithRequest { + pub result: T, + pub request: HttpRequest, +} - let dedupe_subgraph_requests = ctx.operation_type_name == "Query"; +pub trait WithResult { + fn with_result(self, result: T) -> ResultWithRequest; +} - let mut start_payload = OnExecuteStartPayload { - router_http_request: ctx.router_http_request, - query_plan: ctx.query_plan, - data: init_value, - errors: Vec::new(), - extensions: ctx.extensions.clone(), - variable_values: ctx.variable_values, - dedupe_subgraph_requests, - }; +impl WithResult for HttpRequest { + fn with_result(self, result: T) -> ResultWithRequest { + ResultWithRequest { result, request: self } + } +} - let mut on_end_callbacks = vec![]; +impl<'exec> QueryPlanExecutionContext<'exec> { + pub async fn execute_query_plan( + self, + ) -> Result, PlanExecutionError> { + let init_value = if let Some(introspection_query) = self.introspection_context.query { + resolve_introspection(introspection_query, self.introspection_context) + } else { + Value::Null + }; - for plugin in ctx.plugins { - let result = plugin.on_execute(start_payload); - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ }, - ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + let dedupe_subgraph_requests = self.operation_type_name == "Query"; + + let mut start_payload = OnExecuteStartPayload { + router_http_request: self.router_http_request, + query_plan: self.query_plan, + data: init_value, + errors: Vec::new(), + extensions: self.extensions.clone(), + variable_values: self.variable_values, + dedupe_subgraph_requests, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugins { + let result = plugin.on_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(start_payload.router_http_request.with_result(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } - } - let init_value = start_payload.data; - - let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); - let executor = Executor::new( - ctx.variable_values, - ctx.executors, - ctx.introspection_context.metadata, - ctx.client_request, - ctx.headers_plan, - ctx.jwt_auth_forwarding, - // Deduplicate subgraph requests only if the operation type is a query - ctx.operation_type_name == "Query", - ); - - if ctx.query_plan.node.is_some() { - executor - .execute(&mut exec_ctx, ctx.query_plan.node.as_ref()) - .await?; - } + let query_plan = start_payload.query_plan; + + let init_value = start_payload.data; - let mut response_headers = HeaderMap::new(); - modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + let mut exec_ctx = ExecutionContext::new(&query_plan, init_value); + let executor = Executor::new( + self.variable_values, + self.executors, + self.introspection_context.metadata, + self.client_request, + self.headers_plan, + self.jwt_auth_forwarding, + // Deduplicate subgraph requests only if the operation type is a query + self.operation_type_name == "Query", + ); + + if query_plan.node.is_some() { + executor + .execute(&mut exec_ctx, query_plan.node.as_ref()) + .await?; + } + + let mut response_headers = HeaderMap::new(); + modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + .with_plan_context(LazyPlanContext { + subgraph_name: || None, + affected_path: || None, + })?; + + let mut end_payload = OnExecuteEndPayload { + data: exec_ctx.final_response, + errors: exec_ctx.errors, + extensions: start_payload.extensions, + response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(output) => { + return Ok(start_payload.router_http_request.with_result(output)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + + let body = project_by_operation( + &end_payload.data, + end_payload.errors, + &self.extensions, + self.operation_type_name, + self.projection_plan, + self.variable_values, + end_payload.response_size_estimate, + ) .with_plan_context(LazyPlanContext { subgraph_name: || None, affected_path: || None, })?; - let mut end_payload = OnExecuteEndPayload { - data: exec_ctx.final_response, - errors: exec_ctx.errors, - extensions: start_payload.extensions, - response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), - }; - - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ }, - ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } - } + Ok(start_payload + .router_http_request + .with_result(PlanExecutionOutput { + body, + headers: response_headers, + })) } - - let body = project_by_operation( - &end_payload.data, - end_payload.errors, - &ctx.extensions, - ctx.operation_type_name, - ctx.projection_plan, - ctx.variable_values, - end_payload.response_size_estimate, - ) - .with_plan_context(LazyPlanContext { - subgraph_name: || None, - affected_path: || None, - })?; - - Ok(PlanExecutionOutput { - body, - headers: response_headers, - }) } -pub struct Executor<'exec, 'req> { +pub struct Executor<'exec> { variable_values: &'exec Option>, schema_metadata: &'exec SchemaMetadata, executors: &'exec SubgraphExecutorMap, - client_request: &'exec ClientRequestDetails<'exec, 'req>, + client_request: &'exec ClientRequestDetails<'exec>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, } @@ -263,14 +296,14 @@ struct PreparedFlattenData { representation_hash_to_index: HashMap, } -impl<'exec, 'req> Executor<'exec, 'req> { +impl<'exec> Executor<'exec> { pub fn new( variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, - client_request: &'exec ClientRequestDetails<'exec, 'req>, + client_request: &'exec ClientRequestDetails<'exec>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, ) -> Self { Executor { diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 1e00e6fce..a5a5ed263 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,7 +2,9 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; -use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; @@ -11,8 +13,8 @@ use tokio::sync::OnceCell; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::{HeaderMap, StatusCode}; use http::HeaderValue; +use http::{HeaderMap, StatusCode}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; @@ -169,104 +171,101 @@ async fn send_request( headers: HeaderMap, plugins: Arc>>, ) -> Result { - let mut req = hyper::Request::builder() - .method(method) - .uri(endpoint) - .version(Version::HTTP_11) - .body(Full::new(Bytes::from(body))) - .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) - })?; - - *req.headers_mut() = headers; - - let mut start_payload = OnSubgraphHttpRequestPayload { - subgraph_name, - request: req, - response: None, - }; + let mut req = hyper::Request::builder() + .method(method) + .uri(endpoint) + .version(Version::HTTP_11) + .body(Full::new(Bytes::from(body))) + .map_err(|e| { + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) + })?; - let mut on_end_callbacks = vec![]; - - for plugin in plugins.as_ref() { - let result = plugin.on_subgraph_http_request(start_payload); - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - // TODO: Fixx - return Ok(SharedResponse { - status: StatusCode::OK, - body: response.body.into(), - headers: response.headers, - }); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); - } + *req.headers_mut() = headers; + + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); } } + } - debug!("making http request to {}", endpoint.to_string()); + debug!("making http request to {}", endpoint.to_string()); - let req = start_payload.request; + let req = start_payload.request; - let res = http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) - })?; + let res = http_client + .request(req) + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))?; - debug!( - "http request to {} completed, status: {}", - endpoint.to_string(), - res.status() - ); + debug!( + "http request to {} completed, status: {}", + endpoint.to_string(), + res.status() + ); - let (parts, body) = res.into_parts(); - let body = body - .collect() - .await - .map_err(|e| { - SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) - })? - .to_bytes(); - - if body.is_empty() { - return Err(SubgraphExecutorError::RequestFailure( - endpoint.to_string(), - "Empty response body".to_string(), - )); - } - - let response = SharedResponse { - status: parts.status, - body: body, - headers: parts.headers, - }; + let (parts, body) = res.into_parts(); + let body = body + .collect() + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))? + .to_bytes(); - let mut end_payload = OnSubgraphHttpResponsePayload { - response, - }; + if body.is_empty() { + return Err(SubgraphExecutorError::RequestFailure( + endpoint.to_string(), + "Empty response body".to_string(), + )); + } - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ } - ControlFlowResult::EndResponse(response) => { - return Ok(SharedResponse { - status: StatusCode::OK, - body: response.body.into(), - headers: response.headers, - }); - } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } + let response = SharedResponse { + status: parts.status, + body, + headers: parts.headers, + }; + + let mut end_payload = OnSubgraphHttpResponsePayload { response }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(response) => { + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); } } + } - Ok(end_payload.response) + Ok(end_payload.response) } #[async_trait] @@ -298,7 +297,17 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - return match send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await { + return match send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await + { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, @@ -330,7 +339,16 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await + send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await }; // It's important to remove the entry from the map before returning the result. // This ensures that once the OnceCell is set, no future requests can join it. diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6f780b76f..cf2f6dc13 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,14 +27,19 @@ use vrl::{ }; use crate::{ - execution::client_request_details::ClientRequestDetails, executors::{ + execution::client_request_details::ClientRequestDetails, + executors::{ common::{ - HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc + HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, + SubgraphExecutorBoxedArc, }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient}, - }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError + }, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{ControlFlowResult, RouterPlugin}, + response::graphql_error::GraphQLError, }; type SubgraphName = String; @@ -119,11 +124,11 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - pub async fn execute<'a, 'req>( + pub async fn execute<'a>( &self, subgraph_name: &str, execution_request: SubgraphExecutionRequest<'a>, - client_request: &ClientRequestDetails<'a, 'req>, + client_request: &ClientRequestDetails<'a>, ) -> HttpExecutionResponse { let mut start_payload = OnSubgraphExecuteStartPayload { subgraph_name: subgraph_name.to_string(), @@ -156,9 +161,7 @@ impl SubgraphExecutorMap { let execution_request = start_payload.execution_request; let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor - .execute(execution_request) - .await, + Ok(Some(executor)) => executor.execute(execution_request).await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -174,10 +177,8 @@ impl SubgraphExecutorMap { self.internal_server_error_response("Internal server error".into(), subgraph_name) } }; - - let mut end_payload = OnSubgraphExecuteEndPayload { - execution_result - }; + + let mut end_payload = OnSubgraphExecuteEndPayload { execution_result }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -187,11 +188,11 @@ impl SubgraphExecutorMap { // continue to next callback } ControlFlowResult::EndResponse(response) => { - // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), - headers: response.headers, - }; + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; } ControlFlowResult::OnEnd(_) => { unreachable!("End callbacks should not register further end callbacks"); @@ -226,7 +227,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_, '_>, + client_request: &ClientRequestDetails<'_>, ) -> Result, SubgraphExecutorError> { let from_expression = self.get_or_create_executor_from_expression(subgraph_name, client_request)?; @@ -245,7 +246,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor_from_expression( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_, '_>, + client_request: &ClientRequestDetails<'_>, ) -> Result, SubgraphExecutorError> { if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) { let original_url_value = VrlValue::Bytes(Bytes::from( diff --git a/lib/executor/src/headers/expression.rs b/lib/executor/src/headers/expression.rs index ed63ebfc0..5852f2b75 100644 --- a/lib/executor/src/headers/expression.rs +++ b/lib/executor/src/headers/expression.rs @@ -46,7 +46,7 @@ fn header_map_to_vrl_value(headers: &HeaderMap) -> Value { Value::Object(obj) } -impl From<&RequestExpressionContext<'_, '_>> for Value { +impl From<&RequestExpressionContext<'_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &RequestExpressionContext) -> Self { // .subgraph @@ -65,7 +65,7 @@ impl From<&RequestExpressionContext<'_, '_>> for Value { } } -impl From<&ResponseExpressionContext<'_, '_>> for Value { +impl From<&ResponseExpressionContext<'_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &ResponseExpressionContext) -> Self { // .subgraph diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index 62f9fe701..c617edfa0 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@ -74,15 +74,15 @@ mod tests { ); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -108,15 +108,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); modify_subgraph_request_headers(&plan, "any", &client_details, &mut out).unwrap(); @@ -155,15 +155,15 @@ mod tests { ); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -193,15 +193,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: Some("MyQuery"), query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -227,15 +227,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -267,15 +267,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; // For "accounts" subgraph, the specific rule should apply. @@ -311,15 +311,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -376,15 +376,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -440,15 +440,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -497,15 +497,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -555,15 +555,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -614,15 +614,15 @@ mod tests { client_headers.insert(header_name_owned("x-keep"), header_value_owned("hi").into()); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 637ab0d58..44b6ed8b9 100644 --- a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers( Ok(()) } -pub struct RequestExpressionContext<'a, 'req> { +pub struct RequestExpressionContext<'a> { pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, + pub client_request: &'a ClientRequestDetails<'a>, } trait ApplyRequestHeader { @@ -117,7 +117,7 @@ impl ApplyRequestHeader for RequestPropagateRegex { ctx: &RequestExpressionContext, output_headers: &mut HeaderMap, ) -> Result<(), HeaderRuleRuntimeError> { - for (header_name, header_value) in ctx.client_request.headers { + for (header_name, header_value) in ctx.client_request.headers.iter() { if is_denied_header(header_name) { continue; } diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 6a5c34444..94019d585 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -50,9 +50,9 @@ pub fn apply_subgraph_response_headers( Ok(()) } -pub struct ResponseExpressionContext<'a, 'req> { +pub struct ResponseExpressionContext<'a> { pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, + pub client_request: &'a ClientRequestDetails<'a>, pub subgraph_headers: &'a HeaderMap, } diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 1f29c192e..bdcbdadc0 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -10,6 +10,5 @@ pub mod response; pub mod utils; pub mod variables; -pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; pub use plugins::*; diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index d5400d314..f5e380973 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -13,11 +13,13 @@ pub struct APQPlugin { impl RouterPlugin for APQPlugin { fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> - { + payload: OnGraphQLParamsStartPayload, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { payload.on_end(|mut payload| { - let persisted_query_ext = payload.graphql_params.extensions.as_ref() + let persisted_query_ext = payload + .graphql_params + .extensions + .as_ref() .and_then(|ext| ext.get("persistedQuery")) .and_then(|pq| pq.as_object()); if let Some(persisted_query_ext) = persisted_query_ext { @@ -28,7 +30,10 @@ impl RouterPlugin for APQPlugin { return payload.cont(); } } - let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { + let sha256_hash = match persisted_query_ext + .get(&"sha256Hash") + .and_then(|h| h.as_str()) + { Some(h) => h, None => { return payload.cont(); @@ -36,7 +41,8 @@ impl RouterPlugin for APQPlugin { }; if let Some(query_param) = &payload.graphql_params.query { // Store the query in the cache - self.cache.insert(sha256_hash.to_string(), query_param.to_string()); + self.cache + .insert(sha256_hash.to_string(), query_param.to_string()); } else { // Try to get the query from the cache if let Some(cached_query) = self.cache.get(sha256_hash) { diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 68e3e7092..a6d766a9c 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,3 +1,3 @@ +pub mod apq; pub mod response_cache; pub mod subgraph_response_cache; -pub mod apq; \ No newline at end of file diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs new file mode 100644 index 000000000..e69de29bb diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index d9d611307..fc6276643 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -5,7 +5,8 @@ use redis::Commands; use crate::{ execution::plan::PlanExecutionOutput, hooks::{ - on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, }, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, @@ -87,24 +88,22 @@ impl RouterPlugin for ResponseCachePlugin { } payload.cont() } - fn on_supergraph_reload<'a>(&'a self, payload: OnSupergraphLoadStartPayload) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + fn on_supergraph_reload<'a>( + &'a self, + payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { // Visit the schema and update ttl_per_type based on some directive - payload - .new_ast - .definitions - .iter() - .for_each(|def| { - if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { - if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { - for directive in &obj_type.directives { - if directive.name == "cacheControl" { - for arg in &directive.arguments { - if arg.0 == "maxAge" { - if let graphql_parser::query::Value::Int(max_age) = &arg.1 { - if let Some(max_age) = max_age.as_i64() { - self.ttl_per_type - .insert(obj_type.name.clone(), max_age as u64); - } + payload.new_ast.definitions.iter().for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); } } } @@ -112,7 +111,8 @@ impl RouterPlugin for ResponseCachePlugin { } } } - }); + } + }); payload.cont() } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 71ff8c1d9..4d192dd39 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,6 +1,10 @@ use dashmap::DashMap; -use crate::{executors::{common::HttpExecutionResponse}, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{ + executors::common::HttpExecutionResponse, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; pub struct SubgraphResponseCachePlugin { cache: DashMap, @@ -8,9 +12,9 @@ pub struct SubgraphResponseCachePlugin { impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_execute<'exec>( - &'exec self, - mut payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -21,10 +25,10 @@ impl RouterPlugin for SubgraphResponseCachePlugin { payload.execution_result = Some(cached_response.clone()); return payload.cont(); } - payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { + payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option self.cache.insert(key, payload.execution_result.clone()); payload.cont() }) } -} \ No newline at end of file +} diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 453c84c98..64851d0fd 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,9 +1,9 @@ pub mod on_execute; -pub mod on_supergraph_load; -pub mod on_subgraph_http_request; -pub mod on_http_request; pub mod on_graphql_params; pub mod on_graphql_parse; pub mod on_graphql_validation; +pub mod on_http_request; pub mod on_query_plan; -pub mod on_subgraph_execute; \ No newline at end of file +pub mod on_subgraph_execute; +pub mod on_subgraph_http_request; +pub mod on_supergraph_load; diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index dfcdaceb8..9a77679b9 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,15 +1,16 @@ use std::collections::HashMap; +use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use ntex::web::HttpRequest; use crate::plugin_trait::{EndPayload, StartPayload}; -use crate::response::{value::Value}; use crate::response::graphql_error::GraphQLError; +use crate::response::value::Value; pub struct OnExecuteStartPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: &'exec QueryPlan, + pub router_http_request: HttpRequest, + pub query_plan: Arc, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index 5e6ce1c47..a9afabed1 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -1,41 +1,103 @@ +use core::fmt; + use std::collections::HashMap; use ntex::util::Bytes; -use serde::Deserialize; -use serde::Deserializer; +use serde::{de, Deserialize, Deserializer}; use sonic_rs::Value; use crate::plugin_trait::EndPayload; use crate::plugin_trait::StartPayload; -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] +#[derive(Debug, Clone, Default)] pub struct GraphQLParams { pub query: Option, pub operation_name: Option, - #[serde(default, deserialize_with = "deserialize_null_default")] pub variables: HashMap, // TODO: We don't use extensions yet, but we definitely will in the future. #[allow(dead_code)] pub extensions: Option>, } -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) +// Workaround for https://github.com/cloudwego/sonic-rs/issues/114 + +impl<'de> Deserialize<'de> for GraphQLParams { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct GraphQLErrorExtensionsVisitor; + + impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor { + type Value = GraphQLParams; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map for GraphQLErrorExtensions") + } + + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut query = None; + let mut operation_name = None; + let mut variables: Option> = None; + let mut extensions: Option> = None; + let mut extra_params = HashMap::new(); + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "query" => { + if query.is_some() { + return Err(de::Error::duplicate_field("query")); + } + query = map.next_value::>()?; + } + "operationName" => { + if operation_name.is_some() { + return Err(de::Error::duplicate_field("operationName")); + } + operation_name = map.next_value::>()?; + } + "variables" => { + if variables.is_some() { + return Err(de::Error::duplicate_field("variables")); + } + variables = map.next_value::>>()?; + } + "extensions" => { + if extensions.is_some() { + return Err(de::Error::duplicate_field("extensions")); + } + extensions = map.next_value::>>()?; + } + other => { + let value: Value = map.next_value()?; + extra_params.insert(other.to_string(), value); + } + } + } + + Ok(GraphQLParams { + query, + operation_name, + variables: variables.unwrap_or_default(), + extensions, + }) + } + } + + deserializer.deserialize_map(GraphQLErrorExtensionsVisitor) + } } -pub struct OnGraphQLParamsStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, +pub struct OnGraphQLParamsStartPayload { + pub router_http_request: ntex::web::HttpRequest, pub body: Bytes, pub graphql_params: Option, } -impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} +impl StartPayload for OnGraphQLParamsStartPayload {} pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index 162a7eee2..df9b4e480 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -1,19 +1,20 @@ use graphql_tools::static_graphql::query::Document; -use crate::{hooks::on_graphql_params::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::{ + hooks::on_graphql_params::GraphQLParams, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnGraphQLParseStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: ntex::web::HttpRequest, pub graphql_params: &'exec GraphQLParams, pub document: Option, } -impl<'exec> StartPayload> for OnGraphQLParseStartPayload<'exec> {} +impl<'exec> StartPayload for OnGraphQLParseStartPayload<'exec> {} -pub struct OnGraphQLParseEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, +pub struct OnGraphQLParseEndPayload { pub document: Document, } -impl<'exec> EndPayload for OnGraphQLParseEndPayload<'exec> {} \ No newline at end of file +impl EndPayload for OnGraphQLParseEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index a789cb5fd..f6bb55004 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -1,13 +1,17 @@ use graphql_tools::{ static_graphql::query::Document, - validation::{rules::{ValidationRule, default_rules_validation_plan}, utils::ValidationError, validate::ValidationPlan}, + validation::{ + rules::{default_rules_validation_plan, ValidationRule}, + utils::ValidationError, + validate::ValidationPlan, + }, }; use hive_router_query_planner::state::supergraph_state::SchemaDocument; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnGraphQLValidationStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: &'exec mut ntex::web::HttpRequest, pub schema: &'exec SchemaDocument, pub document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -15,14 +19,11 @@ pub struct OnGraphQLValidationStartPayload<'exec> { pub errors: Option>, } -impl<'exec> StartPayload> - for OnGraphQLValidationStartPayload<'exec> -{ -} +impl<'exec> StartPayload for OnGraphQLValidationStartPayload<'exec> {} impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn new( - router_http_request: &'exec ntex::web::HttpRequest, + router_http_request: &'exec mut ntex::web::HttpRequest, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -39,17 +40,17 @@ impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn add_validation_rule(&mut self, rule: Box) { self.new_validation_plan - .get_or_insert_with(|| default_rules_validation_plan()) + .get_or_insert_with(default_rules_validation_plan) .add_rule(rule); } - + pub fn filter_validation_rules(&mut self, mut f: F) where F: FnMut(&Box) -> bool, { let plan = self .new_validation_plan - .get_or_insert_with(|| default_rules_validation_plan()); + .get_or_insert_with(default_rules_validation_plan); plan.rules.retain(|rule| f(rule)); } @@ -61,11 +62,8 @@ impl<'exec> OnGraphQLValidationStartPayload<'exec> { } } -pub struct OnGraphQLValidationEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub schema: &'exec SchemaDocument, - pub document: &'exec Document, +pub struct OnGraphQLValidationEndPayload { pub errors: Vec, } -impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} +impl EndPayload for OnGraphQLValidationEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 29a8344e5..a7f6f6bb5 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -13,4 +13,4 @@ pub struct OnHttpResponse<'exec> { pub response: &'exec mut Response, } -impl<'exec> EndPayload for OnHttpResponse<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnHttpResponse<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index 39ae3c2d6..fd2089ec6 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -1,9 +1,14 @@ -use hive_router_query_planner::{ast::operation::OperationDefinition, graph::PlannerOverrideContext, planner::{Planner, plan_nodes::QueryPlan}, utils::cancellation::CancellationToken}; +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + graph::PlannerOverrideContext, + planner::{plan_nodes::QueryPlan, Planner}, + utils::cancellation::CancellationToken, +}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnQueryPlanStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: &'exec mut ntex::web::HttpRequest, pub filtered_operation_for_plan: &'exec OperationDefinition, pub planner_override_context: PlannerOverrideContext, pub cancellation_token: &'exec CancellationToken, @@ -11,15 +16,10 @@ pub struct OnQueryPlanStartPayload<'exec> { pub planner: &'exec Planner, } -impl<'exec> StartPayload> for OnQueryPlanStartPayload<'exec> {} +impl<'exec> StartPayload for OnQueryPlanStartPayload<'exec> {} -pub struct OnQueryPlanEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub filtered_operation_for_plan: &'exec OperationDefinition, - pub planner_override_context: PlannerOverrideContext, - pub cancellation_token: &'exec CancellationToken, +pub struct OnQueryPlanEndPayload { pub query_plan: QueryPlan, - pub planner: &'exec Planner, } -impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file +impl EndPayload for OnQueryPlanEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 6a514006f..5a6fcc6a6 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,14 +1,14 @@ - - -use crate::{executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, plugin_trait::{EndPayload, StartPayload}}; - +use crate::{ + executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnSubgraphExecuteStartPayload<'exec> { pub subgraph_name: String, pub execution_request: SubgraphExecutionRequest<'exec>, pub execution_result: Option, -} +} impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} @@ -16,4 +16,4 @@ pub struct OnSubgraphExecuteEndPayload { pub execution_result: HttpExecutionResponse, } -impl<'exec> EndPayload for OnSubgraphExecuteEndPayload {} \ No newline at end of file +impl EndPayload for OnSubgraphExecuteEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index f7c798ed7..1b50f001b 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,10 +1,11 @@ use bytes::Bytes; -use http::{Request}; +use http::Request; use http_body_util::Full; use crate::{ - executors::{dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} -; + executors::dedupe::SharedResponse, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnSubgraphHttpRequestPayload<'exec> { pub subgraph_name: &'exec str, @@ -18,7 +19,7 @@ pub struct OnSubgraphHttpRequestPayload<'exec> { impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} pub struct OnSubgraphHttpResponsePayload { - pub response: SharedResponse, + pub response: SharedResponse, } -impl<'exec> EndPayload for OnSubgraphHttpResponsePayload {} +impl EndPayload for OnSubgraphHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs index e4e68ca35..21dfbf5a5 100644 --- a/lib/executor/src/plugins/hooks/on_supergraph_load.rs +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -1,11 +1,14 @@ use std::sync::Arc; +use arc_swap::ArcSwap; use graphql_tools::static_graphql::schema::Document; -use hive_router_query_planner::{planner::Planner}; -use arc_swap::{ArcSwap}; - -use crate::{SubgraphExecutorMap, introspection::schema::SchemaMetadata, plugin_trait::{EndPayload, StartPayload}}; +use hive_router_query_planner::planner::Planner; +use crate::{ + introspection::schema::SchemaMetadata, + plugin_trait::{EndPayload, StartPayload}, + SubgraphExecutorMap, +}; pub struct SupergraphData { pub metadata: SchemaMetadata, @@ -24,4 +27,4 @@ pub struct OnSupergraphLoadEndPayload { pub new_supergraph_data: SupergraphData, } -impl EndPayload for OnSupergraphLoadEndPayload {} \ No newline at end of file +impl EndPayload for OnSupergraphLoadEndPayload {} diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 6c35286af..02490fb5e 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,3 +1,3 @@ pub mod examples; +pub mod hooks; pub mod plugin_trait; -pub mod hooks; \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index d56856652..c502ba087 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -86,30 +86,27 @@ pub trait RouterPlugin { } fn on_graphql_params<'exec>( &'exec self, - start_payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + start_payload: OnGraphQLParamsStartPayload, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { start_payload.cont() } fn on_graphql_parse<'exec>( &self, start_payload: OnGraphQLParseStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { + ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() } fn on_graphql_validation<'exec>( &self, start_payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult< - 'exec, - OnGraphQLValidationStartPayload<'exec>, - OnGraphQLValidationEndPayload<'exec>, - > { + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { start_payload.cont() } fn on_query_plan<'exec>( &self, start_payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() } fn on_execute<'exec>( @@ -121,15 +118,13 @@ pub trait RouterPlugin { fn on_subgraph_execute<'exec>( &'exec self, start_payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> - { + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { start_payload.cont() } fn on_subgraph_http_request<'exec>( &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> - { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { start_payload.cont() } fn on_supergraph_reload<'exec>( From 9ebd0fb895a8bf24ed0e847d2153925cb5f8be81 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Thu, 20 Nov 2025 18:09:14 +0300 Subject: [PATCH 09/31] Async Hooks and Shared Context --- Cargo.lock | 2 + bin/router/Cargo.toml | 2 + bin/router/src/lib.rs | 10 +- bin/router/src/pipeline/execution.rs | 17 ++- bin/router/src/pipeline/mod.rs | 109 ++++++++++-------- bin/router/src/pipeline/parser.rs | 35 +++--- .../src/pipeline/progressive_override.rs | 8 +- bin/router/src/pipeline/query_plan.rs | 21 ++-- bin/router/src/pipeline/validation.rs | 8 +- bin/router/src/plugins/mod.rs | 1 + bin/router/src/plugins/plugins_service.rs | 109 ++++++++++++++++++ .../src/execution/client_request_details.rs | 12 +- lib/executor/src/execution/plan.rs | 76 ++++++------ lib/executor/src/executors/http.rs | 2 +- lib/executor/src/executors/map.rs | 16 ++- lib/executor/src/headers/expression.rs | 4 +- lib/executor/src/headers/mod.rs | 72 ++++++------ lib/executor/src/headers/request.rs | 6 +- lib/executor/src/headers/response.rs | 8 +- lib/executor/src/plugins/examples/apq.rs | 7 +- .../src/plugins/examples/response_cache.rs | 3 +- .../examples/subgraph_response_cache.rs | 3 +- lib/executor/src/plugins/hooks/on_execute.rs | 8 +- .../src/plugins/hooks/on_graphql_params.rs | 9 +- .../src/plugins/hooks/on_graphql_parse.rs | 4 +- .../plugins/hooks/on_graphql_validation.rs | 13 ++- .../src/plugins/hooks/on_http_request.rs | 22 ++-- .../src/plugins/hooks/on_query_plan.rs | 8 +- .../src/plugins/hooks/on_subgraph_execute.rs | 4 + lib/executor/src/plugins/mod.rs | 1 + lib/executor/src/plugins/plugin_context.rs | 43 +++++++ lib/executor/src/plugins/plugin_trait.rs | 27 ++--- 32 files changed, 424 insertions(+), 246 deletions(-) create mode 100644 bin/router/src/plugins/mod.rs create mode 100644 bin/router/src/plugins/plugins_service.rs create mode 100644 lib/executor/src/plugins/plugin_context.rs diff --git a/Cargo.lock b/Cargo.lock index b81088873..52b09935b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2004,6 +2004,8 @@ dependencies = [ "mimalloc", "moka", "ntex", + "ntex-service", + "ntex-util", "rand 0.9.2", "regex-automata", "reqwest", diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 4969b0a46..15d5e2051 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -53,3 +53,5 @@ tokio-util = "0.7.16" cookie = "0.18.1" regex-automata = "0.4.10" arc-swap = "1.7.1" +ntex-util = "2.15.0" +ntex-service = "3.5.0" \ No newline at end of file diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 416021b3a..dec90d730 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -4,6 +4,7 @@ mod http_utils; mod jwt; mod logger; mod pipeline; +mod plugins; mod schema_state; mod shared_state; mod supergraph; @@ -24,6 +25,7 @@ use crate::{ graphql_request_handler, header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, }, + plugins::plugins_service::PluginService, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; @@ -56,8 +58,8 @@ async fn graphql_endpoint_handler( let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - let result = match graphql_request_handler( - req, + let mut response = match graphql_request_handler( + &req, body_bytes, supergraph, app_state.get_ref().clone(), @@ -69,9 +71,6 @@ async fn graphql_endpoint_handler( Err(error) => return PipelineError { accept_ok, error }.into(), }; - let mut response = result.result; - let req = result.request; - // Apply CORS headers to the final response if CORS is configured. if let Some(cors) = app_state.cors_runtime.as_ref() { cors.set_headers(&req, response.headers_mut()); @@ -99,6 +98,7 @@ pub async fn router_entrypoint() -> Result<(), Box> { let maybe_error = web::HttpServer::new(move || { web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app) diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index e69a89179..e40fc02c3 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -7,11 +7,10 @@ use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::shared_state::RouterSharedState; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, QueryPlanExecutionContext, ResultWithRequest, -}; +use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; use ntex::web::HttpRequest; @@ -26,15 +25,16 @@ enum ExposeQueryPlanMode { } #[inline] -pub async fn execute_plan( - req: HttpRequest, +pub async fn execute_plan<'exec, 'req>( + req: &HttpRequest, supergraph: &SupergraphData, app_state: Arc, normalized_payload: Arc, query_plan_payload: Arc, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'_>, -) -> Result, PipelineErrorVariant> { + client_request_details: &ClientRequestDetails<'exec, 'req>, + plugin_manager: PluginManager<'req>, +) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -86,7 +86,7 @@ pub async fn execute_plan( }; let ctx = QueryPlanExecutionContext { - router_http_request: req, + plugin_manager: &plugin_manager, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, @@ -97,7 +97,6 @@ pub async fn execute_plan( operation_type_name: normalized_payload.root_type_name, jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, - plugins: &app_state.plugins, }; ctx.execute_query_plan().await.map_err(|err| { diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index b97f7e67f..860a13c91 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -3,12 +3,13 @@ use std::sync::Arc; use hive_router_plan_executor::{ execution::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::{PlanExecutionOutput, ResultWithRequest, WithResult}, + plan::PlanExecutionOutput, }, hooks::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, on_supergraph_load::SupergraphData, }, + plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ @@ -59,28 +60,26 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn graphql_request_handler( - req: HttpRequest, + req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, shared_state: Arc, schema_state: Arc, -) -> Result, PipelineErrorVariant> { +) -> Result { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return Ok(req.with_result( - web::HttpResponse::Ok() - .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML), - )); + return Ok(web::HttpResponse::Ok() + .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) + .body(GRAPHIQL_HTML)); } else { - return Ok(req.with_result(web::HttpResponse::NotFound().into())); + return Ok(web::HttpResponse::NotFound().into()); } } let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { - match jwt.validate_request(&req) { + match jwt.validate_request(req) { Ok(jwt_context) => jwt_context, - Err(err) => return Ok(req.with_result(err.make_response())), + Err(err) => return Ok(err.make_response()), } } else { None @@ -93,16 +92,36 @@ pub async fn graphql_request_handler( &APPLICATION_JSON }; - let execution_result_with_req = execute_pipeline( + let plugin_context = req + .extensions() + .get::>() + .cloned() + .expect("Plugin manager should be loaded"); + + let plugin_manager = PluginManager { + plugins: shared_state.plugins.clone(), + router_http_request: RouterHttpRequest { + uri: req.uri(), + method: req.method(), + version: req.version(), + headers: req.headers(), + match_info: req.match_info(), + query_string: req.query_string(), + path: req.path(), + }, + context: plugin_context, + }; + + let response = execute_pipeline( req, body_bytes, supergraph, shared_state, schema_state, jwt_context, + plugin_manager, ) .await?; - let response = execution_result_with_req.result; let response_bytes = Bytes::from(response.body); let response_headers = response.headers; @@ -113,41 +132,39 @@ pub async fn graphql_request_handler( } } - Ok(execution_result_with_req.request.with_result( - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes), - )) + Ok(response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .body(response_bytes)) } #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline( - req: HttpRequest, +pub async fn execute_pipeline<'req>( + req: &'req HttpRequest, body: Bytes, supergraph: &SupergraphData, shared_state: Arc, schema_state: Arc, jwt_context: Option, -) -> Result, PipelineErrorVariant> { - perform_csrf_prevention(&req, &shared_state.router_config.csrf)?; + plugin_manager: PluginManager<'req>, +) -> Result { + perform_csrf_prevention(req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { - router_http_request: req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, body, graphql_params: None, }; for plugin in shared_state.plugins.as_ref() { - let result = plugin.on_graphql_params(deserialization_payload); + let result = plugin.on_graphql_params(deserialization_payload).await; deserialization_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(deserialization_payload - .router_http_request - .with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(callback) => { deserialization_end_callbacks.push(callback); @@ -155,11 +172,8 @@ pub async fn execute_pipeline( } } let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { - deserialize_graphql_params( - &deserialization_payload.router_http_request, - deserialization_payload.body, - ) - .expect("Failed to parse execution request") + deserialize_graphql_params(req, deserialization_payload.body) + .expect("Failed to parse execution request") }); let mut payload = OnGraphQLParamsEndPayload { graphql_params }; @@ -169,9 +183,7 @@ pub async fn execute_pipeline( match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(deserialization_payload - .router_http_request - .with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -182,25 +194,22 @@ pub async fn execute_pipeline( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let req = deserialization_payload.router_http_request; let parser_result = - parse_operation_with_cache(req, shared_state.clone(), &graphql_params).await?; - - let mut req = parser_result.request; + parse_operation_with_cache(shared_state.clone(), &graphql_params, &plugin_manager).await?; - let parser_payload = match parser_result.result { + let parser_payload = match parser_result { ParseResult::Payload(payload) => payload, ParseResult::Response(response) => { - return Ok(req.with_result(response)); + return Ok(response); } }; validate_operation_with_cache( - &mut req, supergraph, schema_state.clone(), shared_state.clone(), &parser_payload, + &plugin_manager, ) .await?; @@ -213,7 +222,7 @@ pub async fn execute_pipeline( .await?; let variable_payload = - coerce_request_variables(&req, supergraph, &mut graphql_params, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); @@ -231,9 +240,9 @@ pub async fn execute_pipeline( }; let client_request_details = ClientRequestDetails { - method: req.method().clone(), - url: req.uri().clone(), - headers: req.headers().clone(), + method: req.method(), + url: req.uri(), + headers: req.headers(), operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -254,20 +263,19 @@ pub async fn execute_pipeline( .map_err(PipelineErrorVariant::LabelEvaluationError)?; let query_plan_result = plan_operation_with_cache( - req, supergraph, schema_state.clone(), normalize_payload.clone(), &progressive_override_ctx, &query_plan_cancellation_token, shared_state.clone(), + &plugin_manager, ) .await?; - let req = query_plan_result.request; - let query_plan_payload = match query_plan_result.result { + let query_plan_payload = match query_plan_result { QueryPlanResult::QueryPlan(plan) => plan, QueryPlanResult::Response(response) => { - return Ok(req.with_result(response)); + return Ok(response); } }; @@ -279,6 +287,7 @@ pub async fn execute_pipeline( query_plan_payload, &variable_payload, &client_request_details, + plugin_manager, ) .await?; diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index e6194fbd5..0a11ab2aa 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,16 +2,14 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, ResultWithRequest, WithResult, -}; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{ OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, }; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::deserialize_graphql_params::GetQueryStr; @@ -31,11 +29,11 @@ pub enum ParseResult { } #[inline] -pub async fn parse_operation_with_cache( - req: HttpRequest, +pub async fn parse_operation_with_cache<'req>( app_state: Arc, graphql_params: &GraphQLParams, -) -> Result, PipelineErrorVariant> { + plugin_manager: &PluginManager<'req>, +) -> Result { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); @@ -43,7 +41,8 @@ pub async fn parse_operation_with_cache( }; /* Handle on_graphql_parse hook in the plugins - START */ let mut start_payload = OnGraphQLParseStartPayload { - router_http_request: req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, graphql_params, document: None, }; @@ -54,16 +53,14 @@ pub async fn parse_operation_with_cache( } else { let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_parse(start_payload); + let result = plugin.on_graphql_parse(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload - .router_http_request - .with_result(ParseResult::Response(response))); + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -93,9 +90,7 @@ pub async fn parse_operation_with_cache( // continue to next callback } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload - .router_http_request - .with_result(ParseResult::Response(response))); + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -114,10 +109,8 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(start_payload - .router_http_request - .with_result(ParseResult::Payload(GraphQLParserPayload { - parsed_operation, - cache_key, - }))) + Ok(ParseResult::Payload(GraphQLParserPayload { + parsed_operation, + cache_key, + })) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index dc28a0d60..d0b09c183 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec>( +pub fn request_override_context<'exec, 'req>( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec>, + client_request_details: &ClientRequestDetails<'exec, 'req>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; @@ -158,9 +158,9 @@ impl OverrideLabelsEvaluator { }) } - pub(crate) fn evaluate<'exec>( + pub(crate) fn evaluate<'exec, 'req>( &self, - client_request: &ClientRequestDetails<'exec>, + client_request: &ClientRequestDetails<'exec, 'req>, ) -> Result, LabelEvaluationError> { let mut active_flags = self.static_enabled_labels.clone(); diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 807946296..156b8ef9f 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -6,18 +6,16 @@ use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; use crate::schema_state::SchemaState; use crate::RouterSharedState; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, ResultWithRequest, WithResult, -}; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; use hive_router_plan_executor::hooks::on_query_plan::{ OnQueryPlanEndPayload, OnQueryPlanStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; pub enum QueryPlanResult { @@ -31,15 +29,15 @@ pub enum QueryPlanGetterError { } #[inline] -pub async fn plan_operation_with_cache( - mut req: HttpRequest, +pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, schema_state: Arc, normalized_operation: Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, app_state: Arc, -) -> Result, PipelineErrorVariant> { + plugin_manager: &PluginManager<'req>, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -61,7 +59,8 @@ pub async fn plan_operation_with_cache( /* Handle on_query_plan hook in the plugins - START */ let mut start_payload = OnQueryPlanStartPayload { - router_http_request: &mut req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, filtered_operation_for_plan, planner_override_context: (&request_override_context.clone()).into(), cancellation_token, @@ -71,7 +70,7 @@ pub async fn plan_operation_with_cache( let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_query_plan(start_payload); + let result = plugin.on_query_plan(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { @@ -121,11 +120,11 @@ pub async fn plan_operation_with_cache( .await; match plan_result { - Ok(plan) => Ok(req.with_result(QueryPlanResult::QueryPlan(plan))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), Err(e) => match e.as_ref() { QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), QueryPlanGetterError::Response(response) => { - Ok(req.with_result(QueryPlanResult::Response(response.clone()))) + Ok(QueryPlanResult::Response(response.clone())) } }, } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 3efbddcca..afe833656 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -10,17 +10,17 @@ use hive_router_plan_executor::hooks::on_graphql_validation::{ OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; -use ntex::web::HttpRequest; use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &mut HttpRequest, supergraph: &SupergraphData, schema_state: Arc, app_state: Arc, parser_payload: &GraphQLParserPayload, + plugin_manager: &PluginManager<'_>, ) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; @@ -45,14 +45,14 @@ pub async fn validate_operation_with_cache( /* Handle on_graphql_validate hook in the plugins - START */ let mut start_payload = OnGraphQLValidationStartPayload::new( - req, + plugin_manager, consumer_schema_ast, &parser_payload.parsed_operation, &app_state.validation_plan, ); let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_validation(start_payload); + let result = plugin.on_graphql_validation(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs new file mode 100644 index 000000000..3753246b2 --- /dev/null +++ b/bin/router/src/plugins/mod.rs @@ -0,0 +1 @@ +pub mod plugins_service; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs new file mode 100644 index 000000000..49204b662 --- /dev/null +++ b/bin/router/src/plugins/plugins_service.rs @@ -0,0 +1,109 @@ +use std::sync::Arc; + +use hive_router_plan_executor::{ + hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_context::PluginContext, + plugin_trait::ControlFlowResult, +}; +use ntex::{ + service::{Service, ServiceCtx}, + web::{self, DefaultError}, + Middleware, +}; + +use crate::RouterSharedState; + +pub struct PluginService; + +impl Middleware for PluginService { + type Service = PluginMiddleware; + + fn create(&self, service: S) -> Self::Service { + PluginMiddleware { service } + } +} + +pub struct PluginMiddleware { + // This is special: We need this to avoid lifetime issues. + service: S, +} + +impl Service> for PluginMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + let plugins = req + .app_state::>() + .map(|shared_state| shared_state.plugins.clone()); + + if let Some(plugins) = plugins { + let plugin_context = Arc::new(PluginContext::default()); + req.extensions_mut().insert(plugin_context.clone()); + let mut start_payload = OnHttpRequestPayload { + router_http_request: req, + context: &plugin_context, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.iter() { + let result = plugin.on_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + ControlFlowResult::EndResponse(_response) => { + // Short-circuit the request with the provided response + unimplemented!(); + } + } + } + + let req = start_payload.router_http_request; + + let response = match start_payload.response { + Some(response) => response, + None => ctx.call(&self.service, req).await?, + }; + + let mut end_payload = OnHttpResponsePayload { response }; + + for callback in on_end_callbacks.into_iter().rev() { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(_response) => { + // Short-circuit the request with the provided response + unimplemented!() + } + ControlFlowResult::OnEnd(_) => { + // This should not happen + unreachable!(); + } + } + } + + return Ok(end_payload.response); + } + + ctx.call(&self.service, req).await + } +} diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 71f28b746..20b7dcf98 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -13,10 +13,10 @@ pub struct OperationDetails<'exec> { pub kind: &'static str, } -pub struct ClientRequestDetails<'exec> { - pub method: Method, - pub url: http::Uri, - pub headers: NtexHeaderMap, +pub struct ClientRequestDetails<'exec, 'req> { + pub method: &'req Method, + pub url: &'req http::Uri, + pub headers: &'req NtexHeaderMap, pub operation: OperationDetails<'exec>, pub jwt: JwtRequestDetails, } @@ -31,10 +31,10 @@ pub enum JwtRequestDetails { Unauthenticated, } -impl From<&ClientRequestDetails<'_>> for Value { +impl From<&ClientRequestDetails<'_, '_>> for Value { fn from(details: &ClientRequestDetails) -> Self { // .request.headers - let headers_value = client_header_map_to_vrl_value(&details.headers); + let headers_value = client_header_map_to_vrl_value(details.headers); // .request.url let url_value = Self::Object(BTreeMap::from([ diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 5ea8bd758..bbe5ba836 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -10,7 +10,6 @@ use hive_router_query_planner::planner::plan_nodes::{ QueryPlan, SequenceNode, }; use http::HeaderMap; -use ntex::web::HttpRequest; use serde::Deserialize; use sonic_rs::ValueRef; @@ -36,7 +35,8 @@ use crate::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, - plugin_trait::{ControlFlowResult, RouterPlugin}, + plugin_context::PluginManager, + plugin_trait::ControlFlowResult, projection::{ plan::FieldProjectionPlan, request::{project_requires, RequestProjectionContext}, @@ -54,15 +54,14 @@ use crate::{ }, }; -pub struct QueryPlanExecutionContext<'exec> { - pub router_http_request: HttpRequest, - pub plugins: &'exec Vec>, +pub struct QueryPlanExecutionContext<'exec, 'req> { + pub plugin_manager: &'exec PluginManager<'exec>, pub query_plan: Arc, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, pub extensions: Option>, - pub client_request: &'exec ClientRequestDetails<'exec>, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, @@ -75,25 +74,8 @@ pub struct PlanExecutionOutput { pub headers: HeaderMap, } -pub struct ResultWithRequest { - pub result: T, - pub request: HttpRequest, -} - -pub trait WithResult { - fn with_result(self, result: T) -> ResultWithRequest; -} - -impl WithResult for HttpRequest { - fn with_result(self, result: T) -> ResultWithRequest { - ResultWithRequest { result, request: self } - } -} - -impl<'exec> QueryPlanExecutionContext<'exec> { - pub async fn execute_query_plan( - self, - ) -> Result, PlanExecutionError> { +impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { + pub async fn execute_query_plan(self) -> Result { let init_value = if let Some(introspection_query) = self.introspection_context.query { resolve_introspection(introspection_query, self.introspection_context) } else { @@ -103,8 +85,9 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let dedupe_subgraph_requests = self.operation_type_name == "Query"; let mut start_payload = OnExecuteStartPayload { - router_http_request: self.router_http_request, - query_plan: self.query_plan, + router_http_request: &self.plugin_manager.router_http_request, + context: &self.plugin_manager.context, + query_plan: &self.query_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), @@ -114,13 +97,13 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let mut on_end_callbacks = vec![]; - for plugin in self.plugins { - let result = plugin.on_execute(start_payload); + for plugin in self.plugin_manager.plugins.iter() { + let result = plugin.on_execute(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload.router_http_request.with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); @@ -132,7 +115,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let init_value = start_payload.data; - let mut exec_ctx = ExecutionContext::new(&query_plan, init_value); + let mut exec_ctx = ExecutionContext::new(query_plan, init_value); let executor = Executor::new( self.variable_values, self.executors, @@ -142,6 +125,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { self.jwt_auth_forwarding, // Deduplicate subgraph requests only if the operation type is a query self.operation_type_name == "Query", + self.plugin_manager, ); if query_plan.node.is_some() { @@ -170,7 +154,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { match result.control_flow { ControlFlowResult::Continue => { /* continue to next callback */ } ControlFlowResult::EndResponse(output) => { - return Ok(start_payload.router_http_request.with_result(output)); + return Ok(output); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -193,23 +177,22 @@ impl<'exec> QueryPlanExecutionContext<'exec> { affected_path: || None, })?; - Ok(start_payload - .router_http_request - .with_result(PlanExecutionOutput { - body, - headers: response_headers, - })) + Ok(PlanExecutionOutput { + body, + headers: response_headers, + }) } } -pub struct Executor<'exec> { +pub struct Executor<'exec, 'req> { variable_values: &'exec Option>, schema_metadata: &'exec SchemaMetadata, executors: &'exec SubgraphExecutorMap, - client_request: &'exec ClientRequestDetails<'exec>, + client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, } struct ConcurrencyScope<'exec, T> { @@ -296,15 +279,16 @@ struct PreparedFlattenData { representation_hash_to_index: HashMap, } -impl<'exec> Executor<'exec> { +impl<'exec, 'req> Executor<'exec, 'req> { pub fn new( variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, - client_request: &'exec ClientRequestDetails<'exec>, + client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, ) -> Self { Executor { variable_values, @@ -314,6 +298,7 @@ impl<'exec> Executor<'exec> { headers_plan, dedupe_subgraph_requests, jwt_forwarding_plan, + plugin_manager, } } @@ -803,7 +788,12 @@ impl<'exec> Executor<'exec> { subgraph_name: node.service_name.clone(), response: self .executors - .execute(&node.service_name, subgraph_request, self.client_request) + .execute( + &node.service_name, + subgraph_request, + self.client_request, + self.plugin_manager, + ) .await .into(), })) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index a5a5ed263..9e2770965 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -191,7 +191,7 @@ async fn send_request( let mut on_end_callbacks = vec![]; for plugin in plugins.as_ref() { - let result = plugin.on_subgraph_http_request(start_payload); + let result = plugin.on_subgraph_http_request(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index cf2f6dc13..cf6033040 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -38,6 +38,7 @@ use crate::{ http::{HTTPSubgraphExecutor, HttpClient}, }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_context::PluginManager, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError, }; @@ -124,13 +125,16 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - pub async fn execute<'a>( + pub async fn execute<'exec, 'req>( &self, subgraph_name: &str, - execution_request: SubgraphExecutionRequest<'a>, - client_request: &ClientRequestDetails<'a>, + execution_request: SubgraphExecutionRequest<'exec>, + client_request: &ClientRequestDetails<'exec, 'req>, + plugin_manager: &PluginManager<'req>, ) -> HttpExecutionResponse { let mut start_payload = OnSubgraphExecuteStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, subgraph_name: subgraph_name.to_string(), execution_request, execution_result: None, @@ -139,7 +143,7 @@ impl SubgraphExecutorMap { let mut on_end_callbacks = vec![]; for plugin in self.plugins.as_ref() { - let result = plugin.on_subgraph_execute(start_payload); + let result = plugin.on_subgraph_execute(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { @@ -227,7 +231,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_>, + client_request: &ClientRequestDetails<'_, '_>, ) -> Result, SubgraphExecutorError> { let from_expression = self.get_or_create_executor_from_expression(subgraph_name, client_request)?; @@ -246,7 +250,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor_from_expression( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_>, + client_request: &ClientRequestDetails<'_, '_>, ) -> Result, SubgraphExecutorError> { if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) { let original_url_value = VrlValue::Bytes(Bytes::from( diff --git a/lib/executor/src/headers/expression.rs b/lib/executor/src/headers/expression.rs index 5852f2b75..ed63ebfc0 100644 --- a/lib/executor/src/headers/expression.rs +++ b/lib/executor/src/headers/expression.rs @@ -46,7 +46,7 @@ fn header_map_to_vrl_value(headers: &HeaderMap) -> Value { Value::Object(obj) } -impl From<&RequestExpressionContext<'_>> for Value { +impl From<&RequestExpressionContext<'_, '_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &RequestExpressionContext) -> Self { // .subgraph @@ -65,7 +65,7 @@ impl From<&RequestExpressionContext<'_>> for Value { } } -impl From<&ResponseExpressionContext<'_>> for Value { +impl From<&ResponseExpressionContext<'_, '_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &ResponseExpressionContext) -> Self { // .subgraph diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index c617edfa0..0338bf6df 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@ -74,9 +74,9 @@ mod tests { ); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -108,9 +108,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -155,9 +155,9 @@ mod tests { ); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -193,9 +193,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: Some("MyQuery"), query: "{ __typename }", @@ -227,9 +227,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -267,9 +267,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -311,9 +311,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -376,9 +376,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -440,9 +440,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -497,9 +497,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -555,9 +555,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -614,9 +614,9 @@ mod tests { client_headers.insert(header_name_owned("x-keep"), header_value_owned("hi").into()); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 44b6ed8b9..7f362be73 100644 --- a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers( Ok(()) } -pub struct RequestExpressionContext<'a> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a>, +pub struct RequestExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, } trait ApplyRequestHeader { diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 94019d585..b4942837f 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -50,10 +50,10 @@ pub fn apply_subgraph_response_headers( Ok(()) } -pub struct ResponseExpressionContext<'a> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a>, - pub subgraph_headers: &'a HeaderMap, +pub struct ResponseExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, + pub subgraph_headers: &'exec HeaderMap, } trait ApplyResponseHeader { diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index f5e380973..91a473b5e 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -10,11 +10,12 @@ pub struct APQPlugin { cache: DashMap, } +#[async_trait::async_trait] impl RouterPlugin for APQPlugin { - fn on_graphql_params<'exec>( + async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { payload.on_end(|mut payload| { let persisted_query_ext = payload .graphql_params diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index fc6276643..ee92a029d 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -28,8 +28,9 @@ impl ResponseCachePlugin { } } +#[async_trait::async_trait] impl RouterPlugin for ResponseCachePlugin { - fn on_execute<'exec>( + async fn on_execute<'exec>( &'exec self, payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 4d192dd39..4e4b36666 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -10,8 +10,9 @@ pub struct SubgraphResponseCachePlugin { cache: DashMap, } +#[async_trait::async_trait] impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_execute<'exec>( + async fn on_subgraph_execute<'exec>( &'exec self, mut payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 9a77679b9..328a1a8fd 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,16 +1,16 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use ntex::web::HttpRequest; +use crate::plugin_context::{PluginContext, RouterHttpRequest}; use crate::plugin_trait::{EndPayload, StartPayload}; use crate::response::graphql_error::GraphQLError; use crate::response::value::Value; pub struct OnExecuteStartPayload<'exec> { - pub router_http_request: HttpRequest, - pub query_plan: Arc, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub query_plan: &'exec QueryPlan, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index a9afabed1..d954f94e7 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -6,6 +6,8 @@ use ntex::util::Bytes; use serde::{de, Deserialize, Deserializer}; use sonic_rs::Value; +use crate::plugin_context::PluginContext; +use crate::plugin_context::RouterHttpRequest; use crate::plugin_trait::EndPayload; use crate::plugin_trait::StartPayload; @@ -91,13 +93,14 @@ impl<'de> Deserialize<'de> for GraphQLParams { } } -pub struct OnGraphQLParamsStartPayload { - pub router_http_request: ntex::web::HttpRequest, +pub struct OnGraphQLParamsStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub body: Bytes, pub graphql_params: Option, } -impl StartPayload for OnGraphQLParamsStartPayload {} +impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index df9b4e480..fa29e3b9d 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -2,11 +2,13 @@ use graphql_tools::static_graphql::query::Document; use crate::{ hooks::on_graphql_params::GraphQLParams, + plugin_context::{PluginContext, RouterHttpRequest}, plugin_trait::{EndPayload, StartPayload}, }; pub struct OnGraphQLParseStartPayload<'exec> { - pub router_http_request: ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub graphql_params: &'exec GraphQLParams, pub document: Option, } diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index f6bb55004..c341a6a36 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -8,10 +8,14 @@ use graphql_tools::{ }; use hive_router_query_planner::state::supergraph_state::SchemaDocument; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnGraphQLValidationStartPayload<'exec> { - pub router_http_request: &'exec mut ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub schema: &'exec SchemaDocument, pub document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -23,13 +27,14 @@ impl<'exec> StartPayload for OnGraphQLValidationS impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn new( - router_http_request: &'exec mut ntex::web::HttpRequest, + plugin_manager: &'exec PluginManager<'exec>, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, ) -> Self { OnGraphQLValidationStartPayload { - router_http_request, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, schema, document, default_validation_plan, diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index a7f6f6bb5..84a683948 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -1,16 +1,20 @@ -use ntex::{http::Response, web::HttpRequest}; +use ntex::web::{self, DefaultError}; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::PluginContext, + plugin_trait::{EndPayload, StartPayload}, +}; -pub struct OnHttpRequestPayload<'exec> { - pub client_request: &'exec HttpRequest, +pub struct OnHttpRequestPayload<'req> { + pub router_http_request: web::WebRequest, + pub context: &'req PluginContext, + pub response: Option, } -impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} +impl<'req> StartPayload for OnHttpRequestPayload<'req> {} -pub struct OnHttpResponse<'exec> { - pub router_http_request: &'exec HttpRequest, - pub response: &'exec mut Response, +pub struct OnHttpResponsePayload { + pub response: web::WebResponse, } -impl<'exec> EndPayload for OnHttpResponse<'exec> {} +impl EndPayload for OnHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index fd2089ec6..9b2110fd7 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -5,10 +5,14 @@ use hive_router_query_planner::{ utils::cancellation::CancellationToken, }; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::{PluginContext, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnQueryPlanStartPayload<'exec> { - pub router_http_request: &'exec mut ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub filtered_operation_for_plan: &'exec OperationDefinition, pub planner_override_context: PlannerOverrideContext, pub cancellation_token: &'exec CancellationToken, diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 5a6fcc6a6..25ec2c30b 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,9 +1,13 @@ use crate::{ executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + plugin_context::{PluginContext, RouterHttpRequest}, plugin_trait::{EndPayload, StartPayload}, }; pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub subgraph_name: String, pub execution_request: SubgraphExecutionRequest<'exec>, diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 02490fb5e..3c24ff9f2 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,3 +1,4 @@ pub mod examples; pub mod hooks; +pub mod plugin_context; pub mod plugin_trait; diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs new file mode 100644 index 000000000..bce85a517 --- /dev/null +++ b/lib/executor/src/plugins/plugin_context.rs @@ -0,0 +1,43 @@ +use std::{ + any::{Any, TypeId}, + sync::Arc, +}; + +use dashmap::DashMap; +use http::Uri; +use ntex::router::Path; +use ntex_http::HeaderMap; + +use crate::plugin_trait::RouterPlugin; + +pub struct RouterHttpRequest<'exec> { + pub uri: &'exec Uri, + pub method: &'exec http::Method, + pub version: http::Version, + pub headers: &'exec HeaderMap, + pub path: &'exec str, + pub query_string: &'exec str, + pub match_info: &'exec Path, +} + +#[derive(Default)] +pub struct PluginContext { + inner: DashMap>, +} + +impl PluginContext { + pub fn insert(&self, value: T) { + self.inner.insert(TypeId::of::(), Arc::new(value)); + } + pub fn get(&self) -> Option> { + self.inner + .get(&TypeId::of::()) + .map(|v| v.clone().downcast::().ok().unwrap()) + } +} + +pub struct PluginManager<'req> { + pub plugins: Arc>>, + pub router_http_request: RouterHttpRequest<'req>, + pub context: Arc, +} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index c502ba087..0a3f6db32 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -5,7 +5,7 @@ use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseSta use crate::hooks::on_graphql_validation::{ OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, }; -use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; +use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}; use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; use crate::hooks::on_subgraph_execute::{ OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, @@ -77,51 +77,52 @@ where } } +#[async_trait::async_trait] pub trait RouterPlugin { - fn on_http_request<'exec>( + fn on_http_request<'req>( &self, - start_payload: OnHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { + start_payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { start_payload.cont() } - fn on_graphql_params<'exec>( + async fn on_graphql_params<'exec>( &'exec self, - start_payload: OnGraphQLParamsStartPayload, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { + start_payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { start_payload.cont() } - fn on_graphql_parse<'exec>( + async fn on_graphql_parse<'exec>( &self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() } - fn on_graphql_validation<'exec>( + async fn on_graphql_validation<'exec>( &self, start_payload: OnGraphQLValidationStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> { start_payload.cont() } - fn on_query_plan<'exec>( + async fn on_query_plan<'exec>( &self, start_payload: OnQueryPlanStartPayload<'exec>, ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() } - fn on_execute<'exec>( + async fn on_execute<'exec>( &'exec self, start_payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { start_payload.cont() } - fn on_subgraph_execute<'exec>( + async fn on_subgraph_execute<'exec>( &'exec self, start_payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { start_payload.cont() } - fn on_subgraph_http_request<'exec>( + async fn on_subgraph_http_request<'exec>( &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { From 419294e209b15bf58de435c5df24a1e2cb1f0cb3 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Thu, 20 Nov 2025 18:19:02 +0300 Subject: [PATCH 10/31] Improvements --- bin/router/src/lib.rs | 4 +- bin/router/src/pipeline/coerce_variables.rs | 3 +- bin/router/src/pipeline/execution.rs | 14 +++---- bin/router/src/pipeline/mod.rs | 42 +++++++++---------- bin/router/src/pipeline/normalize.rs | 2 +- bin/router/src/pipeline/parser.rs | 6 +-- .../src/pipeline/progressive_override.rs | 4 +- bin/router/src/pipeline/query_plan.rs | 6 +-- bin/router/src/pipeline/validation.rs | 4 +- lib/executor/src/execution/plan.rs | 9 ++-- 10 files changed, 43 insertions(+), 51 deletions(-) diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index dec90d730..5e5a0e353 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -62,8 +62,8 @@ async fn graphql_endpoint_handler( &req, body_bytes, supergraph, - app_state.get_ref().clone(), - schema_state.get_ref().clone(), + app_state.get_ref(), + schema_state.get_ref(), ) .await { diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index d10fbb6c4..ab5759b5e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; @@ -23,7 +22,7 @@ pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, graphql_params: &mut GraphQLParams, - normalized_operation: &Arc, + normalized_operation: &GraphQLNormalizationPayload, ) -> Result { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index e40fc02c3..fae429895 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; use crate::pipeline::error::PipelineErrorVariant; @@ -24,16 +23,17 @@ enum ExposeQueryPlanMode { DryRun, } +#[allow(clippy::too_many_arguments)] #[inline] -pub async fn execute_plan<'exec, 'req>( +pub async fn execute_plan( req: &HttpRequest, supergraph: &SupergraphData, - app_state: Arc, - normalized_payload: Arc, - query_plan_payload: Arc, + app_state: &RouterSharedState, + normalized_payload: &GraphQLNormalizationPayload, + query_plan_payload: &QueryPlan, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'exec, 'req>, - plugin_manager: PluginManager<'req>, + client_request_details: &ClientRequestDetails<'_, '_>, + plugin_manager: PluginManager<'_>, ) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 860a13c91..f05745265 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -63,8 +63,8 @@ pub async fn graphql_request_handler( req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: Arc, - schema_state: Arc, + shared_state: &RouterSharedState, + schema_state: &SchemaState, ) -> Result { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { @@ -139,14 +139,14 @@ pub async fn graphql_request_handler( #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline<'req>( - req: &'req HttpRequest, +pub async fn execute_pipeline( + req: &HttpRequest, body: Bytes, supergraph: &SupergraphData, - shared_state: Arc, - schema_state: Arc, + shared_state: &RouterSharedState, + schema_state: &SchemaState, jwt_context: Option, - plugin_manager: PluginManager<'req>, + plugin_manager: PluginManager<'_>, ) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; @@ -195,7 +195,7 @@ pub async fn execute_pipeline<'req>( /* Handle on_deserialize hook in the plugins - END */ let parser_result = - parse_operation_with_cache(shared_state.clone(), &graphql_params, &plugin_manager).await?; + parse_operation_with_cache(shared_state, &graphql_params, &plugin_manager).await?; let parser_payload = match parser_result { ParseResult::Payload(payload) => payload, @@ -206,20 +206,16 @@ pub async fn execute_pipeline<'req>( validate_operation_with_cache( supergraph, - schema_state.clone(), - shared_state.clone(), + schema_state, + shared_state, &parser_payload, &plugin_manager, ) .await?; - let normalize_payload = normalize_request_with_cache( - supergraph, - schema_state.clone(), - &graphql_params, - &parser_payload, - ) - .await?; + let normalize_payload = + normalize_request_with_cache(supergraph, schema_state, &graphql_params, &parser_payload) + .await?; let variable_payload = coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; @@ -264,11 +260,11 @@ pub async fn execute_pipeline<'req>( let query_plan_result = plan_operation_with_cache( supergraph, - schema_state.clone(), - normalize_payload.clone(), + schema_state, + &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, - shared_state.clone(), + shared_state, &plugin_manager, ) .await?; @@ -282,9 +278,9 @@ pub async fn execute_pipeline<'req>( let execution_result = execute_plan( req, supergraph, - shared_state.clone(), - normalize_payload.clone(), - query_plan_payload, + shared_state, + &normalize_payload, + &query_plan_payload, &variable_payload, &client_request_details, plugin_manager, diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 54093d065..97cbb80ac 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -26,7 +26,7 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( supergraph: &SupergraphData, - schema_state: Arc, + schema_state: &SchemaState, graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, ) -> Result, PipelineErrorVariant> { diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 0a11ab2aa..aebbd6beb 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -29,10 +29,10 @@ pub enum ParseResult { } #[inline] -pub async fn parse_operation_with_cache<'req>( - app_state: Arc, +pub async fn parse_operation_with_cache( + app_state: &RouterSharedState, graphql_params: &GraphQLParams, - plugin_manager: &PluginManager<'req>, + plugin_manager: &PluginManager<'_>, ) -> Result { let cache_key = { let mut hasher = Xxh3::new(); diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..4743d8672 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'_, '_>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 156b8ef9f..d1f83ca7d 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -31,11 +31,11 @@ pub enum QueryPlanGetterError { #[inline] pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, - schema_state: Arc, - normalized_operation: Arc, + schema_state: &SchemaState, + normalized_operation: &GraphQLNormalizationPayload, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, - app_state: Arc, + app_state: &RouterSharedState, plugin_manager: &PluginManager<'req>, ) -> Result { let stable_override_context = diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index afe833656..92cb1eb6f 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -17,8 +17,8 @@ use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( supergraph: &SupergraphData, - schema_state: Arc, - app_state: Arc, + schema_state: &SchemaState, + app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, plugin_manager: &PluginManager<'_>, ) -> Result, PipelineErrorVariant> { diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index bbe5ba836..7d07f430e 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{BTreeSet, HashMap}, - sync::Arc, -}; +use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; @@ -56,7 +53,7 @@ use crate::{ pub struct QueryPlanExecutionContext<'exec, 'req> { pub plugin_manager: &'exec PluginManager<'exec>, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, @@ -87,7 +84,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let mut start_payload = OnExecuteStartPayload { router_http_request: &self.plugin_manager.router_http_request, context: &self.plugin_manager.context, - query_plan: &self.query_plan, + query_plan: self.query_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), From 57bb0f06c5c2e2fbe16c37070553f81c29ab7293 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 16:07:22 +0300 Subject: [PATCH 11/31] Examples --- Cargo.lock | 21 +++ bin/router/src/pipeline/mod.rs | 5 +- bin/router/src/plugins/plugins_service.rs | 23 ++- lib/executor/Cargo.toml | 3 + lib/executor/src/execution/plan.rs | 2 + lib/executor/src/executors/common.rs | 1 + lib/executor/src/executors/http.rs | 5 + lib/executor/src/executors/map.rs | 8 +- .../src/plugins/examples/apollo_sandbox.rs | 155 +++++++++++++++++ .../src/plugins/examples/async_auth.rs | 113 +++++++++++++ .../src/plugins/examples/context_data.rs | 67 ++++++++ .../examples/forbid_anonymous_operations.rs | 55 ++++++ lib/executor/src/plugins/examples/mod.rs | 7 + .../src/plugins/examples/multipart.rs | 159 ++++++++++++++++++ .../plugins/examples/propagate_status_code.rs | 46 +++++ .../src/plugins/examples/response_cache.rs | 3 +- .../src/plugins/examples/root_field_limit.rs | 62 +++++++ .../src/plugins/hooks/on_graphql_params.rs | 7 +- .../src/plugins/hooks/on_http_request.rs | 5 +- .../src/plugins/hooks/on_subgraph_execute.rs | 10 +- lib/executor/src/plugins/plugin_context.rs | 109 +++++++++++- lib/executor/src/plugins/plugin_trait.rs | 4 +- 22 files changed, 842 insertions(+), 28 deletions(-) create mode 100644 lib/executor/src/plugins/examples/apollo_sandbox.rs create mode 100644 lib/executor/src/plugins/examples/async_auth.rs create mode 100644 lib/executor/src/plugins/examples/context_data.rs create mode 100644 lib/executor/src/plugins/examples/forbid_anonymous_operations.rs create mode 100644 lib/executor/src/plugins/examples/propagate_status_code.rs create mode 100644 lib/executor/src/plugins/examples/root_field_limit.rs diff --git a/Cargo.lock b/Cargo.lock index 52b09935b..43720d595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2055,6 +2055,7 @@ dependencies = [ "criterion", "dashmap", "futures", + "futures-util", "graphql-parser", "graphql-tools", "hive-router-config", @@ -2067,11 +2068,13 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "multer", "ntex", "ntex-http", "ordered-float", "redis", "regex-automata", + "reqwest", "ryu", "serde", "sonic-rs", @@ -2802,6 +2805,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4349,6 +4362,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -4360,6 +4374,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -5831,6 +5846,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index f05745265..4d2f264ee 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -176,7 +176,10 @@ pub async fn execute_pipeline( .expect("Failed to parse execution request") }); - let mut payload = OnGraphQLParamsEndPayload { graphql_params }; + let mut payload = OnGraphQLParamsEndPayload { + graphql_params, + context: &plugin_manager.context, + }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); payload = result.payload; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 49204b662..2d0d70c50 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -1,11 +1,14 @@ use std::sync::Arc; use hive_router_plan_executor::{ + execution::plan::PlanExecutionOutput, hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, plugin_context::PluginContext, plugin_trait::ControlFlowResult, }; +use http::StatusCode; use ntex::{ + http::ResponseBuilder, service::{Service, ServiceCtx}, web::{self, DefaultError}, Middleware, @@ -52,11 +55,11 @@ where let mut start_payload = OnHttpRequestPayload { router_http_request: req, context: &plugin_context, - response: None, }; let mut on_end_callbacks = vec![]; + let mut early_response: Option = None; for plugin in plugins.iter() { let result = plugin.on_http_request(start_payload); start_payload = result.payload; @@ -67,18 +70,24 @@ where ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); } - ControlFlowResult::EndResponse(_response) => { - // Short-circuit the request with the provided response - unimplemented!(); + ControlFlowResult::EndResponse(response) => { + early_response = Some(response); + break; } } } let req = start_payload.router_http_request; - let response = match start_payload.response { - Some(response) => response, - None => ctx.call(&self.service, req).await?, + let response = if let Some(early_response) = early_response { + let mut builder = ResponseBuilder::new(StatusCode::OK); + for (key, value) in early_response.headers.iter() { + builder.header(key, value); + } + let res = builder.body(early_response.body); + req.into_response(res) + } else { + ctx.call(&self.service, req).await? }; let mut end_payload = OnHttpResponsePayload { response }; diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 39d51ee7e..07aefd5b6 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -30,6 +30,7 @@ xxhash-rust = { workspace = true } tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } vrl = { workspace = true } +reqwest = { workspace = true, features = ["multipart"] } ahash = "0.8.12" regex-automata = "0.4.10" @@ -53,6 +54,8 @@ ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" redis = "0.32.7" +multer = "3.1.0" +futures-util = "0.3.31" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 7d07f430e..b38dccbac 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -69,6 +69,7 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub struct PlanExecutionOutput { pub body: Vec, pub headers: HeaderMap, + pub status: http::StatusCode, } impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { @@ -177,6 +178,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { Ok(PlanExecutionOutput { body, headers: response_headers, + status: http::StatusCode::OK, }) } } diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 9044c062e..6b3c804b5 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -49,4 +49,5 @@ impl SubgraphExecutionRequest<'_> { pub struct HttpExecutionResponse { pub body: Bytes, pub headers: HeaderMap, + pub status: http::StatusCode, } diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 9e2770965..f33cbaedc 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -282,6 +282,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { return HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, }; } }; @@ -311,12 +312,14 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, } } }; @@ -362,12 +365,14 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body.clone(), headers: shared_response.headers.clone(), + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e.clone()), headers: Default::default(), + status: StatusCode::OK, } } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index cf6033040..bf8eb7838 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -154,6 +154,7 @@ impl SubgraphExecutorMap { return HttpExecutionResponse { body: response.body.into(), headers: response.headers, + status: response.status, }; } ControlFlowResult::OnEnd(callback) => { @@ -182,7 +183,10 @@ impl SubgraphExecutorMap { } }; - let mut end_payload = OnSubgraphExecuteEndPayload { execution_result }; + let mut end_payload = OnSubgraphExecuteEndPayload { + context: &plugin_manager.context, + execution_result, + }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -196,6 +200,7 @@ impl SubgraphExecutorMap { return HttpExecutionResponse { body: response.body.into(), headers: response.headers, + status: response.status, }; } ControlFlowResult::OnEnd(_) => { @@ -222,6 +227,7 @@ impl SubgraphExecutorMap { HttpExecutionResponse { body: buffer.freeze(), headers: Default::default(), + status: http::StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs new file mode 100644 index 000000000..7c559bc8c --- /dev/null +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -0,0 +1,155 @@ +use ::serde::{Deserialize, Serialize}; +use ahash::HashMap; +use http::{HeaderMap, StatusCode}; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Default, Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxOptions { + /** + * The URL of the GraphQL endpoint that Sandbox introspects on initial load. Sandbox populates its pages using the schema obtained from this endpoint. + * The default value is `http://localhost:4000`. + * You should only pass non-production endpoints to Sandbox. Sandbox is powered by schema introspection, and we recommend [disabling introspection in production](https://www.apollographql.com/blog/graphql/security/why-you-should-disable-graphql-introspection-in-production/). + * To provide a "Sandbox-like" experience for production endpoints, we recommend using either a [public variant](https://www.apollographql.com/docs/graphos/platform/graph-management/variants#public-variants) or the [embedded Explorer](https://www.apollographql.com/docs/graphos/platform/explorer/embed). + */ + pub initial_endpoint: String, + /** + * By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.Set `hideCookieToggle` to `false` to enable users of your embedded Sandbox instance to toggle the **Include cookies** setting. + */ + pub hide_cookie_toggle: bool, + /** + * By default, the embedded Sandbox has a URL input box that is editable by users.Set endpointIsEditable to false to prevent users of your embedded Sandbox instance from changing the endpoint URL. + */ + pub endpoint_is_editable: bool, + /** + * You can set `includeCookies` to `true` if you instead want Sandbox to pass `{ credentials: 'include' }` for its requests.If you pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials).This config option is deprecated in favor of using the connection settings cookie toggle in Sandbox and setting the default value via `initialState.includeCookies`. + */ + pub include_cookies: bool, + /** + * An object containing additional options related to the state of the embedded Sandbox on page load. + */ + pub initial_state: ApolloSandboxInitialStateOptions, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxInitialStateOptions { + /** + * Set this value to `true` if you want Sandbox to pass `{ credentials: 'include' }` for its requests by default.If you set `hideCookieToggle` to `false`, users can override this default setting with the **Include cookies** toggle. (By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.)If you also pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials). + */ + pub include_cookies: bool, + /** + * A URI-encoded operation to populate in Sandbox's editor on load.If you omit this, Sandbox initially loads an example query based on your schema.Example: + * ```js + * initialState: { + * document: ` + * query ExampleQuery { + * books { + * title + * } + * } + * ` + * } + * ``` + */ + pub document: Option, + /** + * A URI-encoded, serialized object containing initial variable values to populate in Sandbox on load.If provided, these variables should apply to the initial query you provide for [`document`](https://www.apollographql.com/docs/apollo-sandbox#document).Example: + * + * ```js + * initialState: { + * variables: { + * userID: "abc123" + * }, + * } + * ``` + */ + pub variables: Option, + /** + * A URI-encoded, serialized object containing initial HTTP header values to populate in Sandbox on load.Example: + * + * + * ```js + * initialState: { + * headers: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub headers: Option, + /** + * The ID of a collection, paired with an operation ID to populate in Sandbox on load. You can find these values from a registered graph in Studio by clicking the **...** menu next to an operation in the Explorer of that graph and selecting **View operation details**.Example: + * + * ```js + * initialState: { + * collectionId: 'abc1234', + * operationId: 'xyz1234' + * } + * ``` + */ + pub collection_id: Option, + pub operation_id: Option, + /** + * If `true`, the embedded Sandbox periodically polls your `initialEndpoint` for schema updates.The default value is `true`.Example: + * + * ```js + * initialState: { + * pollForSchemaUpdates: false; + * } + * ``` + */ + pub poll_for_schema_updates: bool, + /** + * Headers that are applied by default to every operation executed by the embedded Sandbox. Users can turn off the application of these headers, but they can't modify their values.The embedded Sandbox always includes these headers in its introspection queries to your `initialEndpoint`.Example: + * + * ```js + * initialState: { + * sharedHeaders: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub shared_headers: HashMap, +} + +pub struct ApolloSandboxPlugin { + pub options: ApolloSandboxOptions, +} + +impl RouterPlugin for ApolloSandboxPlugin { + fn on_http_request<'req>( + &self, + payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + if payload.router_http_request.path() == "/apollo-sandbox" { + let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); + let html = format!( + r#" +
+ + + "#, + config + ); + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/html".parse().unwrap()); + return payload.end_response(PlanExecutionOutput { + body: html.into_bytes(), + headers, + status: StatusCode::OK, + }); + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/lib/executor/src/plugins/examples/async_auth.rs new file mode 100644 index 000000000..ce6d73331 --- /dev/null +++ b/lib/executor/src/plugins/examples/async_auth.rs @@ -0,0 +1,113 @@ +use std::path::PathBuf; + +// From https://github.com/apollographql/router/blob/dev/examples/async-auth/rust/src/allow_client_id_from_file.rs +use serde::Deserialize; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Deserialize)] +pub struct AllowClientIdConfig { + pub header: String, + pub path: String, +} + +pub struct AllowClientIdFromFile { + header_key: String, + allowed_ids_path: PathBuf, +} + +#[async_trait::async_trait] +impl RouterPlugin for AllowClientIdFromFile { + // Whenever it is a GraphQL request, + // We don't use on_http_request here because we want to run this only when it is a GraphQL request + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let header = payload.router_http_request.headers.get(&self.header_key); + match header { + Some(client_id) => { + let client_id_str = client_id.to_str(); + match client_id_str { + Ok(client_id) => { + let allowed_clients: Vec = sonic_rs::from_str( + std::fs::read_to_string(self.allowed_ids_path.clone()) + .unwrap() + .as_str(), + ) + .unwrap(); + + if !allowed_clients.contains(&client_id.to_string()) { + // Prepare an HTTP 403 response with a GraphQL error message + let body = json!( + { + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": "UNAUTHORIZED_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::FORBIDDEN, + }); + } + } + Err(_not_a_string_error) => { + let message = format!("'{}' value is not a string", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "BAD_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + None => { + let message = format!("Missing '{}' header", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::UNAUTHORIZED, + }); + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/context_data.rs b/lib/executor/src/plugins/examples/context_data.rs new file mode 100644 index 000000000..38265fe57 --- /dev/null +++ b/lib/executor/src/plugins/examples/context_data.rs @@ -0,0 +1,67 @@ +// From https://github.com/apollographql/router/blob/dev/examples/context/rust/src/context_data.rs + +use crate::{ + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_context::PluginContextMutEntry, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ContextDataPlugin {} + +pub struct ContextData { + incoming_data: String, + response_count: u64, +} + +#[async_trait::async_trait] +impl RouterPlugin for ContextDataPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let context_data = ContextData { + incoming_data: "world".to_string(), + response_count: 0, + }; + + payload.context.insert(context_data); + + payload.on_end(|payload| { + let mut ctx_data_entry = payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + let ctx_data_entry = payload.context.get_ref_entry(); + let context_data: Option<&ContextData> = ctx_data_entry.get_ref(); + if let Some(context_data) = context_data { + tracing::info!("hello {}", context_data.incoming_data); // Hello world! + let new_header_value = format!("Hello {}", context_data.incoming_data); + payload.execution_request.headers.insert( + "x-hello", + http::HeaderValue::from_str(&new_header_value).unwrap(), + ); + } + payload.on_end(|payload: OnSubgraphExecuteEndPayload<'exec>| { + let mut ctx_data_entry: PluginContextMutEntry = + payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs new file mode 100644 index 000000000..a566d0e9d --- /dev/null +++ b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs @@ -0,0 +1,55 @@ +// Same with https://github.com/apollographql/router/blob/dev/examples/forbid-anonymous-operations/rust/src/forbid_anonymous_operations.rs + +use http::StatusCode; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ForbidAnonymousOperations {} + +#[async_trait::async_trait] +impl RouterPlugin for ForbidAnonymousOperations { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let maybe_operation_name = &payload + .graphql_params + .as_ref() + .and_then(|params| params.operation_name.as_ref()); + + if maybe_operation_name.is_none() + || maybe_operation_name + .expect("is_none() has been checked before; qed") + .is_empty() + { + // let's log the error + tracing::error!("Operation is not allowed!"); + + // Prepare an HTTP 400 response with a GraphQL error message + let response_body = json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&response_body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: StatusCode::BAD_REQUEST, + }); + } else { + // we're good to go! + tracing::info!("operation is allowed!"); + return payload.cont(); + } + } +} diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index a6d766a9c..70ff81639 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,3 +1,10 @@ +pub mod apollo_sandbox; pub mod apq; +pub mod async_auth; +pub mod context_data; +pub mod forbid_anonymous_operations; +pub mod multipart; +pub mod propagate_status_code; pub mod response_cache; +pub mod root_field_limit; pub mod subgraph_response_cache; diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs index e69de29bb..6a6162bc6 100644 --- a/lib/executor/src/plugins/examples/multipart.rs +++ b/lib/executor/src/plugins/examples/multipart.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use crate::{ + executors::common::HttpExecutionResponse, + hooks::{ + on_graphql_params::{ + GraphQLParams, OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload, + }, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use bytes::Bytes; +use dashmap::DashMap; +use multer::Multipart; +use serde::Serialize; +pub struct MultipartPlugin {} + +pub struct MultipartFile { + pub filename: Option, + pub content_type: Option, + pub content: Bytes, +} + +pub struct MultipartContext { + pub file_map: HashMap>, + pub files: DashMap, +} + +#[derive(Serialize)] +struct MultipartOperations<'a> { + pub query: &'a str, + pub variables: Option<&'a HashMap<&'a str, &'a sonic_rs::Value>>, + pub operation_name: Option<&'a str>, +} + +#[async_trait::async_trait] +impl RouterPlugin for MultipartPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + mut payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + if let Some(content_type) = payload.router_http_request.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.starts_with("multipart/form-data") { + let boundary = multer::parse_boundary(content_type_str).unwrap(); + let body = payload.body.clone(); + let stream = futures_util::stream::once(async move { + Ok::(Bytes::from(body.to_vec())) + }); + let mut multipart = Multipart::new(stream, boundary); + while let Some(field) = multipart.next_field().await.unwrap() { + let field_name = field.name().unwrap().to_string(); + let filename = field.file_name().map(|s| s.to_string()); + let content_type = field.content_type().map(|s| s.to_string()); + let data = field.bytes().await.unwrap(); + match field_name.as_str() { + "operations" => { + let graphql_params: GraphQLParams = + sonic_rs::from_slice(&data).unwrap(); + payload.graphql_params = Some(graphql_params); + } + "map" => { + let file_map: HashMap> = + sonic_rs::from_slice(&data).unwrap(); + payload.context.insert(MultipartContext { + file_map, + files: DashMap::new(), + }); + } + field_name => { + let mut ctx_entry = payload.context.get_mut_entry(); + let multipart_ctx: Option<&mut MultipartContext> = + ctx_entry.get_ref_mut(); + if let Some(multipart_ctx) = multipart_ctx { + let multipart_file = MultipartFile { + filename, + content_type, + content: data, + }; + multipart_ctx + .files + .insert(field_name.to_string(), multipart_file); + } + } + } + } + } + } + } + payload.cont() + } + + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + if let Some(variables) = &payload.execution_request.variables { + let ctx_ref = payload.context.get_ref_entry(); + let multipart_ctx: Option<&MultipartContext> = ctx_ref.get_ref(); + if let Some(multipart_ctx) = multipart_ctx { + let mut file_map: HashMap> = HashMap::new(); + for variable_name in variables.keys() { + // Matching variables that are file references + for (files_ref, op_refs) in &multipart_ctx.file_map { + for op_ref in op_refs { + if op_ref.starts_with(format!("variables.{}", variable_name).as_str()) { + let op_refs_in_curr_map = + file_map.entry(files_ref.to_string()).or_default(); + op_refs_in_curr_map.push(op_ref.to_string()); + } + } + } + } + if !file_map.is_empty() { + let mut form = reqwest::multipart::Form::new(); + let operations_struct = MultipartOperations { + query: payload.execution_request.query, + variables: payload.execution_request.variables.as_ref(), + operation_name: payload.execution_request.operation_name, + }; + let operations = sonic_rs::to_string(&operations_struct).unwrap(); + form = form.text("operations", operations); + let file_map_str: String = sonic_rs::to_string(&file_map).unwrap(); + form = form.text("map", file_map_str); + for (file_ref, _op_refs) in file_map { + if let Some(file_field) = multipart_ctx.files.get(&file_ref) { + let mut part = + reqwest::multipart::Part::bytes(file_field.content.to_vec()); + if let Some(file_name) = &file_field.filename { + part = part.file_name(file_name.to_string()); + } + if let Some(content_type) = &file_field.content_type { + part = part.mime_str(&content_type.to_string()).unwrap(); + } + form = form.part(file_ref, part); + } + } + let resp = reqwest::Client::new() + .post("http://example.com/graphql") + // Using query as endpoint URL + .multipart(form) + .send() + .await + .unwrap(); + let headers = resp.headers().clone(); + let status = resp.status(); + let body = resp.bytes().await.unwrap(); + payload.execution_result = Some(HttpExecutionResponse { + body, + headers, + status, + }); + } + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs new file mode 100644 index 000000000..1e1b28d9d --- /dev/null +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -0,0 +1,46 @@ +// From https://github.com/apollographql/router/blob/dev/examples/status-code-propagation/rust/src/propagate_status_code.rs + +use http::StatusCode; + +use crate::{ + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct PropagateStatusCodePlugin { + pub status_codes: Vec, +} + +pub struct PropagateStatusCodeCtx { + pub status_code: StatusCode, +} + +#[async_trait::async_trait] +impl RouterPlugin for PropagateStatusCodePlugin { + async fn on_subgraph_execute<'exec>( + &'exec self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload<'exec>> + { + payload.on_end(|payload| { + let status_code = payload.execution_result.status; + // if a response contains a status code we're watching... + if self.status_codes.contains(&status_code) { + // Checking if there is already a context entry + let mut ctx_entry = payload.context.get_mut_entry(); + let ctx: Option<&mut PropagateStatusCodeCtx> = ctx_entry.get_ref_mut(); + if let Some(ctx) = ctx { + // Update the status code if the new one is more severe (higher) + if status_code.as_u16() > ctx.status_code.as_u16() { + ctx.status_code = status_code; + } + } else { + // Insert a new context entry + let new_ctx = PropagateStatusCodeCtx { status_code }; + payload.context.insert(new_ctx); + } + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index ee92a029d..d144b07ec 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -1,5 +1,5 @@ use dashmap::DashMap; -use http::HeaderMap; +use http::{HeaderMap, StatusCode}; use redis::Commands; use crate::{ @@ -44,6 +44,7 @@ impl RouterPlugin for ResponseCachePlugin { return payload.end_response(PlanExecutionOutput { body: cached_response, headers: HeaderMap::new(), + status: StatusCode::OK, }); } return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs new file mode 100644 index 000000000..0c57771c3 --- /dev/null +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -0,0 +1,62 @@ +use hive_router_query_planner::ast::selection_item::SelectionItem; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +pub struct RootFieldLimitPlugin { + pub max_root_fields: usize, +} + +#[async_trait::async_trait] +impl RouterPlugin for RootFieldLimitPlugin { + async fn on_query_plan<'exec>( + &'exec self, + payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + let mut cnt = 0; + for selection in payload + .filtered_operation_for_plan + .selection_set + .items + .iter() + { + match selection { + SelectionItem::Field(_) => { + cnt += 1; + if cnt > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + cnt, self.max_root_fields + ); + tracing::warn!("{}", err_msg); + let body = json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_ROOT_FIELDS" + } + }] + }); + // Return error + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::PAYLOAD_TOO_LARGE, + }); + } + } + SelectionItem::InlineFragment(_) => { + unreachable!("Inline fragments should have been inlined before query planning"); + } + SelectionItem::FragmentSpread(_) => { + unreachable!("Fragment spreads should have been inlined before query planning"); + } + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index d954f94e7..c69d094e4 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -100,10 +100,11 @@ pub struct OnGraphQLParamsStartPayload<'exec> { pub graphql_params: Option, } -impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} +impl<'exec> StartPayload> for OnGraphQLParamsStartPayload<'exec> {} -pub struct OnGraphQLParamsEndPayload { +pub struct OnGraphQLParamsEndPayload<'exec> { pub graphql_params: GraphQLParams, + pub context: &'exec PluginContext, } -impl EndPayload for OnGraphQLParamsEndPayload {} +impl<'exec> EndPayload for OnGraphQLParamsEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 84a683948..9964425df 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -1,4 +1,4 @@ -use ntex::web::{self, DefaultError}; +use ntex::web::{self, DefaultError, WebRequest}; use crate::{ plugin_context::PluginContext, @@ -6,9 +6,8 @@ use crate::{ }; pub struct OnHttpRequestPayload<'req> { - pub router_http_request: web::WebRequest, + pub router_http_request: WebRequest, pub context: &'req PluginContext, - pub response: Option, } impl<'req> StartPayload for OnHttpRequestPayload<'req> {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 25ec2c30b..18d037c10 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -14,10 +14,14 @@ pub struct OnSubgraphExecuteStartPayload<'exec> { pub execution_result: Option, } -impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} +impl<'exec> StartPayload> + for OnSubgraphExecuteStartPayload<'exec> +{ +} -pub struct OnSubgraphExecuteEndPayload { +pub struct OnSubgraphExecuteEndPayload<'exec> { pub execution_result: HttpExecutionResponse, + pub context: &'exec PluginContext, } -impl EndPayload for OnSubgraphExecuteEndPayload {} +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index bce85a517..d5ea9421f 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -3,7 +3,10 @@ use std::{ sync::Arc, }; -use dashmap::DashMap; +use dashmap::{ + mapref::one::{Ref, RefMut}, + DashMap, +}; use http::Uri; use ntex::router::Path; use ntex_http::HeaderMap; @@ -22,17 +25,69 @@ pub struct RouterHttpRequest<'exec> { #[derive(Default)] pub struct PluginContext { - inner: DashMap>, + inner: DashMap>, +} + +pub struct PluginContextRefEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextRefEntry<'a, T> { + pub fn get_ref(&self) -> Option<&T> { + match &self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value(); + Some(boxed_any.downcast_ref::()?) + } + } + } +} +pub struct PluginContextMutEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextMutEntry<'a, T> { + pub fn get_ref_mut(&mut self) -> Option<&mut T> { + match &mut self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value_mut(); + Some(boxed_any.downcast_mut::()?) + } + } + } } impl PluginContext { - pub fn insert(&self, value: T) { - self.inner.insert(TypeId::of::(), Arc::new(value)); + pub fn contains(&self) -> bool { + let type_id = TypeId::of::(); + self.inner.contains_key(&type_id) } - pub fn get(&self) -> Option> { + pub fn insert(&self, value: T) -> Option> { + let type_id = TypeId::of::(); self.inner - .get(&TypeId::of::()) - .map(|v| v.clone().downcast::().ok().unwrap()) + .insert(type_id, Box::new(value)) + .and_then(|boxed_any| boxed_any.downcast::().ok()) + } + pub fn get_ref_entry(&self) -> PluginContextRefEntry<'_, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get(&type_id); + PluginContextRefEntry { + entry, + phantom: std::marker::PhantomData, + } + } + pub fn get_mut_entry<'a, T: Any + Send + Sync>(&'a self) -> PluginContextMutEntry<'a, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get_mut(&type_id); + + PluginContextMutEntry { + entry, + phantom: std::marker::PhantomData, + } } } @@ -41,3 +96,43 @@ pub struct PluginManager<'req> { pub router_http_request: RouterHttpRequest<'req>, pub context: Arc, } + +#[cfg(test)] +mod tests { + #[test] + fn inserts_and_gets_immut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 42); + } + #[test] + fn inserts_and_mutates_with_mut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + { + let mut entry = ctx.get_mut_entry(); + let ctx_mut: &mut TestCtx = entry.get_ref_mut().unwrap(); + ctx_mut.value = 100; + } + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 100); + } +} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 0a3f6db32..c946bc0ed 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -92,7 +92,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_graphql_parse<'exec>( - &self, + &'exec self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() @@ -105,7 +105,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_query_plan<'exec>( - &self, + &'exec self, start_payload: OnQueryPlanStartPayload<'exec>, ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() From 1ac23c24b364d84c68ab58d6c2ec8cead94e85b4 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:06:56 +0300 Subject: [PATCH 12/31] Add oneOf --- bin/router/src/pipeline/execution.rs | 1 + lib/executor/src/execution/plan.rs | 12 +- lib/executor/src/plugins/examples/mod.rs | 1 + lib/executor/src/plugins/examples/one_of.rs | 134 ++++++++++++++++++ .../src/plugins/examples/root_field_limit.rs | 91 +++++++++++- lib/executor/src/plugins/hooks/on_execute.rs | 2 + lib/executor/src/plugins/plugin_trait.rs | 2 +- 7 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 lib/executor/src/plugins/examples/one_of.rs diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index fae429895..991f503c1 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -88,6 +88,7 @@ pub async fn execute_plan( let ctx = QueryPlanExecutionContext { plugin_manager: &plugin_manager, query_plan: query_plan_payload, + operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, variable_values: &variable_payload.variables_map, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index b38dccbac..ad1a46744 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -2,9 +2,12 @@ use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use hive_router_query_planner::planner::plan_nodes::{ - ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, - QueryPlan, SequenceNode, +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + planner::plan_nodes::{ + ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, + PlanNode, QueryPlan, SequenceNode, + }, }; use http::HeaderMap; use serde::Deserialize; @@ -54,6 +57,7 @@ use crate::{ pub struct QueryPlanExecutionContext<'exec, 'req> { pub plugin_manager: &'exec PluginManager<'exec>, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, @@ -81,11 +85,11 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { }; let dedupe_subgraph_requests = self.operation_type_name == "Query"; - let mut start_payload = OnExecuteStartPayload { router_http_request: &self.plugin_manager.router_http_request, context: &self.plugin_manager.context, query_plan: self.query_plan, + operation_for_plan: self.operation_for_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 70ff81639..eee9d7f7a 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -4,6 +4,7 @@ pub mod async_auth; pub mod context_data; pub mod forbid_anonymous_operations; pub mod multipart; +pub mod one_of; pub mod propagate_status_code; pub mod response_cache; pub mod root_field_limit; diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs new file mode 100644 index 000000000..90ced2d07 --- /dev/null +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -0,0 +1,134 @@ +// This example will show `@oneOf` input type validation in two steps: +// 1. During validation step +// 2. During execution step + +// We handle execution too to validate input objects at runtime as well. + +use std::{collections::BTreeMap, sync::RwLock}; + +use crate::{ + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use graphql_parser::{ + query::Value, + schema::{Definition, TypeDefinition}, +}; +use graphql_tools::ast::visit_document; +use graphql_tools::{ + ast::{OperationVisitor, OperationVisitorContext}, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; + +pub struct OneOfPlugin { + pub one_of_types: RwLock>, +} + +#[async_trait::async_trait] +impl RouterPlugin for OneOfPlugin { + // 1. During validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = OneOfValidationRule { + one_of_types: self.one_of_types.read().unwrap().clone(), + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // 2. During execution step + async fn on_execute<'exec>( + &'exec self, + payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> { + payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + for def in start_payload.new_ast.definitions.iter() { + if let Definition::TypeDefinition(TypeDefinition::InputObject(input_obj)) = def { + for directive in input_obj.directives.iter() { + if directive.name == "oneOf" { + self.one_of_types + .write() + .unwrap() + .push(input_obj.name.clone()); + } + } + } + } + start_payload.cont() + } +} + +struct OneOfValidationRule { + one_of_types: Vec, +} + +impl ValidationRule for OneOfValidationRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + op_ctx: &mut OperationVisitorContext<'_>, + validation_error_context: &mut ValidationErrorContext, + ) { + visit_document( + &mut OneOfValidation { + one_of_types: self.one_of_types.clone(), + }, + op_ctx.operation, + op_ctx, + validation_error_context, + ); + } +} + +struct OneOfValidation { + one_of_types: Vec, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for OneOfValidation { + fn enter_object_value( + &mut self, + visitor_context: &mut OperationVisitorContext<'a>, + user_context: &mut ValidationErrorContext, + fields: &BTreeMap, + ) { + if let Some(TypeDefinition::InputObject(input_type)) = visitor_context.current_input_type() + { + if self.one_of_types.contains(&input_type.name) { + let mut set_fields = vec![]; + for (field_name, field_value) in fields.iter() { + if !matches!(field_value, Value::Null) { + set_fields.push(field_name.clone()); + } + } + if set_fields.len() > 1 { + let err_msg = format!( + "Input object of type '{}' with @oneOf directive has multiple fields set: {:?}. Only one field must be set.", + input_type.name, + set_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_FIELDS_SET_IN_ONEOF", + locations: vec![], + message: err_msg, + }); + } + } + } + } +} diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs index 0c57771c3..24872e737 100644 --- a/lib/executor/src/plugins/examples/root_field_limit.rs +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -1,18 +1,42 @@ +use graphql_tools::{ + ast::{visit_document, OperationVisitor, OperationVisitorContext, TypeDefinitionExtension}, + static_graphql, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; use hive_router_query_planner::ast::selection_item::SelectionItem; use sonic_rs::json; use crate::{ execution::plan::PlanExecutionOutput, - hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + hooks::{ + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + }, plugin_trait::{HookResult, RouterPlugin, StartPayload}, }; -pub struct RootFieldLimitPlugin { - pub max_root_fields: usize, -} +// This example shows two ways of limiting the number of root fields in a query: +// 1. During validation step +// 2. During query planning step #[async_trait::async_trait] impl RouterPlugin for RootFieldLimitPlugin { + // Using validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = RootFieldLimitRule { + max_root_fields: self.max_root_fields, + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // Or during query planning async fn on_query_plan<'exec>( &'exec self, payload: OnQueryPlanStartPayload<'exec>, @@ -60,3 +84,62 @@ impl RouterPlugin for RootFieldLimitPlugin { payload.cont() } } + +pub struct RootFieldLimitPlugin { + max_root_fields: usize, +} + +pub struct RootFieldLimitRule { + max_root_fields: usize, +} + +struct RootFieldSelections { + max_root_fields: usize, + count: usize, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for RootFieldSelections { + fn enter_field( + &mut self, + visitor_context: &mut OperationVisitorContext, + user_context: &mut ValidationErrorContext, + field: &static_graphql::query::Field, + ) { + let parent_type_name = visitor_context.current_parent_type().map(|t| t.name()); + if parent_type_name == Some("Query") { + self.count += 1; + if self.count > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + self.count, self.max_root_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_ROOT_FIELDS", + locations: vec![field.position], + message: err_msg, + }); + } + } + } +} + +impl ValidationRule for RootFieldLimitRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + ctx: &mut OperationVisitorContext<'_>, + error_collector: &mut ValidationErrorContext, + ) { + visit_document( + &mut RootFieldSelections { + max_root_fields: self.max_root_fields, + count: 0, + }, + ctx.operation, + ctx, + error_collector, + ); + } +} diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 328a1a8fd..b69ba3297 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use hive_router_query_planner::ast::operation::OperationDefinition; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use crate::plugin_context::{PluginContext, RouterHttpRequest}; @@ -11,6 +12,7 @@ pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index c946bc0ed..5a3e7c242 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -98,7 +98,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_graphql_validation<'exec>( - &self, + &'exec self, start_payload: OnGraphQLValidationStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> { From f1a48ef6f333d9b65611268182999cda422c469f Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:07:06 +0300 Subject: [PATCH 13/31] Runtime error --- lib/executor/src/plugins/examples/one_of.rs | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index 90ced2d07..601107af1 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -7,6 +7,7 @@ use std::{collections::BTreeMap, sync::RwLock}; use crate::{ + execution::plan::PlanExecutionOutput, hooks::{ on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, @@ -26,6 +27,7 @@ use graphql_tools::{ utils::{ValidationError, ValidationErrorContext}, }, }; +use sonic_rs::{json, JsonContainerTrait}; pub struct OneOfPlugin { pub one_of_types: RwLock>, @@ -50,6 +52,42 @@ impl RouterPlugin for OneOfPlugin { &'exec self, payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> { + if let (Some(variable_values), Some(variable_defs)) = ( + &payload.variable_values, + &payload.operation_for_plan.variable_definitions, + ) { + for def in variable_defs { + let variable_named_type = def.variable_type.inner_type(); + let one_of_types = self.one_of_types.read().unwrap(); + if one_of_types.contains(&variable_named_type.to_string()) { + let var_name = &def.name; + if let Some(value) = variable_values.get(var_name).and_then(|v| v.as_object()) { + let keys_num = value.len(); + if keys_num > 1 { + let err_msg = format!( + "Variable '${}' of input object type '{}' with @oneOf directive has multiple fields set: {:?}. Only one field must be set.", + var_name, + variable_named_type, + keys_num + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_FIELDS_SET_IN_ONEOF" + } + }] + })) + .unwrap(), + headers: Default::default(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + } + } payload.cont() } fn on_supergraph_reload<'exec>( From 73c693ec3bc5ae7ae1ccf195ca09f7c392e379ef Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:10:24 +0300 Subject: [PATCH 14/31] Add description --- lib/executor/src/plugins/examples/one_of.rs | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index 601107af1..738ece25a 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -3,6 +3,47 @@ // 2. During execution step // We handle execution too to validate input objects at runtime as well. +/* + Let's say we have the following input type with `@oneOf` directive: + input PaymentMethod @oneOf { + creditCard: CreditCardInput + bankTransfer: BankTransferInput + paypal: PayPalInput + } + + During validation, if a variable of type `PaymentMethod` is provided with multiple fields set, + we will raise a validation error. + + ```graphql + mutation MakePayment { + makePayment(method: { + creditCard: { number: "1234", expiry: "12/24" }, + paypal: { email: "john@doe.com" } + }) { + success + } + } + ``` + + But since variables can be dynamic, we also validate during execution. If the input object has multiple fields set, + we return an error in the response. + + ```graphql + mutation MakePayment($method: PaymentMethod!) { + makePayment(method: $method) { + success + } + } + ``` + + with variables: + { + "method": { + "creditCard": { "number": "1234", "expiry": "12/24" }, + "paypal": { "email": "john@doe.com" } + } + } +*/ use std::{collections::BTreeMap, sync::RwLock}; From 0afbea593ca025eb0772f3f08f5a622225ad85e8 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:34:55 +0300 Subject: [PATCH 15/31] Propagate status code --- bin/router/src/plugins/plugins_service.rs | 5 ++++- .../src/plugins/examples/apollo_sandbox.rs | 4 ++-- .../plugins/examples/propagate_status_code.rs | 20 ++++++++++++++++++- .../src/plugins/hooks/on_http_request.rs | 7 ++++--- lib/executor/src/plugins/plugin_trait.rs | 4 ++-- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 2d0d70c50..223a702ee 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -90,7 +90,10 @@ where ctx.call(&self.service, req).await? }; - let mut end_payload = OnHttpResponsePayload { response }; + let mut end_payload = OnHttpResponsePayload { + response, + context: &plugin_context, + }; for callback in on_end_callbacks.into_iter().rev() { let result = callback(end_payload); diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs index 7c559bc8c..224654bdb 100644 --- a/lib/executor/src/plugins/examples/apollo_sandbox.rs +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -125,9 +125,9 @@ pub struct ApolloSandboxPlugin { impl RouterPlugin for ApolloSandboxPlugin { fn on_http_request<'req>( - &self, + &'req self, payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { if payload.router_http_request.path() == "/apollo-sandbox" { let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); let html = format!( diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs index 1e1b28d9d..1d519904e 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -3,7 +3,10 @@ use http::StatusCode; use crate::{ - hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + hooks::{ + on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, }; @@ -43,4 +46,19 @@ impl RouterPlugin for PropagateStatusCodePlugin { payload.cont() }) } + fn on_http_request<'exec>( + &'exec self, + payload: OnHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponsePayload<'exec>> { + payload.on_end(|mut payload| { + // Checking if there is a context entry + let ctx_entry = payload.context.get_ref_entry(); + let ctx: Option<&PropagateStatusCodeCtx> = ctx_entry.get_ref(); + if let Some(ctx) = ctx { + // Update the HTTP response status code + *payload.response.response_mut().status_mut() = ctx.status_code; + } + payload.cont() + }) + } } diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 9964425df..e9e857a80 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -10,10 +10,11 @@ pub struct OnHttpRequestPayload<'req> { pub context: &'req PluginContext, } -impl<'req> StartPayload for OnHttpRequestPayload<'req> {} +impl<'req> StartPayload> for OnHttpRequestPayload<'req> {} -pub struct OnHttpResponsePayload { +pub struct OnHttpResponsePayload<'req> { pub response: web::WebResponse, + pub context: &'req PluginContext, } -impl EndPayload for OnHttpResponsePayload {} +impl<'req> EndPayload for OnHttpResponsePayload<'req> {} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 5a3e7c242..12292be11 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -80,9 +80,9 @@ where #[async_trait::async_trait] pub trait RouterPlugin { fn on_http_request<'req>( - &self, + &'req self, start_payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { start_payload.cont() } async fn on_graphql_params<'exec>( From 6d8fb529db854b91ef8bc549760d71c37bf59831 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 24 Nov 2025 16:59:30 +0300 Subject: [PATCH 16/31] Plugin System Configuration --- Cargo.lock | 1 + bin/router/src/lib.rs | 13 +++-- bin/router/src/main.rs | 8 ++- bin/router/src/plugins/mod.rs | 1 + bin/router/src/plugins/registry.rs | 53 +++++++++++++++++++ bin/router/src/shared_state.rs | 3 +- e2e/src/testkit.rs | 5 +- lib/executor/Cargo.toml | 1 + .../src/plugins/examples/apollo_sandbox.rs | 12 ++++- lib/executor/src/plugins/examples/apq.rs | 14 ++++- .../src/plugins/examples/async_auth.rs | 15 +++++- .../src/plugins/examples/context_data.rs | 12 ++++- .../examples/forbid_anonymous_operations.rs | 12 ++++- .../src/plugins/examples/multipart.rs | 12 ++++- lib/executor/src/plugins/examples/one_of.rs | 14 ++++- .../plugins/examples/propagate_status_code.rs | 23 +++++++- .../src/plugins/examples/response_cache.rs | 29 ++++++---- .../src/plugins/examples/root_field_limit.rs | 20 ++++++- .../examples/subgraph_response_cache.rs | 14 ++++- lib/executor/src/plugins/plugin_trait.rs | 20 +++++++ lib/router-config/src/lib.rs | 4 ++ 21 files changed, 258 insertions(+), 28 deletions(-) create mode 100644 bin/router/src/plugins/registry.rs diff --git a/Cargo.lock b/Cargo.lock index 43720d595..544be59be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2077,6 +2077,7 @@ dependencies = [ "reqwest", "ryu", "serde", + "serde_json", "sonic-rs", "strum 0.27.2", "subgraphs", diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 5e5a0e353..b7320f8d8 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -9,7 +9,7 @@ mod schema_state; mod shared_state; mod supergraph; -use std::sync::Arc; +use std::{sync::Arc}; use crate::{ background_tasks::BackgroundTasksManager, @@ -37,6 +37,7 @@ use ntex::{ web::{self, HttpRequest}, }; use tracing::{info, warn}; +pub use crate::plugins::registry::PluginRegistry; async fn graphql_endpoint_handler( req: HttpRequest, @@ -86,7 +87,9 @@ async fn graphql_endpoint_handler( } } -pub async fn router_entrypoint() -> Result<(), Box> { +pub async fn router_entrypoint( + plugin_factories: PluginRegistry +) -> Result<(), Box> { let config_path = std::env::var("ROUTER_CONFIG_FILE_PATH").ok(); let router_config = load_config(config_path)?; configure_logging(&router_config.log); @@ -94,7 +97,7 @@ pub async fn router_entrypoint() -> Result<(), Box> { let addr = router_config.http.address(); let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_factories).await?; let maybe_error = web::HttpServer::new(move || { web::App::new() @@ -118,16 +121,20 @@ pub async fn router_entrypoint() -> Result<(), Box> { pub async fn configure_app_from_config( router_config: HiveRouterConfig, bg_tasks_manager: &mut BackgroundTasksManager, + plugin_factories: PluginRegistry ) -> Result<(Arc, Arc), Box> { let jwt_runtime = match router_config.jwt.is_jwt_auth_enabled() { true => Some(JwtAuthRuntime::init(bg_tasks_manager, &router_config.jwt).await?), false => None, }; + let plugins = plugin_factories.initialize_plugins(&router_config); + let router_config_arc = Arc::new(router_config); let shared_state = Arc::new(RouterSharedState::new( router_config_arc.clone(), jwt_runtime, + plugins, )?); let schema_state = SchemaState::new_from_config( bg_tasks_manager, diff --git a/bin/router/src/main.rs b/bin/router/src/main.rs index b4f250b38..44430eef7 100644 --- a/bin/router/src/main.rs +++ b/bin/router/src/main.rs @@ -1,11 +1,15 @@ -use hive_router::router_entrypoint; +use hive_router::{PluginRegistry, router_entrypoint}; +use hive_router_plan_executor::examples::apq::APQPlugin; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[ntex::main] async fn main() -> Result<(), Box> { - match router_entrypoint().await { + let mut plugin_factories = PluginRegistry::new(); + plugin_factories.register::(); + + match router_entrypoint(plugin_factories).await { Ok(_) => Ok(()), Err(err) => { eprintln!("Failed to start Hive Router:\n {}", err); diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs index 3753246b2..b6110a8df 100644 --- a/bin/router/src/plugins/mod.rs +++ b/bin/router/src/plugins/mod.rs @@ -1 +1,2 @@ pub mod plugins_service; +pub mod registry; \ No newline at end of file diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs new file mode 100644 index 000000000..4e2412b9c --- /dev/null +++ b/bin/router/src/plugins/registry.rs @@ -0,0 +1,53 @@ +use std::collections::HashMap; + +use hive_router_config::HiveRouterConfig; +use hive_router_plan_executor::plugin_trait::{RouterPlugin, RouterPluginWithConfig}; +use serde_json::Value; +use tracing::{info, warn}; + +pub struct PluginRegistry { + map: HashMap< + &'static str, + Box Result, serde_json::Error>>, + >, +} + +impl PluginRegistry { + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + pub fn register(&mut self) { + self.map.insert( + P::plugin_name(), + Box::new(|plugin_config: Value| Ok(P::from_config_value(plugin_config)?)), + ); + } + pub fn initialize_plugins(&self, router_config: &HiveRouterConfig) -> Vec> { + let mut plugins: Vec> = vec![]; + + for (plugin_name, plugin_config_value) in router_config.plugins.iter() { + if let Some(factory) = self.map.get(plugin_name.as_str()) { + match factory(plugin_config_value.clone()) { + Ok(plugin) => { + info!("Loaded plugin: {}", plugin_name); + plugins.push(plugin); + } + Err(err) => { + warn!( + "Failed to load plugin '{}': {}, skipping plugin", + plugin_name, err + ); + } + } + } else { + warn!( + "No factory found for plugin '{}', skipping plugin", + plugin_name + ); + } + } + plugins + } +} diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 877ffa0e3..f876dcb49 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -26,6 +26,7 @@ impl RouterSharedState { pub fn new( router_config: Arc, jwt_auth_runtime: Option, + plugins: Vec>, ) -> Result { Ok(Self { validation_plan: graphql_tools::validation::rules::default_rules_validation_plan(), @@ -38,7 +39,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, - plugins: Arc::new(vec![]), + plugins: Arc::new(plugins), }) } } diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index 638138801..4af1a12cf 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,8 +1,7 @@ use std::{path::PathBuf, sync::Arc, time::Duration}; use hive_router::{ - background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, - RouterSharedState, SchemaState, + PluginRegistry, RouterSharedState, SchemaState, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app }; use hive_router_config::{load_config, parse_yaml_config, HiveRouterConfig}; use ntex::{ @@ -181,7 +180,7 @@ pub async fn init_router_from_config( > { let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, PluginRegistry::new()).await?; let ntex_app = test::init_service( web::App::new() diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 07aefd5b6..0ba354798 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -31,6 +31,7 @@ tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } vrl = { workspace = true } reqwest = { workspace = true, features = ["multipart"] } +serde_json = { workspace = true } ahash = "0.8.12" regex-automata = "0.4.10" diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs index 224654bdb..6febd288e 100644 --- a/lib/executor/src/plugins/examples/apollo_sandbox.rs +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -5,7 +5,7 @@ use http::{HeaderMap, StatusCode}; use crate::{ execution::plan::PlanExecutionOutput, hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; #[derive(Default, Serialize, Deserialize, Debug, Clone)] @@ -119,6 +119,16 @@ pub struct ApolloSandboxInitialStateOptions { pub shared_headers: HashMap, } +impl RouterPluginWithConfig for ApolloSandboxPlugin { + type Config = ApolloSandboxOptions; + fn plugin_name() -> &'static str { + "apollo_sandbox" + } + fn new(config: ApolloSandboxOptions) -> Self { + ApolloSandboxPlugin { options: config } + } +} + pub struct ApolloSandboxPlugin { pub options: ApolloSandboxOptions, } diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index 91a473b5e..1843e42cb 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -3,13 +3,25 @@ use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use crate::{ hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPluginWithConfig, RouterPlugin, StartPayload}, }; pub struct APQPlugin { cache: DashMap, } +impl RouterPluginWithConfig for APQPlugin { + type Config = (); + fn plugin_name() -> &'static str { + "apq_plugin" + } + fn new(_config: Self::Config) -> Self { + APQPlugin { + cache: DashMap::new(), + } + } +} + #[async_trait::async_trait] impl RouterPlugin for APQPlugin { async fn on_graphql_params<'exec>( diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/lib/executor/src/plugins/examples/async_auth.rs index ce6d73331..58c7e0b3c 100644 --- a/lib/executor/src/plugins/examples/async_auth.rs +++ b/lib/executor/src/plugins/examples/async_auth.rs @@ -7,7 +7,7 @@ use sonic_rs::json; use crate::{ execution::plan::PlanExecutionOutput, hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; #[derive(Deserialize)] @@ -16,6 +16,19 @@ pub struct AllowClientIdConfig { pub path: String, } +impl RouterPluginWithConfig for AllowClientIdFromFile { + type Config = AllowClientIdConfig; + fn plugin_name() -> &'static str { + "allow_client_id_from_file" + } + fn new(config: AllowClientIdConfig) -> Self { + AllowClientIdFromFile { + header_key: config.header, + allowed_ids_path: PathBuf::from(config.path), + } + } +} + pub struct AllowClientIdFromFile { header_key: String, allowed_ids_path: PathBuf, diff --git a/lib/executor/src/plugins/examples/context_data.rs b/lib/executor/src/plugins/examples/context_data.rs index 38265fe57..25b2f34c1 100644 --- a/lib/executor/src/plugins/examples/context_data.rs +++ b/lib/executor/src/plugins/examples/context_data.rs @@ -6,7 +6,7 @@ use crate::{ on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, }, plugin_context::PluginContextMutEntry, - plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; pub struct ContextDataPlugin {} @@ -16,6 +16,16 @@ pub struct ContextData { response_count: u64, } +impl RouterPluginWithConfig for ContextDataPlugin { + type Config = (); + fn plugin_name() -> &'static str { + "context_data_plugin" + } + fn new(_config: ()) -> Self { + ContextDataPlugin {} + } +} + #[async_trait::async_trait] impl RouterPlugin for ContextDataPlugin { async fn on_graphql_params<'exec>( diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs index a566d0e9d..4fbe21099 100644 --- a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs +++ b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs @@ -6,11 +6,21 @@ use sonic_rs::json; use crate::{ execution::plan::PlanExecutionOutput, hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; pub struct ForbidAnonymousOperations {} +impl RouterPluginWithConfig for ForbidAnonymousOperations { + type Config = (); + fn plugin_name() -> &'static str { + "forbid_anonymous_operations" + } + fn new(_config: Self::Config) -> Self { + ForbidAnonymousOperations {} + } +} + #[async_trait::async_trait] impl RouterPlugin for ForbidAnonymousOperations { async fn on_graphql_params<'exec>( diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs index 6a6162bc6..051fda9ce 100644 --- a/lib/executor/src/plugins/examples/multipart.rs +++ b/lib/executor/src/plugins/examples/multipart.rs @@ -8,7 +8,7 @@ use crate::{ }, on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, }, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; use bytes::Bytes; use dashmap::DashMap; @@ -34,6 +34,16 @@ struct MultipartOperations<'a> { pub operation_name: Option<&'a str>, } +impl RouterPluginWithConfig for MultipartPlugin { + type Config = (); + fn plugin_name() -> &'static str { + "multipart_plugin" + } + fn new(_config: ()) -> Self { + MultipartPlugin {} + } +} + #[async_trait::async_trait] impl RouterPlugin for MultipartPlugin { async fn on_graphql_params<'exec>( diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index 738ece25a..e0a33e7dc 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -54,7 +54,7 @@ use crate::{ on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, }, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; use graphql_parser::{ query::Value, @@ -70,6 +70,18 @@ use graphql_tools::{ }; use sonic_rs::{json, JsonContainerTrait}; +impl RouterPluginWithConfig for OneOfPlugin { + type Config = (); + fn plugin_name() -> &'static str { + "one_of_plugin" + } + fn new(_config: ()) -> Self { + OneOfPlugin { + one_of_types: RwLock::new(vec![]), + } + } +} + pub struct OneOfPlugin { pub one_of_types: RwLock>, } diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs index 1d519904e..0cc2e6e71 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -1,15 +1,36 @@ // From https://github.com/apollographql/router/blob/dev/examples/status-code-propagation/rust/src/propagate_status_code.rs use http::StatusCode; +use serde::Deserialize; use crate::{ hooks::{ on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, }, - plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +#[derive(Deserialize)] +pub struct PropagateStatusCodePluginConfig { + pub status_codes: Vec +} + +impl RouterPluginWithConfig for PropagateStatusCodePlugin { + type Config = PropagateStatusCodePluginConfig; + fn plugin_name() -> &'static str { + "propagate_status_code_plugin" + } + fn new(config: PropagateStatusCodePluginConfig) -> Self { + let status_codes = config + .status_codes + .into_iter() + .filter_map(|code| StatusCode::from_u16(code as u16).ok()) + .collect(); + PropagateStatusCodePlugin { status_codes } + } +} + pub struct PropagateStatusCodePlugin { pub status_codes: Vec, } diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index d144b07ec..8941d7c3b 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -1,6 +1,7 @@ use dashmap::DashMap; use http::{HeaderMap, StatusCode}; use redis::Commands; +use serde::Deserialize; use crate::{ execution::plan::PlanExecutionOutput, @@ -8,26 +9,36 @@ use crate::{ on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, }, - plugin_trait::{EndPayload, HookResult, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPluginWithConfig, StartPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME, }; -pub struct ResponseCachePlugin { - redis_client: redis::Client, - ttl_per_type: DashMap, +#[derive(Deserialize)] +pub struct ResponseCachePluginOptions { + pub redis_url: String, } -impl ResponseCachePlugin { - pub fn try_new(redis_url: &str) -> Result { - let redis_client = redis::Client::open(redis_url)?; - Ok(Self { +impl RouterPluginWithConfig for ResponseCachePlugin { + type Config = ResponseCachePluginOptions; + fn plugin_name() -> &'static str { + "response_cache_plugin" + } + fn new(config: ResponseCachePluginOptions) -> Self { + let redis_client = redis::Client::open(config.redis_url) + .expect("Failed to create Redis client"); + Self { redis_client, ttl_per_type: DashMap::new(), - }) + } } } +pub struct ResponseCachePlugin { + redis_client: redis::Client, + ttl_per_type: DashMap, +} + #[async_trait::async_trait] impl RouterPlugin for ResponseCachePlugin { async fn on_execute<'exec>( diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs index 24872e737..04aa525ec 100644 --- a/lib/executor/src/plugins/examples/root_field_limit.rs +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -7,6 +7,7 @@ use graphql_tools::{ }, }; use hive_router_query_planner::ast::selection_item::SelectionItem; +use serde::Deserialize; use sonic_rs::json; use crate::{ @@ -15,7 +16,7 @@ use crate::{ on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, }, - plugin_trait::{HookResult, RouterPlugin, StartPayload}, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; // This example shows two ways of limiting the number of root fields in a query: @@ -85,6 +86,23 @@ impl RouterPlugin for RootFieldLimitPlugin { } } +#[derive(Deserialize)] +pub struct RootFieldLimitPluginConfig { + max_root_fields: usize, +} + +impl RouterPluginWithConfig for RootFieldLimitPlugin { + type Config = RootFieldLimitPluginConfig; + fn plugin_name() -> &'static str { + "root_field_limit_plugin" + } + fn new(config: Self::Config) -> Self { + RootFieldLimitPlugin { + max_root_fields: config.max_root_fields, + } + } +} + pub struct RootFieldLimitPlugin { max_root_fields: usize, } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 4e4b36666..e456cf3ae 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -3,9 +3,21 @@ use dashmap::DashMap; use crate::{ executors::common::HttpExecutionResponse, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, - plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +impl RouterPluginWithConfig for SubgraphResponseCachePlugin { + type Config = (); + fn plugin_name() -> &'static str { + "subgraph_response_cache_plugin" + } + fn new(_config: ()) -> Self { + SubgraphResponseCachePlugin { + cache: DashMap::new(), + } + } +} + pub struct SubgraphResponseCachePlugin { cache: DashMap, } diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 12292be11..a802f239b 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,3 +1,5 @@ +use serde::de::DeserializeOwned; + use crate::execution::plan::PlanExecutionOutput; use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; use crate::hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}; @@ -77,6 +79,24 @@ where } } +pub trait RouterPluginWithConfig where + Self: Sized, + Self: RouterPlugin, +{ + fn plugin_name() -> &'static str; + type Config: Send + Sync + DeserializeOwned; + fn new(config: Self::Config) -> Self; + fn from_config_value(value: serde_json::Value) -> serde_json::Result> + where + Self: Sized, + { + let config: Self::Config = serde_json::from_value(value)?; + Ok( + Box::new(Self::new(config)) + ) + } +} + #[async_trait::async_trait] pub trait RouterPlugin { fn on_http_request<'req>( diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs index 537244c9e..925246c2c 100644 --- a/lib/router-config/src/lib.rs +++ b/lib/router-config/src/lib.rs @@ -92,6 +92,10 @@ pub struct HiveRouterConfig { /// Configuration for overriding labels. #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub override_labels: OverrideLabelsConfig, + + /// Configuration for plugins. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub plugins: HashMap, } #[derive(Debug, thiserror::Error)] From c5c2f06bafa624abfe2f33d1a59c419afe64b38e Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 24 Nov 2025 18:01:17 +0300 Subject: [PATCH 17/31] Better --- bin/router/src/lib.rs | 8 ++++---- bin/router/src/main.rs | 8 ++++---- bin/router/src/plugins/mod.rs | 2 +- bin/router/src/plugins/registry.rs | 13 +++++++++++-- e2e/src/testkit.rs | 6 ++++-- lib/executor/src/plugins/examples/apq.rs | 4 ++-- .../src/plugins/examples/propagate_status_code.rs | 2 +- lib/executor/src/plugins/examples/response_cache.rs | 4 ++-- lib/executor/src/plugins/plugin_trait.rs | 7 +++---- 9 files changed, 32 insertions(+), 22 deletions(-) diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index b7320f8d8..8df43c132 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -9,7 +9,7 @@ mod schema_state; mod shared_state; mod supergraph; -use std::{sync::Arc}; +use std::sync::Arc; use crate::{ background_tasks::BackgroundTasksManager, @@ -30,6 +30,7 @@ use crate::{ pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; +pub use crate::plugins::registry::PluginRegistry; use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ @@ -37,7 +38,6 @@ use ntex::{ web::{self, HttpRequest}, }; use tracing::{info, warn}; -pub use crate::plugins::registry::PluginRegistry; async fn graphql_endpoint_handler( req: HttpRequest, @@ -88,7 +88,7 @@ async fn graphql_endpoint_handler( } pub async fn router_entrypoint( - plugin_factories: PluginRegistry + plugin_factories: PluginRegistry, ) -> Result<(), Box> { let config_path = std::env::var("ROUTER_CONFIG_FILE_PATH").ok(); let router_config = load_config(config_path)?; @@ -121,7 +121,7 @@ pub async fn router_entrypoint( pub async fn configure_app_from_config( router_config: HiveRouterConfig, bg_tasks_manager: &mut BackgroundTasksManager, - plugin_factories: PluginRegistry + plugin_factories: PluginRegistry, ) -> Result<(Arc, Arc), Box> { let jwt_runtime = match router_config.jwt.is_jwt_auth_enabled() { true => Some(JwtAuthRuntime::init(bg_tasks_manager, &router_config.jwt).await?), diff --git a/bin/router/src/main.rs b/bin/router/src/main.rs index 44430eef7..71d126fda 100644 --- a/bin/router/src/main.rs +++ b/bin/router/src/main.rs @@ -1,4 +1,4 @@ -use hive_router::{PluginRegistry, router_entrypoint}; +use hive_router::{router_entrypoint, PluginRegistry}; use hive_router_plan_executor::examples::apq::APQPlugin; #[global_allocator] @@ -6,10 +6,10 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[ntex::main] async fn main() -> Result<(), Box> { - let mut plugin_factories = PluginRegistry::new(); - plugin_factories.register::(); + let mut plugin_registry = PluginRegistry::new(); + plugin_registry.register::(); - match router_entrypoint(plugin_factories).await { + match router_entrypoint(plugin_registry).await { Ok(_) => Ok(()), Err(err) => { eprintln!("Failed to start Hive Router:\n {}", err); diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs index b6110a8df..6ffeef508 100644 --- a/bin/router/src/plugins/mod.rs +++ b/bin/router/src/plugins/mod.rs @@ -1,2 +1,2 @@ pub mod plugins_service; -pub mod registry; \ No newline at end of file +pub mod registry; diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index 4e2412b9c..905dbb4bf 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -12,6 +12,12 @@ pub struct PluginRegistry { >, } +impl Default for PluginRegistry { + fn default() -> Self { + Self::new() + } +} + impl PluginRegistry { pub fn new() -> Self { Self { @@ -24,7 +30,10 @@ impl PluginRegistry { Box::new(|plugin_config: Value| Ok(P::from_config_value(plugin_config)?)), ); } - pub fn initialize_plugins(&self, router_config: &HiveRouterConfig) -> Vec> { + pub fn initialize_plugins( + &self, + router_config: &HiveRouterConfig, + ) -> Vec> { let mut plugins: Vec> = vec![]; for (plugin_name, plugin_config_value) in router_config.plugins.iter() { @@ -43,7 +52,7 @@ impl PluginRegistry { } } else { warn!( - "No factory found for plugin '{}', skipping plugin", + "No plugin found registered '{}', skipping plugin", plugin_name ); } diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index 4af1a12cf..a44210edb 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,7 +1,8 @@ use std::{path::PathBuf, sync::Arc, time::Duration}; use hive_router::{ - PluginRegistry, RouterSharedState, SchemaState, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app + background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, + PluginRegistry, RouterSharedState, SchemaState, }; use hive_router_config::{load_config, parse_yaml_config, HiveRouterConfig}; use ntex::{ @@ -180,7 +181,8 @@ pub async fn init_router_from_config( > { let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager, PluginRegistry::new()).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, PluginRegistry::new()) + .await?; let ntex_app = test::init_service( web::App::new() diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index 1843e42cb..49d9f6f04 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -3,7 +3,7 @@ use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use crate::{ hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{EndPayload, HookResult, RouterPluginWithConfig, RouterPlugin, StartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; pub struct APQPlugin { @@ -17,7 +17,7 @@ impl RouterPluginWithConfig for APQPlugin { } fn new(_config: Self::Config) -> Self { APQPlugin { - cache: DashMap::new(), + cache: DashMap::new(), } } } diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs index 0cc2e6e71..1f5e48592 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -13,7 +13,7 @@ use crate::{ #[derive(Deserialize)] pub struct PropagateStatusCodePluginConfig { - pub status_codes: Vec + pub status_codes: Vec, } impl RouterPluginWithConfig for PropagateStatusCodePlugin { diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index 8941d7c3b..aff44b0dd 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -25,8 +25,8 @@ impl RouterPluginWithConfig for ResponseCachePlugin { "response_cache_plugin" } fn new(config: ResponseCachePluginOptions) -> Self { - let redis_client = redis::Client::open(config.redis_url) - .expect("Failed to create Redis client"); + let redis_client = + redis::Client::open(config.redis_url).expect("Failed to create Redis client"); Self { redis_client, ttl_per_type: DashMap::new(), diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index a802f239b..ad1ae32e4 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -79,7 +79,8 @@ where } } -pub trait RouterPluginWithConfig where +pub trait RouterPluginWithConfig +where Self: Sized, Self: RouterPlugin, { @@ -91,9 +92,7 @@ pub trait RouterPluginWithConfig where Self: Sized, { let config: Self::Config = serde_json::from_value(value)?; - Ok( - Box::new(Self::new(config)) - ) + Ok(Box::new(Self::new(config))) } } From fb17a7b38fa1307b7dfdda0456290ada8179f6ef Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 24 Nov 2025 18:18:17 +0300 Subject: [PATCH 18/31] Better --- bin/router/src/plugins/registry.rs | 17 ++++++++++++--- .../src/plugins/examples/apollo_sandbox.rs | 10 +++++++-- lib/executor/src/plugins/examples/apq.rs | 18 ++++++++++++---- .../src/plugins/examples/async_auth.rs | 19 ++++++++++------- .../src/plugins/examples/context_data.rs | 17 ++++++++++++--- .../examples/forbid_anonymous_operations.rs | 21 +++++++++++++------ .../src/plugins/examples/multipart.rs | 17 +++++++++++---- lib/executor/src/plugins/examples/one_of.rs | 18 ++++++++++++---- .../plugins/examples/propagate_status_code.rs | 8 +++++-- .../src/plugins/examples/response_cache.rs | 10 ++++++--- .../src/plugins/examples/root_field_limit.rs | 10 ++++++--- .../examples/subgraph_response_cache.rs | 18 ++++++++++++---- lib/executor/src/plugins/plugin_trait.rs | 10 ++++++--- 13 files changed, 145 insertions(+), 48 deletions(-) diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index 905dbb4bf..261bb0b52 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -8,7 +8,9 @@ use tracing::{info, warn}; pub struct PluginRegistry { map: HashMap< &'static str, - Box Result, serde_json::Error>>, + Box< + dyn Fn(Value) -> Result>, serde_json::Error>, + >, >, } @@ -27,7 +29,13 @@ impl PluginRegistry { pub fn register(&mut self) { self.map.insert( P::plugin_name(), - Box::new(|plugin_config: Value| Ok(P::from_config_value(plugin_config)?)), + Box::new(|plugin_config: Value| { + let config: P::Config = serde_json::from_value(plugin_config)?; + match P::from_config(config) { + Some(plugin) => Ok(Some(Box::new(plugin))), + None => Ok(None), + } + }), ); } pub fn initialize_plugins( @@ -41,7 +49,10 @@ impl PluginRegistry { match factory(plugin_config_value.clone()) { Ok(plugin) => { info!("Loaded plugin: {}", plugin_name); - plugins.push(plugin); + match plugin { + Some(plugin) => plugins.push(plugin), + None => info!("Plugin '{}' is disabled, skipping", plugin_name), + } } Err(err) => { warn!( diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs index 6febd288e..787f71cf6 100644 --- a/lib/executor/src/plugins/examples/apollo_sandbox.rs +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -11,6 +11,7 @@ use crate::{ #[derive(Default, Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct ApolloSandboxOptions { + pub enabled: bool, /** * The URL of the GraphQL endpoint that Sandbox introspects on initial load. Sandbox populates its pages using the schema obtained from this endpoint. * The default value is `http://localhost:4000`. @@ -39,6 +40,7 @@ pub struct ApolloSandboxOptions { #[derive(Serialize, Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct ApolloSandboxInitialStateOptions { + pub enabled: bool, /** * Set this value to `true` if you want Sandbox to pass `{ credentials: 'include' }` for its requests by default.If you set `hideCookieToggle` to `false`, users can override this default setting with the **Include cookies** toggle. (By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.)If you also pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials). */ @@ -124,8 +126,12 @@ impl RouterPluginWithConfig for ApolloSandboxPlugin { fn plugin_name() -> &'static str { "apollo_sandbox" } - fn new(config: ApolloSandboxOptions) -> Self { - ApolloSandboxPlugin { options: config } + fn from_config(config: ApolloSandboxOptions) -> Option { + if config.enabled { + Some(ApolloSandboxPlugin { options: config }) + } else { + None + } } } diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index 49d9f6f04..32b5ac6a5 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -1,4 +1,5 @@ use dashmap::DashMap; +use serde::Deserialize; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use crate::{ @@ -6,18 +7,27 @@ use crate::{ plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +#[derive(Deserialize)] +pub struct APQPluginConfig { + pub enabled: bool, +} + pub struct APQPlugin { cache: DashMap, } impl RouterPluginWithConfig for APQPlugin { - type Config = (); + type Config = APQPluginConfig; fn plugin_name() -> &'static str { "apq_plugin" } - fn new(_config: Self::Config) -> Self { - APQPlugin { - cache: DashMap::new(), + fn from_config(config: Self::Config) -> Option { + if config.enabled { + Some(APQPlugin { + cache: DashMap::new(), + }) + } else { + None } } } diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/lib/executor/src/plugins/examples/async_auth.rs index 58c7e0b3c..c72fbac6e 100644 --- a/lib/executor/src/plugins/examples/async_auth.rs +++ b/lib/executor/src/plugins/examples/async_auth.rs @@ -12,30 +12,35 @@ use crate::{ #[derive(Deserialize)] pub struct AllowClientIdConfig { + pub enabled: bool, pub header: String, pub path: String, } -impl RouterPluginWithConfig for AllowClientIdFromFile { +impl RouterPluginWithConfig for AllowClientIdFromFilePlugin { type Config = AllowClientIdConfig; fn plugin_name() -> &'static str { "allow_client_id_from_file" } - fn new(config: AllowClientIdConfig) -> Self { - AllowClientIdFromFile { - header_key: config.header, - allowed_ids_path: PathBuf::from(config.path), + fn from_config(config: AllowClientIdConfig) -> Option { + if config.enabled { + Some(AllowClientIdFromFilePlugin { + header_key: config.header, + allowed_ids_path: PathBuf::from(config.path), + }) + } else { + None } } } -pub struct AllowClientIdFromFile { +pub struct AllowClientIdFromFilePlugin { header_key: String, allowed_ids_path: PathBuf, } #[async_trait::async_trait] -impl RouterPlugin for AllowClientIdFromFile { +impl RouterPlugin for AllowClientIdFromFilePlugin { // Whenever it is a GraphQL request, // We don't use on_http_request here because we want to run this only when it is a GraphQL request async fn on_graphql_params<'exec>( diff --git a/lib/executor/src/plugins/examples/context_data.rs b/lib/executor/src/plugins/examples/context_data.rs index 25b2f34c1..d7b89cdee 100644 --- a/lib/executor/src/plugins/examples/context_data.rs +++ b/lib/executor/src/plugins/examples/context_data.rs @@ -1,5 +1,7 @@ // From https://github.com/apollographql/router/blob/dev/examples/context/rust/src/context_data.rs +use serde::Deserialize; + use crate::{ hooks::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, @@ -9,6 +11,11 @@ use crate::{ plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +#[derive(Deserialize)] +pub struct ContextDataPluginConfig { + pub enabled: bool, +} + pub struct ContextDataPlugin {} pub struct ContextData { @@ -17,12 +24,16 @@ pub struct ContextData { } impl RouterPluginWithConfig for ContextDataPlugin { - type Config = (); + type Config = ContextDataPluginConfig; fn plugin_name() -> &'static str { "context_data_plugin" } - fn new(_config: ()) -> Self { - ContextDataPlugin {} + fn from_config(config: ContextDataPluginConfig) -> Option { + if config.enabled { + Some(ContextDataPlugin {}) + } else { + None + } } } diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs index 4fbe21099..9b14976d8 100644 --- a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs +++ b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs @@ -1,6 +1,7 @@ // Same with https://github.com/apollographql/router/blob/dev/examples/forbid-anonymous-operations/rust/src/forbid_anonymous_operations.rs use http::StatusCode; +use serde::Deserialize; use sonic_rs::json; use crate::{ @@ -9,20 +10,28 @@ use crate::{ plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; -pub struct ForbidAnonymousOperations {} +#[derive(Deserialize)] +pub struct ForbidAnonymousOperationsPluginConfig { + pub enabled: bool, +} +pub struct ForbidAnonymousOperationsPlugin {} -impl RouterPluginWithConfig for ForbidAnonymousOperations { - type Config = (); +impl RouterPluginWithConfig for ForbidAnonymousOperationsPlugin { + type Config = ForbidAnonymousOperationsPluginConfig; fn plugin_name() -> &'static str { "forbid_anonymous_operations" } - fn new(_config: Self::Config) -> Self { - ForbidAnonymousOperations {} + fn from_config(config: Self::Config) -> Option { + if config.enabled { + Some(ForbidAnonymousOperationsPlugin {}) + } else { + None + } } } #[async_trait::async_trait] -impl RouterPlugin for ForbidAnonymousOperations { +impl RouterPlugin for ForbidAnonymousOperationsPlugin { async fn on_graphql_params<'exec>( &'exec self, payload: OnGraphQLParamsStartPayload<'exec>, diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs index 051fda9ce..4e6085851 100644 --- a/lib/executor/src/plugins/examples/multipart.rs +++ b/lib/executor/src/plugins/examples/multipart.rs @@ -13,7 +13,12 @@ use crate::{ use bytes::Bytes; use dashmap::DashMap; use multer::Multipart; -use serde::Serialize; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize)] +pub struct MultipartPluginConfig { + pub enabled: bool, +} pub struct MultipartPlugin {} pub struct MultipartFile { @@ -35,12 +40,16 @@ struct MultipartOperations<'a> { } impl RouterPluginWithConfig for MultipartPlugin { - type Config = (); + type Config = MultipartPluginConfig; fn plugin_name() -> &'static str { "multipart_plugin" } - fn new(_config: ()) -> Self { - MultipartPlugin {} + fn from_config(config: MultipartPluginConfig) -> Option { + if config.enabled { + Some(MultipartPlugin {}) + } else { + None + } } } diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index e0a33e7dc..a85e4ed48 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -68,16 +68,26 @@ use graphql_tools::{ utils::{ValidationError, ValidationErrorContext}, }, }; +use serde::Deserialize; use sonic_rs::{json, JsonContainerTrait}; +#[derive(Deserialize)] +pub struct OneOfPluginConfig { + pub enabled: bool, +} + impl RouterPluginWithConfig for OneOfPlugin { - type Config = (); + type Config = OneOfPluginConfig; fn plugin_name() -> &'static str { "one_of_plugin" } - fn new(_config: ()) -> Self { - OneOfPlugin { - one_of_types: RwLock::new(vec![]), + fn from_config(config: OneOfPluginConfig) -> Option { + if config.enabled { + Some(OneOfPlugin { + one_of_types: RwLock::new(vec![]), + }) + } else { + None } } } diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs index 1f5e48592..0e73fd2dd 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -13,6 +13,7 @@ use crate::{ #[derive(Deserialize)] pub struct PropagateStatusCodePluginConfig { + pub enabled: bool, pub status_codes: Vec, } @@ -21,13 +22,16 @@ impl RouterPluginWithConfig for PropagateStatusCodePlugin { fn plugin_name() -> &'static str { "propagate_status_code_plugin" } - fn new(config: PropagateStatusCodePluginConfig) -> Self { + fn from_config(config: PropagateStatusCodePluginConfig) -> Option { + if !config.enabled { + return None; + } let status_codes = config .status_codes .into_iter() .filter_map(|code| StatusCode::from_u16(code as u16).ok()) .collect(); - PropagateStatusCodePlugin { status_codes } + Some(PropagateStatusCodePlugin { status_codes }) } } diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index aff44b0dd..25a1dbb03 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -16,6 +16,7 @@ use crate::{ #[derive(Deserialize)] pub struct ResponseCachePluginOptions { + pub enabled: bool, pub redis_url: String, } @@ -24,13 +25,16 @@ impl RouterPluginWithConfig for ResponseCachePlugin { fn plugin_name() -> &'static str { "response_cache_plugin" } - fn new(config: ResponseCachePluginOptions) -> Self { + fn from_config(config: ResponseCachePluginOptions) -> Option { + if !config.enabled { + return None; + } let redis_client = redis::Client::open(config.redis_url).expect("Failed to create Redis client"); - Self { + Some(Self { redis_client, ttl_per_type: DashMap::new(), - } + }) } } diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs index 04aa525ec..589f8e3eb 100644 --- a/lib/executor/src/plugins/examples/root_field_limit.rs +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -88,6 +88,7 @@ impl RouterPlugin for RootFieldLimitPlugin { #[derive(Deserialize)] pub struct RootFieldLimitPluginConfig { + enabled: bool, max_root_fields: usize, } @@ -96,10 +97,13 @@ impl RouterPluginWithConfig for RootFieldLimitPlugin { fn plugin_name() -> &'static str { "root_field_limit_plugin" } - fn new(config: Self::Config) -> Self { - RootFieldLimitPlugin { - max_root_fields: config.max_root_fields, + fn from_config(config: Self::Config) -> Option { + if !config.enabled { + return None; } + Some(RootFieldLimitPlugin { + max_root_fields: config.max_root_fields, + }) } } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index e456cf3ae..54464606c 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,4 +1,5 @@ use dashmap::DashMap; +use serde::Deserialize; use crate::{ executors::common::HttpExecutionResponse, @@ -6,14 +7,23 @@ use crate::{ plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +#[derive(Deserialize)] +pub struct SubgraphResponseCachePluginConfig { + enabled: bool, +} + impl RouterPluginWithConfig for SubgraphResponseCachePlugin { - type Config = (); + type Config = SubgraphResponseCachePluginConfig; fn plugin_name() -> &'static str { "subgraph_response_cache_plugin" } - fn new(_config: ()) -> Self { - SubgraphResponseCachePlugin { - cache: DashMap::new(), + fn from_config(config: SubgraphResponseCachePluginConfig) -> Option { + if config.enabled { + Some(SubgraphResponseCachePlugin { + cache: DashMap::new(), + }) + } else { + None } } } diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index ad1ae32e4..5f6b4a0c8 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -86,13 +86,17 @@ where { fn plugin_name() -> &'static str; type Config: Send + Sync + DeserializeOwned; - fn new(config: Self::Config) -> Self; - fn from_config_value(value: serde_json::Value) -> serde_json::Result> + fn from_config(config: Self::Config) -> Option; + fn from_config_value(value: serde_json::Value) -> serde_json::Result>> where Self: Sized, { let config: Self::Config = serde_json::from_value(value)?; - Ok(Box::new(Self::new(config))) + let plugin = Self::from_config(config); + match plugin { + None => Ok(None), + Some(plugin) => Ok(Some(Box::new(plugin))), + } } } From 8fa6e2cde708e3fbfb28de9de3b7373e0537912a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 25 Nov 2025 18:36:27 +0300 Subject: [PATCH 19/31] E2E Tests --- Cargo.lock | 15 +- Cargo.toml | 1 + bin/router/src/lib.rs | 11 +- bin/router/src/main.rs | 8 +- bin/router/src/pipeline/execution.rs | 6 +- bin/router/src/pipeline/mod.rs | 133 ++++++++++-------- bin/router/src/pipeline/parser.rs | 50 ++++--- bin/router/src/pipeline/query_plan.rs | 63 +++++---- bin/router/src/pipeline/validation.rs | 66 +++++---- bin/router/src/plugins/plugins_service.rs | 5 +- bin/router/src/plugins/registry.rs | 3 +- bin/router/src/schema_state.rs | 42 +++--- bin/router/src/shared_state.rs | 6 +- e2e/Cargo.toml | 13 ++ e2e/src/file_supergraph.rs | 4 +- e2e/src/hive_cdn_supergraph.rs | 12 +- e2e/src/jwt.rs | 20 +-- e2e/src/lib.rs | 2 + e2e/src/override_subgraph_urls.rs | 3 +- .../src/plugins}/apollo_sandbox.rs | 59 ++++++-- .../examples => e2e/src/plugins}/apq.rs | 2 +- .../src/plugins}/async_auth.rs | 2 +- .../src/plugins}/context_data.rs | 2 +- .../plugins}/forbid_anonymous_operations.rs | 2 +- .../examples => e2e/src/plugins}/mod.rs | 0 .../examples => e2e/src/plugins}/multipart.rs | 2 +- .../examples => e2e/src/plugins}/one_of.rs | 2 +- .../src/plugins}/propagate_status_code.rs | 2 +- .../src/plugins}/response_cache.rs | 2 +- .../src/plugins}/root_field_limit.rs | 2 +- .../src/plugins}/subgraph_response_cache.rs | 2 +- e2e/src/probes.rs | 4 +- e2e/src/supergraph.rs | 4 +- e2e/src/testkit.rs | 13 +- lib/executor/Cargo.toml | 5 +- lib/executor/src/execution/plan.rs | 126 ++++++++++------- lib/executor/src/executors/http.rs | 52 +++---- lib/executor/src/executors/map.rs | 115 ++++++++------- .../plugins/hooks/on_graphql_validation.rs | 8 +- .../src/plugins/hooks/on_subgraph_execute.rs | 2 +- lib/executor/src/plugins/mod.rs | 1 - lib/executor/src/plugins/plugin_context.rs | 2 +- 42 files changed, 494 insertions(+), 380 deletions(-) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/apollo_sandbox.rs (80%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/apq.rs (98%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/async_auth.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/context_data.rs (98%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/forbid_anonymous_operations.rs (98%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/mod.rs (100%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/multipart.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/one_of.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/propagate_status_code.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/response_cache.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/root_field_limit.rs (99%) rename {lib/executor/src/plugins/examples => e2e/src/plugins}/subgraph_response_cache.rs (98%) diff --git a/Cargo.lock b/Cargo.lock index 544be59be..b5d240160 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1346,14 +1346,26 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" name = "e2e" version = "0.0.1" dependencies = [ + "async-trait", + "bytes", + "dashmap", + "futures-util", + "graphql-parser", + "graphql-tools", "hive-router", "hive-router-config", + "hive-router-plan-executor", + "hive-router-query-planner", + "http", "insta", "jsonwebtoken", "lazy_static", "mockito", + "multer", "ntex", + "redis", "reqwest", + "serde", "sonic-rs", "subgraphs", "tempfile", @@ -2055,7 +2067,6 @@ dependencies = [ "criterion", "dashmap", "futures", - "futures-util", "graphql-parser", "graphql-tools", "hive-router-config", @@ -2068,11 +2079,9 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", - "multer", "ntex", "ntex-http", "ordered-float", - "redis", "regex-automata", "reqwest", "ryu", diff --git a/Cargo.toml b/Cargo.toml index 3b3d217ed..9a2d8e04b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,3 +59,4 @@ retry-policies = "0.4.0" reqwest-retry = "0.7.0" reqwest-middleware = "0.4.2" vrl = { version = "0.28.0", features = ["compiler", "parser", "value", "diagnostic", "stdlib", "core"] } +bytes = "1.10.1" \ No newline at end of file diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 8df43c132..071bcd0b7 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -4,7 +4,7 @@ mod http_utils; mod jwt; mod logger; mod pipeline; -mod plugins; +pub mod plugins; mod schema_state; mod shared_state; mod supergraph; @@ -88,7 +88,7 @@ async fn graphql_endpoint_handler( } pub async fn router_entrypoint( - plugin_factories: PluginRegistry, + plugin_registry: Option, ) -> Result<(), Box> { let config_path = std::env::var("ROUTER_CONFIG_FILE_PATH").ok(); let router_config = load_config(config_path)?; @@ -97,7 +97,7 @@ pub async fn router_entrypoint( let addr = router_config.http.address(); let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_factories).await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry).await?; let maybe_error = web::HttpServer::new(move || { web::App::new() @@ -121,14 +121,15 @@ pub async fn router_entrypoint( pub async fn configure_app_from_config( router_config: HiveRouterConfig, bg_tasks_manager: &mut BackgroundTasksManager, - plugin_factories: PluginRegistry, + plugin_registry: Option, ) -> Result<(Arc, Arc), Box> { let jwt_runtime = match router_config.jwt.is_jwt_auth_enabled() { true => Some(JwtAuthRuntime::init(bg_tasks_manager, &router_config.jwt).await?), false => None, }; - let plugins = plugin_factories.initialize_plugins(&router_config); + let plugins = + plugin_registry.map(|plugin_registry| plugin_registry.initialize_plugins(&router_config)); let router_config_arc = Arc::new(router_config); let shared_state = Arc::new(RouterSharedState::new( diff --git a/bin/router/src/main.rs b/bin/router/src/main.rs index 71d126fda..d46192b27 100644 --- a/bin/router/src/main.rs +++ b/bin/router/src/main.rs @@ -1,15 +1,11 @@ -use hive_router::{router_entrypoint, PluginRegistry}; -use hive_router_plan_executor::examples::apq::APQPlugin; +use hive_router::{router_entrypoint}; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[ntex::main] async fn main() -> Result<(), Box> { - let mut plugin_registry = PluginRegistry::new(); - plugin_registry.register::(); - - match router_entrypoint(plugin_registry).await { + match router_entrypoint(None).await { Ok(_) => Ok(()), Err(err) => { eprintln!("Failed to start Hive Router:\n {}", err); diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 991f503c1..75fbf499b 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -9,7 +9,7 @@ use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; -use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; use ntex::web::HttpRequest; @@ -33,7 +33,7 @@ pub async fn execute_plan( query_plan_payload: &QueryPlan, variable_payload: &CoerceVariablesPayload, client_request_details: &ClientRequestDetails<'_, '_>, - plugin_manager: PluginManager<'_>, + plugin_req_state: &Option>, ) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; @@ -86,7 +86,7 @@ pub async fn execute_plan( }; let ctx = QueryPlanExecutionContext { - plugin_manager: &plugin_manager, + plugin_req_state: &plugin_req_state, query_plan: query_plan_payload, operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 4d2f264ee..529154efa 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -9,7 +9,7 @@ use hive_router_plan_executor::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, on_supergraph_load::SupergraphData, }, - plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ @@ -92,25 +92,29 @@ pub async fn graphql_request_handler( &APPLICATION_JSON }; - let plugin_context = req - .extensions() - .get::>() - .cloned() - .expect("Plugin manager should be loaded"); + let mut plugin_req_state = None; + if let Some(plugins) = shared_state.plugins.as_ref() { + let plugin_context = req + .extensions() + .get::>() + .cloned() + .expect("Plugin manager should be loaded"); + + plugin_req_state = Some(PluginRequestState { + plugins: plugins.clone(), + router_http_request: RouterHttpRequest { + uri: req.uri(), + method: req.method(), + version: req.version(), + headers: req.headers(), + match_info: req.match_info(), + query_string: req.query_string(), + path: req.path(), + }, + context: plugin_context, + }); + } - let plugin_manager = PluginManager { - plugins: shared_state.plugins.clone(), - router_http_request: RouterHttpRequest { - uri: req.uri(), - method: req.method(), - version: req.version(), - headers: req.headers(), - match_info: req.match_info(), - query_string: req.query_string(), - path: req.path(), - }, - context: plugin_context, - }; let response = execute_pipeline( req, @@ -119,7 +123,7 @@ pub async fn graphql_request_handler( shared_state, schema_state, jwt_context, - plugin_manager, + plugin_req_state, ) .await?; let response_bytes = Bytes::from(response.body); @@ -146,59 +150,69 @@ pub async fn execute_pipeline( shared_state: &RouterSharedState, schema_state: &SchemaState, jwt_context: Option, - plugin_manager: PluginManager<'_>, + plugin_req_state: Option>, ) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; - let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { - router_http_request: &plugin_manager.router_http_request, - context: &plugin_manager.context, - body, - graphql_params: None, - }; - for plugin in shared_state.plugins.as_ref() { - let result = plugin.on_graphql_params(deserialization_payload).await; - deserialization_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - return Ok(response); - } - ControlFlowResult::OnEnd(callback) => { - deserialization_end_callbacks.push(callback); + + let mut graphql_params = None; + let mut body = body; + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + body, + graphql_params: None, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_params(deserialization_payload).await; + deserialization_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + deserialization_end_callbacks.push(callback); + } } } + graphql_params = deserialization_payload.graphql_params; + body = deserialization_payload.body; } - let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { - deserialize_graphql_params(req, deserialization_payload.body) + let mut graphql_params = graphql_params.unwrap_or_else(|| { + deserialize_graphql_params(req, body) .expect("Failed to parse execution request") }); - let mut payload = OnGraphQLParamsEndPayload { - graphql_params, - context: &plugin_manager.context, - }; - for deserialization_end_callback in deserialization_end_callbacks { - let result = deserialization_end_callback(payload); - payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - return Ok(response); - } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); + if let Some(plugin_req_state) = &plugin_req_state { + let mut payload = OnGraphQLParamsEndPayload { + graphql_params, + context: &plugin_req_state.context, + }; + for deserialization_end_callback in deserialization_end_callbacks { + let result = deserialization_end_callback(payload); + payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } } } + graphql_params = payload.graphql_params; } - let mut graphql_params = payload.graphql_params; + /* Handle on_deserialize hook in the plugins - END */ let parser_result = - parse_operation_with_cache(shared_state, &graphql_params, &plugin_manager).await?; + parse_operation_with_cache(shared_state, &graphql_params, &plugin_req_state).await?; let parser_payload = match parser_result { ParseResult::Payload(payload) => payload, @@ -212,7 +226,7 @@ pub async fn execute_pipeline( schema_state, shared_state, &parser_payload, - &plugin_manager, + &plugin_req_state, ) .await?; @@ -267,8 +281,7 @@ pub async fn execute_pipeline( &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, - shared_state, - &plugin_manager, + &plugin_req_state, ) .await?; let query_plan_payload = match query_plan_result { @@ -286,7 +299,7 @@ pub async fn execute_pipeline( &query_plan_payload, &variable_payload, &client_request_details, - plugin_manager, + &plugin_req_state, ) .await?; diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index aebbd6beb..933640abe 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -7,7 +7,7 @@ use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{ OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, }; -use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; use xxhash_rust::xxh3::Xxh3; @@ -32,44 +32,48 @@ pub enum ParseResult { pub async fn parse_operation_with_cache( app_state: &RouterSharedState, graphql_params: &GraphQLParams, - plugin_manager: &PluginManager<'_>, + plugin_req_state: &Option>, ) -> Result { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); hasher.finish() }; - /* Handle on_graphql_parse hook in the plugins - START */ - let mut start_payload = OnGraphQLParseStartPayload { - router_http_request: &plugin_manager.router_http_request, - context: &plugin_manager.context, - graphql_params, - document: None, - }; let parsed_operation = if let Some(cached) = app_state.parse_cache.get(&cache_key).await { trace!("Found cached parsed operation for query"); cached } else { + let mut document = None; let mut on_end_callbacks = vec![]; - for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_parse(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next plugin - } - ControlFlowResult::EndResponse(response) => { - return Ok(ParseResult::Response(response)); - } - ControlFlowResult::OnEnd(callback) => { - // store the callback to be called later - on_end_callbacks.push(callback); + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + graphql_params, + document, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_parse(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(ParseResult::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + // store the callback to be called later + on_end_callbacks.push(callback); + } } } + document = start_payload.document; } - let document = match start_payload.document { + let document = match document { Some(parsed) => parsed, None => { let query_str = graphql_params.get_query()?; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index d1f83ca7d..766b0b28c 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -5,16 +5,15 @@ use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; use crate::schema_state::SchemaState; -use crate::RouterSharedState; use hive_router_plan_executor::execution::plan::PlanExecutionOutput; use hive_router_plan_executor::hooks::on_query_plan::{ OnQueryPlanEndPayload, OnQueryPlanStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; -use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use hive_router_query_planner::planner::PlannerError; +use hive_router_query_planner::planner::{PlannerError}; use hive_router_query_planner::utils::cancellation::CancellationToken; use xxhash_rust::xxh3::Xxh3; @@ -35,8 +34,7 @@ pub async fn plan_operation_with_cache<'req>( normalized_operation: &GraphQLNormalizationPayload, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, - app_state: &RouterSharedState, - plugin_manager: &PluginManager<'req>, + plugin_req_state: &Option>, ) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -57,34 +55,41 @@ pub async fn plan_operation_with_cache<'req>( })); } - /* Handle on_query_plan hook in the plugins - START */ - let mut start_payload = OnQueryPlanStartPayload { - router_http_request: &plugin_manager.router_http_request, - context: &plugin_manager.context, - filtered_operation_for_plan, - planner_override_context: (&request_override_context.clone()).into(), - cancellation_token, - query_plan: None, - planner: &supergraph.planner, - }; - + let mut query_plan: Option = None; let mut on_end_callbacks = vec![]; - for plugin in app_state.plugins.as_ref() { - let result = plugin.on_query_plan(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next plugin - } - ControlFlowResult::EndResponse(response) => { - return Err(QueryPlanGetterError::Response(response)); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + + if let Some(plugin_req_state) = plugin_req_state { + /* Handle on_query_plan hook in the plugins - START */ + let mut start_payload = OnQueryPlanStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + filtered_operation_for_plan, + planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan, + planner: &supergraph.planner, + }; + + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_query_plan(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } + + query_plan = start_payload.query_plan; } - let query_plan = match start_payload.query_plan { + + let query_plan = match query_plan { Some(plan) => plan, None => supergraph .planner diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 92cb1eb6f..d4635051e 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -10,7 +10,7 @@ use hive_router_plan_executor::hooks::on_graphql_validation::{ OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; -use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use tracing::{error, trace}; @@ -20,7 +20,7 @@ pub async fn validate_operation_with_cache( schema_state: &SchemaState, app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, - plugin_manager: &PluginManager<'_>, + plugin_req_state: &Option>, ) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; @@ -43,37 +43,45 @@ pub async fn validate_operation_with_cache( parser_payload.cache_key ); - /* Handle on_graphql_validate hook in the plugins - START */ - let mut start_payload = OnGraphQLValidationStartPayload::new( - plugin_manager, - consumer_schema_ast, - &parser_payload.parsed_operation, - &app_state.validation_plan, - ); let mut on_end_callbacks = vec![]; - for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_validation(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next plugin - } - ControlFlowResult::EndResponse(response) => { - return Ok(Some(response)); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + let document = &parser_payload.parsed_operation; + let errors = if let Some(plugin_req_state) = plugin_req_state.as_ref() { + /* Handle on_graphql_validate hook in the plugins - START */ + let mut start_payload = OnGraphQLValidationStartPayload::new( + plugin_req_state, + consumer_schema_ast, + document, + &app_state.validation_plan, + ); + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_graphql_validation(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } - } - - let errors = match start_payload.errors { - Some(errors) => errors, - None => validate( + match start_payload.errors { + Some(errors) => errors, + None => validate( + consumer_schema_ast, + start_payload.document, + start_payload.get_validation_plan(), + ), + } + } else { + validate( consumer_schema_ast, - start_payload.document, - start_payload.get_validation_plan(), - ), + document, + &app_state.validation_plan, + ) }; let mut end_payload = OnGraphQLValidationEndPayload { errors }; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 223a702ee..6bf43aac6 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -47,11 +47,12 @@ where ) -> Result { let plugins = req .app_state::>() - .map(|shared_state| shared_state.plugins.clone()); + .and_then(|shared_state| shared_state.plugins.clone()); - if let Some(plugins) = plugins { + if let Some(plugins) = plugins.as_ref() { let plugin_context = Arc::new(PluginContext::default()); req.extensions_mut().insert(plugin_context.clone()); + let mut start_payload = OnHttpRequestPayload { router_http_request: req, context: &plugin_context, diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index 261bb0b52..651576adc 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -26,7 +26,7 @@ impl PluginRegistry { map: HashMap::new(), } } - pub fn register(&mut self) { + pub fn register(mut self) -> Self { self.map.insert( P::plugin_name(), Box::new(|plugin_config: Value| { @@ -37,6 +37,7 @@ impl PluginRegistry { } }), ); + return self; } pub fn initialize_plugins( &self, diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 69db5db3b..2d6cb0f66 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -85,32 +85,34 @@ impl SchemaState { while let Some(new_sdl) = rx.recv().await { debug!("Received new supergraph SDL, building new supergraph state..."); - let new_ast = parse_schema(&new_sdl); - - let mut start_payload = OnSupergraphLoadStartPayload { - current_supergraph_data: swappable_data_spawn_clone.clone(), - new_ast, - }; + let mut new_ast = parse_schema(&new_sdl); let mut on_end_callbacks = vec![]; - for plugin in app_state.plugins.as_ref() { - let result = plugin.on_supergraph_reload(start_payload); - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next plugin - } - ControlFlowResult::EndResponse(_) => { - unreachable!("Plugins should not end supergraph reload processing"); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + if let Some(plugins) = app_state.plugins.as_ref() { + + let mut start_payload = OnSupergraphLoadStartPayload { + current_supergraph_data: swappable_data_spawn_clone.clone(), + new_ast, + }; + for plugin in plugins.as_ref() { + let result = plugin.on_supergraph_reload(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } + new_ast = start_payload.new_ast; } - let new_ast = start_payload.new_ast; match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { Ok(new_supergraph_data) => { @@ -166,7 +168,7 @@ impl SchemaState { fn build_data( router_config: Arc, parsed_supergraph_sdl: &Document, - plugins: Arc>>, + plugins: Option>>>, ) -> Result { let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index f876dcb49..a5ce177d7 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -19,14 +19,14 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, - pub plugins: Arc>>, + pub plugins: Option>>>, } impl RouterSharedState { pub fn new( router_config: Arc, jwt_auth_runtime: Option, - plugins: Vec>, + plugins: Option>>, ) -> Result { Ok(Self { validation_plan: graphql_tools::validation::rules::default_rules_validation_plan(), @@ -39,7 +39,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, - plugins: Arc::new(plugins), + plugins: plugins.map(|p| Arc::new(p)), }) } } diff --git a/e2e/Cargo.toml b/e2e/Cargo.toml index 5a604afc1..2ecb5434f 100644 --- a/e2e/Cargo.toml +++ b/e2e/Cargo.toml @@ -17,10 +17,23 @@ lazy_static = { workspace = true } jsonwebtoken = { workspace = true } insta = { workspace = true } reqwest = { workspace = true } +serde = { workspace = true } +dashmap = { workspace = true } +async-trait = { workspace = true } +http = { workspace = true } +bytes = { workspace = true } +graphql-tools = { workspace = true } +graphql-parser = { workspace = true } hive-router = { path = "../bin/router" } hive-router-config = { path = "../lib/router-config" } +hive-router-plan-executor = { path = "../lib/executor" } +hive-router-query-planner = { path = "../lib/query-planner" } subgraphs = { path = "../bench/subgraphs" } mockito = "1.7.0" tempfile = "3.23.0" +redis = "0.32.7" +multer = "3.1.0" +futures-util = "0.3.31" + diff --git a/e2e/src/file_supergraph.rs b/e2e/src/file_supergraph.rs index 2ceb77f14..c6782ae3e 100644 --- a/e2e/src/file_supergraph.rs +++ b/e2e/src/file_supergraph.rs @@ -26,7 +26,7 @@ mod file_supergraph_e2e_tests { source: file path: {supergraph_file_path} "#, - )) + ), None) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; @@ -70,7 +70,7 @@ mod file_supergraph_e2e_tests { path: {supergraph_file_path} poll_interval: 100ms "#, - )) + ), None) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; diff --git a/e2e/src/hive_cdn_supergraph.rs b/e2e/src/hive_cdn_supergraph.rs index 03c1efafb..b32137e7d 100644 --- a/e2e/src/hive_cdn_supergraph.rs +++ b/e2e/src/hive_cdn_supergraph.rs @@ -30,7 +30,7 @@ mod hive_cdn_supergraph_e2e_tests { endpoint: http://{host}/supergraph key: dummy_key "#, - )) + ), None) .await .expect("failed to start router"); @@ -84,7 +84,7 @@ mod hive_cdn_supergraph_e2e_tests { key: dummy_key poll_interval: 100ms "#, - )) + ), None) .await .expect("failed to start router"); @@ -142,7 +142,7 @@ mod hive_cdn_supergraph_e2e_tests { key: dummy_key poll_interval: 100ms "#, - )) + ), None) .await .expect("failed to start router"); @@ -206,7 +206,7 @@ mod hive_cdn_supergraph_e2e_tests { key: dummy_key poll_interval: 800ms "#, - )) + ), None) .await .expect("failed to start router"); @@ -285,7 +285,7 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 10 "#, - )) + ), None) .await .expect("failed to start router"); @@ -313,7 +313,7 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 3 "#, - )) + ), None) .await .expect("failed to start router"); diff --git a/e2e/src/jwt.rs b/e2e/src/jwt.rs index eb4f5eba4..aeadb0d59 100644 --- a/e2e/src/jwt.rs +++ b/e2e/src/jwt.rs @@ -27,7 +27,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_forward_claims_to_subgraph_via_extensions() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_forward.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_forward.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -69,7 +69,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_details() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -132,7 +132,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_scopes() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -243,7 +243,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_without_token_when_auth_is_required() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -270,7 +270,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_malformed_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -301,7 +301,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_invalid_signature() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -324,7 +324,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn accepts_request_with_valid_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -358,7 +358,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_expired_token() { - let app = init_router_from_config_file("configs/jwt_auth.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth.router.yaml", None) .await .unwrap(); @@ -391,7 +391,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_wrong_issuer() { - let app = init_router_from_config_file("configs/jwt_auth_issuer.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_issuer.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; @@ -413,7 +413,7 @@ mod jwt_e2e_tests { #[ntex::test] async fn rejects_request_with_wrong_audience() { - let app = init_router_from_config_file("configs/jwt_auth_audience.router.yaml") + let app = init_router_from_config_file("configs/jwt_auth_audience.router.yaml", None) .await .unwrap(); wait_for_readiness(&app.app).await; diff --git a/e2e/src/lib.rs b/e2e/src/lib.rs index 9086e01f4..936873b74 100644 --- a/e2e/src/lib.rs +++ b/e2e/src/lib.rs @@ -12,3 +12,5 @@ mod probes; mod supergraph; #[cfg(test)] mod testkit; +#[cfg(test)] +mod plugins; diff --git a/e2e/src/override_subgraph_urls.rs b/e2e/src/override_subgraph_urls.rs index cd3e0789d..80ea7d931 100644 --- a/e2e/src/override_subgraph_urls.rs +++ b/e2e/src/override_subgraph_urls.rs @@ -15,7 +15,7 @@ mod override_subgraph_urls_e2e_tests { async fn should_override_subgraph_url_based_on_static_value() { let subgraphs_server = SubgraphsServer::start_with_port(4100).await; let app = init_router_from_config_file( - "configs/override_subgraph_urls/override_static.router.yaml", + "configs/override_subgraph_urls/override_static.router.yaml", None, ) .await .unwrap(); @@ -48,6 +48,7 @@ mod override_subgraph_urls_e2e_tests { let subgraphs_server = SubgraphsServer::start_with_port(4100).await; let app = init_router_from_config_file( "configs/override_subgraph_urls/override_dynamic_header.router.yaml", + None, ) .await .unwrap(); diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/e2e/src/plugins/apollo_sandbox.rs similarity index 80% rename from lib/executor/src/plugins/examples/apollo_sandbox.rs rename to e2e/src/plugins/apollo_sandbox.rs index 787f71cf6..d3c28be68 100644 --- a/lib/executor/src/plugins/examples/apollo_sandbox.rs +++ b/e2e/src/plugins/apollo_sandbox.rs @@ -1,23 +1,18 @@ -use ::serde::{Deserialize, Serialize}; -use ahash::HashMap; -use http::{HeaderMap, StatusCode}; +use std::collections::HashMap; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; +use http::HeaderMap; +use reqwest::StatusCode; +pub(crate) use sonic_rs::{Deserialize, Serialize}; #[derive(Default, Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] +#[serde(default, rename_all = "camelCase")] pub struct ApolloSandboxOptions { pub enabled: bool, - /** - * The URL of the GraphQL endpoint that Sandbox introspects on initial load. Sandbox populates its pages using the schema obtained from this endpoint. - * The default value is `http://localhost:4000`. - * You should only pass non-production endpoints to Sandbox. Sandbox is powered by schema introspection, and we recommend [disabling introspection in production](https://www.apollographql.com/blog/graphql/security/why-you-should-disable-graphql-introspection-in-production/). - * To provide a "Sandbox-like" experience for production endpoints, we recommend using either a [public variant](https://www.apollographql.com/docs/graphos/platform/graph-management/variants#public-variants) or the [embedded Explorer](https://www.apollographql.com/docs/graphos/platform/explorer/embed). - */ pub initial_endpoint: String, /** * By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.Set `hideCookieToggle` to `false` to enable users of your embedded Sandbox instance to toggle the **Include cookies** setting. @@ -128,7 +123,10 @@ impl RouterPluginWithConfig for ApolloSandboxPlugin { } fn from_config(config: ApolloSandboxOptions) -> Option { if config.enabled { - Some(ApolloSandboxPlugin { options: config }) + Some(ApolloSandboxPlugin { + serialized_options: sonic_rs::to_string(&config) + .unwrap_or_else(|_| "{}".to_string()), + }) } else { None } @@ -136,7 +134,7 @@ impl RouterPluginWithConfig for ApolloSandboxPlugin { } pub struct ApolloSandboxPlugin { - pub options: ApolloSandboxOptions, + serialized_options: String, } impl RouterPlugin for ApolloSandboxPlugin { @@ -145,7 +143,8 @@ impl RouterPlugin for ApolloSandboxPlugin { payload: OnHttpRequestPayload<'req>, ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { if payload.router_http_request.path() == "/apollo-sandbox" { - let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); + let config = + sonic_rs::to_string(&self.serialized_options).unwrap_or_else(|_| "{}".to_string()); let html = format!( r#"
@@ -169,3 +168,35 @@ impl RouterPlugin for ApolloSandboxPlugin { payload.cont() } } + +#[cfg(test)] +mod apollo_sandbox_tests { + use hive_router::PluginRegistry; + + #[ntex::test] + async fn renders_apollo_sandbox_page() { + use crate::testkit::init_router_from_config_inline; + use ntex::web::test; + + let app = init_router_from_config_inline( + r#" + plugins: + apollo_sandbox: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + + let req = test::TestRequest::get().uri("/apollo-sandbox").to_request(); + let response = app.call(req).await.expect("failed to call /apollo-sandbox"); + let status = response.status(); + + let body_bytes = test::read_body(response).await; + let body_str = std::str::from_utf8(&body_bytes).expect("response body is not valid UTF-8"); + + assert_eq!(status, 200); + assert!(body_str.contains("EmbeddedSandbox")); + } +} diff --git a/lib/executor/src/plugins/examples/apq.rs b/e2e/src/plugins/apq.rs similarity index 98% rename from lib/executor/src/plugins/examples/apq.rs rename to e2e/src/plugins/apq.rs index 32b5ac6a5..aeb9d107b 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/e2e/src/plugins/apq.rs @@ -2,7 +2,7 @@ use dashmap::DashMap; use serde::Deserialize; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; -use crate::{ +use hive_router_plan_executor::{ hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/e2e/src/plugins/async_auth.rs similarity index 99% rename from lib/executor/src/plugins/examples/async_auth.rs rename to e2e/src/plugins/async_auth.rs index c72fbac6e..197e367be 100644 --- a/lib/executor/src/plugins/examples/async_auth.rs +++ b/e2e/src/plugins/async_auth.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use serde::Deserialize; use sonic_rs::json; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, diff --git a/lib/executor/src/plugins/examples/context_data.rs b/e2e/src/plugins/context_data.rs similarity index 98% rename from lib/executor/src/plugins/examples/context_data.rs rename to e2e/src/plugins/context_data.rs index d7b89cdee..305d9c7fa 100644 --- a/lib/executor/src/plugins/examples/context_data.rs +++ b/e2e/src/plugins/context_data.rs @@ -2,7 +2,7 @@ use serde::Deserialize; -use crate::{ +use hive_router_plan_executor::{ hooks::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/e2e/src/plugins/forbid_anonymous_operations.rs similarity index 98% rename from lib/executor/src/plugins/examples/forbid_anonymous_operations.rs rename to e2e/src/plugins/forbid_anonymous_operations.rs index 9b14976d8..0d82c2334 100644 --- a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs +++ b/e2e/src/plugins/forbid_anonymous_operations.rs @@ -4,7 +4,7 @@ use http::StatusCode; use serde::Deserialize; use sonic_rs::json; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, diff --git a/lib/executor/src/plugins/examples/mod.rs b/e2e/src/plugins/mod.rs similarity index 100% rename from lib/executor/src/plugins/examples/mod.rs rename to e2e/src/plugins/mod.rs diff --git a/lib/executor/src/plugins/examples/multipart.rs b/e2e/src/plugins/multipart.rs similarity index 99% rename from lib/executor/src/plugins/examples/multipart.rs rename to e2e/src/plugins/multipart.rs index 4e6085851..ab8290f42 100644 --- a/lib/executor/src/plugins/examples/multipart.rs +++ b/e2e/src/plugins/multipart.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::{ +use hive_router_plan_executor::{ executors::common::HttpExecutionResponse, hooks::{ on_graphql_params::{ diff --git a/lib/executor/src/plugins/examples/one_of.rs b/e2e/src/plugins/one_of.rs similarity index 99% rename from lib/executor/src/plugins/examples/one_of.rs rename to e2e/src/plugins/one_of.rs index a85e4ed48..e04abbb48 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/e2e/src/plugins/one_of.rs @@ -47,7 +47,7 @@ use std::{collections::BTreeMap, sync::RwLock}; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::{ on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/e2e/src/plugins/propagate_status_code.rs similarity index 99% rename from lib/executor/src/plugins/examples/propagate_status_code.rs rename to e2e/src/plugins/propagate_status_code.rs index 0e73fd2dd..57c7098d1 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/e2e/src/plugins/propagate_status_code.rs @@ -3,7 +3,7 @@ use http::StatusCode; use serde::Deserialize; -use crate::{ +use hive_router_plan_executor::{ hooks::{ on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/e2e/src/plugins/response_cache.rs similarity index 99% rename from lib/executor/src/plugins/examples/response_cache.rs rename to e2e/src/plugins/response_cache.rs index 25a1dbb03..93dbdb939 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/e2e/src/plugins/response_cache.rs @@ -3,7 +3,7 @@ use http::{HeaderMap, StatusCode}; use redis::Commands; use serde::Deserialize; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::{ on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/e2e/src/plugins/root_field_limit.rs similarity index 99% rename from lib/executor/src/plugins/examples/root_field_limit.rs rename to e2e/src/plugins/root_field_limit.rs index 589f8e3eb..12f6356f3 100644 --- a/lib/executor/src/plugins/examples/root_field_limit.rs +++ b/e2e/src/plugins/root_field_limit.rs @@ -10,7 +10,7 @@ use hive_router_query_planner::ast::selection_item::SelectionItem; use serde::Deserialize; use sonic_rs::json; -use crate::{ +use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, hooks::{ on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/e2e/src/plugins/subgraph_response_cache.rs similarity index 98% rename from lib/executor/src/plugins/examples/subgraph_response_cache.rs rename to e2e/src/plugins/subgraph_response_cache.rs index 54464606c..553840433 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/e2e/src/plugins/subgraph_response_cache.rs @@ -1,7 +1,7 @@ use dashmap::DashMap; use serde::Deserialize; -use crate::{ +use hive_router_plan_executor::{ executors::common::HttpExecutionResponse, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, diff --git a/e2e/src/probes.rs b/e2e/src/probes.rs index 86c6a8d8f..dc4fd57ab 100644 --- a/e2e/src/probes.rs +++ b/e2e/src/probes.rs @@ -30,7 +30,7 @@ mod probes_e2e_tests { key: dummy_key poll_interval: 500ms "#, - )) + ), None) .await .expect("failed to start router"); @@ -83,7 +83,7 @@ mod probes_e2e_tests { key: dummy_key poll_interval: 500ms "#, - )) + ), None) .await .expect("failed to start router"); diff --git a/e2e/src/supergraph.rs b/e2e/src/supergraph.rs index 07b345aac..0b4635843 100644 --- a/e2e/src/supergraph.rs +++ b/e2e/src/supergraph.rs @@ -36,7 +36,7 @@ mod supergraph_e2e_tests { key: dummy_key poll_interval: 500ms "#, - )) + ), None) .await .expect("failed to start router"); @@ -198,7 +198,7 @@ mod supergraph_e2e_tests { key: dummy_key poll_interval: 300ms "#, - )) + ), None) .await .expect("failed to start router"), ); diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index a44210edb..a1c266027 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,8 +1,7 @@ use std::{path::PathBuf, sync::Arc, time::Duration}; use hive_router::{ - background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, - PluginRegistry, RouterSharedState, SchemaState, + PluginRegistry, RouterSharedState, SchemaState, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, plugins::plugins_service::PluginService }; use hive_router_config::{load_config, parse_yaml_config, HiveRouterConfig}; use ntex::{ @@ -124,6 +123,7 @@ impl SubgraphsServer { pub async fn init_router_from_config_file( config_path: &str, + plugin_registry: Option, ) -> Result< TestRouterApp< impl ntex::Service, @@ -133,11 +133,12 @@ pub async fn init_router_from_config_file( let supergraph_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(config_path); let router_config = load_config(Some(supergraph_path.to_str().unwrap().to_string()))?; - init_router_from_config(router_config).await + init_router_from_config(router_config, plugin_registry).await } pub async fn init_router_from_config_inline( config_yaml: &str, + plugin_registry: Option, ) -> Result< TestRouterApp< impl ntex::Service, @@ -145,7 +146,7 @@ pub async fn init_router_from_config_inline( Box, > { let router_config = parse_yaml_config(config_yaml.to_string())?; - init_router_from_config(router_config).await + init_router_from_config(router_config, plugin_registry).await } pub struct TestRouterApp { @@ -173,6 +174,7 @@ impl TestRouterApp { pub async fn init_router_from_config( router_config: HiveRouterConfig, + plugin_registry: Option, ) -> Result< TestRouterApp< impl ntex::Service, @@ -181,11 +183,12 @@ pub async fn init_router_from_config( > { let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager, PluginRegistry::new()) + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry) .await?; let ntex_app = test::init_service( web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app), diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 0ba354798..629d65e5d 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -49,14 +49,11 @@ hyper-util = { version = "0.1.16", features = [ "http2", "tokio", ] } -bytes = "1.10.1" +bytes = { workspace = true } itoa = "1.0.15" ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" -redis = "0.32.7" -multer = "3.1.0" -futures-util = "0.3.31" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index ad1a46744..fc1999d46 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -35,7 +35,7 @@ use crate::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, - plugin_context::PluginManager, + plugin_context::PluginRequestState, plugin_trait::ControlFlowResult, projection::{ plan::FieldProjectionPlan, @@ -55,7 +55,7 @@ use crate::{ }; pub struct QueryPlanExecutionContext<'exec, 'req> { - pub plugin_manager: &'exec PluginManager<'exec>, + pub plugin_req_state: &'exec Option>, pub query_plan: &'exec QueryPlan, pub operation_for_plan: &'exec OperationDefinition, pub projection_plan: &'exec Vec, @@ -78,44 +78,51 @@ pub struct PlanExecutionOutput { impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { pub async fn execute_query_plan(self) -> Result { - let init_value = if let Some(introspection_query) = self.introspection_context.query { + let mut init_value = if let Some(introspection_query) = self.introspection_context.query { resolve_introspection(introspection_query, self.introspection_context) } else { Value::Null }; + let mut query_plan = self.query_plan; + let dedupe_subgraph_requests = self.operation_type_name == "Query"; - let mut start_payload = OnExecuteStartPayload { - router_http_request: &self.plugin_manager.router_http_request, - context: &self.plugin_manager.context, - query_plan: self.query_plan, - operation_for_plan: self.operation_for_plan, - data: init_value, - errors: Vec::new(), - extensions: self.extensions.clone(), - variable_values: self.variable_values, - dedupe_subgraph_requests, - }; + let mut extensions = self.extensions; let mut on_end_callbacks = vec![]; - for plugin in self.plugin_manager.plugins.iter() { - let result = plugin.on_execute(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - return Ok(response); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + if let Some(plugin_req_state) = self.plugin_req_state.as_ref() { + let mut start_payload = OnExecuteStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + query_plan, + operation_for_plan: self.operation_for_plan, + data: init_value, + errors: Vec::new(), + extensions, + variable_values: self.variable_values, + dedupe_subgraph_requests, + }; + + for plugin in plugin_req_state.plugins.iter() { + let result = plugin.on_execute(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } - } + query_plan = start_payload.query_plan; - let query_plan = start_payload.query_plan; + init_value = start_payload.data; - let init_value = start_payload.data; + extensions = start_payload.extensions; + } let mut exec_ctx = ExecutionContext::new(query_plan, init_value); let executor = Executor::new( @@ -127,7 +134,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { self.jwt_auth_forwarding, // Deduplicate subgraph requests only if the operation type is a query self.operation_type_name == "Query", - self.plugin_manager, + self.plugin_req_state, ); if query_plan.node.is_some() { @@ -143,36 +150,47 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { affected_path: || None, })?; - let mut end_payload = OnExecuteEndPayload { - data: exec_ctx.final_response, - errors: exec_ctx.errors, - extensions: start_payload.extensions, - response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), - }; - - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ } - ControlFlowResult::EndResponse(output) => { - return Ok(output); - } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); + let mut data = exec_ctx.final_response; + let mut errors = exec_ctx.errors; + let mut response_size_estimate = exec_ctx.response_storage.estimate_final_response_size(); + + if on_end_callbacks.len() > 0 { + let mut end_payload = OnExecuteEndPayload { + data, + errors, + extensions, + response_size_estimate, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(output) => { + return Ok(output); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } } } + + data = end_payload.data; + errors = end_payload.errors; + extensions = end_payload.extensions; + response_size_estimate = end_payload.response_size_estimate; } let body = project_by_operation( - &end_payload.data, - end_payload.errors, - &self.extensions, + &data, + errors, + &extensions, self.operation_type_name, self.projection_plan, self.variable_values, - end_payload.response_size_estimate, + response_size_estimate, ) .with_plan_context(LazyPlanContext { subgraph_name: || None, @@ -195,7 +213,7 @@ pub struct Executor<'exec, 'req> { headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, - plugin_manager: &'exec PluginManager<'exec>, + plugin_req_state: &'exec Option>, } struct ConcurrencyScope<'exec, T> { @@ -291,7 +309,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, - plugin_manager: &'exec PluginManager<'exec>, + plugin_req_state: &'exec Option>, ) -> Self { Executor { variable_values, @@ -301,7 +319,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { headers_plan, dedupe_subgraph_requests, jwt_forwarding_plan, - plugin_manager, + plugin_req_state, } } @@ -795,7 +813,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { &node.service_name, subgraph_request, self.client_request, - self.plugin_manager, + self.plugin_req_state, ) .await .into(), diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index f33cbaedc..90e3942e2 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -40,7 +40,7 @@ pub struct HTTPSubgraphExecutor { pub semaphore: Arc, pub config: Arc, pub in_flight_requests: Arc>, ABuildHasher>>, - pub plugins: Arc>>, + pub plugins: Option>>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -56,7 +56,7 @@ impl HTTPSubgraphExecutor { semaphore: Arc, config: Arc, in_flight_requests: Arc>, ABuildHasher>>, - plugins: Arc>>, + plugins: Option>>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -169,7 +169,7 @@ async fn send_request( method: http::Method, body: Vec, headers: HeaderMap, - plugins: Arc>>, + plugins: Option>>>, ) -> Result { let mut req = hyper::Request::builder() .method(method) @@ -182,37 +182,39 @@ async fn send_request( *req.headers_mut() = headers; - let mut start_payload = OnSubgraphHttpRequestPayload { - subgraph_name, - request: req, - response: None, - }; + let mut req = req; let mut on_end_callbacks = vec![]; - for plugin in plugins.as_ref() { - let result = plugin.on_subgraph_http_request(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - // TODO: Fixx - return Ok(SharedResponse { - status: StatusCode::OK, - body: response.body.into(), - headers: response.headers, - }); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + if let Some(plugins) = plugins.as_ref() { + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } + req = start_payload.request; } debug!("making http request to {}", endpoint.to_string()); - let req = start_payload.request; - let res = http_client .request(req) .await diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index bf8eb7838..ef433b468 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -38,7 +38,7 @@ use crate::{ http::{HTTPSubgraphExecutor, HttpClient}, }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, - plugin_context::PluginManager, + plugin_context::PluginRequestState, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError, }; @@ -64,13 +64,13 @@ pub struct SubgraphExecutorMap { semaphores_by_origin: DashMap>, max_connections_per_host: usize, in_flight_requests: Arc>, ABuildHasher>>, - plugins: Arc>>, + plugins: Option>>>, } impl SubgraphExecutorMap { pub fn new( config: Arc, - plugins: Arc>>, + plugins: Option>>>, ) -> Self { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) @@ -100,7 +100,7 @@ impl SubgraphExecutorMap { pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, config: Arc, - plugins: Arc>>, + plugins: Option>>>, ) -> Result { let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), plugins); @@ -130,42 +130,44 @@ impl SubgraphExecutorMap { subgraph_name: &str, execution_request: SubgraphExecutionRequest<'exec>, client_request: &ClientRequestDetails<'exec, 'req>, - plugin_manager: &PluginManager<'req>, + plugin_req_state: &Option>, ) -> HttpExecutionResponse { - let mut start_payload = OnSubgraphExecuteStartPayload { - router_http_request: &plugin_manager.router_http_request, - context: &plugin_manager.context, - subgraph_name: subgraph_name.to_string(), - execution_request, - execution_result: None, - }; let mut on_end_callbacks = vec![]; - for plugin in self.plugins.as_ref() { - let result = plugin.on_subgraph_execute(start_payload).await; - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next plugin - } - ControlFlowResult::EndResponse(response) => { - // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), - headers: response.headers, - status: response.status, - }; - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + let mut execution_request = execution_request; + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + let mut start_payload = OnSubgraphExecuteStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + subgraph_name, + execution_request, + execution_result: None, + }; + for plugin in plugin_req_state.plugins.as_ref() { + let result = plugin.on_subgraph_execute(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + status: response.status, + }; + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } + execution_request = start_payload.execution_request; } - let execution_request = start_payload.execution_request; - - let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { + let mut execution_result = match self.get_or_create_executor(subgraph_name, client_request) { Ok(Some(executor)) => executor.execute(execution_request).await, Err(err) => { error!( @@ -183,33 +185,38 @@ impl SubgraphExecutorMap { } }; - let mut end_payload = OnSubgraphExecuteEndPayload { - context: &plugin_manager.context, - execution_result, - }; + if let Some(plugin_req_state) = plugin_req_state.as_ref() { + let mut end_payload = OnSubgraphExecuteEndPayload { + context: &plugin_req_state.context, + execution_result, + }; - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { - // continue to next callback - } - ControlFlowResult::EndResponse(response) => { - // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), - headers: response.headers, - status: response.status, - }; - } - ControlFlowResult::OnEnd(_) => { - unreachable!("End callbacks should not register further end callbacks"); + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + status: response.status, + }; + } + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } } } + + execution_result = end_payload.execution_result; } - end_payload.execution_result + + execution_result } fn internal_server_error_response( diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index c341a6a36..a839e06da 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -9,7 +9,7 @@ use graphql_tools::{ use hive_router_query_planner::state::supergraph_state::SchemaDocument; use crate::{ - plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest}, plugin_trait::{EndPayload, StartPayload}, }; @@ -27,14 +27,14 @@ impl<'exec> StartPayload for OnGraphQLValidationS impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn new( - plugin_manager: &'exec PluginManager<'exec>, + plugin_req_state: &'exec PluginRequestState<'exec>, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, ) -> Self { OnGraphQLValidationStartPayload { - router_http_request: &plugin_manager.router_http_request, - context: &plugin_manager.context, + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, schema, document, default_validation_plan, diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 18d037c10..870f28cba 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -8,7 +8,7 @@ pub struct OnSubgraphExecuteStartPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, - pub subgraph_name: String, + pub subgraph_name: &'exec str, pub execution_request: SubgraphExecutionRequest<'exec>, pub execution_result: Option, diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 3c24ff9f2..008dc147a 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,4 +1,3 @@ -pub mod examples; pub mod hooks; pub mod plugin_context; pub mod plugin_trait; diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index d5ea9421f..a17da3f91 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -91,7 +91,7 @@ impl PluginContext { } } -pub struct PluginManager<'req> { +pub struct PluginRequestState<'req> { pub plugins: Arc>>, pub router_http_request: RouterHttpRequest<'req>, pub context: Arc, From 878097781385cdc7abf2193a12b2c684b2f2d84d Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Thu, 27 Nov 2025 18:45:41 +0300 Subject: [PATCH 20/31] More tests --- Cargo.lock | 110 +++++++++++ bin/router/src/lib.rs | 6 +- bin/router/src/main.rs | 2 +- bin/router/src/pipeline/execution.rs | 2 +- bin/router/src/pipeline/mod.rs | 21 ++- bin/router/src/pipeline/query_plan.rs | 2 +- bin/router/src/pipeline/validation.rs | 8 +- bin/router/src/plugins/registry.rs | 35 ++-- bin/router/src/schema_state.rs | 2 - bin/router/src/shared_state.rs | 2 +- e2e/Cargo.toml | 5 +- e2e/src/file_supergraph.rs | 18 +- e2e/src/hive_cdn_supergraph.rs | 54 ++++-- e2e/src/jwt.rs | 14 +- e2e/src/lib.rs | 4 +- e2e/src/override_subgraph_urls.rs | 3 +- e2e/src/plugins/allowed_clients.json | 1 + e2e/src/plugins/apq.rs | 173 +++++++++++++++++- e2e/src/plugins/async_auth.rs | 92 +++++++++- e2e/src/plugins/context_data.rs | 51 +++++- .../plugins/forbid_anonymous_operations.rs | 57 +++++- e2e/src/plugins/multipart.rs | 6 +- e2e/src/plugins/one_of.rs | 18 +- e2e/src/plugins/propagate_status_code.rs | 93 +++++++++- e2e/src/plugins/response_cache.rs | 156 ++++++++++++++-- e2e/src/plugins/root_field_limit.rs | 45 ++++- e2e/src/plugins/subgraph_response_cache.rs | 39 +++- e2e/src/probes.rs | 18 +- e2e/src/supergraph.rs | 18 +- e2e/src/testkit.rs | 149 ++++++++++++++- lib/executor/src/execution/plan.rs | 2 +- lib/executor/src/executors/map.rs | 44 +++-- lib/router-config/src/primitives/file_path.rs | 6 + 33 files changed, 1104 insertions(+), 152 deletions(-) create mode 100644 e2e/src/plugins/allowed_clients.json diff --git a/Cargo.lock b/Cargo.lock index b5d240160..72bf2d71e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,51 @@ dependencies = [ "objc2", ] +[[package]] +name = "bollard" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87a52479c9237eb04047ddb94788c41ca0d26eaff8b697ecfbb4c32f7fdc3b1b" +dependencies = [ + "base64 0.22.1", + "bollard-stubs", + "bytes", + "futures-core", + "futures-util", + "hex", + "http", + "http-body-util", + "hyper", + "hyper-named-pipe", + "hyper-util", + "hyperlocal", + "log", + "pin-project-lite", + "serde", + "serde_derive", + "serde_json", + "serde_repr", + "serde_urlencoded", + "thiserror 2.0.17", + "tokio", + "tokio-util", + "tower-service", + "url", + "winapi", +] + +[[package]] +name = "bollard-stubs" +version = "1.49.1-rc.28.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5731fe885755e92beff1950774068e0cae67ea6ec7587381536fca84f1779623" +dependencies = [ + "serde", + "serde_json", + "serde_repr", + "serde_with", +] + [[package]] name = "borrow-or-share" version = "0.2.4" @@ -1347,6 +1392,7 @@ name = "e2e" version = "0.0.1" dependencies = [ "async-trait", + "bollard", "bytes", "dashmap", "futures-util", @@ -1363,9 +1409,11 @@ dependencies = [ "mockito", "multer", "ntex", + "r2d2", "redis", "reqwest", "serde", + "serde_json", "sonic-rs", "subgraphs", "tempfile", @@ -2233,6 +2281,21 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-named-pipe" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b7d8abf35697b81a825e386fc151e0d503e8cb5fcb93cc8669c376dfd6f278" +dependencies = [ + "hex", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", + "winapi", +] + [[package]] name = "hyper-rustls" version = "0.27.7" @@ -2292,6 +2355,21 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "hyperlocal" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "986c5ce3b994526b3cd75578e62554abd09f0899d6206de48b3e96ab34ccc8c7" +dependencies = [ + "hex", + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -4157,6 +4235,17 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot 0.12.5", + "scheduled-thread-pool", +] + [[package]] name = "rancor" version = "0.1.1" @@ -4255,6 +4344,7 @@ dependencies = [ "itoa", "num-bigint", "percent-encoding", + "r2d2", "ryu", "sha1_smol", "socket2 0.6.1", @@ -4679,6 +4769,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot 0.12.5", +] + [[package]] name = "schemars" version = "0.9.0" @@ -4861,6 +4960,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.108", +] + [[package]] name = "serde_spanned" version = "1.0.3" diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 071bcd0b7..daa5047ab 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -128,8 +128,10 @@ pub async fn configure_app_from_config( false => None, }; - let plugins = - plugin_registry.map(|plugin_registry| plugin_registry.initialize_plugins(&router_config)); + let plugins = match plugin_registry { + Some(plugin_registry) => plugin_registry.initialize_plugins(&router_config)?, + None => None, + }; let router_config_arc = Arc::new(router_config); let shared_state = Arc::new(RouterSharedState::new( diff --git a/bin/router/src/main.rs b/bin/router/src/main.rs index d46192b27..162b3dfa4 100644 --- a/bin/router/src/main.rs +++ b/bin/router/src/main.rs @@ -1,4 +1,4 @@ -use hive_router::{router_entrypoint}; +use hive_router::router_entrypoint; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 75fbf499b..7a011a1aa 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -86,7 +86,7 @@ pub async fn execute_plan( }; let ctx = QueryPlanExecutionContext { - plugin_req_state: &plugin_req_state, + plugin_req_state: plugin_req_state, query_plan: query_plan_payload, operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 529154efa..730921683 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -115,7 +115,6 @@ pub async fn graphql_request_handler( }); } - let response = execute_pipeline( req, body_bytes, @@ -126,6 +125,7 @@ pub async fn graphql_request_handler( plugin_req_state, ) .await?; + let response_status = response.status; let response_bytes = Bytes::from(response.body); let response_headers = response.headers; @@ -138,6 +138,7 @@ pub async fn graphql_request_handler( Ok(response_builder .header(http::header::CONTENT_TYPE, response_content_type) + .status(response_status) .body(response_bytes)) } @@ -160,12 +161,13 @@ pub async fn execute_pipeline( let mut graphql_params = None; let mut body = body; if let Some(plugin_req_state) = plugin_req_state.as_ref() { - let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { - router_http_request: &plugin_req_state.router_http_request, - context: &plugin_req_state.context, - body, - graphql_params: None, - }; + let mut deserialization_payload: OnGraphQLParamsStartPayload = + OnGraphQLParamsStartPayload { + router_http_request: &plugin_req_state.router_http_request, + context: &plugin_req_state.context, + body, + graphql_params: None, + }; for plugin in plugin_req_state.plugins.as_ref() { let result = plugin.on_graphql_params(deserialization_payload).await; deserialization_payload = result.payload; @@ -183,8 +185,7 @@ pub async fn execute_pipeline( body = deserialization_payload.body; } let mut graphql_params = graphql_params.unwrap_or_else(|| { - deserialize_graphql_params(req, body) - .expect("Failed to parse execution request") + deserialize_graphql_params(req, body).expect("Failed to parse execution request") }); if let Some(plugin_req_state) = &plugin_req_state { @@ -208,7 +209,7 @@ pub async fn execute_pipeline( } graphql_params = payload.graphql_params; } - + /* Handle on_deserialize hook in the plugins - END */ let parser_result = diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 766b0b28c..343972bd0 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -13,7 +13,7 @@ use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_context::PluginRequestState; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use hive_router_query_planner::planner::{PlannerError}; +use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; use xxhash_rust::xxh3::Xxh3; diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index d4635051e..8dd98f922 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -46,7 +46,7 @@ pub async fn validate_operation_with_cache( let mut on_end_callbacks = vec![]; let document = &parser_payload.parsed_operation; let errors = if let Some(plugin_req_state) = plugin_req_state.as_ref() { - /* Handle on_graphql_validate hook in the plugins - START */ + /* Handle on_graphql_validate hook in the plugins - START */ let mut start_payload = OnGraphQLValidationStartPayload::new( plugin_req_state, consumer_schema_ast, @@ -77,11 +77,7 @@ pub async fn validate_operation_with_cache( ), } } else { - validate( - consumer_schema_ast, - document, - &app_state.validation_plan, - ) + validate(consumer_schema_ast, document, &app_state.validation_plan) }; let mut end_payload = OnGraphQLValidationEndPayload { errors }; diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index 651576adc..a0b81125b 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::plugin_trait::{RouterPlugin, RouterPluginWithConfig}; use serde_json::Value; -use tracing::{info, warn}; +use tracing::info; pub struct PluginRegistry { map: HashMap< @@ -20,13 +20,23 @@ impl Default for PluginRegistry { } } +#[derive(Debug, thiserror::Error)] +pub enum PluginRegistryError { + #[error("Failed to initialize plugin '{0}': {1}")] + Config(String, serde_json::Error), + #[error( + "Plugin '{0}' is not registered in the registry but is specified in the configuration" + )] + MissingInRegistry(String), +} + impl PluginRegistry { pub fn new() -> Self { Self { map: HashMap::new(), } } - pub fn register(mut self) -> Self { + pub fn register(mut self) -> Self { self.map.insert( P::plugin_name(), Box::new(|plugin_config: Value| { @@ -37,12 +47,12 @@ impl PluginRegistry { } }), ); - return self; + self } pub fn initialize_plugins( &self, router_config: &HiveRouterConfig, - ) -> Vec> { + ) -> Result>>, PluginRegistryError> { let mut plugins: Vec> = vec![]; for (plugin_name, plugin_config_value) in router_config.plugins.iter() { @@ -56,19 +66,18 @@ impl PluginRegistry { } } Err(err) => { - warn!( - "Failed to load plugin '{}': {}, skipping plugin", - plugin_name, err - ); + return Err(PluginRegistryError::Config(plugin_name.clone(), err)); } } } else { - warn!( - "No plugin found registered '{}', skipping plugin", - plugin_name - ); + return Err(PluginRegistryError::MissingInRegistry(plugin_name.clone())); } } - plugins + + if plugins.is_empty() { + Ok(None) + } else { + Ok(Some(plugins)) + } } } diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 2d6cb0f66..6e05d1430 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -90,7 +90,6 @@ impl SchemaState { let mut on_end_callbacks = vec![]; if let Some(plugins) = app_state.plugins.as_ref() { - let mut start_payload = OnSupergraphLoadStartPayload { current_supergraph_data: swappable_data_spawn_clone.clone(), new_ast, @@ -113,7 +112,6 @@ impl SchemaState { new_ast = start_payload.new_ast; } - match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { Ok(new_supergraph_data) => { let mut end_payload = OnSupergraphLoadEndPayload { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index a5ce177d7..e8541d92c 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -39,7 +39,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, - plugins: plugins.map(|p| Arc::new(p)), + plugins: plugins.map(Arc::new), }) } } diff --git a/e2e/Cargo.toml b/e2e/Cargo.toml index 2ecb5434f..f3c6e37a0 100644 --- a/e2e/Cargo.toml +++ b/e2e/Cargo.toml @@ -24,6 +24,7 @@ http = { workspace = true } bytes = { workspace = true } graphql-tools = { workspace = true } graphql-parser = { workspace = true } +serde_json = { workspace = true } hive-router = { path = "../bin/router" } hive-router-config = { path = "../lib/router-config" } @@ -33,7 +34,9 @@ subgraphs = { path = "../bench/subgraphs" } mockito = "1.7.0" tempfile = "3.23.0" -redis = "0.32.7" +redis = { version= "0.32.7", features = ["r2d2"]} +r2d2 = "0.8.10" multer = "3.1.0" futures-util = "0.3.31" +bollard = "0.19.4" diff --git a/e2e/src/file_supergraph.rs b/e2e/src/file_supergraph.rs index c6782ae3e..a9f825fb8 100644 --- a/e2e/src/file_supergraph.rs +++ b/e2e/src/file_supergraph.rs @@ -21,12 +21,15 @@ mod file_supergraph_e2e_tests { let first_supergraph = include_str!("../supergraph.graphql"); fs::write(&supergraph_file_path, first_supergraph).expect("failed to write supergraph"); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: file path: {supergraph_file_path} "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; @@ -64,13 +67,16 @@ mod file_supergraph_e2e_tests { fs::write(&supergraph_file_path, "type Query { f: String }") .expect("failed to write supergraph"); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: file path: {supergraph_file_path} poll_interval: 100ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); wait_for_readiness(&app.app).await; diff --git a/e2e/src/hive_cdn_supergraph.rs b/e2e/src/hive_cdn_supergraph.rs index b32137e7d..363492f96 100644 --- a/e2e/src/hive_cdn_supergraph.rs +++ b/e2e/src/hive_cdn_supergraph.rs @@ -24,13 +24,16 @@ mod hive_cdn_supergraph_e2e_tests { .with_body(include_str!("../supergraph.graphql")) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -77,14 +80,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(304) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 100ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -135,14 +141,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(304) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 100ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -199,14 +208,17 @@ mod hive_cdn_supergraph_e2e_tests { .with_body(include_str!("../supergraph.graphql")) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 800ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -276,8 +288,9 @@ mod hive_cdn_supergraph_e2e_tests { .with_body("type Query { dummy: String }") .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key @@ -285,7 +298,9 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 10 "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -304,8 +319,9 @@ mod hive_cdn_supergraph_e2e_tests { .with_status(500) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key @@ -313,7 +329,9 @@ mod hive_cdn_supergraph_e2e_tests { retry_policy: max_retries: 3 "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); diff --git a/e2e/src/jwt.rs b/e2e/src/jwt.rs index aeadb0d59..3c0471d72 100644 --- a/e2e/src/jwt.rs +++ b/e2e/src/jwt.rs @@ -69,9 +69,10 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_details() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) - .await - .unwrap(); + let app = + init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) + .await + .unwrap(); wait_for_readiness(&app.app).await; let exp = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -132,9 +133,10 @@ mod jwt_e2e_tests { #[ntex::test] async fn should_allow_expressions_to_access_jwt_scopes() { let subgraphs_server = SubgraphsServer::start().await; - let app = init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) - .await - .unwrap(); + let app = + init_router_from_config_file("configs/jwt_auth_header_expression.router.yaml", None) + .await + .unwrap(); wait_for_readiness(&app.app).await; // First request with a token and "scope: read:accounts" diff --git a/e2e/src/lib.rs b/e2e/src/lib.rs index 936873b74..3875f71a1 100644 --- a/e2e/src/lib.rs +++ b/e2e/src/lib.rs @@ -7,10 +7,10 @@ mod jwt; #[cfg(test)] mod override_subgraph_urls; #[cfg(test)] +mod plugins; +#[cfg(test)] mod probes; #[cfg(test)] mod supergraph; #[cfg(test)] mod testkit; -#[cfg(test)] -mod plugins; diff --git a/e2e/src/override_subgraph_urls.rs b/e2e/src/override_subgraph_urls.rs index 80ea7d931..1932608f1 100644 --- a/e2e/src/override_subgraph_urls.rs +++ b/e2e/src/override_subgraph_urls.rs @@ -15,7 +15,8 @@ mod override_subgraph_urls_e2e_tests { async fn should_override_subgraph_url_based_on_static_value() { let subgraphs_server = SubgraphsServer::start_with_port(4100).await; let app = init_router_from_config_file( - "configs/override_subgraph_urls/override_static.router.yaml", None, + "configs/override_subgraph_urls/override_static.router.yaml", + None, ) .await .unwrap(); diff --git a/e2e/src/plugins/allowed_clients.json b/e2e/src/plugins/allowed_clients.json new file mode 100644 index 000000000..381a87082 --- /dev/null +++ b/e2e/src/plugins/allowed_clients.json @@ -0,0 +1 @@ +["urql", "graphql-request"] \ No newline at end of file diff --git a/e2e/src/plugins/apq.rs b/e2e/src/plugins/apq.rs index aeb9d107b..e9c9c6ff5 100644 --- a/e2e/src/plugins/apq.rs +++ b/e2e/src/plugins/apq.rs @@ -1,8 +1,11 @@ use dashmap::DashMap; +use http::StatusCode; use serde::Deserialize; +use serde_json::json; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use hive_router_plan_executor::{ + execution::plan::PlanExecutionOutput, hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; @@ -19,7 +22,7 @@ pub struct APQPlugin { impl RouterPluginWithConfig for APQPlugin { type Config = APQPluginConfig; fn plugin_name() -> &'static str { - "apq_plugin" + "apq" } fn from_config(config: Self::Config) -> Option { if config.enabled { @@ -46,11 +49,24 @@ impl RouterPlugin for APQPlugin { .and_then(|ext| ext.get("persistedQuery")) .and_then(|pq| pq.as_object()); if let Some(persisted_query_ext) = persisted_query_ext { - match persisted_query_ext.get(&"version").and_then(|v| v.as_str()) { - Some("1") => {} + match persisted_query_ext.get(&"version").and_then(|v| v.as_i64()) { + Some(1) => {} _ => { - // TODO: Error for unsupported version - return payload.cont(); + let body = json!({ + "errors": [ + { + "message": "Unsupported persisted query version", + "extensions": { + "code": "UNSUPPORTED_PERSISTED_QUERY_VERSION" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: body.to_string().into_bytes(), + status: StatusCode::BAD_REQUEST, + headers: http::HeaderMap::new(), + }); } } let sha256_hash = match persisted_query_ext @@ -59,7 +75,21 @@ impl RouterPlugin for APQPlugin { { Some(h) => h, None => { - return payload.cont(); + let body = json!({ + "errors": [ + { + "message": "Missing sha256Hash in persisted query", + "extensions": { + "code": "MISSING_PERSISTED_QUERY_HASH" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: body.to_string().into_bytes(), + status: StatusCode::BAD_REQUEST, + headers: http::HeaderMap::new(), + }); } }; if let Some(query_param) = &payload.graphql_params.query { @@ -72,8 +102,21 @@ impl RouterPlugin for APQPlugin { // Update the graphql_params with the cached query payload.graphql_params.query = Some(cached_query.value().to_string()); } else { - // Error - return payload.cont(); + let body = json!({ + "errors": [ + { + "message": "PersistedQueryNotFound", + "extensions": { + "code": "PERSISTED_QUERY_NOT_FOUND" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: body.to_string().into_bytes(), + status: StatusCode::NOT_FOUND, + headers: http::HeaderMap::new(), + }); } } } @@ -82,3 +125,117 @@ impl RouterPlugin for APQPlugin { }) } } + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + + use hive_router::PluginRegistry; + use ntex::web::test; + use serde_json::json; + #[ntex::test] + async fn sends_not_found_error_if_query_missing() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + apq: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let body = json!( + { + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38", + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + let body = test::read_body(resp).await; + let body_json: serde_json::Value = + serde_json::from_slice(&body).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + json!({ + "errors": [ + { + "message": "PersistedQueryNotFound", + "extensions": { + "code": "PERSISTED_QUERY_NOT_FOUND" + } + } + ] + }), + "Expected PersistedQueryNotFound error" + ); + } + #[ntex::test] + async fn saves_persisted_query() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + apq: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let query = "{ users { id } }"; + let sha256_hash = "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"; + let body = json!( + { + "query": query, + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": sha256_hash, + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!( + resp.status().is_success(), + "Expected 200 OK when sending full query" + ); + + // Now send only the hash and expect it to be found + let body = json!( + { + "extensions": { + "persistedQuery": { + "version": 1, + "sha256Hash": sha256_hash, + }, + }, + } + ); + let req = test::TestRequest::post() + .uri("/graphql") + .header("content-type", "application/json") + .set_payload(body.to_string()); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!( + resp.status().is_success(), + "Expected 200 OK when sending persisted query hash" + ); + } +} diff --git a/e2e/src/plugins/async_auth.rs b/e2e/src/plugins/async_auth.rs index 197e367be..f9e3b82eb 100644 --- a/e2e/src/plugins/async_auth.rs +++ b/e2e/src/plugins/async_auth.rs @@ -1,8 +1,7 @@ -use std::path::PathBuf; - // From https://github.com/apollographql/router/blob/dev/examples/async-auth/rust/src/allow_client_id_from_file.rs use serde::Deserialize; use sonic_rs::json; +use std::path::PathBuf; use hive_router_plan_executor::{ execution::plan::PlanExecutionOutput, @@ -82,7 +81,7 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { } } Err(_not_a_string_error) => { - let message = format!("'{}' value is not a string", self.header_key); + let message = format!("'{}' value is not a string", &self.header_key); tracing::error!(message); let body = json!( { @@ -105,7 +104,7 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { } } None => { - let message = format!("Missing '{}' header", self.header_key); + let message = format!("Missing '{}' header", &self.header_key); tracing::error!(message); let body = json!( { @@ -129,3 +128,88 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { payload.cont() } } + +#[cfg(test)] +mod tests { + use crate::testkit::{ + init_graphql_request, init_router_from_config_inline, wait_for_readiness, SubgraphsServer, + }; + + use hive_router::PluginRegistry; + use ntex::web::test; + use serde_json::Value; + #[ntex::test] + async fn should_allow_only_allowed_client_ids() { + SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" + plugins: + allow_client_id_from_file: + enabled: true + path: "./src/plugins/allowed_clients.json" + header: "x-client-id" + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + // Test with an allowed client id + let req = init_graphql_request("{ users { id } }", None).header("x-client-id", "urql"); + let resp = test::call_service(&app.app, req.to_request()).await; + let status = resp.status(); + assert!(status.is_success(), "Expected 200 OK for allowed client id"); + // Test with a disallowed client id + let req = init_graphql_request("{ users { id } }", None) + .header("x-client-id", "forbidden-client"); + let resp = test::call_service(&app.app, req.to_request()).await; + assert_eq!( + resp.status(), + http::StatusCode::FORBIDDEN, + "Expected 403 FORBIDDEN for disallowed client id" + ); + let body_bytes = test::read_body(resp).await; + let body_json: Value = + serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + serde_json::json!({ + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": "UNAUTHORIZED_CLIENT_ID" + } + } + ] + }), + "Expected error message for disallowed client id" + ); + // Test with missing client id + let req = init_graphql_request("{ users { id } }", None); + let resp = test::call_service(&app.app, req.to_request()).await; + assert_eq!( + resp.status(), + http::StatusCode::UNAUTHORIZED, + "Expected 401 UNAUTHORIZED for missing client id" + ); + let body_bytes = test::read_body(resp).await; + let body_json: Value = + serde_json::from_slice(&body_bytes).expect("Response body should be valid JSON"); + assert_eq!( + body_json, + serde_json::json!({ + "errors": [ + { + "message": "Missing 'x-client-id' header", + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + }), + "Expected error message for missing client id" + ); + } +} diff --git a/e2e/src/plugins/context_data.rs b/e2e/src/plugins/context_data.rs index 305d9c7fa..d2ef891bd 100644 --- a/e2e/src/plugins/context_data.rs +++ b/e2e/src/plugins/context_data.rs @@ -26,7 +26,7 @@ pub struct ContextData { impl RouterPluginWithConfig for ContextDataPlugin { type Config = ContextDataPluginConfig; fn plugin_name() -> &'static str { - "context_data_plugin" + "context_data" } fn from_config(config: ContextDataPluginConfig) -> Option { if config.enabled { @@ -86,3 +86,52 @@ impl RouterPlugin for ContextDataPlugin { }) } } + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use ntex::web::test; + #[ntex::test] + async fn should_add_context_data_and_modify_subgraph_request() { + let subgraphs = SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" + plugins: + context_data: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + + wait_for_readiness(&app.app).await; + + let resp = test::call_service( + &app.app, + crate::testkit::init_graphql_request("{ users { id } }", None).to_request(), + ) + .await; + + assert!(resp.status().is_success(), "Expected 200 OK"); + + let request_logs = subgraphs + .get_subgraph_requests_log("accounts") + .await + .expect("expected requests sent to accounts subgraph"); + assert_eq!( + request_logs.len(), + 1, + "expected 1 request to accounts subgraph" + ); + let hello_header_value = request_logs[0] + .headers + .get("x-hello") + .expect("expected x-hello header to be present in subgraph request") + .to_str() + .expect("header value should be valid string"); + assert_eq!(hello_header_value, "Hello world"); + } +} diff --git a/e2e/src/plugins/forbid_anonymous_operations.rs b/e2e/src/plugins/forbid_anonymous_operations.rs index 0d82c2334..1539c3ec0 100644 --- a/e2e/src/plugins/forbid_anonymous_operations.rs +++ b/e2e/src/plugins/forbid_anonymous_operations.rs @@ -65,10 +65,59 @@ impl RouterPlugin for ForbidAnonymousOperationsPlugin { headers: http::HeaderMap::new(), status: StatusCode::BAD_REQUEST, }); - } else { - // we're good to go! - tracing::info!("operation is allowed!"); - return payload.cont(); } + // we're good to go! + tracing::info!("operation is allowed!"); + payload.cont() + } +} + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use http::StatusCode; + use ntex::web::test; + use serde_json::{json, Value}; + #[ntex::test] + async fn should_forbid_anonymous_operations() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + forbid_anonymous_operations: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + + let resp = test::call_service( + &app.app, + test::TestRequest::post() + .uri("/graphql") + .set_payload(r#"{"query":"{ __schema { types { name } } }"}"#) + .header("content-type", "application/json") + .to_request(), + ) + .await; + + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let json_body: Value = serde_json::from_slice(&test::read_body(resp).await).unwrap(); + assert_eq!( + json_body, + json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }) + ); } } diff --git a/e2e/src/plugins/multipart.rs b/e2e/src/plugins/multipart.rs index ab8290f42..cb0fa1646 100644 --- a/e2e/src/plugins/multipart.rs +++ b/e2e/src/plugins/multipart.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; +use bytes::Bytes; +use dashmap::DashMap; use hive_router_plan_executor::{ executors::common::HttpExecutionResponse, hooks::{ @@ -10,8 +12,6 @@ use hive_router_plan_executor::{ }, plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; -use bytes::Bytes; -use dashmap::DashMap; use multer::Multipart; use serde::{Deserialize, Serialize}; @@ -42,7 +42,7 @@ struct MultipartOperations<'a> { impl RouterPluginWithConfig for MultipartPlugin { type Config = MultipartPluginConfig; fn plugin_name() -> &'static str { - "multipart_plugin" + "multipart" } fn from_config(config: MultipartPluginConfig) -> Option { if config.enabled { diff --git a/e2e/src/plugins/one_of.rs b/e2e/src/plugins/one_of.rs index e04abbb48..91679042a 100644 --- a/e2e/src/plugins/one_of.rs +++ b/e2e/src/plugins/one_of.rs @@ -47,15 +47,6 @@ use std::{collections::BTreeMap, sync::RwLock}; -use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::{ - on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, - on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, - on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, - }, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, -}; use graphql_parser::{ query::Value, schema::{Definition, TypeDefinition}, @@ -68,6 +59,15 @@ use graphql_tools::{ utils::{ValidationError, ValidationErrorContext}, }, }; +use hive_router_plan_executor::{ + execution::plan::PlanExecutionOutput, + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, +}; use serde::Deserialize; use sonic_rs::{json, JsonContainerTrait}; diff --git a/e2e/src/plugins/propagate_status_code.rs b/e2e/src/plugins/propagate_status_code.rs index 57c7098d1..177dbbb8c 100644 --- a/e2e/src/plugins/propagate_status_code.rs +++ b/e2e/src/plugins/propagate_status_code.rs @@ -20,7 +20,7 @@ pub struct PropagateStatusCodePluginConfig { impl RouterPluginWithConfig for PropagateStatusCodePlugin { type Config = PropagateStatusCodePluginConfig; fn plugin_name() -> &'static str { - "propagate_status_code_plugin" + "propagate_status_code" } fn from_config(config: PropagateStatusCodePluginConfig) -> Option { if !config.enabled { @@ -87,3 +87,94 @@ impl RouterPlugin for PropagateStatusCodePlugin { }) } } + +#[cfg(test)] +mod tests { + #[ntex::test] + async fn propagates_highest_status_code() { + let mut subgraphs_server = mockito::Server::new_async().await; + let accounts_mock_207 = subgraphs_server + .mock("POST", "/accounts") + .with_status(207) + .with_body(r#"{"data": {"users": [{"id": "1"}]}}"#) + .create_async() + .await; + let products_mock_206 = subgraphs_server + .mock("POST", "/products") + .with_status(206) + .with_body(r#"{"data": {"topProducts": [{"upc": "a"}]}}"#) + .create_async() + .await; + let app = crate::testkit::init_router_from_config_inline( + &format!( + r#" + override_subgraph_urls: + accounts: + url: http://{}/accounts + products: + url: http://{}/products + plugins: + propagate_status_code: + enabled: true + status_codes: [206, 207] + "#, + subgraphs_server.host_with_port(), + subgraphs_server.host_with_port() + ), + Some(hive_router::PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + crate::testkit::wait_for_readiness(&app.app).await; + + let req = + crate::testkit::init_graphql_request("{ users { id } topProducts { upc } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert_eq!(resp.status().as_u16(), 207); + accounts_mock_207.assert_async().await; + products_mock_206.assert_async().await; + } + #[ntex::test] + async fn ignores_unlisted_status_codes() { + let mut subgraphs_server = mockito::Server::new_async().await; + let accounts_mock_208 = subgraphs_server + .mock("POST", "/accounts") + .with_status(208) + .with_body(r#"{"data": {"users": [{"id": "1"}]}}"#) + .create_async() + .await; + let products_mock_209 = subgraphs_server + .mock("POST", "/products") + .with_status(209) + .with_body(r#"{"data": {"topProducts": [{"upc": "a"}]}}"#) + .create_async() + .await; + let app = crate::testkit::init_router_from_config_inline( + &format!( + r#" + override_subgraph_urls: + accounts: + url: http://{}/accounts + products: + url: http://{}/products + plugins: + propagate_status_code: + enabled: true + status_codes: [208] + "#, + subgraphs_server.host_with_port(), + subgraphs_server.host_with_port() + ), + Some(hive_router::PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + crate::testkit::wait_for_readiness(&app.app).await; + let req = + crate::testkit::init_graphql_request("{ users { id } topProducts { upc } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert_eq!(resp.status().as_u16(), 208); + accounts_mock_208.assert_async().await; + products_mock_209.assert_async().await; + } +} diff --git a/e2e/src/plugins/response_cache.rs b/e2e/src/plugins/response_cache.rs index 93dbdb939..d5d02db30 100644 --- a/e2e/src/plugins/response_cache.rs +++ b/e2e/src/plugins/response_cache.rs @@ -13,11 +13,18 @@ use hive_router_plan_executor::{ plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME, }; +use tracing::trace; #[derive(Deserialize)] pub struct ResponseCachePluginOptions { pub enabled: bool, pub redis_url: String, + #[serde(default = "default_ttl_seconds")] + pub default_ttl_seconds: u64, +} + +fn default_ttl_seconds() -> u64 { + 5 } impl RouterPluginWithConfig for ResponseCachePlugin { @@ -31,16 +38,21 @@ impl RouterPluginWithConfig for ResponseCachePlugin { } let redis_client = redis::Client::open(config.redis_url).expect("Failed to create Redis client"); + let pool = r2d2::Pool::builder() + .build(redis_client) + .unwrap_or_else(|err| panic!("Failed to create Redis connection pool: {}", err)); Some(Self { - redis_client, + redis: pool, ttl_per_type: DashMap::new(), + default_ttl_seconds: config.default_ttl_seconds, }) } } pub struct ResponseCachePlugin { - redis_client: redis::Client, + redis: r2d2::Pool, ttl_per_type: DashMap, + default_ttl_seconds: u64, } #[async_trait::async_trait] @@ -53,22 +65,39 @@ impl RouterPlugin for ResponseCachePlugin { "response_cache:{}:{:?}", payload.query_plan, payload.variable_values ); - if let Ok(mut conn) = self.redis_client.get_connection() { - let cached_response: Option> = conn.get(&key).ok(); - if let Some(cached_response) = cached_response { - return payload.end_response(PlanExecutionOutput { - body: cached_response, - headers: HeaderMap::new(), - status: StatusCode::OK, - }); + if let Ok(mut conn) = self.redis.get() { + trace!("Checking cache for key: {}", key); + let cache_result: Result, redis::RedisError> = conn.get(&key); + match cache_result { + Ok(body) => { + if body.is_empty() { + trace!("Cache miss for key: {}", key); + } else { + trace!( + "Cache hit for key: {} -> {}", + key, + String::from_utf8_lossy(&body) + ); + return payload.end_response(PlanExecutionOutput { + body: body, + headers: HeaderMap::new(), + status: StatusCode::OK, + }); + } + } + Err(err) => { + trace!("Error accessing cache for key {}: {}", key, err); + } } return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { + trace!("Not caching response due to errors"); return payload.cont(); } if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { + trace!("Caching response for key: {}", key); // Decide on the ttl somehow // Get the type names let mut max_ttl = 0; @@ -86,10 +115,11 @@ impl RouterPlugin for ResponseCachePlugin { } } - // If no ttl found, default to 60 seconds + // If no ttl found, default if max_ttl == 0 { - max_ttl = 60; + max_ttl = self.default_ttl_seconds; } + trace!("Using TTL of {} seconds for key: {}", max_ttl, key); // Insert the ttl into extensions for client awareness payload @@ -98,7 +128,13 @@ impl RouterPlugin for ResponseCachePlugin { .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); // Set the cache with the decided ttl - let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); + let result = + conn.set_ex::<&str, Vec, ()>(&key, serialized.clone(), max_ttl); + if let Err(err) = result { + trace!("Failed to set cache for key {}: {}", key, err); + } else { + trace!("Cached response for key: {} with TTL: {}", key, max_ttl); + } } payload.cont() }); @@ -134,3 +170,97 @@ impl RouterPlugin for ResponseCachePlugin { payload.cont() } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use http::StatusCode; + use tracing::trace; + + use crate::testkit::{ + wait_for_readiness, SubgraphsServer, TestDockerContainer, TestDockerContainerOpts, + }; + + #[ntex::test] + async fn test_caching_with_default_ttl() { + let container = TestDockerContainer::async_new(TestDockerContainerOpts { + name: "redis_resp_caching_test".to_string(), + image: "redis/redis-stack:latest".to_string(), + ports: HashMap::from([(6379, 6379)]), + env: vec!["ALLOW_EMPTY_PASSWORD=yes".to_string()], + ..Default::default() + }) + .await + .expect("failed to start redis container"); + + // Redis flush all to ensure clean state + container + .exec(vec!["redis-cli", "FLUSHALL"]) + .await + .expect("Failed to flush redis"); + let subgraphs_server = SubgraphsServer::start().await; + + let app = crate::testkit::init_router_from_config_inline( + r#" + plugins: + response_cache_plugin: + enabled: true + redis_url: "redis://0.0.0.0:6379" + default_ttl_seconds: 2 + "#, + Some(hive_router::PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + trace!("First response received"); + assert_eq!(resp.status(), StatusCode::OK); + let resp_body = ntex::web::test::read_body(resp).await; + trace!( + "Response body read: {:?}", + String::from_utf8_lossy(&resp_body) + ); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!(subgraph_requests.len(), 1, "Expected one subgraph request"); + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp2 = ntex::web::test::call_service(&app.app, req.to_request()).await; + trace!("Second response received"); + assert!(resp2.status().is_success()); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!( + subgraph_requests.len(), + 1, + "Expected only one subgraph request due to caching" + ); + trace!("Waiting for cache to expire..."); + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + let req = crate::testkit::init_graphql_request("{ users { id } }", None); + let resp3 = ntex::web::test::call_service(&app.app, req.to_request()).await; + assert!(resp3.status().is_success()); + let subgraph_requests = subgraphs_server + .get_subgraph_requests_log("accounts") + .await + .expect("Failed to get subgraph requests log"); + assert_eq!( + subgraph_requests.len(), + 2, + "Expected a second subgraph request after cache expiry" + ); + container.stop().await; + } + #[ntex::test] + async fn respect_directives_on_supergraph_reload() { + todo!(); + } +} diff --git a/e2e/src/plugins/root_field_limit.rs b/e2e/src/plugins/root_field_limit.rs index 12f6356f3..64da0cc28 100644 --- a/e2e/src/plugins/root_field_limit.rs +++ b/e2e/src/plugins/root_field_limit.rs @@ -95,7 +95,7 @@ pub struct RootFieldLimitPluginConfig { impl RouterPluginWithConfig for RootFieldLimitPlugin { type Config = RootFieldLimitPluginConfig; fn plugin_name() -> &'static str { - "root_field_limit_plugin" + "root_field_limit" } fn from_config(config: Self::Config) -> Option { if !config.enabled { @@ -165,3 +165,46 @@ impl ValidationRule for RootFieldLimitRule { ); } } + +#[cfg(test)] +mod tests { + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + use hive_router::PluginRegistry; + use ntex::web::test; + #[ntex::test] + async fn rejects_query_with_too_many_root_fields() { + SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + root_field_limit: + enabled: true + max_root_fields: 1 + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let resp = test::call_service( + &app.app, + test::TestRequest::post() + .uri("/graphql") + .set_payload( + r#"{"query":"query TooManyRootFields { users { id } topProducts { upc } }"}"#, + ) + .header("content-type", "application/json") + .to_request(), + ) + .await; + let json_body: serde_json::Value = + serde_json::from_slice(&test::read_body(resp).await).unwrap(); + + let error_msg = json_body["errors"][0]["message"].as_str().unwrap(); + assert!( + error_msg.contains("Query has too many root fields"), + "Unexpected error message: {}", + error_msg + ); + } +} diff --git a/e2e/src/plugins/subgraph_response_cache.rs b/e2e/src/plugins/subgraph_response_cache.rs index 553840433..617427382 100644 --- a/e2e/src/plugins/subgraph_response_cache.rs +++ b/e2e/src/plugins/subgraph_response_cache.rs @@ -15,7 +15,7 @@ pub struct SubgraphResponseCachePluginConfig { impl RouterPluginWithConfig for SubgraphResponseCachePlugin { type Config = SubgraphResponseCachePluginConfig; fn plugin_name() -> &'static str { - "subgraph_response_cache_plugin" + "subgraph_response_cache" } fn from_config(config: SubgraphResponseCachePluginConfig) -> Option { if config.enabled { @@ -55,3 +55,40 @@ impl RouterPlugin for SubgraphResponseCachePlugin { }) } } + +#[cfg(test)] +mod tests { + use crate::testkit::{ + init_graphql_request, init_router_from_config_inline, wait_for_readiness, SubgraphsServer, + }; + use hive_router::PluginRegistry; + use ntex::web::test; + + // Tests on_subgraph_execute's override behavior + #[ntex::test] + async fn caches_subgraph_responses() { + let subgraphs = SubgraphsServer::start().await; + let app = init_router_from_config_inline( + r#" + plugins: + subgraph_response_cache: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("failed to start router"); + wait_for_readiness(&app.app).await; + let req = init_graphql_request("{ users { id } }", None); + let resp = test::call_service(&app.app, req.to_request()).await; + assert!(resp.status().is_success()); + let req = init_graphql_request("{ users { id } }", None); + let resp2 = test::call_service(&app.app, req.to_request()).await; + assert!(resp2.status().is_success()); + let subgraph_requests = subgraphs + .get_subgraph_requests_log("accounts") + .await + .expect("failed to get subgraph requests log"); + assert_eq!(subgraph_requests.len(), 1); + } +} diff --git a/e2e/src/probes.rs b/e2e/src/probes.rs index dc4fd57ab..027154d35 100644 --- a/e2e/src/probes.rs +++ b/e2e/src/probes.rs @@ -23,14 +23,17 @@ mod probes_e2e_tests { }) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -76,14 +79,17 @@ mod probes_e2e_tests { .with_status(500) .create(); - let app = init_router_from_config_inline(&format!( - r#"supergraph: + let app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); diff --git a/e2e/src/supergraph.rs b/e2e/src/supergraph.rs index 0b4635843..9cabfe952 100644 --- a/e2e/src/supergraph.rs +++ b/e2e/src/supergraph.rs @@ -29,14 +29,17 @@ mod supergraph_e2e_tests { .with_body("type Query { dummyNew: NewType } type NewType { id: ID! }") .create(); - let test_app = init_router_from_config_inline(&format!( - r#"supergraph: + let test_app = init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 500ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"); @@ -191,14 +194,17 @@ mod supergraph_e2e_tests { .create(); let test_app = Arc::new( - init_router_from_config_inline(&format!( - r#"supergraph: + init_router_from_config_inline( + &format!( + r#"supergraph: source: hive endpoint: http://{host}/supergraph key: dummy_key poll_interval: 300ms "#, - ), None) + ), + None, + ) .await .expect("failed to start router"), ); diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index a1c266027..114e801e9 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,7 +1,17 @@ -use std::{path::PathBuf, sync::Arc, time::Duration}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; +use bollard::{ + container::StartContainerOptions, + exec::{CreateExecOptions, StartExecResults}, + image, + query_parameters::CreateImageOptionsBuilder, + secret::{ContainerCreateBody, ContainerCreateResponse, CreateImageInfo, HostConfig, PortMap}, + Docker, +}; +use futures_util::TryStreamExt; use hive_router::{ - PluginRegistry, RouterSharedState, SchemaState, background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, plugins::plugins_service::PluginService + background_tasks::BackgroundTasksManager, configure_app_from_config, configure_ntex_app, + plugins::plugins_service::PluginService, PluginRegistry, RouterSharedState, SchemaState, }; use hive_router_config::{load_config, parse_yaml_config, HiveRouterConfig}; use ntex::{ @@ -183,8 +193,7 @@ pub async fn init_router_from_config( > { let mut bg_tasks_manager = BackgroundTasksManager::new(); let (shared_state, schema_state) = - configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry) - .await?; + configure_app_from_config(router_config, &mut bg_tasks_manager, plugin_registry).await?; let ntex_app = test::init_service( web::App::new() @@ -208,3 +217,135 @@ impl Drop for TestRouterApp { self.bg_tasks_manager.shutdown(); } } + +#[derive(Default)] +pub struct TestDockerContainerOpts { + pub name: String, + pub image: String, + pub ports: HashMap, + pub env: Vec, +} + +pub struct TestDockerContainer { + docker: Docker, + image: Vec, + container: ContainerCreateResponse, +} + +impl TestDockerContainer { + pub async fn async_new(opts: TestDockerContainerOpts) -> Result { + let docker = + Docker::connect_with_local_defaults().expect("Failed to connect to Docker daemon"); + let mut port_bindings = PortMap::new(); + for (container_port, host_port) in opts.ports.iter() { + port_bindings.insert( + format!("{}/tcp", container_port), + Some(vec![bollard::models::PortBinding { + host_port: Some(host_port.to_string()), + ..Default::default() + }]), + ); + } + let image: Vec = docker + .create_image( + Some( + CreateImageOptionsBuilder::default() + .from_image(&opts.image) + .build(), + ), + None, + None, + ) + .try_collect() + .await + .expect("Failed to pull the image"); + let container_exists = docker + .list_containers(Some(bollard::container::ListContainersOptions:: { + all: true, + ..Default::default() + })) + .await? + .into_iter() + .any(|c| { + c.names + .unwrap_or_default() + .iter() + .any(|name| name.trim_start_matches('/').eq(&opts.name)) + }); + if container_exists { + docker + .remove_container( + &opts.name, + Some(bollard::query_parameters::RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await + .expect("Failed to remove existing container"); + } + let container = docker + .create_container( + Some( + bollard::query_parameters::CreateContainerOptionsBuilder::default() + .name(&opts.name) + .build(), + ), + ContainerCreateBody { + image: Some(opts.image.to_string()), + host_config: Some(HostConfig { + port_bindings: Some(port_bindings), + ..Default::default() + }), + env: Some(opts.env), + ..Default::default() + }, + ) + .await + .expect("Failed to create the container"); + docker + .start_container(&container.id, None::>) + .await + .expect("Failed to start the container"); + Ok(Self { + docker, + image, + container, + }) + } + pub async fn exec(&self, cmd: Vec<&str>) -> Result<(), bollard::errors::Error> { + let exec = self + .docker + .create_exec( + &self.container.id, + CreateExecOptions { + attach_stdout: Some(true), + attach_stderr: Some(true), + cmd: Some(cmd), + ..Default::default() + }, + ) + .await?; + match self.docker.start_exec(&exec.id, None).await? { + StartExecResults::Attached { mut output, .. } => { + while let Some(msg) = output.try_next().await? { + print!("{}", msg); + } + } + _ => {} + } + Ok(()) + } + pub async fn stop(&self) { + self.docker + .remove_container( + &self.container.id, + Some(bollard::query_parameters::RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await + .expect("Failed to remove the container"); + } +} diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index fc1999d46..1ee1d8f34 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -154,7 +154,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let mut errors = exec_ctx.errors; let mut response_size_estimate = exec_ctx.response_storage.estimate_final_response_size(); - if on_end_callbacks.len() > 0 { + if !on_end_callbacks.is_empty() { let mut end_payload = OnExecuteEndPayload { data, errors, diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index ef433b468..fc308677d 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -132,17 +132,17 @@ impl SubgraphExecutorMap { client_request: &ClientRequestDetails<'exec, 'req>, plugin_req_state: &Option>, ) -> HttpExecutionResponse { - let mut on_end_callbacks = vec![]; let mut execution_request = execution_request; + let mut execution_result = None; if let Some(plugin_req_state) = plugin_req_state.as_ref() { let mut start_payload = OnSubgraphExecuteStartPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, subgraph_name, execution_request, - execution_result: None, + execution_result, }; for plugin in plugin_req_state.plugins.as_ref() { let result = plugin.on_subgraph_execute(start_payload).await; @@ -165,24 +165,31 @@ impl SubgraphExecutorMap { } } execution_request = start_payload.execution_request; + execution_result = start_payload.execution_result; } - let mut execution_result = match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request).await, - Err(err) => { - error!( - "Subgraph executor error for subgraph '{}': {}", - subgraph_name, err, - ); - self.internal_server_error_response(err.into(), subgraph_name) - } - Ok(None) => { - error!( - "Subgraph executor not found for subgraph '{}'", - subgraph_name - ); - self.internal_server_error_response("Internal server error".into(), subgraph_name) - } + let mut execution_result = match execution_result { + Some(execution_result) => execution_result, + None => match self.get_or_create_executor(subgraph_name, client_request) { + Ok(Some(executor)) => executor.execute(execution_request).await, + Err(err) => { + error!( + "Subgraph executor error for subgraph '{}': {}", + subgraph_name, err, + ); + self.internal_server_error_response(err.into(), subgraph_name) + } + Ok(None) => { + error!( + "Subgraph executor not found for subgraph '{}'", + subgraph_name + ); + self.internal_server_error_response( + "Internal server error".into(), + subgraph_name, + ) + } + }, }; if let Some(plugin_req_state) = plugin_req_state.as_ref() { @@ -215,7 +222,6 @@ impl SubgraphExecutorMap { execution_result = end_payload.execution_result; } - execution_result } diff --git a/lib/router-config/src/primitives/file_path.rs b/lib/router-config/src/primitives/file_path.rs index f8140562a..13aac9c37 100644 --- a/lib/router-config/src/primitives/file_path.rs +++ b/lib/router-config/src/primitives/file_path.rs @@ -9,6 +9,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, }; +use tracing::info; #[derive(Debug, Clone, Serialize)] pub struct FilePath { @@ -70,6 +71,11 @@ impl<'de> Visitor<'de> for FilePathVisitor { { CONTEXT_START_PATH.with(|ctx| { if let Some(start_path) = ctx.borrow().as_ref() { + info!( + "Deserializing FilePath '{}' with start path '{}'", + v, + start_path.display() + ); match FilePath::resolve_relative(start_path, v, true) { Ok(file_path) => Ok(file_path), Err(err) => Err(E::custom(format!("Failed to canonicalize path: {}", err))), From f89a9f7bf5ba3a80e9b5c7682b0b4e80afe9631c Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 17:46:20 +0300 Subject: [PATCH 21/31] Go --- Cargo.lock | 1 + bench/subgraphs/Cargo.toml | 1 + bench/subgraphs/lib.rs | 2 +- bench/subgraphs/products.rs | 57 ++++++- bin/router/src/pipeline/execution.rs | 2 +- bin/router/src/plugins/registry.rs | 15 +- bin/router/src/schema_state.rs | 6 +- bin/router/src/shared_state.rs | 6 +- e2e/src/plugins/multipart.rs | 153 ++++++++++++++---- e2e/src/plugins/one_of.rs | 114 ++++++++++++- e2e/src/testkit.rs | 18 +-- e2e/supergraph.graphql | 29 ++++ lib/executor/src/executors/common.rs | 3 + lib/executor/src/executors/http.rs | 139 ++++++++-------- lib/executor/src/executors/map.rs | 15 +- .../plugins/hooks/on_subgraph_http_request.rs | 16 +- lib/executor/src/plugins/plugin_context.rs | 4 +- lib/executor/src/plugins/plugin_trait.rs | 2 + 18 files changed, 432 insertions(+), 151 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 72bf2d71e..855a04b78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5390,6 +5390,7 @@ dependencies = [ "async-graphql", "async-graphql-axum", "axum", + "dashmap", "lazy_static", "rand 0.9.2", "sonic-rs", diff --git a/bench/subgraphs/Cargo.toml b/bench/subgraphs/Cargo.toml index 7d1f940bf..724e74e7b 100644 --- a/bench/subgraphs/Cargo.toml +++ b/bench/subgraphs/Cargo.toml @@ -17,6 +17,7 @@ lazy_static = { workspace = true } rand = { workspace = true } tokio = { workspace = true } sonic-rs = { workspace = true } +dashmap = { workspace = true } async-graphql = "7.0.17" async-graphql-axum = "7.0.17" diff --git a/bench/subgraphs/lib.rs b/bench/subgraphs/lib.rs index 575d0f499..0d8fa8469 100644 --- a/bench/subgraphs/lib.rs +++ b/bench/subgraphs/lib.rs @@ -74,7 +74,7 @@ async fn track_requests( fn extract_record(request_parts: &Parts, request_body: Bytes) -> RequestLog { let header_map = request_parts.headers.clone(); - let body_value: Value = sonic_rs::from_slice(&request_body).unwrap(); + let body_value: Value = sonic_rs::from_slice(&request_body).unwrap_or(Value::new()); RequestLog { headers: header_map, diff --git a/bench/subgraphs/products.rs b/bench/subgraphs/products.rs index 363f5640b..3b978dff1 100644 --- a/bench/subgraphs/products.rs +++ b/bench/subgraphs/products.rs @@ -1,5 +1,10 @@ -use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema, SimpleObject, ID}; +use std::io::Read; + +use async_graphql::{ + Context, EmptySubscription, InputObject, Object, Schema, SimpleObject, Upload, ID, +}; use lazy_static::lazy_static; +use tokio::io::AsyncWriteExt; lazy_static! { static ref PRODUCTS: Vec = vec![ @@ -69,6 +74,8 @@ pub struct Product { pub struct Query; +pub struct Mutation; + #[Object(extends = true)] impl Query { async fn top_products( @@ -94,8 +101,52 @@ impl Query { } } -pub fn get_subgraph() -> Schema { - Schema::build(Query, EmptyMutation, EmptySubscription) +pub fn get_subgraph() -> Schema { + Schema::build(Query, Mutation, EmptySubscription) .enable_federation() .finish() } + +#[Object(extends = true)] +impl Mutation { + async fn upload(&self, ctx: &Context<'_>, file: Option) -> String { + if file.is_none() { + return "No file uploaded".to_string(); + } + // Write to a temp location, and return the path + let uploaded_file = file.unwrap().value(ctx).unwrap(); + let path = format!("/tmp/{}", uploaded_file.filename); + let mut buf = vec![]; + let _ = uploaded_file.into_read().read_to_end(&mut buf); + let mut tmp_file_on_disk = tokio::fs::File::create(&path).await.unwrap(); + tmp_file_on_disk.write_all(&buf).await.unwrap(); + path + } + async fn oneof_test(&self, input: OneOfTestInput) -> OneOfTestResult { + OneOfTestResult { + string: input.string, + int: input.int, + float: input.float, + boolean: input.boolean, + id: input.id, + } + } +} + +#[derive(InputObject)] +struct OneOfTestInput { + pub string: Option, + pub int: Option, + pub float: Option, + pub boolean: Option, + pub id: Option, +} + +#[derive(SimpleObject)] +struct OneOfTestResult { + pub string: Option, + pub int: Option, + pub float: Option, + pub boolean: Option, + pub id: Option, +} diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 7a011a1aa..fb0e453c3 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -86,7 +86,7 @@ pub async fn execute_plan( }; let ctx = QueryPlanExecutionContext { - plugin_req_state: plugin_req_state, + plugin_req_state, query_plan: query_plan_payload, operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index a0b81125b..b41472ecb 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -1,17 +1,14 @@ use std::collections::HashMap; use hive_router_config::HiveRouterConfig; -use hive_router_plan_executor::plugin_trait::{RouterPlugin, RouterPluginWithConfig}; +use hive_router_plan_executor::plugin_trait::{RouterPluginBoxed, RouterPluginWithConfig}; use serde_json::Value; use tracing::info; +type PluginFactory = Box Result, serde_json::Error>>; + pub struct PluginRegistry { - map: HashMap< - &'static str, - Box< - dyn Fn(Value) -> Result>, serde_json::Error>, - >, - >, + map: HashMap<&'static str, PluginFactory>, } impl Default for PluginRegistry { @@ -52,8 +49,8 @@ impl PluginRegistry { pub fn initialize_plugins( &self, router_config: &HiveRouterConfig, - ) -> Result>>, PluginRegistryError> { - let mut plugins: Vec> = vec![]; + ) -> Result>, PluginRegistryError> { + let mut plugins: Vec = vec![]; for (plugin_name, plugin_config_value) in router_config.plugins.iter() { if let Some(factory) = self.map.get(plugin_name.as_str()) { diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 6e05d1430..ed87b02bd 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -8,7 +8,7 @@ use hive_router_plan_executor::{ OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, }, introspection::schema::SchemaWithMetadata, - plugin_trait::{ControlFlowResult, RouterPlugin}, + plugin_trait::{ControlFlowResult}, SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -112,7 +112,7 @@ impl SchemaState { new_ast = start_payload.new_ast; } - match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { + match Self::build_data(router_config.clone(), &new_ast) { Ok(new_supergraph_data) => { let mut end_payload = OnSupergraphLoadEndPayload { new_supergraph_data, @@ -166,7 +166,6 @@ impl SchemaState { fn build_data( router_config: Arc, parsed_supergraph_sdl: &Document, - plugins: Option>>>, ) -> Result { let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; @@ -174,7 +173,6 @@ impl SchemaState { let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, router_config, - plugins.clone(), )?; Ok(SupergraphData { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index e8541d92c..bffcfc4b4 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -3,7 +3,7 @@ use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; -use hive_router_plan_executor::plugin_trait::RouterPlugin; +use hive_router_plan_executor::plugin_trait::RouterPluginBoxed; use moka::future::Cache; use std::sync::Arc; @@ -19,14 +19,14 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, - pub plugins: Option>>>, + pub plugins: Option>>, } impl RouterSharedState { pub fn new( router_config: Arc, jwt_auth_runtime: Option, - plugins: Option>>, + plugins: Option>, ) -> Result { Ok(Self { validation_plan: graphql_tools::validation::rules::default_rules_validation_plan(), diff --git a/e2e/src/plugins/multipart.rs b/e2e/src/plugins/multipart.rs index cb0fa1646..21fa90592 100644 --- a/e2e/src/plugins/multipart.rs +++ b/e2e/src/plugins/multipart.rs @@ -1,19 +1,21 @@ use std::collections::HashMap; use bytes::Bytes; -use dashmap::DashMap; use hive_router_plan_executor::{ - executors::common::HttpExecutionResponse, + execution::plan::PlanExecutionOutput, + executors::dedupe::SharedResponse, hooks::{ on_graphql_params::{ GraphQLParams, OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload, }, - on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, }, plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; use multer::Multipart; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; +use serde_json::json; +use tracing::error; #[derive(Deserialize)] pub struct MultipartPluginConfig { @@ -29,14 +31,7 @@ pub struct MultipartFile { pub struct MultipartContext { pub file_map: HashMap>, - pub files: DashMap, -} - -#[derive(Serialize)] -struct MultipartOperations<'a> { - pub query: &'a str, - pub variables: Option<&'a HashMap<&'a str, &'a sonic_rs::Value>>, - pub operation_name: Option<&'a str>, + pub files: HashMap, } impl RouterPluginWithConfig for MultipartPlugin { @@ -84,7 +79,7 @@ impl RouterPlugin for MultipartPlugin { sonic_rs::from_slice(&data).unwrap(); payload.context.insert(MultipartContext { file_map, - files: DashMap::new(), + files: HashMap::new(), }); } field_name => { @@ -110,10 +105,10 @@ impl RouterPlugin for MultipartPlugin { payload.cont() } - async fn on_subgraph_execute<'exec>( + async fn on_subgraph_http_request<'exec>( &'exec self, - mut payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + mut payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { if let Some(variables) = &payload.execution_request.variables { let ctx_ref = payload.context.get_ref_entry(); let multipart_ctx: Option<&MultipartContext> = ctx_ref.get_ref(); @@ -133,13 +128,10 @@ impl RouterPlugin for MultipartPlugin { } if !file_map.is_empty() { let mut form = reqwest::multipart::Form::new(); - let operations_struct = MultipartOperations { - query: payload.execution_request.query, - variables: payload.execution_request.variables.as_ref(), - operation_name: payload.execution_request.operation_name, - }; - let operations = sonic_rs::to_string(&operations_struct).unwrap(); - form = form.text("operations", operations); + form = form.text( + "operations", + String::from_utf8(payload.body.clone()).unwrap(), + ); let file_map_str: String = sonic_rs::to_string(&file_map).unwrap(); form = form.text("map", file_map_str); for (file_ref, _op_refs) in file_map { @@ -156,23 +148,116 @@ impl RouterPlugin for MultipartPlugin { } } let resp = reqwest::Client::new() - .post("http://example.com/graphql") + .post(payload.endpoint.to_string()) // Using query as endpoint URL .multipart(form) .send() - .await - .unwrap(); - let headers = resp.headers().clone(); - let status = resp.status(); - let body = resp.bytes().await.unwrap(); - payload.execution_result = Some(HttpExecutionResponse { - body, - headers, - status, - }); + .await; + match resp { + Ok(resp) => { + payload.response = Some(SharedResponse { + status: resp.status(), + headers: resp.headers().clone(), + body: resp.bytes().await.unwrap(), + }); + } + Err(err) => { + error!("Failed to send multipart request to subgraph: {}", err); + let body = json!({ + "errors": [{ + "message": format!("Failed to send multipart request to subgraph: {}", err) + }] + }); + return payload.end_response(PlanExecutionOutput { + status: reqwest::StatusCode::INTERNAL_SERVER_ERROR, + headers: reqwest::header::HeaderMap::new(), + body: serde_json::to_vec(&body).unwrap(), + }); + } + } } } } payload.cont() } } + +#[cfg(test)] +mod tests { + use futures_util::StreamExt; + use hive_router::PluginRegistry; + use ntex::web::test; + + use crate::testkit::{init_router_from_config_inline, wait_for_readiness, SubgraphsServer}; + + #[ntex::test] + async fn forward_files() { + let subgraphs_server = SubgraphsServer::start().await; + + let app = init_router_from_config_inline( + r#" + plugins: + multipart: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let form = reqwest::multipart::Form::new() + .text("operations", r#"{"query":"mutation ($file: Upload) { upload(file: $file) }","variables":{"file":null}}"#) + .text("map", r#"{"0":["variables.file"]}"#) + .part( + "0", + reqwest::multipart::Part::bytes("file content".as_bytes().to_vec()) + .file_name("test.txt") + .mime_str("text/plain") + .unwrap(), + ); + + let boundary = form.boundary().to_string(); + let form_stream = form.into_stream(); + + let mut form_bytes = vec![]; + let mut stream = form_stream; + while let Some(item) = stream.next().await { + let chunk = item.expect("Failed to read chunk"); + form_bytes.extend_from_slice(&chunk); + } + + let req = test::TestRequest::post() + .uri("/graphql") + .header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + ) + .set_payload(form_bytes); + + let resp = test::call_service(&app.app, req.to_request()).await; + + let body = test::read_body(resp).await; + let body_str = String::from_utf8_lossy(&body); + let body_json = serde_json::from_str::(&body_str).unwrap(); + let upload_file_path = &body_json["data"]["upload"].as_str().unwrap(); + assert!( + upload_file_path.contains("test.txt"), + "Response should contain the filename" + ); + let file_content = tokio::fs::read(upload_file_path).await.unwrap(); + assert_eq!( + file_content, b"file content", + "File content should match the uploaded content" + ); + assert_eq!( + subgraphs_server + .get_subgraph_requests_log("products") + .await + .unwrap() + .len(), + 1, + "Expected 1 request to products subgraph" + ); + } +} diff --git a/e2e/src/plugins/one_of.rs b/e2e/src/plugins/one_of.rs index 91679042a..e133bbb30 100644 --- a/e2e/src/plugins/one_of.rs +++ b/e2e/src/plugins/one_of.rs @@ -79,7 +79,7 @@ pub struct OneOfPluginConfig { impl RouterPluginWithConfig for OneOfPlugin { type Config = OneOfPluginConfig; fn plugin_name() -> &'static str { - "one_of_plugin" + "oneof" } fn from_config(config: OneOfPluginConfig) -> Option { if config.enabled { @@ -233,3 +233,115 @@ impl<'a> OperationVisitor<'a, ValidationErrorContext> for OneOfValidation { } } } + +#[cfg(test)] +mod tests { + use hive_router::PluginRegistry; + use serde_json::{from_slice, Value}; + + use crate::testkit::{init_router_from_config_inline, wait_for_readiness}; + + #[ntex::test] + async fn one_of_validates_in_validation_rule() { + let app = init_router_from_config_inline( + r#" + plugins: + oneof: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request( + r#" + mutation OneOfTest { + oneofTest(input: { + string: "test", + int: 42 + }) { + string + int + float + boolean + id + } + } + "#, + None, + ); + + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + let body = ntex::web::test::read_body(resp).await; + let body_val: Value = from_slice(&body).expect("Response body should be valid JSON"); + let errors = body_val + .get("errors") + .expect("Response should contain errors"); + let first_error = errors + .as_array() + .expect("Errors should be an array") + .first() + .expect("There should be at least one error"); + let message = first_error + .get("message") + .expect("Error should have a message") + .as_str() + .expect("Message should be a string"); + assert!(message.contains("multiple fields set")); + } + + #[ntex::test] + async fn one_of_validates_during_execution() { + let app = init_router_from_config_inline( + r#" + plugins: + oneof: + enabled: true + "#, + Some(PluginRegistry::new().register::()), + ) + .await + .expect("Router should initialize successfully"); + wait_for_readiness(&app.app).await; + + let req = crate::testkit::init_graphql_request( + r#" + mutation OneOfTest($input: OneOfTestInput!) { + oneofTest(input: $input) { + string + int + float + boolean + id + } + } + "#, + Some(sonic_rs::json!({ + "input": { + "string": "test", + "int": 42 + } + })), + ); + + let resp = ntex::web::test::call_service(&app.app, req.to_request()).await; + let body = ntex::web::test::read_body(resp).await; + let body_val: Value = from_slice(&body).expect("Response body should be valid JSON"); + let errors = body_val + .get("errors") + .expect("Response should contain errors"); + let first_error = errors + .as_array() + .expect("Errors should be an array") + .first() + .expect("There should be at least one error"); + let message = first_error + .get("message") + .expect("Error should have a message") + .as_str() + .expect("Message should be a string"); + assert!(message.contains("multiple fields set")); + } +} diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index 114e801e9..f360d66da 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -1,9 +1,7 @@ use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; use bollard::{ - container::StartContainerOptions, exec::{CreateExecOptions, StartExecResults}, - image, query_parameters::CreateImageOptionsBuilder, secret::{ContainerCreateBody, ContainerCreateResponse, CreateImageInfo, HostConfig, PortMap}, Docker, @@ -228,7 +226,6 @@ pub struct TestDockerContainerOpts { pub struct TestDockerContainer { docker: Docker, - image: Vec, container: ContainerCreateResponse, } @@ -246,7 +243,7 @@ impl TestDockerContainer { }]), ); } - let image: Vec = docker + let _: Vec = docker .create_image( Some( CreateImageOptionsBuilder::default() @@ -260,7 +257,7 @@ impl TestDockerContainer { .await .expect("Failed to pull the image"); let container_exists = docker - .list_containers(Some(bollard::container::ListContainersOptions:: { + .list_containers(Some(bollard::query_parameters::ListContainersOptions { all: true, ..Default::default() })) @@ -304,14 +301,13 @@ impl TestDockerContainer { .await .expect("Failed to create the container"); docker - .start_container(&container.id, None::>) + .start_container( + &container.id, + None::, + ) .await .expect("Failed to start the container"); - Ok(Self { - docker, - image, - container, - }) + Ok(Self { docker, container }) } pub async fn exec(&self, cmd: Vec<&str>) -> Result<(), bollard::errors::Error> { let exec = self diff --git a/e2e/supergraph.graphql b/e2e/supergraph.graphql index 1fe16b12b..655a18c1a 100644 --- a/e2e/supergraph.graphql +++ b/e2e/supergraph.graphql @@ -2,6 +2,7 @@ schema @link(url: "https://specs.apollo.dev/link/v1.0") @link(url: "https://specs.apollo.dev/join/v0.3", for: EXECUTION) { query: Query + mutation: Mutation } directive @join__enumValue(graph: join__Graph!) repeatable on ENUM_VALUE @@ -113,3 +114,31 @@ type User birthday: Int @join__field(graph: ACCOUNTS) reviews: [Review] @join__field(graph: REVIEWS) } + +scalar Upload + +type Mutation + @join__type(graph: PRODUCTS) { + upload(file: Upload): String @join__field(graph: PRODUCTS) + + oneofTest(input: OneOfTestInput!): OneOfTestResult + @join__field(graph: PRODUCTS) +} + +directive @oneOf on INPUT_OBJECT + +input OneOfTestInput @oneOf @join__type(graph: PRODUCTS) { + string: String + int: Int + float: Float + boolean: Boolean + id: ID +} + +type OneOfTestResult @join__type(graph: PRODUCTS) { + string: String @join__field(graph: PRODUCTS) + int: Int @join__field(graph: PRODUCTS) + float: Float @join__field(graph: PRODUCTS) + boolean: Boolean @join__field(graph: PRODUCTS) + id: ID @join__field(graph: PRODUCTS) +} \ No newline at end of file diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 6b3c804b5..b607a7e43 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -5,11 +5,14 @@ use bytes::Bytes; use http::HeaderMap; use sonic_rs::Value; +use crate::plugin_context::PluginRequestState; + #[async_trait] pub trait SubgraphExecutor { async fn execute<'a>( &self, execution_request: SubgraphExecutionRequest<'a>, + plugin_req_state: &'a Option>, ) -> HttpExecutionResponse; fn to_boxed_arc<'a>(self) -> Arc> diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 90e3942e2..6798bc084 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -5,7 +5,8 @@ use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse use crate::hooks::on_subgraph_http_request::{ OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, }; -use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; +use crate::plugin_context::PluginRequestState; +use crate::plugin_trait::ControlFlowResult; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -40,7 +41,6 @@ pub struct HTTPSubgraphExecutor { pub semaphore: Arc, pub config: Arc, pub in_flight_requests: Arc>, ABuildHasher>>, - pub plugins: Option>>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -56,7 +56,6 @@ impl HTTPSubgraphExecutor { semaphore: Arc, config: Arc, in_flight_requests: Arc>, ABuildHasher>>, - plugins: Option>>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -76,7 +75,6 @@ impl HTTPSubgraphExecutor { semaphore, config, in_flight_requests, - plugins, } } @@ -166,33 +164,25 @@ async fn send_request( http_client: &Client, Full>, subgraph_name: &str, endpoint: &http::Uri, - method: http::Method, - body: Vec, - headers: HeaderMap, - plugins: Option>>>, + mut method: http::Method, + mut body: Vec, + mut execution_request: SubgraphExecutionRequest<'_>, + plugin_req_state: &Option>, ) -> Result { - let mut req = hyper::Request::builder() - .method(method) - .uri(endpoint) - .version(Version::HTTP_11) - .body(Full::new(Bytes::from(body))) - .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) - })?; - - *req.headers_mut() = headers; - - let mut req = req; - let mut on_end_callbacks = vec![]; + let mut response = None; - if let Some(plugins) = plugins.as_ref() { + if let Some(plugin_req_state) = plugin_req_state.as_ref() { let mut start_payload = OnSubgraphHttpRequestPayload { subgraph_name, - request: req, - response: None, + endpoint, + method, + body, + execution_request, + context: &plugin_req_state.context, + response, }; - for plugin in plugins.as_ref() { + for plugin in plugin_req_state.plugins.as_ref() { let result = plugin.on_subgraph_http_request(start_payload).await; start_payload = result.payload; match result.control_flow { @@ -210,40 +200,60 @@ async fn send_request( } } } - req = start_payload.request; + method = start_payload.method; + body = start_payload.body; + execution_request = start_payload.execution_request; + response = start_payload.response; } - debug!("making http request to {}", endpoint.to_string()); - - let res = http_client - .request(req) - .await - .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))?; - - debug!( - "http request to {} completed, status: {}", - endpoint.to_string(), - res.status() - ); - - let (parts, body) = res.into_parts(); - let body = body - .collect() - .await - .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))? - .to_bytes(); - - if body.is_empty() { - return Err(SubgraphExecutorError::RequestFailure( - endpoint.to_string(), - "Empty response body".to_string(), - )); - } + let response = match response { + Some(response) => response, + None => { + let mut req = hyper::Request::builder() + .method(method) + .uri(endpoint) + .version(Version::HTTP_11) + .body(Full::new(Bytes::from(body))) + .map_err(|e| { + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) + })?; - let response = SharedResponse { - status: parts.status, - body, - headers: parts.headers, + *req.headers_mut() = execution_request.headers; + + debug!("making http request to {}", endpoint.to_string()); + + let res = http_client.request(req).await.map_err(|e| { + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) + })?; + + debug!( + "http request to {} completed, status: {}", + endpoint.to_string(), + res.status() + ); + + let (parts, body) = res.into_parts(); + let body = body + .collect() + .await + .map_err(|e| { + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) + })? + .to_bytes(); + + if body.is_empty() { + return Err(SubgraphExecutorError::RequestFailure( + endpoint.to_string(), + "Empty response body".to_string(), + )); + } + + SharedResponse { + status: parts.status, + body, + headers: parts.headers, + } + } }; let mut end_payload = OnSubgraphHttpResponsePayload { response }; @@ -275,7 +285,8 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))] async fn execute<'a>( &self, - execution_request: SubgraphExecutionRequest<'a>, + mut execution_request: SubgraphExecutionRequest<'a>, + plugin_req_state: &'a Option>, ) -> HttpExecutionResponse { let body = match self.build_request_body(&execution_request) { Ok(body) => body, @@ -289,9 +300,8 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { } }; - let mut headers = execution_request.headers; self.header_map.iter().for_each(|(key, value)| { - headers.insert(key, value.clone()); + execution_request.headers.insert(key, value.clone()); }); let method = http::Method::POST; @@ -306,8 +316,8 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { &self.endpoint, method, body, - headers, - self.plugins.clone(), + execution_request, + plugin_req_state, ) .await { @@ -327,7 +337,8 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { }; } - let fingerprint = request_fingerprint(&method, &self.endpoint, &headers, &body); + let fingerprint = + request_fingerprint(&method, &self.endpoint, &execution_request.headers, &body); // Clone the cell from the map, dropping the lock from the DashMap immediately. // Prevents any deadlocks. @@ -350,8 +361,8 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { &self.endpoint, method, body, - headers, - self.plugins.clone(), + execution_request, + plugin_req_state, ) .await }; diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index fc308677d..cd33b511e 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -39,7 +39,7 @@ use crate::{ }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_context::PluginRequestState, - plugin_trait::{ControlFlowResult, RouterPlugin}, + plugin_trait::ControlFlowResult, response::graphql_error::GraphQLError, }; @@ -64,14 +64,10 @@ pub struct SubgraphExecutorMap { semaphores_by_origin: DashMap>, max_connections_per_host: usize, in_flight_requests: Arc>, ABuildHasher>>, - plugins: Option>>>, } impl SubgraphExecutorMap { - pub fn new( - config: Arc, - plugins: Option>>>, - ) -> Self { + pub fn new(config: Arc) -> Self { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) .pool_timer(TokioTimer::new()) @@ -93,16 +89,14 @@ impl SubgraphExecutorMap { semaphores_by_origin: Default::default(), max_connections_per_host, in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())), - plugins, } } pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, config: Arc, - plugins: Option>>>, ) -> Result { - let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), plugins); + let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone()); for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.into_iter() { let endpoint_str = config @@ -171,7 +165,7 @@ impl SubgraphExecutorMap { let mut execution_result = match execution_result { Some(execution_result) => execution_result, None => match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request).await, + Ok(Some(executor)) => executor.execute(execution_request, plugin_req_state).await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -410,7 +404,6 @@ impl SubgraphExecutorMap { semaphore, self.config.clone(), self.in_flight_requests.clone(), - self.plugins.clone(), ); self.executors_by_subgraph diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 1b50f001b..2ce479b3d 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,16 +1,18 @@ -use bytes::Bytes; -use http::Request; -use http_body_util::Full; - use crate::{ - executors::dedupe::SharedResponse, + executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, + plugin_context::PluginContext, plugin_trait::{EndPayload, StartPayload}, }; pub struct OnSubgraphHttpRequestPayload<'exec> { pub subgraph_name: &'exec str, - // At this point, there is no point of mutating this - pub request: Request>, + + pub endpoint: &'exec http::Uri, + pub method: http::Method, + pub body: Vec, + pub execution_request: SubgraphExecutionRequest<'exec>, + + pub context: &'exec PluginContext, // Early response pub response: Option, diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index a17da3f91..14a93974d 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -11,7 +11,7 @@ use http::Uri; use ntex::router::Path; use ntex_http::HeaderMap; -use crate::plugin_trait::RouterPlugin; +use crate::plugin_trait::RouterPluginBoxed; pub struct RouterHttpRequest<'exec> { pub uri: &'exec Uri, @@ -92,7 +92,7 @@ impl PluginContext { } pub struct PluginRequestState<'req> { - pub plugins: Arc>>, + pub plugins: Arc>, pub router_http_request: RouterHttpRequest<'req>, pub context: Arc, } diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 5f6b4a0c8..ea18c16e2 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -158,3 +158,5 @@ pub trait RouterPlugin { start_payload.cont() } } + +pub type RouterPluginBoxed = Box; From 2fbb1cd3068b5366a3f08bbbf8b176ffc8962277 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 17:47:02 +0300 Subject: [PATCH 22/31] Go --- bin/router/src/schema_state.rs | 2 +- lib/executor/src/execution/plan.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index ed87b02bd..bffb9ce9d 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -8,7 +8,7 @@ use hive_router_plan_executor::{ OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, }, introspection::schema::SchemaWithMetadata, - plugin_trait::{ControlFlowResult}, + plugin_trait::ControlFlowResult, SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 1ee1d8f34..f5b7ec815 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -301,6 +301,7 @@ struct PreparedFlattenData { } impl<'exec, 'req> Executor<'exec, 'req> { + #[allow(clippy::too_many_arguments)] pub fn new( variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, From d8ece27cf2317ac18b5e0b3119ebb57e4c2ea3fb Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 17:50:33 +0300 Subject: [PATCH 23/31] Configuration --- docs/README.md | 9 +++++++++ lib/router-config/src/lib.rs | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index fb474b9d2..0a61d1baa 100644 --- a/docs/README.md +++ b/docs/README.md @@ -13,6 +13,7 @@ |[**log**](#log)|`object`|The router logger configuration.
Default: `{"filter":null,"format":"json","level":"info"}`
|| |[**override\_labels**](#override_labels)|`object`|Configuration for overriding labels.
|| |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| +|[**plugins**](#plugins)|`object`|Configuration for custom plugins
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| |[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"dedupe_enabled":true,"max_connections_per_host":100,"pool_idle_timeout_seconds":50}`
|| @@ -102,6 +103,7 @@ override_subgraph_urls: .original_url } +plugins: {} query_planner: allow_expose: false timeout: 10s @@ -1641,6 +1643,13 @@ products: |----|----|-----------|--------| |**url**||Overrides for the URL of the subgraph.

For convenience, a plain string in your configuration will be treated as a static URL.

### Static URL Example
```yaml
url: "https://api.example.com/graphql"
```

### Dynamic Expression Example

The expression has access to the following variables:
- `request`: The incoming HTTP request, including headers and other metadata.
- `original_url`: The original URL of the subgraph (from supergraph sdl).

```yaml
url:
expression: \|
if .request.headers."x-region" == "us-east" {
"https://products-us-east.example.com/graphql"
} else if .request.headers."x-region" == "eu-west" {
"https://products-eu-west.example.com/graphql"
} else {
.original_url
}
|yes| +
+## plugins: object + +Configuration for custom plugins + + +**Additional Properties:** allowed ## query\_planner: object diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs index 925246c2c..d113288f7 100644 --- a/lib/router-config/src/lib.rs +++ b/lib/router-config/src/lib.rs @@ -93,7 +93,7 @@ pub struct HiveRouterConfig { #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub override_labels: OverrideLabelsConfig, - /// Configuration for plugins. + /// Configuration for custom plugins #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub plugins: HashMap, } From 3b6d1a84374a3f067742911d68637406caf68f1a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 17:52:12 +0300 Subject: [PATCH 24/31] Lets go --- bin/router/src/pipeline/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 730921683..4e0fe8e42 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -184,9 +184,10 @@ pub async fn execute_pipeline( graphql_params = deserialization_payload.graphql_params; body = deserialization_payload.body; } - let mut graphql_params = graphql_params.unwrap_or_else(|| { - deserialize_graphql_params(req, body).expect("Failed to parse execution request") - }); + let mut graphql_params = match graphql_params { + Some(params) => params, + None => deserialize_graphql_params(req, body)?, + }; if let Some(plugin_req_state) = &plugin_req_state { let mut payload = OnGraphQLParamsEndPayload { From 779bad96315a2177d8d928e675157aa775667225 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 17:58:45 +0300 Subject: [PATCH 25/31] No Mutex --- bench/subgraphs/lib.rs | 24 ++++++++++-------------- e2e/src/testkit.rs | 9 +++++---- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/bench/subgraphs/lib.rs b/bench/subgraphs/lib.rs index 0d8fa8469..507de10fa 100644 --- a/bench/subgraphs/lib.rs +++ b/bench/subgraphs/lib.rs @@ -13,14 +13,12 @@ use axum::{ routing::{get, post_service}, Router, }; +use dashmap::DashMap; use sonic_rs::Value; -use std::{collections::HashMap, env::var, sync::Arc}; +use std::{env::var, sync::Arc}; use tokio::{ net::TcpListener, - sync::{ - oneshot::{self, Sender}, - Mutex, - }, + sync::oneshot::{self, Sender}, task::JoinHandle, }; @@ -55,7 +53,7 @@ async fn add_subgraph_header(req: Request, next: Next) -> Response { } async fn track_requests( - State(state): State, + State(state): State>, request: Request, next: Next, ) -> impl IntoResponse { @@ -63,9 +61,8 @@ async fn track_requests( let (parts, body) = request.into_parts(); let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); let record = extract_record(&parts, body_bytes.clone()); - let mut log = state.request_log.lock().await; - log.entry(path).or_default().push(record); + state.request_log.entry(path).or_default().push(record); let new_body = axum::body::Body::from(body_bytes); let request = Request::from_parts(parts, new_body); @@ -92,25 +89,24 @@ pub struct RequestLog { pub request_body: Value, } -#[derive(Clone)] pub struct SubgraphsServiceState { - pub request_log: Arc>>>, + pub request_log: DashMap>, pub health_check_url: String, } pub fn start_subgraphs_server( port: Option, -) -> (JoinHandle<()>, Sender<()>, SubgraphsServiceState) { +) -> (JoinHandle<()>, Sender<()>, Arc) { let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let host = var("HOST").unwrap_or("0.0.0.0".to_owned()); let port = port .map(|v| v.to_string()) .unwrap_or(var("PORT").unwrap_or("4200".to_owned())); - let shared_state = SubgraphsServiceState { - request_log: Arc::new(Mutex::new(HashMap::new())), + let shared_state = Arc::new(SubgraphsServiceState { + request_log: DashMap::new(), health_check_url: format!("http://{}:{}/health", host, port), - }; + }); let app = Router::new() .route( diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs index f360d66da..1f73c6d18 100644 --- a/e2e/src/testkit.rs +++ b/e2e/src/testkit.rs @@ -81,7 +81,7 @@ where pub struct SubgraphsServer { shutdown_tx: Option>, - subgraph_shared_state: SubgraphsServiceState, + subgraph_shared_state: Arc, } impl Drop for SubgraphsServer { @@ -123,9 +123,10 @@ impl SubgraphsServer { } pub async fn get_subgraph_requests_log(&self, subgraph_name: &str) -> Option> { - let log = self.subgraph_shared_state.request_log.lock().await; - - log.get(&format!("/{}", subgraph_name)).cloned() + self.subgraph_shared_state + .request_log + .get(&format!("/{}", subgraph_name)) + .map(|entry| entry.value().clone()) } } From 8c23103d75f79da59c462e3feebcc83f996d94c2 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 19:08:41 +0300 Subject: [PATCH 26/31] Easier context handling --- e2e/src/plugins/context_data.rs | 21 ++--- e2e/src/plugins/multipart.rs | 9 +- e2e/src/plugins/propagate_status_code.rs | 8 +- lib/executor/src/plugins/plugin_context.rs | 95 +++++++++++++--------- 4 files changed, 72 insertions(+), 61 deletions(-) diff --git a/e2e/src/plugins/context_data.rs b/e2e/src/plugins/context_data.rs index d2ef891bd..c0b7840ff 100644 --- a/e2e/src/plugins/context_data.rs +++ b/e2e/src/plugins/context_data.rs @@ -7,7 +7,6 @@ use hive_router_plan_executor::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, }, - plugin_context::PluginContextMutEntry, plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, }; @@ -51,9 +50,8 @@ impl RouterPlugin for ContextDataPlugin { payload.context.insert(context_data); payload.on_end(|payload| { - let mut ctx_data_entry = payload.context.get_mut_entry(); - let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); - if let Some(context_data) = context_data { + let context_data = payload.context.get_mut::(); + if let Some(mut context_data) = context_data { context_data.response_count += 1; tracing::info!("subrequest count {}", context_data.response_count); } @@ -64,21 +62,18 @@ impl RouterPlugin for ContextDataPlugin { &'exec self, mut payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { - let ctx_data_entry = payload.context.get_ref_entry(); - let context_data: Option<&ContextData> = ctx_data_entry.get_ref(); - if let Some(context_data) = context_data { - tracing::info!("hello {}", context_data.incoming_data); // Hello world! - let new_header_value = format!("Hello {}", context_data.incoming_data); + let context_data_entry = payload.context.get_ref::(); + if let Some(ref context_data_entry) = context_data_entry { + tracing::info!("hello {}", context_data_entry.incoming_data); // Hello world! + let new_header_value = format!("Hello {}", context_data_entry.incoming_data); payload.execution_request.headers.insert( "x-hello", http::HeaderValue::from_str(&new_header_value).unwrap(), ); } payload.on_end(|payload: OnSubgraphExecuteEndPayload<'exec>| { - let mut ctx_data_entry: PluginContextMutEntry = - payload.context.get_mut_entry(); - let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); - if let Some(context_data) = context_data { + let context_data = payload.context.get_mut::(); + if let Some(mut context_data) = context_data { context_data.response_count += 1; tracing::info!("subrequest count {}", context_data.response_count); } diff --git a/e2e/src/plugins/multipart.rs b/e2e/src/plugins/multipart.rs index 21fa90592..b421cde6c 100644 --- a/e2e/src/plugins/multipart.rs +++ b/e2e/src/plugins/multipart.rs @@ -83,10 +83,8 @@ impl RouterPlugin for MultipartPlugin { }); } field_name => { - let mut ctx_entry = payload.context.get_mut_entry(); - let multipart_ctx: Option<&mut MultipartContext> = - ctx_entry.get_ref_mut(); - if let Some(multipart_ctx) = multipart_ctx { + let multipart_ctx = payload.context.get_mut::(); + if let Some(mut multipart_ctx) = multipart_ctx { let multipart_file = MultipartFile { filename, content_type, @@ -110,8 +108,7 @@ impl RouterPlugin for MultipartPlugin { mut payload: OnSubgraphHttpRequestPayload<'exec>, ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { if let Some(variables) = &payload.execution_request.variables { - let ctx_ref = payload.context.get_ref_entry(); - let multipart_ctx: Option<&MultipartContext> = ctx_ref.get_ref(); + let multipart_ctx = payload.context.get_ref::(); if let Some(multipart_ctx) = multipart_ctx { let mut file_map: HashMap> = HashMap::new(); for variable_name in variables.keys() { diff --git a/e2e/src/plugins/propagate_status_code.rs b/e2e/src/plugins/propagate_status_code.rs index 177dbbb8c..856424f62 100644 --- a/e2e/src/plugins/propagate_status_code.rs +++ b/e2e/src/plugins/propagate_status_code.rs @@ -55,9 +55,8 @@ impl RouterPlugin for PropagateStatusCodePlugin { // if a response contains a status code we're watching... if self.status_codes.contains(&status_code) { // Checking if there is already a context entry - let mut ctx_entry = payload.context.get_mut_entry(); - let ctx: Option<&mut PropagateStatusCodeCtx> = ctx_entry.get_ref_mut(); - if let Some(ctx) = ctx { + let ctx = payload.context.get_mut::(); + if let Some(mut ctx) = ctx { // Update the status code if the new one is more severe (higher) if status_code.as_u16() > ctx.status_code.as_u16() { ctx.status_code = status_code; @@ -77,8 +76,7 @@ impl RouterPlugin for PropagateStatusCodePlugin { ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponsePayload<'exec>> { payload.on_end(|mut payload| { // Checking if there is a context entry - let ctx_entry = payload.context.get_ref_entry(); - let ctx: Option<&PropagateStatusCodeCtx> = ctx_entry.get_ref(); + let ctx = payload.context.get_ref::(); if let Some(ctx) = ctx { // Update the HTTP response status code *payload.response.response_mut().status_mut() = ctx.status_code; diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index 14a93974d..8f8960131 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -1,5 +1,6 @@ use std::{ any::{Any, TypeId}, + ops::{Deref, DerefMut}, sync::Arc, }; @@ -29,35 +30,59 @@ pub struct PluginContext { } pub struct PluginContextRefEntry<'a, T> { - pub entry: Option>>, + pub entry: Ref<'a, TypeId, Box>, phantom: std::marker::PhantomData, } -impl<'a, T: Any + Send + Sync> PluginContextRefEntry<'a, T> { - pub fn get_ref(&self) -> Option<&T> { - match &self.entry { - None => None, - Some(entry) => { - let boxed_any = entry.value(); - Some(boxed_any.downcast_ref::()?) - } - } +impl<'a, T: Any + Send + Sync> AsRef for PluginContextRefEntry<'a, T> { + fn as_ref(&self) -> &T { + let boxed_any = self.entry.value(); + boxed_any + .downcast_ref::() + .expect("type mismatch in PluginContextRefEntry") + } +} + +impl<'a, T: Any + Send + Sync> Deref for PluginContextRefEntry<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.as_ref() } } + pub struct PluginContextMutEntry<'a, T> { - pub entry: Option>>, + pub entry: RefMut<'a, TypeId, Box>, phantom: std::marker::PhantomData, } -impl<'a, T: Any + Send + Sync> PluginContextMutEntry<'a, T> { - pub fn get_ref_mut(&mut self) -> Option<&mut T> { - match &mut self.entry { - None => None, - Some(entry) => { - let boxed_any = entry.value_mut(); - Some(boxed_any.downcast_mut::()?) - } - } +impl<'a, T: Any + Send + Sync> AsRef for PluginContextMutEntry<'a, T> { + fn as_ref(&self) -> &T { + let boxed_any = self.entry.value(); + boxed_any + .downcast_ref::() + .expect("type mismatch in PluginContextMutEntry") + } +} + +impl<'a, T: Any + Send + Sync> Deref for PluginContextMutEntry<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl<'a, T: Any + Send + Sync> AsMut for PluginContextMutEntry<'a, T> { + fn as_mut(&mut self) -> &mut T { + let boxed_any = self.entry.value_mut(); + boxed_any + .downcast_mut::() + .expect("type mismatch in PluginContextMutEntry") + } +} + +impl<'a, T: Any + Send + Sync> DerefMut for PluginContextMutEntry<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.as_mut() } } @@ -72,22 +97,21 @@ impl PluginContext { .insert(type_id, Box::new(value)) .and_then(|boxed_any| boxed_any.downcast::().ok()) } - pub fn get_ref_entry(&self) -> PluginContextRefEntry<'_, T> { + pub fn get_ref(&self) -> Option> { let type_id = TypeId::of::(); - let entry = self.inner.get(&type_id); - PluginContextRefEntry { + self.inner.get(&type_id).map(|entry| PluginContextRefEntry { entry, phantom: std::marker::PhantomData, - } + }) } - pub fn get_mut_entry<'a, T: Any + Send + Sync>(&'a self) -> PluginContextMutEntry<'a, T> { + pub fn get_mut(&self) -> Option> { let type_id = TypeId::of::(); - let entry = self.inner.get_mut(&type_id); - - PluginContextMutEntry { - entry, - phantom: std::marker::PhantomData, - } + self.inner + .get_mut(&type_id) + .map(|entry| PluginContextMutEntry { + entry, + phantom: std::marker::PhantomData, + }) } } @@ -110,8 +134,7 @@ mod tests { let ctx = PluginContext::default(); ctx.insert(TestCtx { value: 42 }); - let entry = ctx.get_ref_entry(); - let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + let ctx_ref: &TestCtx = &ctx.get_ref().unwrap(); assert_eq!(ctx_ref.value, 42); } #[test] @@ -126,13 +149,11 @@ mod tests { ctx.insert(TestCtx { value: 42 }); { - let mut entry = ctx.get_mut_entry(); - let ctx_mut: &mut TestCtx = entry.get_ref_mut().unwrap(); + let ctx_mut: &mut TestCtx = &mut ctx.get_mut().unwrap(); ctx_mut.value = 100; } - let entry = ctx.get_ref_entry(); - let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + let ctx_ref: &TestCtx = &ctx.get_ref().unwrap(); assert_eq!(ctx_ref.value, 100); } } From 25fe7624ae6cc886b5c94100795b423ebf589739 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Fri, 28 Nov 2025 22:23:38 +0300 Subject: [PATCH 27/31] Cleanup --- bin/router/src/plugins/registry.rs | 2 +- lib/executor/src/plugins/plugin_trait.rs | 13 +------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/bin/router/src/plugins/registry.rs b/bin/router/src/plugins/registry.rs index b41472ecb..0383c8000 100644 --- a/bin/router/src/plugins/registry.rs +++ b/bin/router/src/plugins/registry.rs @@ -5,7 +5,7 @@ use hive_router_plan_executor::plugin_trait::{RouterPluginBoxed, RouterPluginWit use serde_json::Value; use tracing::info; -type PluginFactory = Box Result, serde_json::Error>>; +type PluginFactory = Box serde_json::Result>>; pub struct PluginRegistry { map: HashMap<&'static str, PluginFactory>, diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index ea18c16e2..3bb750e59 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -85,19 +85,8 @@ where Self: RouterPlugin, { fn plugin_name() -> &'static str; - type Config: Send + Sync + DeserializeOwned; + type Config: DeserializeOwned; fn from_config(config: Self::Config) -> Option; - fn from_config_value(value: serde_json::Value) -> serde_json::Result>> - where - Self: Sized, - { - let config: Self::Config = serde_json::from_value(value)?; - let plugin = Self::from_config(config); - match plugin { - None => Ok(None), - Some(plugin) => Ok(Some(Box::new(plugin))), - } - } } #[async_trait::async_trait] From 94dcd3207323fb2cc3d8962f304eeaa1780f5394 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 2 Dec 2025 15:19:25 +0300 Subject: [PATCH 28/31] Unify HTTP Response --- bin/router/src/pipeline/execution.rs | 5 +- bin/router/src/pipeline/mod.rs | 36 +++-- bin/router/src/pipeline/parser.rs | 26 ++-- bin/router/src/pipeline/query_plan.rs | 27 ++-- bin/router/src/pipeline/validation.rs | 25 ++-- bin/router/src/plugins/plugins_service.rs | 59 ++++---- bin/router/src/schema_state.rs | 23 ++-- e2e/src/plugins/apollo_sandbox.rs | 14 +- e2e/src/plugins/apq.rs | 22 +-- e2e/src/plugins/async_auth.rs | 22 +-- e2e/src/plugins/context_data.rs | 19 +-- .../plugins/forbid_anonymous_operations.rs | 16 +-- e2e/src/plugins/multipart.rs | 25 ++-- e2e/src/plugins/one_of.rs | 31 +++-- e2e/src/plugins/propagate_status_code.rs | 17 +-- e2e/src/plugins/response_cache.rs | 24 ++-- e2e/src/plugins/root_field_limit.rs | 23 ++-- e2e/src/plugins/subgraph_response_cache.rs | 17 ++- lib/executor/src/execution/plan.rs | 74 +++------- lib/executor/src/executors/common.rs | 12 +- lib/executor/src/executors/dedupe.rs | 10 +- lib/executor/src/executors/http.rs | 68 +++++---- lib/executor/src/executors/map.rs | 48 +++---- lib/executor/src/plugins/hooks/on_execute.rs | 15 +- .../src/plugins/hooks/on_graphql_params.rs | 25 +++- .../src/plugins/hooks/on_graphql_parse.rs | 18 ++- .../plugins/hooks/on_graphql_validation.rs | 25 +++- .../src/plugins/hooks/on_http_request.rs | 15 +- .../src/plugins/hooks/on_query_plan.rs | 15 +- .../src/plugins/hooks/on_subgraph_execute.rs | 27 ++-- .../plugins/hooks/on_subgraph_http_request.rs | 28 ++-- .../src/plugins/hooks/on_supergraph_load.rs | 19 ++- lib/executor/src/plugins/plugin_trait.rs | 129 ++++++++++-------- 33 files changed, 492 insertions(+), 467 deletions(-) diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index fb0e453c3..d870a8d45 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -6,7 +6,8 @@ use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::shared_state::RouterSharedState; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::execution::plan::QueryPlanExecutionContext; +use hive_router_plan_executor::executors::http::HttpResponse; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; use hive_router_plan_executor::plugin_context::PluginRequestState; @@ -34,7 +35,7 @@ pub async fn execute_plan( variable_payload: &CoerceVariablesPayload, client_request_details: &ClientRequestDetails<'_, '_>, plugin_req_state: &Option>, -) -> Result { +) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 4e0fe8e42..104bcd19a 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,16 +1,16 @@ use std::sync::Arc; use hive_router_plan_executor::{ - execution::{ - client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, + execution::client_request_details::{ + ClientRequestDetails, JwtRequestDetails, OperationDetails, }, + executors::http::HttpResponse, hooks::{ - on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_graphql_params::{OnGraphQLParamsEndHookPayload, OnGraphQLParamsStartHookPayload}, on_supergraph_load::SupergraphData, }, plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest}, - plugin_trait::ControlFlowResult, + plugin_trait::{EndControlFlow, StartControlFlow}, }; use hive_router_query_planner::{ state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, @@ -126,7 +126,7 @@ pub async fn graphql_request_handler( ) .await?; let response_status = response.status; - let response_bytes = Bytes::from(response.body); + let response_bytes = response.body; let response_headers = response.headers; let mut response_builder = web::HttpResponse::Ok(); @@ -139,7 +139,7 @@ pub async fn graphql_request_handler( Ok(response_builder .header(http::header::CONTENT_TYPE, response_content_type) .status(response_status) - .body(response_bytes)) + .body(response_bytes.to_vec())) } #[inline] @@ -152,7 +152,7 @@ pub async fn execute_pipeline( schema_state: &SchemaState, jwt_context: Option, plugin_req_state: Option>, -) -> Result { +) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ @@ -161,8 +161,8 @@ pub async fn execute_pipeline( let mut graphql_params = None; let mut body = body; if let Some(plugin_req_state) = plugin_req_state.as_ref() { - let mut deserialization_payload: OnGraphQLParamsStartPayload = - OnGraphQLParamsStartPayload { + let mut deserialization_payload: OnGraphQLParamsStartHookPayload = + OnGraphQLParamsStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, body, @@ -172,11 +172,11 @@ pub async fn execute_pipeline( let result = plugin.on_graphql_params(deserialization_payload).await; deserialization_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { return Ok(response); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { deserialization_end_callbacks.push(callback); } } @@ -190,7 +190,7 @@ pub async fn execute_pipeline( }; if let Some(plugin_req_state) = &plugin_req_state { - let mut payload = OnGraphQLParamsEndPayload { + let mut payload = OnGraphQLParamsEndHookPayload { graphql_params, context: &plugin_req_state.context, }; @@ -198,14 +198,10 @@ pub async fn execute_pipeline( let result = deserialization_end_callback(payload); payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { + EndControlFlow::Continue => { /* continue to next plugin */ } + EndControlFlow::EndResponse(response) => { return Ok(response); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } } } graphql_params = payload.graphql_params; diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 933640abe..2ccef500e 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,13 +2,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::executors::http::HttpResponse; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{ - OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, + OnGraphQLParseEndHookPayload, OnGraphQLParseStartHookPayload, }; use hive_router_plan_executor::plugin_context::PluginRequestState; -use hive_router_plan_executor::plugin_trait::ControlFlowResult; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use hive_router_query_planner::utils::parsing::safe_parse_operation; use xxhash_rust::xxh3::Xxh3; @@ -25,7 +25,7 @@ pub struct GraphQLParserPayload { pub enum ParseResult { Payload(GraphQLParserPayload), - Response(PlanExecutionOutput), + Response(HttpResponse), } #[inline] @@ -48,7 +48,7 @@ pub async fn parse_operation_with_cache( let mut on_end_callbacks = vec![]; if let Some(plugin_req_state) = plugin_req_state.as_ref() { /* Handle on_graphql_parse hook in the plugins - START */ - let mut start_payload = OnGraphQLParseStartPayload { + let mut start_payload = OnGraphQLParseStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, graphql_params, @@ -58,13 +58,13 @@ pub async fn parse_operation_with_cache( let result = plugin.on_graphql_parse(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::EndResponse(response) => { return Ok(ParseResult::Response(response)); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { // store the callback to be called later on_end_callbacks.push(callback); } @@ -85,21 +85,17 @@ pub async fn parse_operation_with_cache( parsed } }; - let mut end_payload = OnGraphQLParseEndPayload { document }; + let mut end_payload = OnGraphQLParseEndHookPayload { document }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(response) => { + EndControlFlow::EndResponse(response) => { return Ok(ParseResult::Response(response)); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!(); - } } } let document = end_payload.document; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 343972bd0..87e25b627 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -5,13 +5,13 @@ use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; use crate::schema_state::SchemaState; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::executors::http::HttpResponse; use hive_router_plan_executor::hooks::on_query_plan::{ - OnQueryPlanEndPayload, OnQueryPlanStartPayload, + OnQueryPlanEndHookPayload, OnQueryPlanStartHookPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_context::PluginRequestState; -use hive_router_plan_executor::plugin_trait::ControlFlowResult; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; @@ -19,12 +19,12 @@ use xxhash_rust::xxh3::Xxh3; pub enum QueryPlanResult { QueryPlan(Arc), - Response(PlanExecutionOutput), + Response(HttpResponse), } pub enum QueryPlanGetterError { Planner(PlannerError), - Response(PlanExecutionOutput), + Response(HttpResponse), } #[inline] @@ -60,7 +60,7 @@ pub async fn plan_operation_with_cache<'req>( if let Some(plugin_req_state) = plugin_req_state { /* Handle on_query_plan hook in the plugins - START */ - let mut start_payload = OnQueryPlanStartPayload { + let mut start_payload = OnQueryPlanStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, filtered_operation_for_plan, @@ -74,13 +74,13 @@ pub async fn plan_operation_with_cache<'req>( let result = plugin.on_query_plan(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::EndResponse(response) => { return Err(QueryPlanGetterError::Response(response)); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -101,21 +101,18 @@ pub async fn plan_operation_with_cache<'req>( .map_err(QueryPlanGetterError::Planner)?, }; - let mut end_payload = OnQueryPlanEndPayload { query_plan }; + let mut end_payload = OnQueryPlanEndHookPayload { query_plan }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(response) => { + EndControlFlow::EndResponse(response) => { return Err(QueryPlanGetterError::Response(response)); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - } } } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 8dd98f922..da03ddada 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -5,13 +5,13 @@ use crate::pipeline::parser::GraphQLParserPayload; use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::executors::http::HttpResponse; use hive_router_plan_executor::hooks::on_graphql_validation::{ - OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, + OnGraphQLValidationEndHookPayload, OnGraphQLValidationStartHookPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_context::PluginRequestState; -use hive_router_plan_executor::plugin_trait::ControlFlowResult; +use hive_router_plan_executor::plugin_trait::{EndControlFlow, StartControlFlow}; use tracing::{error, trace}; #[inline] @@ -21,7 +21,7 @@ pub async fn validate_operation_with_cache( app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, plugin_req_state: &Option>, -) -> Result, PipelineErrorVariant> { +) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -47,7 +47,7 @@ pub async fn validate_operation_with_cache( let document = &parser_payload.parsed_operation; let errors = if let Some(plugin_req_state) = plugin_req_state.as_ref() { /* Handle on_graphql_validate hook in the plugins - START */ - let mut start_payload = OnGraphQLValidationStartPayload::new( + let mut start_payload = OnGraphQLValidationStartHookPayload::new( plugin_req_state, consumer_schema_ast, document, @@ -57,13 +57,13 @@ pub async fn validate_operation_with_cache( let result = plugin.on_graphql_validation(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::EndResponse(response) => { return Ok(Some(response)); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -80,21 +80,18 @@ pub async fn validate_operation_with_cache( validate(consumer_schema_ast, document, &app_state.validation_plan) }; - let mut end_payload = OnGraphQLValidationEndPayload { errors }; + let mut end_payload = OnGraphQLValidationEndHookPayload { errors }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(response) => { + EndControlFlow::EndResponse(response) => { return Ok(Some(response)); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - } } } /* Handle on_graphql_validate hook in the plugins - END */ diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 6bf43aac6..b88728379 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -1,12 +1,10 @@ use std::sync::Arc; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + hooks::on_http_request::{OnHttpRequestHookPayload, OnHttpResponseHookPayload}, plugin_context::PluginContext, - plugin_trait::ControlFlowResult, + plugin_trait::{EndControlFlow, StartControlFlow}, }; -use http::StatusCode; use ntex::{ http::ResponseBuilder, service::{Service, ServiceCtx}, @@ -53,45 +51,43 @@ where let plugin_context = Arc::new(PluginContext::default()); req.extensions_mut().insert(plugin_context.clone()); - let mut start_payload = OnHttpRequestPayload { + let mut start_payload = OnHttpRequestHookPayload { router_http_request: req, context: &plugin_context, }; let mut on_end_callbacks = vec![]; - let mut early_response: Option = None; for plugin in plugins.iter() { let result = plugin.on_http_request(start_payload); start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } - ControlFlowResult::EndResponse(response) => { - early_response = Some(response); - break; + StartControlFlow::EndResponse(response) => { + let mut resp_builder = ResponseBuilder::new(response.status); + for (key, value) in response.headers { + if let Some(key) = key { + resp_builder.header(key, value); + } + } + let response = start_payload + .router_http_request + .into_response(resp_builder.body(response.body.to_vec())); + return Ok(response); } } } let req = start_payload.router_http_request; - let response = if let Some(early_response) = early_response { - let mut builder = ResponseBuilder::new(StatusCode::OK); - for (key, value) in early_response.headers.iter() { - builder.header(key, value); - } - let res = builder.body(early_response.body); - req.into_response(res) - } else { - ctx.call(&self.service, req).await? - }; + let response = ctx.call(&self.service, req).await?; - let mut end_payload = OnHttpResponsePayload { + let mut end_payload = OnHttpResponseHookPayload { response, context: &plugin_context, }; @@ -100,16 +96,19 @@ where let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(_response) => { - // Short-circuit the request with the provided response - unimplemented!() - } - ControlFlowResult::OnEnd(_) => { - // This should not happen - unreachable!(); + EndControlFlow::EndResponse(response) => { + let mut resp_builder = ResponseBuilder::new(response.status); + for (key, value) in response.headers { + if let Some(key) = key { + resp_builder.header(key, value); + } + } + let response = resp_builder.body(response.body.to_vec()); + end_payload.response = end_payload.response.into_response(response); + return Ok(end_payload.response); } } } diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index bffb9ce9d..ed88b29b7 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -5,10 +5,10 @@ use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{ - OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, + OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload, SupergraphData, }, introspection::schema::SchemaWithMetadata, - plugin_trait::ControlFlowResult, + plugin_trait::{EndControlFlow, StartControlFlow}, SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -90,7 +90,7 @@ impl SchemaState { let mut on_end_callbacks = vec![]; if let Some(plugins) = app_state.plugins.as_ref() { - let mut start_payload = OnSupergraphLoadStartPayload { + let mut start_payload = OnSupergraphLoadStartHookPayload { current_supergraph_data: swappable_data_spawn_clone.clone(), new_ast, }; @@ -98,13 +98,13 @@ impl SchemaState { let result = plugin.on_supergraph_reload(start_payload); start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::EndResponse(_) => { + StartControlFlow::EndResponse(_) => { unreachable!("Plugins should not end supergraph reload processing"); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -114,7 +114,7 @@ impl SchemaState { match Self::build_data(router_config.clone(), &new_ast) { Ok(new_supergraph_data) => { - let mut end_payload = OnSupergraphLoadEndPayload { + let mut end_payload = OnSupergraphLoadEndHookPayload { new_supergraph_data, }; @@ -122,19 +122,14 @@ impl SchemaState { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(_) => { + EndControlFlow::EndResponse(_) => { unreachable!( "Plugins should not end supergraph reload processing" ); } - ControlFlowResult::OnEnd(_) => { - unreachable!( - "End callbacks should not register further end callbacks" - ); - } } } diff --git a/e2e/src/plugins/apollo_sandbox.rs b/e2e/src/plugins/apollo_sandbox.rs index d3c28be68..327cb1c44 100644 --- a/e2e/src/plugins/apollo_sandbox.rs +++ b/e2e/src/plugins/apollo_sandbox.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + executors::http::HttpResponse, + hooks::on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; use http::HeaderMap; use reqwest::StatusCode; @@ -140,8 +140,8 @@ pub struct ApolloSandboxPlugin { impl RouterPlugin for ApolloSandboxPlugin { fn on_http_request<'req>( &'req self, - payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { + payload: OnHttpRequestHookPayload<'req>, + ) -> OnHttpRequestHookResult<'req> { if payload.router_http_request.path() == "/apollo-sandbox" { let config = sonic_rs::to_string(&self.serialized_options).unwrap_or_else(|_| "{}".to_string()); @@ -159,8 +159,8 @@ impl RouterPlugin for ApolloSandboxPlugin { ); let mut headers = HeaderMap::new(); headers.insert("Content-Type", "text/html".parse().unwrap()); - return payload.end_response(PlanExecutionOutput { - body: html.into_bytes(), + return payload.end_response(HttpResponse { + body: html.into_bytes().into(), headers, status: StatusCode::OK, }); diff --git a/e2e/src/plugins/apq.rs b/e2e/src/plugins/apq.rs index e9c9c6ff5..683606f30 100644 --- a/e2e/src/plugins/apq.rs +++ b/e2e/src/plugins/apq.rs @@ -5,9 +5,9 @@ use serde_json::json; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -39,8 +39,8 @@ impl RouterPluginWithConfig for APQPlugin { impl RouterPlugin for APQPlugin { async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { payload.on_end(|mut payload| { let persisted_query_ext = payload .graphql_params @@ -62,8 +62,8 @@ impl RouterPlugin for APQPlugin { } ] }); - return payload.end_response(PlanExecutionOutput { - body: body.to_string().into_bytes(), + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), status: StatusCode::BAD_REQUEST, headers: http::HeaderMap::new(), }); @@ -85,8 +85,8 @@ impl RouterPlugin for APQPlugin { } ] }); - return payload.end_response(PlanExecutionOutput { - body: body.to_string().into_bytes(), + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), status: StatusCode::BAD_REQUEST, headers: http::HeaderMap::new(), }); @@ -112,8 +112,8 @@ impl RouterPlugin for APQPlugin { } ] }); - return payload.end_response(PlanExecutionOutput { - body: body.to_string().into_bytes(), + return payload.end_response(HttpResponse { + body: body.to_string().into_bytes().into(), status: StatusCode::NOT_FOUND, headers: http::HeaderMap::new(), }); diff --git a/e2e/src/plugins/async_auth.rs b/e2e/src/plugins/async_auth.rs index f9e3b82eb..07871d6a2 100644 --- a/e2e/src/plugins/async_auth.rs +++ b/e2e/src/plugins/async_auth.rs @@ -4,9 +4,9 @@ use sonic_rs::json; use std::path::PathBuf; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -44,8 +44,8 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { // We don't use on_http_request here because we want to run this only when it is a GraphQL request async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { let header = payload.router_http_request.headers.get(&self.header_key); match header { Some(client_id) => { @@ -73,8 +73,8 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { ] } ); - return payload.end_response(PlanExecutionOutput { - body: sonic_rs::to_vec(&body).unwrap_or_default(), + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), headers: http::HeaderMap::new(), status: http::StatusCode::FORBIDDEN, }); @@ -95,8 +95,8 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { ] } ); - return payload.end_response(PlanExecutionOutput { - body: sonic_rs::to_vec(&body).unwrap_or_default(), + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), headers: http::HeaderMap::new(), status: http::StatusCode::BAD_REQUEST, }); @@ -118,8 +118,8 @@ impl RouterPlugin for AllowClientIdFromFilePlugin { ] } ); - return payload.end_response(PlanExecutionOutput { - body: sonic_rs::to_vec(&body).unwrap_or_default(), + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), headers: http::HeaderMap::new(), status: http::StatusCode::UNAUTHORIZED, }); diff --git a/e2e/src/plugins/context_data.rs b/e2e/src/plugins/context_data.rs index c0b7840ff..ab120122d 100644 --- a/e2e/src/plugins/context_data.rs +++ b/e2e/src/plugins/context_data.rs @@ -4,10 +4,13 @@ use serde::Deserialize; use hive_router_plan_executor::{ hooks::{ - on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + on_subgraph_execute::{ + OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload, + OnSubgraphExecuteStartHookResult, + }, }, - plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -40,8 +43,8 @@ impl RouterPluginWithConfig for ContextDataPlugin { impl RouterPlugin for ContextDataPlugin { async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { let context_data = ContextData { incoming_data: "world".to_string(), response_count: 0, @@ -60,8 +63,8 @@ impl RouterPlugin for ContextDataPlugin { } async fn on_subgraph_execute<'exec>( &'exec self, - mut payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + mut payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { let context_data_entry = payload.context.get_ref::(); if let Some(ref context_data_entry) = context_data_entry { tracing::info!("hello {}", context_data_entry.incoming_data); // Hello world! @@ -71,7 +74,7 @@ impl RouterPlugin for ContextDataPlugin { http::HeaderValue::from_str(&new_header_value).unwrap(), ); } - payload.on_end(|payload: OnSubgraphExecuteEndPayload<'exec>| { + payload.on_end(|payload: OnSubgraphExecuteEndHookPayload<'exec>| { let context_data = payload.context.get_mut::(); if let Some(mut context_data) = context_data { context_data.response_count += 1; diff --git a/e2e/src/plugins/forbid_anonymous_operations.rs b/e2e/src/plugins/forbid_anonymous_operations.rs index 1539c3ec0..e4eaace23 100644 --- a/e2e/src/plugins/forbid_anonymous_operations.rs +++ b/e2e/src/plugins/forbid_anonymous_operations.rs @@ -5,9 +5,9 @@ use serde::Deserialize; use sonic_rs::json; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + executors::http::HttpResponse, + hooks::on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -34,8 +34,8 @@ impl RouterPluginWithConfig for ForbidAnonymousOperationsPlugin { impl RouterPlugin for ForbidAnonymousOperationsPlugin { async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { let maybe_operation_name = &payload .graphql_params .as_ref() @@ -50,7 +50,7 @@ impl RouterPlugin for ForbidAnonymousOperationsPlugin { tracing::error!("Operation is not allowed!"); // Prepare an HTTP 400 response with a GraphQL error message - let response_body = json!({ + let body = json!({ "errors": [ { "message": "Anonymous operations are not allowed", @@ -60,8 +60,8 @@ impl RouterPlugin for ForbidAnonymousOperationsPlugin { } ] }); - return payload.end_response(PlanExecutionOutput { - body: sonic_rs::to_vec(&response_body).unwrap_or_default(), + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), headers: http::HeaderMap::new(), status: StatusCode::BAD_REQUEST, }); diff --git a/e2e/src/plugins/multipart.rs b/e2e/src/plugins/multipart.rs index b421cde6c..2cf929a95 100644 --- a/e2e/src/plugins/multipart.rs +++ b/e2e/src/plugins/multipart.rs @@ -2,15 +2,16 @@ use std::collections::HashMap; use bytes::Bytes; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, - executors::dedupe::SharedResponse, + executors::http::HttpResponse, hooks::{ on_graphql_params::{ - GraphQLParams, OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload, + GraphQLParams, OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult, + }, + on_subgraph_http_request::{ + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult, }, - on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, }, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; use multer::Multipart; use serde::Deserialize; @@ -52,8 +53,8 @@ impl RouterPluginWithConfig for MultipartPlugin { impl RouterPlugin for MultipartPlugin { async fn on_graphql_params<'exec>( &'exec self, - mut payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + mut payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { if let Some(content_type) = payload.router_http_request.headers.get("content-type") { if let Ok(content_type_str) = content_type.to_str() { if content_type_str.starts_with("multipart/form-data") { @@ -105,8 +106,8 @@ impl RouterPlugin for MultipartPlugin { async fn on_subgraph_http_request<'exec>( &'exec self, - mut payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { + mut payload: OnSubgraphHttpRequestHookPayload<'exec>, + ) -> OnSubgraphHttpRequestHookResult<'exec> { if let Some(variables) = &payload.execution_request.variables { let multipart_ctx = payload.context.get_ref::(); if let Some(multipart_ctx) = multipart_ctx { @@ -152,7 +153,7 @@ impl RouterPlugin for MultipartPlugin { .await; match resp { Ok(resp) => { - payload.response = Some(SharedResponse { + payload.response = Some(HttpResponse { status: resp.status(), headers: resp.headers().clone(), body: resp.bytes().await.unwrap(), @@ -165,10 +166,10 @@ impl RouterPlugin for MultipartPlugin { "message": format!("Failed to send multipart request to subgraph: {}", err) }] }); - return payload.end_response(PlanExecutionOutput { + return payload.end_response(HttpResponse { status: reqwest::StatusCode::INTERNAL_SERVER_ERROR, headers: reqwest::header::HeaderMap::new(), - body: serde_json::to_vec(&body).unwrap(), + body: serde_json::to_vec(&body).unwrap().into(), }); } } diff --git a/e2e/src/plugins/one_of.rs b/e2e/src/plugins/one_of.rs index e133bbb30..08f6fd5f8 100644 --- a/e2e/src/plugins/one_of.rs +++ b/e2e/src/plugins/one_of.rs @@ -60,13 +60,15 @@ use graphql_tools::{ }, }; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, + executors::http::HttpResponse, hooks::{ - on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, - on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, - on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + on_execute::{OnExecuteStartHookPayload, OnExecuteStartHookResult}, + on_graphql_validation::{ + OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult, + }, + on_supergraph_load::{OnSupergraphLoadEndHookPayload, OnSupergraphLoadStartHookPayload}, }, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload, StartHookResult}, }; use serde::Deserialize; use sonic_rs::{json, JsonContainerTrait}; @@ -101,9 +103,8 @@ impl RouterPlugin for OneOfPlugin { // 1. During validation step async fn on_graphql_validation<'exec>( &'exec self, - mut payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> - { + mut payload: OnGraphQLValidationStartHookPayload<'exec>, + ) -> OnGraphQLValidationStartHookResult<'exec> { let rule = OneOfValidationRule { one_of_types: self.one_of_types.read().unwrap().clone(), }; @@ -113,8 +114,8 @@ impl RouterPlugin for OneOfPlugin { // 2. During execution step async fn on_execute<'exec>( &'exec self, - payload: OnExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> { + payload: OnExecuteStartHookPayload<'exec>, + ) -> OnExecuteStartHookResult<'exec> { if let (Some(variable_values), Some(variable_defs)) = ( &payload.variable_values, &payload.operation_for_plan.variable_definitions, @@ -133,7 +134,7 @@ impl RouterPlugin for OneOfPlugin { variable_named_type, keys_num ); - return payload.end_response(PlanExecutionOutput { + return payload.end_response(HttpResponse { body: sonic_rs::to_vec(&json!({ "errors": [{ "message": err_msg, @@ -142,7 +143,8 @@ impl RouterPlugin for OneOfPlugin { } }] })) - .unwrap(), + .unwrap() + .into(), headers: Default::default(), status: http::StatusCode::BAD_REQUEST, }); @@ -155,8 +157,9 @@ impl RouterPlugin for OneOfPlugin { } fn on_supergraph_reload<'exec>( &'exec self, - start_payload: OnSupergraphLoadStartPayload, - ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + start_payload: OnSupergraphLoadStartHookPayload, + ) -> StartHookResult<'exec, OnSupergraphLoadStartHookPayload, OnSupergraphLoadEndHookPayload> + { for def in start_payload.new_ast.definitions.iter() { if let Definition::TypeDefinition(TypeDefinition::InputObject(input_obj)) = def { for directive in input_obj.directives.iter() { diff --git a/e2e/src/plugins/propagate_status_code.rs b/e2e/src/plugins/propagate_status_code.rs index 856424f62..8a7b0d4cf 100644 --- a/e2e/src/plugins/propagate_status_code.rs +++ b/e2e/src/plugins/propagate_status_code.rs @@ -5,10 +5,12 @@ use serde::Deserialize; use hive_router_plan_executor::{ hooks::{ - on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, - on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult}, + on_subgraph_execute::{ + OnSubgraphExecuteStartHookPayload, OnSubgraphExecuteStartHookResult, + }, }, - plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -47,9 +49,8 @@ pub struct PropagateStatusCodeCtx { impl RouterPlugin for PropagateStatusCodePlugin { async fn on_subgraph_execute<'exec>( &'exec self, - payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload<'exec>> - { + payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { payload.on_end(|payload| { let status_code = payload.execution_result.status; // if a response contains a status code we're watching... @@ -72,8 +73,8 @@ impl RouterPlugin for PropagateStatusCodePlugin { } fn on_http_request<'exec>( &'exec self, - payload: OnHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponsePayload<'exec>> { + payload: OnHttpRequestHookPayload<'exec>, + ) -> OnHttpRequestHookResult<'exec> { payload.on_end(|mut payload| { // Checking if there is a context entry let ctx = payload.context.get_ref::(); diff --git a/e2e/src/plugins/response_cache.rs b/e2e/src/plugins/response_cache.rs index d5d02db30..4a09b9e0f 100644 --- a/e2e/src/plugins/response_cache.rs +++ b/e2e/src/plugins/response_cache.rs @@ -4,12 +4,14 @@ use redis::Commands; use serde::Deserialize; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, + executors::http::HttpResponse, hooks::{ - on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, - on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + on_execute::{ + OnExecuteEndHookPayload, OnExecuteStartHookPayload, OnExecuteStartHookResult, + }, + on_supergraph_load::{OnSupergraphLoadStartHookPayload, OnSupergraphLoadStartHookResult}, }, - plugin_trait::{EndPayload, HookResult, RouterPluginWithConfig, StartPayload}, + plugin_trait::{EndHookPayload, RouterPluginWithConfig, StartHookPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME, }; @@ -59,8 +61,8 @@ pub struct ResponseCachePlugin { impl RouterPlugin for ResponseCachePlugin { async fn on_execute<'exec>( &'exec self, - payload: OnExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + payload: OnExecuteStartHookPayload<'exec>, + ) -> OnExecuteStartHookResult<'exec> { let key = format!( "response_cache:{}:{:?}", payload.query_plan, payload.variable_values @@ -78,8 +80,8 @@ impl RouterPlugin for ResponseCachePlugin { key, String::from_utf8_lossy(&body) ); - return payload.end_response(PlanExecutionOutput { - body: body, + return payload.end_response(HttpResponse { + body: body.into(), headers: HeaderMap::new(), status: StatusCode::OK, }); @@ -89,7 +91,7 @@ impl RouterPlugin for ResponseCachePlugin { trace!("Error accessing cache for key {}: {}", key, err); } } - return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { + return payload.on_end(move |mut payload: OnExecuteEndHookPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { trace!("Not caching response due to errors"); @@ -143,8 +145,8 @@ impl RouterPlugin for ResponseCachePlugin { } fn on_supergraph_reload<'a>( &'a self, - payload: OnSupergraphLoadStartPayload, - ) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + payload: OnSupergraphLoadStartHookPayload, + ) -> OnSupergraphLoadStartHookResult<'a> { // Visit the schema and update ttl_per_type based on some directive payload.new_ast.definitions.iter().for_each(|def| { if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { diff --git a/e2e/src/plugins/root_field_limit.rs b/e2e/src/plugins/root_field_limit.rs index 64da0cc28..6baf07ecb 100644 --- a/e2e/src/plugins/root_field_limit.rs +++ b/e2e/src/plugins/root_field_limit.rs @@ -11,12 +11,14 @@ use serde::Deserialize; use sonic_rs::json; use hive_router_plan_executor::{ - execution::plan::PlanExecutionOutput, + executors::http::HttpResponse, hooks::{ - on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, - on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + on_graphql_validation::{ + OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult, + }, + on_query_plan::{OnQueryPlanStartHookPayload, OnQueryPlanStartHookResult}, }, - plugin_trait::{HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + plugin_trait::{RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; // This example shows two ways of limiting the number of root fields in a query: @@ -28,9 +30,8 @@ impl RouterPlugin for RootFieldLimitPlugin { // Using validation step async fn on_graphql_validation<'exec>( &'exec self, - mut payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> - { + mut payload: OnGraphQLValidationStartHookPayload<'exec>, + ) -> OnGraphQLValidationStartHookResult<'exec> { let rule = RootFieldLimitRule { max_root_fields: self.max_root_fields, }; @@ -40,8 +41,8 @@ impl RouterPlugin for RootFieldLimitPlugin { // Or during query planning async fn on_query_plan<'exec>( &'exec self, - payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + payload: OnQueryPlanStartHookPayload<'exec>, + ) -> OnQueryPlanStartHookResult<'exec> { let mut cnt = 0; for selection in payload .filtered_operation_for_plan @@ -67,8 +68,8 @@ impl RouterPlugin for RootFieldLimitPlugin { }] }); // Return error - return payload.end_response(PlanExecutionOutput { - body: sonic_rs::to_vec(&body).unwrap_or_default(), + return payload.end_response(HttpResponse { + body: sonic_rs::to_vec(&body).unwrap_or_default().into(), headers: http::HeaderMap::new(), status: http::StatusCode::PAYLOAD_TOO_LARGE, }); diff --git a/e2e/src/plugins/subgraph_response_cache.rs b/e2e/src/plugins/subgraph_response_cache.rs index 617427382..b7b8a0289 100644 --- a/e2e/src/plugins/subgraph_response_cache.rs +++ b/e2e/src/plugins/subgraph_response_cache.rs @@ -2,9 +2,12 @@ use dashmap::DashMap; use serde::Deserialize; use hive_router_plan_executor::{ - executors::common::HttpExecutionResponse, - hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, - plugin_trait::{EndPayload, HookResult, RouterPlugin, RouterPluginWithConfig, StartPayload}, + executors::http::HttpResponse, + hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload, + OnSubgraphExecuteStartHookResult, + }, + plugin_trait::{EndHookPayload, RouterPlugin, RouterPluginWithConfig, StartHookPayload}, }; #[derive(Deserialize)] @@ -29,15 +32,15 @@ impl RouterPluginWithConfig for SubgraphResponseCachePlugin { } pub struct SubgraphResponseCachePlugin { - cache: DashMap, + cache: DashMap, } #[async_trait::async_trait] impl RouterPlugin for SubgraphResponseCachePlugin { async fn on_subgraph_execute<'exec>( &'exec self, - mut payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + mut payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -48,7 +51,7 @@ impl RouterPlugin for SubgraphResponseCachePlugin { payload.execution_result = Some(cached_response.clone()); return payload.cont(); } - payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { + payload.on_end(move |payload: OnSubgraphExecuteEndHookPayload| { // Here payload.response is not Option self.cache.insert(key, payload.execution_result.clone()); payload.cont() diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index f5b7ec815..188240ce9 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -21,22 +21,19 @@ use crate::{ jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, }, - executors::{ - common::{HttpExecutionResponse, SubgraphExecutionRequest}, - map::SubgraphExecutorMap, - }, + executors::{common::SubgraphExecutionRequest, http::HttpResponse, map::SubgraphExecutorMap}, headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, }, - hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + hooks::on_execute::{OnExecuteEndHookPayload, OnExecuteStartHookPayload}, introspection::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, plugin_context::PluginRequestState, - plugin_trait::ControlFlowResult, + plugin_trait::{EndControlFlow, StartControlFlow}, projection::{ plan::FieldProjectionPlan, request::{project_requires, RequestProjectionContext}, @@ -69,15 +66,8 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub jwt_auth_forwarding: Option, } -#[derive(Clone)] -pub struct PlanExecutionOutput { - pub body: Vec, - pub headers: HeaderMap, - pub status: http::StatusCode, -} - impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { - pub async fn execute_query_plan(self) -> Result { + pub async fn execute_query_plan(self) -> Result { let mut init_value = if let Some(introspection_query) = self.introspection_context.query { resolve_introspection(introspection_query, self.introspection_context) } else { @@ -92,7 +82,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let mut on_end_callbacks = vec![]; if let Some(plugin_req_state) = self.plugin_req_state.as_ref() { - let mut start_payload = OnExecuteStartPayload { + let mut start_payload = OnExecuteStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, query_plan, @@ -108,11 +98,11 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let result = plugin.on_execute(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { return Ok(response); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -155,7 +145,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let mut response_size_estimate = exec_ctx.response_storage.estimate_final_response_size(); if !on_end_callbacks.is_empty() { - let mut end_payload = OnExecuteEndPayload { + let mut end_payload = OnExecuteEndHookPayload { data, errors, extensions, @@ -166,14 +156,10 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ } - ControlFlowResult::EndResponse(output) => { + EndControlFlow::Continue => { /* continue to next callback */ } + EndControlFlow::EndResponse(output) => { return Ok(output); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } } } @@ -197,8 +183,8 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { affected_path: || None, })?; - Ok(PlanExecutionOutput { - body, + Ok(HttpResponse { + body: body.into(), headers: response_headers, status: http::StatusCode::OK, }) @@ -240,20 +226,15 @@ impl<'exec, T> ConcurrencyScope<'exec, T> { } } -struct SubgraphOutput { - body: Bytes, - headers: HeaderMap, -} - struct FetchJob { fetch_node_id: i64, subgraph_name: String, - response: SubgraphOutput, + response: HttpResponse, } struct FlattenFetchJob { flatten_node_path: FlattenNodePath, - response: SubgraphOutput, + response: HttpResponse, fetch_node_id: i64, subgraph_name: String, representation_hashes: Vec, @@ -266,18 +247,13 @@ enum ExecutionJob { None, } -impl From for SubgraphOutput { +impl From for HttpResponse { fn from(value: ExecutionJob) -> Self { match value { - ExecutionJob::Fetch(j) => Self { - body: j.response.body, - headers: j.response.headers, - }, - ExecutionJob::FlattenFetch(j) => Self { - body: j.response.body, - headers: j.response.headers, - }, + ExecutionJob::Fetch(j) => j.response, + ExecutionJob::FlattenFetch(j) => j.response, ExecutionJob::None => Self { + status: http::StatusCode::OK, body: Bytes::new(), headers: HeaderMap::new(), }, @@ -285,15 +261,6 @@ impl From for SubgraphOutput { } } -impl From for SubgraphOutput { - fn from(res: HttpExecutionResponse) -> Self { - Self { - body: res.body, - headers: res.headers, - } - } -} - struct PreparedFlattenData { representations: Vec, representation_hashes: Vec, @@ -816,8 +783,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { self.client_request, self.plugin_req_state, ) - .await - .into(), + .await, })) } diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index b607a7e43..0cfd4f354 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -1,11 +1,10 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; -use bytes::Bytes; use http::HeaderMap; use sonic_rs::Value; -use crate::plugin_context::PluginRequestState; +use crate::{executors::http::HttpResponse, plugin_context::PluginRequestState}; #[async_trait] pub trait SubgraphExecutor { @@ -13,7 +12,7 @@ pub trait SubgraphExecutor { &self, execution_request: SubgraphExecutionRequest<'a>, plugin_req_state: &'a Option>, - ) -> HttpExecutionResponse; + ) -> HttpResponse; fn to_boxed_arc<'a>(self) -> Arc> where @@ -47,10 +46,3 @@ impl SubgraphExecutionRequest<'_> { .insert(key, value); } } - -#[derive(Clone)] -pub struct HttpExecutionResponse { - pub body: Bytes, - pub headers: HeaderMap, - pub status: http::StatusCode, -} diff --git a/lib/executor/src/executors/dedupe.rs b/lib/executor/src/executors/dedupe.rs index a60599f19..420ede67d 100644 --- a/lib/executor/src/executors/dedupe.rs +++ b/lib/executor/src/executors/dedupe.rs @@ -1,16 +1,8 @@ use ahash::AHasher; -use bytes::Bytes; -use http::{HeaderMap, Method, StatusCode, Uri}; +use http::{HeaderMap, Method, Uri}; use std::collections::BTreeMap; use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; -#[derive(Debug, Clone)] -pub struct SharedResponse { - pub status: StatusCode, - pub headers: HeaderMap, - pub body: Bytes, -} - pub fn request_fingerprint( method: &Method, url: &Uri, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 6798bc084..b8a573b8f 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use crate::executors::common::HttpExecutionResponse; -use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; +use crate::executors::dedupe::{request_fingerprint, ABuildHasher}; use crate::hooks::on_subgraph_http_request::{ - OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpResponseHookPayload, }; use crate::plugin_context::PluginRequestState; -use crate::plugin_trait::ControlFlowResult; +use crate::plugin_trait::{EndControlFlow, StartControlFlow}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -40,7 +39,7 @@ pub struct HTTPSubgraphExecutor { pub header_map: HeaderMap, pub semaphore: Arc, pub config: Arc, - pub in_flight_requests: Arc>, ABuildHasher>>, + pub in_flight_requests: Arc>, ABuildHasher>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -55,7 +54,7 @@ impl HTTPSubgraphExecutor { http_client: Arc, semaphore: Arc, config: Arc, - in_flight_requests: Arc>, ABuildHasher>>, + in_flight_requests: Arc>, ABuildHasher>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -168,12 +167,12 @@ async fn send_request( mut body: Vec, mut execution_request: SubgraphExecutionRequest<'_>, plugin_req_state: &Option>, -) -> Result { +) -> Result { let mut on_end_callbacks = vec![]; let mut response = None; if let Some(plugin_req_state) = plugin_req_state.as_ref() { - let mut start_payload = OnSubgraphHttpRequestPayload { + let mut start_payload = OnSubgraphHttpRequestHookPayload { subgraph_name, endpoint, method, @@ -186,16 +185,16 @@ async fn send_request( let result = plugin.on_subgraph_http_request(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::Continue => { /* continue to next plugin */ } + StartControlFlow::EndResponse(response) => { // TODO: Fixx - return Ok(SharedResponse { + return Ok(HttpResponse { status: StatusCode::OK, - body: response.body.into(), + body: response.body, headers: response.headers, }); } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -248,7 +247,7 @@ async fn send_request( )); } - SharedResponse { + HttpResponse { status: parts.status, body, headers: parts.headers, @@ -256,24 +255,20 @@ async fn send_request( } }; - let mut end_payload = OnSubgraphHttpResponsePayload { response }; + let mut end_payload = OnSubgraphHttpResponseHookPayload { response }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ } - ControlFlowResult::EndResponse(response) => { - return Ok(SharedResponse { + EndControlFlow::Continue => { /* continue to next callback */ } + EndControlFlow::EndResponse(response) => { + return Ok(HttpResponse { status: StatusCode::OK, - body: response.body.into(), + body: response.body, headers: response.headers, }); } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } } } @@ -287,12 +282,12 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { &self, mut execution_request: SubgraphExecutionRequest<'a>, plugin_req_state: &'a Option>, - ) -> HttpExecutionResponse { + ) -> HttpResponse { let body = match self.build_request_body(&execution_request) { Ok(body) => body, Err(e) => { self.log_error(&e); - return HttpExecutionResponse { + return HttpResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), status: StatusCode::OK, @@ -321,14 +316,10 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { ) .await { - Ok(shared_response) => HttpExecutionResponse { - body: shared_response.body, - headers: shared_response.headers, - status: shared_response.status, - }, + Ok(shared_response) => shared_response, Err(e) => { self.log_error(&e); - HttpExecutionResponse { + HttpResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), status: StatusCode::OK, @@ -375,14 +366,10 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { .await; match response_result { - Ok(shared_response) => HttpExecutionResponse { - body: shared_response.body.clone(), - headers: shared_response.headers.clone(), - status: shared_response.status, - }, + Ok(shared_response) => shared_response.clone(), Err(e) => { self.log_error(&e); - HttpExecutionResponse { + HttpResponse { body: self.error_to_graphql_bytes(e.clone()), headers: Default::default(), status: StatusCode::OK, @@ -391,3 +378,10 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { } } } + +#[derive(Clone)] +pub struct HttpResponse { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: Bytes, +} diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index cd33b511e..59586a01b 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -29,17 +29,16 @@ use vrl::{ use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ - common::{ - HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, - SubgraphExecutorBoxedArc, - }, - dedupe::{ABuildHasher, SharedResponse}, + common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc}, + dedupe::ABuildHasher, error::SubgraphExecutorError, - http::{HTTPSubgraphExecutor, HttpClient}, + http::{HTTPSubgraphExecutor, HttpClient, HttpResponse}, + }, + hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload, }, - hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_context::PluginRequestState, - plugin_trait::ControlFlowResult, + plugin_trait::{EndControlFlow, StartControlFlow}, response::graphql_error::GraphQLError, }; @@ -63,7 +62,7 @@ pub struct SubgraphExecutorMap { client: Arc, semaphores_by_origin: DashMap>, max_connections_per_host: usize, - in_flight_requests: Arc>, ABuildHasher>>, + in_flight_requests: Arc>, ABuildHasher>>, } impl SubgraphExecutorMap { @@ -125,13 +124,13 @@ impl SubgraphExecutorMap { execution_request: SubgraphExecutionRequest<'exec>, client_request: &ClientRequestDetails<'exec, 'req>, plugin_req_state: &Option>, - ) -> HttpExecutionResponse { + ) -> HttpResponse { let mut on_end_callbacks = vec![]; let mut execution_request = execution_request; let mut execution_result = None; if let Some(plugin_req_state) = plugin_req_state.as_ref() { - let mut start_payload = OnSubgraphExecuteStartPayload { + let mut start_payload = OnSubgraphExecuteStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, subgraph_name, @@ -142,18 +141,18 @@ impl SubgraphExecutorMap { let result = plugin.on_subgraph_execute(start_payload).await; start_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + StartControlFlow::Continue => { // continue to next plugin } - ControlFlowResult::EndResponse(response) => { + StartControlFlow::EndResponse(response) => { // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), + return HttpResponse { + body: response.body, headers: response.headers, status: response.status, }; } - ControlFlowResult::OnEnd(callback) => { + StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); } } @@ -187,7 +186,7 @@ impl SubgraphExecutorMap { }; if let Some(plugin_req_state) = plugin_req_state.as_ref() { - let mut end_payload = OnSubgraphExecuteEndPayload { + let mut end_payload = OnSubgraphExecuteEndHookPayload { context: &plugin_req_state.context, execution_result, }; @@ -196,20 +195,17 @@ impl SubgraphExecutorMap { let result = callback(end_payload); end_payload = result.payload; match result.control_flow { - ControlFlowResult::Continue => { + EndControlFlow::Continue => { // continue to next callback } - ControlFlowResult::EndResponse(response) => { + EndControlFlow::EndResponse(response) => { // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), + return HttpResponse { + body: response.body, headers: response.headers, status: response.status, }; } - ControlFlowResult::OnEnd(_) => { - unreachable!("End callbacks should not register further end callbacks"); - } } } @@ -223,7 +219,7 @@ impl SubgraphExecutorMap { &self, graphql_error: GraphQLError, subgraph_name: &str, - ) -> HttpExecutionResponse { + ) -> HttpResponse { let errors = vec![graphql_error.add_subgraph_name(subgraph_name)]; let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); let mut buffer = BytesMut::new(); @@ -231,7 +227,7 @@ impl SubgraphExecutorMap { buffer.put_slice(&errors_bytes); buffer.put_slice(b"}"); - HttpExecutionResponse { + HttpResponse { body: buffer.freeze(), headers: Default::default(), status: http::StatusCode::INTERNAL_SERVER_ERROR, diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index b69ba3297..e1392af14 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -4,11 +4,11 @@ use hive_router_query_planner::ast::operation::OperationDefinition; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use crate::plugin_context::{PluginContext, RouterHttpRequest}; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}; use crate::response::graphql_error::GraphQLError; use crate::response::value::Value; -pub struct OnExecuteStartPayload<'exec> { +pub struct OnExecuteStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub query_plan: &'exec QueryPlan, @@ -23,9 +23,12 @@ pub struct OnExecuteStartPayload<'exec> { pub dedupe_subgraph_requests: bool, } -impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} +impl<'exec> StartHookPayload> for OnExecuteStartHookPayload<'exec> {} -pub struct OnExecuteEndPayload<'exec> { +pub type OnExecuteStartHookResult<'exec> = + StartHookResult<'exec, OnExecuteStartHookPayload<'exec>, OnExecuteEndHookPayload<'exec>>; + +pub struct OnExecuteEndHookPayload<'exec> { pub data: Value<'exec>, pub errors: Vec, pub extensions: Option>, @@ -33,4 +36,6 @@ pub struct OnExecuteEndPayload<'exec> { pub response_size_estimate: usize, } -impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} +impl<'exec> EndHookPayload for OnExecuteEndHookPayload<'exec> {} + +pub type OnExecuteEndHookResult<'exec> = EndHookResult>; diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index c69d094e4..ea44426f2 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -8,8 +8,10 @@ use sonic_rs::Value; use crate::plugin_context::PluginContext; use crate::plugin_context::RouterHttpRequest; -use crate::plugin_trait::EndPayload; -use crate::plugin_trait::StartPayload; +use crate::plugin_trait::EndHookPayload; +use crate::plugin_trait::EndHookResult; +use crate::plugin_trait::StartHookPayload; +use crate::plugin_trait::StartHookResult; #[derive(Debug, Clone, Default)] pub struct GraphQLParams { @@ -93,18 +95,29 @@ impl<'de> Deserialize<'de> for GraphQLParams { } } -pub struct OnGraphQLParamsStartPayload<'exec> { +pub struct OnGraphQLParamsStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub body: Bytes, pub graphql_params: Option, } -impl<'exec> StartPayload> for OnGraphQLParamsStartPayload<'exec> {} +impl<'exec> StartHookPayload> + for OnGraphQLParamsStartHookPayload<'exec> +{ +} + +pub type OnGraphQLParamsStartHookResult<'exec> = StartHookResult< + 'exec, + OnGraphQLParamsStartHookPayload<'exec>, + OnGraphQLParamsEndHookPayload<'exec>, +>; -pub struct OnGraphQLParamsEndPayload<'exec> { +pub struct OnGraphQLParamsEndHookPayload<'exec> { pub graphql_params: GraphQLParams, pub context: &'exec PluginContext, } -impl<'exec> EndPayload for OnGraphQLParamsEndPayload<'exec> {} +impl<'exec> EndHookPayload for OnGraphQLParamsEndHookPayload<'exec> {} + +pub type OnGraphQLParamsEndHookResult<'exec> = EndHookResult>; diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index fa29e3b9d..9ee55780a 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -3,20 +3,28 @@ use graphql_tools::static_graphql::query::Document; use crate::{ hooks::on_graphql_params::GraphQLParams, plugin_context::{PluginContext, RouterHttpRequest}, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; -pub struct OnGraphQLParseStartPayload<'exec> { +pub struct OnGraphQLParseStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub graphql_params: &'exec GraphQLParams, pub document: Option, } -impl<'exec> StartPayload for OnGraphQLParseStartPayload<'exec> {} +impl<'exec> StartHookPayload + for OnGraphQLParseStartHookPayload<'exec> +{ +} + +pub type OnGraphQLParseHookResult<'exec> = + StartHookResult<'exec, OnGraphQLParseStartHookPayload<'exec>, OnGraphQLParseEndHookPayload>; -pub struct OnGraphQLParseEndPayload { +pub struct OnGraphQLParseEndHookPayload { pub document: Document, } -impl EndPayload for OnGraphQLParseEndPayload {} +impl EndHookPayload for OnGraphQLParseEndHookPayload {} + +pub type OnGraphQLParseEndHookResult = EndHookResult; diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index a839e06da..9431dd29f 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -10,10 +10,10 @@ use hive_router_query_planner::state::supergraph_state::SchemaDocument; use crate::{ plugin_context::{PluginContext, PluginRequestState, RouterHttpRequest}, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; -pub struct OnGraphQLValidationStartPayload<'exec> { +pub struct OnGraphQLValidationStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub schema: &'exec SchemaDocument, @@ -23,16 +23,25 @@ pub struct OnGraphQLValidationStartPayload<'exec> { pub errors: Option>, } -impl<'exec> StartPayload for OnGraphQLValidationStartPayload<'exec> {} +impl<'exec> StartHookPayload + for OnGraphQLValidationStartHookPayload<'exec> +{ +} + +pub type OnGraphQLValidationStartHookResult<'exec> = StartHookResult< + 'exec, + OnGraphQLValidationStartHookPayload<'exec>, + OnGraphQLValidationEndHookPayload, +>; -impl<'exec> OnGraphQLValidationStartPayload<'exec> { +impl<'exec> OnGraphQLValidationStartHookPayload<'exec> { pub fn new( plugin_req_state: &'exec PluginRequestState<'exec>, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, ) -> Self { - OnGraphQLValidationStartPayload { + OnGraphQLValidationStartHookPayload { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, schema, @@ -67,8 +76,10 @@ impl<'exec> OnGraphQLValidationStartPayload<'exec> { } } -pub struct OnGraphQLValidationEndPayload { +pub struct OnGraphQLValidationEndHookPayload { pub errors: Vec, } -impl EndPayload for OnGraphQLValidationEndPayload {} +impl EndHookPayload for OnGraphQLValidationEndHookPayload {} + +pub type OnGraphQLValidationHookEndResult = EndHookResult; diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index e9e857a80..8473a69fc 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -2,19 +2,24 @@ use ntex::web::{self, DefaultError, WebRequest}; use crate::{ plugin_context::PluginContext, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; -pub struct OnHttpRequestPayload<'req> { +pub struct OnHttpRequestHookPayload<'req> { pub router_http_request: WebRequest, pub context: &'req PluginContext, } -impl<'req> StartPayload> for OnHttpRequestPayload<'req> {} +impl<'req> StartHookPayload> for OnHttpRequestHookPayload<'req> {} -pub struct OnHttpResponsePayload<'req> { +pub type OnHttpRequestHookResult<'req> = + StartHookResult<'req, OnHttpRequestHookPayload<'req>, OnHttpResponseHookPayload<'req>>; + +pub struct OnHttpResponseHookPayload<'req> { pub response: web::WebResponse, pub context: &'req PluginContext, } -impl<'req> EndPayload for OnHttpResponsePayload<'req> {} +impl<'req> EndHookPayload for OnHttpResponseHookPayload<'req> {} + +pub type OnHttpResponseHookResult<'req> = EndHookResult>; diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index 9b2110fd7..103eafdf2 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -7,10 +7,10 @@ use hive_router_query_planner::{ use crate::{ plugin_context::{PluginContext, RouterHttpRequest}, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; -pub struct OnQueryPlanStartPayload<'exec> { +pub struct OnQueryPlanStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub filtered_operation_for_plan: &'exec OperationDefinition, @@ -20,10 +20,15 @@ pub struct OnQueryPlanStartPayload<'exec> { pub planner: &'exec Planner, } -impl<'exec> StartPayload for OnQueryPlanStartPayload<'exec> {} +impl<'exec> StartHookPayload for OnQueryPlanStartHookPayload<'exec> {} -pub struct OnQueryPlanEndPayload { +pub type OnQueryPlanStartHookResult<'exec> = + StartHookResult<'exec, OnQueryPlanStartHookPayload<'exec>, OnQueryPlanEndHookPayload>; + +pub struct OnQueryPlanEndHookPayload { pub query_plan: QueryPlan, } -impl EndPayload for OnQueryPlanEndPayload {} +impl EndHookPayload for OnQueryPlanEndHookPayload {} + +pub type OnQueryPlanEndHookResult = EndHookResult; diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 870f28cba..519a22d54 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,27 +1,36 @@ use crate::{ - executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + executors::{common::SubgraphExecutionRequest, http::HttpResponse}, plugin_context::{PluginContext, RouterHttpRequest}, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; -pub struct OnSubgraphExecuteStartPayload<'exec> { +pub struct OnSubgraphExecuteStartHookPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub subgraph_name: &'exec str, pub execution_request: SubgraphExecutionRequest<'exec>, - pub execution_result: Option, + pub execution_result: Option, } -impl<'exec> StartPayload> - for OnSubgraphExecuteStartPayload<'exec> +impl<'exec> StartHookPayload> + for OnSubgraphExecuteStartHookPayload<'exec> { } -pub struct OnSubgraphExecuteEndPayload<'exec> { - pub execution_result: HttpExecutionResponse, +pub type OnSubgraphExecuteStartHookResult<'exec> = StartHookResult< + 'exec, + OnSubgraphExecuteStartHookPayload<'exec>, + OnSubgraphExecuteEndHookPayload<'exec>, +>; + +pub struct OnSubgraphExecuteEndHookPayload<'exec> { + pub execution_result: HttpResponse, pub context: &'exec PluginContext, } -impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} +impl<'exec> EndHookPayload for OnSubgraphExecuteEndHookPayload<'exec> {} + +pub type OnSubgraphExecuteEndHookResult<'exec> = + EndHookResult>; diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 2ce479b3d..8835c08ef 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,10 +1,10 @@ use crate::{ - executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, + executors::{common::SubgraphExecutionRequest, http::HttpResponse}, plugin_context::PluginContext, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, StartHookPayload}, }; -pub struct OnSubgraphHttpRequestPayload<'exec> { +pub struct OnSubgraphHttpRequestHookPayload<'exec> { pub subgraph_name: &'exec str, pub endpoint: &'exec http::Uri, @@ -15,13 +15,25 @@ pub struct OnSubgraphHttpRequestPayload<'exec> { pub context: &'exec PluginContext, // Early response - pub response: Option, + pub response: Option, } -impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} +impl<'exec> StartHookPayload + for OnSubgraphHttpRequestHookPayload<'exec> +{ +} + +pub type OnSubgraphHttpRequestHookResult<'exec> = crate::plugin_trait::StartHookResult< + 'exec, + OnSubgraphHttpRequestHookPayload<'exec>, + OnSubgraphHttpResponseHookPayload, +>; -pub struct OnSubgraphHttpResponsePayload { - pub response: SharedResponse, +pub struct OnSubgraphHttpResponseHookPayload { + pub response: HttpResponse, } -impl EndPayload for OnSubgraphHttpResponsePayload {} +impl EndHookPayload for OnSubgraphHttpResponseHookPayload {} + +pub type OnSubgraphHttpResponseHookResult = + crate::plugin_trait::EndHookResult; diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs index 21dfbf5a5..2425c9c3f 100644 --- a/lib/executor/src/plugins/hooks/on_supergraph_load.rs +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -6,7 +6,7 @@ use hive_router_query_planner::planner::Planner; use crate::{ introspection::schema::SchemaMetadata, - plugin_trait::{EndPayload, StartPayload}, + plugin_trait::{EndHookPayload, StartHookPayload}, SubgraphExecutorMap, }; @@ -16,15 +16,24 @@ pub struct SupergraphData { pub subgraph_executor_map: SubgraphExecutorMap, } -pub struct OnSupergraphLoadStartPayload { +pub struct OnSupergraphLoadStartHookPayload { pub current_supergraph_data: Arc>>, pub new_ast: Document, } -impl StartPayload for OnSupergraphLoadStartPayload {} +impl StartHookPayload for OnSupergraphLoadStartHookPayload {} -pub struct OnSupergraphLoadEndPayload { +pub type OnSupergraphLoadStartHookResult<'exec> = crate::plugin_trait::StartHookResult< + 'exec, + OnSupergraphLoadStartHookPayload, + OnSupergraphLoadEndHookPayload, +>; + +pub struct OnSupergraphLoadEndHookPayload { pub new_supergraph_data: SupergraphData, } -impl EndPayload for OnSupergraphLoadEndPayload {} +impl EndHookPayload for OnSupergraphLoadEndHookPayload {} + +pub type OnSupergraphLoadEndHookResult = + crate::plugin_trait::EndHookResult; diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 3bb750e59..72517b776 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,80 +1,94 @@ use serde::de::DeserializeOwned; -use crate::execution::plan::PlanExecutionOutput; -use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; -use crate::hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}; -use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; -use crate::hooks::on_graphql_validation::{ - OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +use crate::{ + executors::http::HttpResponse, + hooks::{ + on_execute::{OnExecuteStartHookPayload, OnExecuteStartHookResult}, + on_graphql_params::{OnGraphQLParamsStartHookPayload, OnGraphQLParamsStartHookResult}, + on_graphql_parse::{OnGraphQLParseHookResult, OnGraphQLParseStartHookPayload}, + on_graphql_validation::{ + OnGraphQLValidationStartHookPayload, OnGraphQLValidationStartHookResult, + }, + on_http_request::{OnHttpRequestHookPayload, OnHttpRequestHookResult}, + on_query_plan::{OnQueryPlanStartHookPayload, OnQueryPlanStartHookResult}, + on_subgraph_execute::{ + OnSubgraphExecuteStartHookPayload, OnSubgraphExecuteStartHookResult, + }, + on_subgraph_http_request::{ + OnSubgraphHttpRequestHookPayload, OnSubgraphHttpRequestHookResult, + }, + on_supergraph_load::{OnSupergraphLoadStartHookPayload, OnSupergraphLoadStartHookResult}, + }, }; -use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}; -use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; -use crate::hooks::on_subgraph_execute::{ - OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, -}; -use crate::hooks::on_subgraph_http_request::{ - OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, -}; -use crate::hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}; -pub struct HookResult<'exec, TStartPayload, TEndPayload> { +pub struct StartHookResult<'exec, TStartPayload, TEndPayload> { pub payload: TStartPayload, - pub control_flow: ControlFlowResult<'exec, TEndPayload>, + pub control_flow: StartControlFlow<'exec, TEndPayload>, } -pub enum ControlFlowResult<'exec, TEndPayload> { +pub enum StartControlFlow<'exec, TEndPayload> { Continue, - EndResponse(PlanExecutionOutput), - OnEnd(Box HookResult<'exec, TEndPayload, ()> + Send + 'exec>), + EndResponse(HttpResponse), + OnEnd(Box EndHookResult + Send + 'exec>), } -pub trait StartPayload +pub trait StartHookPayload where Self: Sized, { - fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { - HookResult { + fn cont<'exec>(self) -> StartHookResult<'exec, Self, TEndPayload> { + StartHookResult { payload: self, - control_flow: ControlFlowResult::Continue, + control_flow: StartControlFlow::Continue, } } fn end_response<'exec>( self, - output: PlanExecutionOutput, - ) -> HookResult<'exec, Self, TEndPayload> { - HookResult { + output: HttpResponse, + ) -> StartHookResult<'exec, Self, TEndPayload> { + StartHookResult { payload: self, - control_flow: ControlFlowResult::EndResponse(output), + control_flow: StartControlFlow::EndResponse(output), } } - fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> + fn on_end<'exec, F>(self, f: F) -> StartHookResult<'exec, Self, TEndPayload> where - F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + Send + 'exec, + F: FnOnce(TEndPayload) -> EndHookResult + Send + 'exec, { - HookResult { + StartHookResult { payload: self, - control_flow: ControlFlowResult::OnEnd(Box::new(f)), + control_flow: StartControlFlow::OnEnd(Box::new(f)), } } } -pub trait EndPayload +pub struct EndHookResult { + pub payload: TEndPayload, + pub control_flow: EndControlFlow, +} + +pub enum EndControlFlow { + Continue, + EndResponse(HttpResponse), +} + +pub trait EndHookPayload where Self: Sized, { - fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { - HookResult { + fn cont(self) -> EndHookResult { + EndHookResult { payload: self, - control_flow: ControlFlowResult::Continue, + control_flow: EndControlFlow::Continue, } } - fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { - HookResult { + fn end_response(self, output: HttpResponse) -> EndHookResult { + EndHookResult { payload: self, - control_flow: ControlFlowResult::EndResponse(output), + control_flow: EndControlFlow::EndResponse(output), } } } @@ -93,57 +107,56 @@ where pub trait RouterPlugin { fn on_http_request<'req>( &'req self, - start_payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { + start_payload: OnHttpRequestHookPayload<'req>, + ) -> OnHttpRequestHookResult<'req> { start_payload.cont() } async fn on_graphql_params<'exec>( &'exec self, - start_payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + start_payload: OnGraphQLParamsStartHookPayload<'exec>, + ) -> OnGraphQLParamsStartHookResult<'exec> { start_payload.cont() } async fn on_graphql_parse<'exec>( &'exec self, - start_payload: OnGraphQLParseStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { + start_payload: OnGraphQLParseStartHookPayload<'exec>, + ) -> OnGraphQLParseHookResult<'exec> { start_payload.cont() } async fn on_graphql_validation<'exec>( &'exec self, - start_payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> - { + start_payload: OnGraphQLValidationStartHookPayload<'exec>, + ) -> OnGraphQLValidationStartHookResult<'exec> { start_payload.cont() } async fn on_query_plan<'exec>( &'exec self, - start_payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + start_payload: OnQueryPlanStartHookPayload<'exec>, + ) -> OnQueryPlanStartHookResult<'exec> { start_payload.cont() } async fn on_execute<'exec>( &'exec self, - start_payload: OnExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload: OnExecuteStartHookPayload<'exec>, + ) -> OnExecuteStartHookResult<'exec> { start_payload.cont() } async fn on_subgraph_execute<'exec>( &'exec self, - start_payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + start_payload: OnSubgraphExecuteStartHookPayload<'exec>, + ) -> OnSubgraphExecuteStartHookResult<'exec> { start_payload.cont() } async fn on_subgraph_http_request<'exec>( &'exec self, - start_payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { + start_payload: OnSubgraphHttpRequestHookPayload<'exec>, + ) -> OnSubgraphHttpRequestHookResult<'exec> { start_payload.cont() } fn on_supergraph_reload<'exec>( &'exec self, - start_payload: OnSupergraphLoadStartPayload, - ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + start_payload: OnSupergraphLoadStartHookPayload, + ) -> OnSupergraphLoadStartHookResult<'exec> { start_payload.cont() } } From 5675c8a4da145fb95383ebf633341cdc4667816a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 2 Dec 2025 15:38:22 +0300 Subject: [PATCH 29/31] Continue --- lib/executor/src/executors/common.rs | 3 +- lib/executor/src/executors/http.rs | 16 ++----- lib/executor/src/executors/map.rs | 48 +++++++++---------- .../src/plugins/hooks/on_subgraph_execute.rs | 6 ++- 4 files changed, 35 insertions(+), 38 deletions(-) diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 0cfd4f354..0015d8139 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -1,13 +1,14 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; -use http::HeaderMap; +use http::{HeaderMap, Uri}; use sonic_rs::Value; use crate::{executors::http::HttpResponse, plugin_context::PluginRequestState}; #[async_trait] pub trait SubgraphExecutor { + fn endpoint(&self) -> &Uri; async fn execute<'a>( &self, execution_request: SubgraphExecutionRequest<'a>, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index b8a573b8f..ac3bdd47f 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -187,12 +187,7 @@ async fn send_request( match result.control_flow { StartControlFlow::Continue => { /* continue to next plugin */ } StartControlFlow::EndResponse(response) => { - // TODO: Fixx - return Ok(HttpResponse { - status: StatusCode::OK, - body: response.body, - headers: response.headers, - }); + return Ok(response); } StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); @@ -263,11 +258,7 @@ async fn send_request( match result.control_flow { EndControlFlow::Continue => { /* continue to next callback */ } EndControlFlow::EndResponse(response) => { - return Ok(HttpResponse { - status: StatusCode::OK, - body: response.body, - headers: response.headers, - }); + return Ok(response); } } } @@ -277,6 +268,9 @@ async fn send_request( #[async_trait] impl SubgraphExecutor for HTTPSubgraphExecutor { + fn endpoint(&self) -> &http::Uri { + &self.endpoint + } #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))] async fn execute<'a>( &self, diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 59586a01b..7ecad86ba 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -125,6 +125,25 @@ impl SubgraphExecutorMap { client_request: &ClientRequestDetails<'exec, 'req>, plugin_req_state: &Option>, ) -> HttpResponse { + let mut executor = match self.get_or_create_executor(subgraph_name, client_request) { + Ok(Some(executor)) => executor, + Err(err) => { + error!( + "Subgraph executor error for subgraph '{}': {}", + subgraph_name, err, + ); + return self.internal_server_error_response(err.into(), subgraph_name); + } + Ok(None) => { + error!( + "Subgraph executor not found for subgraph '{}'", + subgraph_name + ); + return self + .internal_server_error_response("Internal server error".into(), subgraph_name); + } + }; + let mut on_end_callbacks = vec![]; let mut execution_request = execution_request; @@ -134,6 +153,7 @@ impl SubgraphExecutorMap { router_http_request: &plugin_req_state.router_http_request, context: &plugin_req_state.context, subgraph_name, + executor, execution_request, execution_result, }; @@ -159,30 +179,12 @@ impl SubgraphExecutorMap { } execution_request = start_payload.execution_request; execution_result = start_payload.execution_result; + executor = start_payload.executor; } let mut execution_result = match execution_result { Some(execution_result) => execution_result, - None => match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request, plugin_req_state).await, - Err(err) => { - error!( - "Subgraph executor error for subgraph '{}': {}", - subgraph_name, err, - ); - self.internal_server_error_response(err.into(), subgraph_name) - } - Ok(None) => { - error!( - "Subgraph executor not found for subgraph '{}'", - subgraph_name - ); - self.internal_server_error_response( - "Internal server error".into(), - subgraph_name, - ) - } - }, + None => executor.execute(execution_request, plugin_req_state).await, }; if let Some(plugin_req_state) = plugin_req_state.as_ref() { @@ -200,11 +202,7 @@ impl SubgraphExecutorMap { } EndControlFlow::EndResponse(response) => { // TODO: FFIX - return HttpResponse { - body: response.body, - headers: response.headers, - status: response.status, - }; + return response; } } } diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 519a22d54..b4e08f320 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,5 +1,8 @@ use crate::{ - executors::{common::SubgraphExecutionRequest, http::HttpResponse}, + executors::{ + common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc}, + http::HttpResponse, + }, plugin_context::{PluginContext, RouterHttpRequest}, plugin_trait::{EndHookPayload, EndHookResult, StartHookPayload, StartHookResult}, }; @@ -9,6 +12,7 @@ pub struct OnSubgraphExecuteStartHookPayload<'exec> { pub context: &'exec PluginContext, pub subgraph_name: &'exec str, + pub executor: SubgraphExecutorBoxedArc, pub execution_request: SubgraphExecutionRequest<'exec>, pub execution_result: Option, From 87eb0daf2660c39895235ffb0c8faf798fdb111f Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 2 Dec 2025 15:55:46 +0300 Subject: [PATCH 30/31] More --- lib/executor/src/executors/map.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 7ecad86ba..b666189c7 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -166,11 +166,7 @@ impl SubgraphExecutorMap { } StartControlFlow::EndResponse(response) => { // TODO: FFIX - return HttpResponse { - body: response.body, - headers: response.headers, - status: response.status, - }; + return response; } StartControlFlow::OnEnd(callback) => { on_end_callbacks.push(callback); From 86e9f805c6ea736730617796142d0f50f76e7403 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 2 Dec 2025 16:24:10 +0300 Subject: [PATCH 31/31] More --- lib/executor/src/plugins/plugin_context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index 8f8960131..b5babeca9 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -97,14 +97,14 @@ impl PluginContext { .insert(type_id, Box::new(value)) .and_then(|boxed_any| boxed_any.downcast::().ok()) } - pub fn get_ref(&self) -> Option> { + pub fn get_ref<'a, T: Any + Send + Sync>(&'a self) -> Option> { let type_id = TypeId::of::(); self.inner.get(&type_id).map(|entry| PluginContextRefEntry { entry, phantom: std::marker::PhantomData, }) } - pub fn get_mut(&self) -> Option> { + pub fn get_mut<'a, T: Any + Send + Sync>(&'a self) -> Option> { let type_id = TypeId::of::(); self.inner .get_mut(&type_id)