diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs
index c0db982c7d..98f060e927 100644
--- a/sqlx-postgres/src/arguments.rs
+++ b/sqlx-postgres/src/arguments.rs
@@ -111,6 +111,7 @@ impl PgArguments {
         &mut self,
         conn: &mut PgConnection,
         parameters: &[PgTypeInfo],
+        persistent: bool,
     ) -> Result<(), Error> {
         let PgArgumentBuffer {
             ref patches,
@@ -128,8 +129,8 @@ impl PgArguments {
 
         for (offset, kind) in type_holes {
             let oid = match kind {
-                HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
-                HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
+                HoleKind::Type { name } => conn.fetch_type_id_by_name(persistent, name).await?,
+                HoleKind::Array(array) => conn.fetch_array_type_id(persistent, array).await?,
             };
             buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
         }
diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs
index dfe5286458..8ae35a4592 100644
--- a/sqlx-postgres/src/connection/describe.rs
+++ b/sqlx-postgres/src/connection/describe.rs
@@ -103,6 +103,7 @@ impl PgConnection {
     pub(super) async fn handle_row_description(
         &mut self,
         desc: Option<RowDescription>,
+        persistent: bool,
         fetch_type_info: bool,
         fetch_column_description: bool,
     ) -> Result<(Vec<PgColumn>, HashMap<UStr, usize>), Error> {
@@ -123,14 +124,19 @@ impl PgConnection {
             let name = UStr::from(field.name);
 
             let type_info = self
-                .maybe_fetch_type_info_by_oid(field.data_type_id, fetch_type_info)
+                .maybe_fetch_type_info_by_oid(field.data_type_id, persistent, fetch_type_info)
                 .await?;
 
             let origin = if let (Some(relation_oid), Some(attribute_no)) =
                 (field.relation_id, field.relation_attribute_no)
             {
-                self.maybe_fetch_column_origin(relation_oid, attribute_no, fetch_column_description)
-                    .await?
+                self.maybe_fetch_column_origin(
+                    relation_oid,
+                    attribute_no,
+                    persistent,
+                    fetch_column_description,
+                )
+                .await?
             } else {
                 ColumnOrigin::Expression
             };
@@ -153,12 +159,16 @@ impl PgConnection {
 
     pub(super) async fn handle_parameter_description(
         &mut self,
+        persistent: bool,
         desc: ParameterDescription,
     ) -> Result<Vec<PgTypeInfo>, Error> {
         let mut params = Vec::with_capacity(desc.types.len());
 
         for ty in desc.types {
-            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
+            params.push(
+                self.maybe_fetch_type_info_by_oid(ty, persistent, true)
+                    .await?,
+            );
         }
 
         Ok(params)
@@ -167,6 +177,7 @@ impl PgConnection {
     async fn maybe_fetch_type_info_by_oid(
         &mut self,
         oid: Oid,
+        persistent: bool,
         should_fetch: bool,
     ) -> Result<PgTypeInfo, Error> {
         // first we check if this is a built-in type
@@ -183,7 +194,7 @@ impl PgConnection {
         // fallback to asking the database directly for a type name
         if should_fetch {
             // we're boxing this future here so we can use async recursion
-            let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?;
+            let info = Box::pin(async { self.fetch_type_by_oid(persistent, oid).await }).await?;
 
             // cache the type name <-> oid relationship in a paired hashmap
             // so we don't come down this road again
@@ -208,6 +219,7 @@ impl PgConnection {
         &mut self,
         relation_id: Oid,
         attribute_no: i16,
+        persistent: bool,
         should_fetch: bool,
     ) -> Result<ColumnOrigin, Error> {
         if let Some(origin) = self
@@ -238,6 +250,7 @@ impl PgConnection {
              FROM pg_catalog.pg_attribute \
              WHERE attrelid = $1 AND attnum = $2",
         )
+        .persistent(persistent)
        .bind(relation_id)
        .bind(attribute_no)
        .fetch_optional(&mut *self)
@@ -267,7 +280,7 @@ impl PgConnection {
         }))
     }
 
-    async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
+    async fn fetch_type_by_oid(&mut self, persistent: bool, oid: Oid) -> Result<PgTypeInfo, Error> {
         let (name, typ_type, category, relation_id, element, base_type): (
             String,
             i8,
@@ -287,6 +300,7 @@ impl PgConnection {
              FROM pg_catalog.pg_type \
              WHERE oid = $1",
         )
+        .persistent(persistent)
        .bind(oid)
        .fetch_one(&mut *self)
        .await?;
@@ -295,12 +309,16 @@ impl PgConnection {
         let category = TypCategory::try_from(category);
 
         match (typ_type, category) {
-            (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
+            (Ok(TypType::Domain), _) => {
+                self.fetch_domain_by_oid(oid, base_type, persistent, name)
+                    .await
+            }
 
             (Ok(TypType::Base), Ok(TypCategory::Array)) => {
                 Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
                     kind: PgTypeKind::Array(
-                        self.maybe_fetch_type_info_by_oid(element, true).await?,
+                        self.maybe_fetch_type_info_by_oid(element, persistent, true)
+                            .await?,
                     ),
                     name: name.into(),
                     oid,
@@ -316,13 +334,16 @@ impl PgConnection {
             }
 
             (Ok(TypType::Range), Ok(TypCategory::Range)) => {
-                self.fetch_range_by_oid(oid, name).await
+                self.fetch_range_by_oid(oid, persistent, name).await
             }
 
-            (Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,
+            (Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
+                self.fetch_enum_by_oid(oid, persistent, name).await
+            }
 
             (Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
-                self.fetch_composite_by_oid(oid, relation_id, name).await
+                self.fetch_composite_by_oid(oid, relation_id, persistent, name)
+                    .await
             }
 
             _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
@@ -333,7 +354,12 @@ impl PgConnection {
         }
     }
 
-    async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
+    async fn fetch_enum_by_oid(
+        &mut self,
+        oid: Oid,
+        persistent: bool,
+        name: String,
+    ) -> Result<PgTypeInfo, Error> {
         let variants: Vec<String> = query_scalar(
             r#"
 SELECT enumlabel
@@ -342,6 +368,7 @@ WHERE enumtypid = $1
 ORDER BY enumsortorder
             "#,
         )
+        .persistent(persistent)
        .bind(oid)
        .fetch_all(self)
        .await?;
@@ -357,6 +384,7 @@ ORDER BY enumsortorder
         &mut self,
         oid: Oid,
         relation_id: Oid,
+        persistent: bool,
         name: String,
     ) -> Result<PgTypeInfo, Error> {
         let raw_fields: Vec<(String, Oid)> = query_as(
@@ -369,6 +397,7 @@ AND attnum > 0
 ORDER BY attnum
             "#,
         )
+        .persistent(persistent)
        .bind(relation_id)
        .fetch_all(&mut *self)
        .await?;
@@ -376,7 +405,9 @@ ORDER BY attnum
         let mut fields = Vec::new();
 
         for (field_name, field_oid) in raw_fields.into_iter() {
-            let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
+            let field_type = self
+                .maybe_fetch_type_info_by_oid(field_oid, persistent, true)
+                .await?;
 
             fields.push((field_name, field_type));
         }
@@ -392,9 +423,12 @@ ORDER BY attnum
         &mut self,
         oid: Oid,
         base_type: Oid,
+        persistent: bool,
         name: String,
     ) -> Result<PgTypeInfo, Error> {
-        let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
+        let base_type = self
+            .maybe_fetch_type_info_by_oid(base_type, persistent, true)
+            .await?;
 
         Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
             oid,
@@ -403,7 +437,12 @@ ORDER BY attnum
         }))))
     }
 
-    async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
+    async fn fetch_range_by_oid(
+        &mut self,
+        oid: Oid,
+        persistent: bool,
+        name: String,
+    ) -> Result<PgTypeInfo, Error> {
         let element_oid: Oid = query_scalar(
             r#"
 SELECT rngsubtype
@@ -415,7 +454,9 @@ WHERE rngtypid = $1
         .fetch_one(&mut *self)
         .await?;
 
-        let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
+        let element = self
+            .maybe_fetch_type_info_by_oid(element_oid, persistent, true)
+            .await?;
 
         Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
             kind: PgTypeKind::Range(element),
@@ -424,26 +465,35 @@ WHERE rngtypid = $1
         }))))
     }
 
-    pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {
+    pub(crate) async fn resolve_type_id(
+        &mut self,
+        persistent: bool,
+        ty: &PgType,
+    ) -> Result<Oid, Error> {
         if let Some(oid) = ty.try_oid() {
             return Ok(oid);
         }
 
         match ty {
-            PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await,
-            PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await,
+            PgType::DeclareWithName(name) => self.fetch_type_id_by_name(persistent, name).await,
+            PgType::DeclareArrayOf(array) => self.fetch_array_type_id(persistent, array).await,
             // `.try_oid()` should return `Some()` or it should be covered here
             _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"),
         }
     }
 
-    pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
+    pub(crate) async fn fetch_type_id_by_name(
+        &mut self,
+        persistent: bool,
+        name: &str,
+    ) -> Result<Oid, Error> {
         if let Some(oid) = self.inner.cache_type_oid.get(name) {
             return Ok(*oid);
         }
 
         // language=SQL
         let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
+            .persistent(persistent)
            .bind(name)
            .fetch_optional(&mut *self)
            .await?
@@ -457,7 +507,11 @@ WHERE rngtypid = $1
         Ok(oid)
     }
 
-    pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
+    pub(crate) async fn fetch_array_type_id(
+        &mut self,
+        persistent: bool,
+        array: &PgArrayOf,
+    ) -> Result<Oid, Error> {
         if let Some(oid) = self
             .inner
             .cache_type_oid
@@ -470,6 +524,7 @@ WHERE rngtypid = $1
         // language=SQL
         let (elem_oid, array_oid): (Oid, Oid) =
             query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid")
+                .persistent(persistent)
                .bind(&*array.elem_name)
                .fetch_optional(&mut *self)
                .await?
@@ -719,19 +774,19 @@ fn explain_parsing() {
 
     // https://github.com/launchbadge/sqlx/issues/2622
     let extra_field = r#"[
-  {
-    "Plan": {
-      "Node Type": "Result",
-      "Parallel Aware": false,
-      "Async Capable": false,
-      "Startup Cost": 0.00,
-      "Total Cost": 0.01,
-      "Plan Rows": 1,
-      "Plan Width": 4,
-      "Output": ["1"]
-    },
+  {
+    "Plan": {
+      "Node Type": "Result",
+      "Parallel Aware": false,
+      "Async Capable": false,
+      "Startup Cost": 0.00,
+      "Total Cost": 0.01,
+      "Plan Rows": 1,
+      "Plan Width": 4,
+      "Output": ["1"]
+    },
     "Query Identifier": 1147616880456321454
-  }
+  }
 ]"#;
 
     // https://github.com/launchbadge/sqlx/issues/1449
diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs
index ba4cffa647..f6527dd72b 100644
--- a/sqlx-postgres/src/connection/executor.rs
+++ b/sqlx-postgres/src/connection/executor.rs
@@ -44,7 +44,7 @@ async fn prepare(
     let mut param_types = Vec::with_capacity(parameters.len());
 
     for ty in parameters {
-        param_types.push(conn.resolve_type_id(&ty.0).await?);
+        param_types.push(conn.resolve_type_id(persistent, &ty.0).await?);
     }
 
     // flush and wait until we are re-ready
@@ -85,16 +85,34 @@ async fn prepare(
     // each SYNC produces one READY FOR QUERY
     conn.recv_ready_for_query().await?;
 
-    let parameters = conn.handle_parameter_description(parameters).await?;
+    let parameters = conn
+        .handle_parameter_description(persistent, parameters)
+        .await?;
 
     let (columns, column_names) = conn
-        .handle_row_description(rows, true, fetch_column_origin)
+        .handle_row_description(rows, persistent, true, fetch_column_origin)
         .await?;
 
     // ensure that if we did fetch custom data, we wait until we are fully ready before
     // continuing
     conn.wait_until_ready().await?;
 
+    // if not persistent, we must Parse/Describe again, as handle_parameter_description/handle_row_description
+    // will overwrite the current unnamed prepared statement.
+    //
+    // we don't need to send sync/wait for response
+    if !persistent {
+        conn.inner.stream.write_msg(Parse {
+            param_types: &param_types,
+            query: sql,
+            statement: id,
+        })?;
+
+        conn.inner
+            .stream
+            .write_msg(message::Describe::Statement(id))?;
+    }
+
     Arc::new(PgStatementMetadata {
         parameters,
         columns,
@@ -246,7 +264,9 @@ impl PgConnection {
             metadata = metadata_;
 
             // patch holes created during encoding
-            arguments.apply_patches(self, &metadata.parameters).await?;
+            arguments
+                .apply_patches(self, &metadata.parameters, persistent)
+                .await?;
 
             // consume messages till `ReadyForQuery` before bind and execute
             self.wait_until_ready().await?;
@@ -347,7 +367,7 @@ impl PgConnection {
                 BackendMessageFormat::RowDescription => {
                     // indicates that a *new* set of rows are about to be returned
                     let (columns, column_names) = self
-                        .handle_row_description(Some(message.decode()?), false, false)
+                        .handle_row_description(Some(message.decode()?), persistent, false, false)
                         .await?;
 
                     metadata = Arc::new(PgStatementMetadata {