Skip to content

Commit

Permalink
Box PgConnection fields (launchbadge#3529)
Browse files Browse the repository at this point in the history
* Update PgConnection code

* rustfmt
  • Loading branch information
joeydewaal authored and jrasanen committed Oct 14, 2024
1 parent 718c70c commit e710cda
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 104 deletions.
35 changes: 25 additions & 10 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl PgConnection {
}

// next we check a local cache for user-defined type names <-> object id
if let Some(info) = self.cache_type_info.get(&oid) {
if let Some(info) = self.inner.cache_type_info.get(&oid) {
return Ok(info.clone());
}

Expand All @@ -173,8 +173,9 @@ impl PgConnection {

// cache the type name <-> oid relationship in a paired hashmap
// so we don't come down this road again
self.cache_type_info.insert(oid, info.clone());
self.cache_type_oid
self.inner.cache_type_info.insert(oid, info.clone());
self.inner
.cache_type_oid
.insert(info.0.name().to_string().into(), oid);

Ok(info)
Expand Down Expand Up @@ -374,7 +375,7 @@ WHERE rngtypid = $1
}

pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<Oid, Error> {
if let Some(oid) = self.cache_type_oid.get(name) {
if let Some(oid) = self.inner.cache_type_oid.get(name) {
return Ok(*oid);
}

Expand All @@ -387,15 +388,18 @@ WHERE rngtypid = $1
type_name: name.into(),
})?;

self.cache_type_oid.insert(name.to_string().into(), oid);
self.inner
.cache_type_oid
.insert(name.to_string().into(), oid);
Ok(oid)
}

pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result<Oid, Error> {
if let Some(oid) = self
.inner
.cache_type_oid
.get(&array.elem_name)
.and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid))
.and_then(|elem_oid| self.inner.cache_elem_type_to_array.get(elem_oid))
{
return Ok(*oid);
}
Expand All @@ -411,10 +415,13 @@ WHERE rngtypid = $1
})?;

// Avoids copying `elem_name` until necessary
self.cache_type_oid
self.inner
.cache_type_oid
.entry_ref(&array.elem_name)
.insert(elem_oid);
self.cache_elem_type_to_array.insert(elem_oid, array_oid);
self.inner
.cache_elem_type_to_array
.insert(elem_oid, array_oid);

Ok(array_oid)
}
Expand Down Expand Up @@ -475,8 +482,16 @@ WHERE rngtypid = $1
})?;

// If the server is CockroachDB or Materialize, skip this step (#1248).
if !self.stream.parameter_statuses.contains_key("crdb_version")
&& !self.stream.parameter_statuses.contains_key("mz_version")
if !self
.inner
.stream
.parameter_statuses
.contains_key("crdb_version")
&& !self
.inner
.stream
.parameter_statuses
.contains_key("mz_version")
{
// patch up our null inference with data from EXPLAIN
let nullable_patch = self
Expand Down
28 changes: 16 additions & 12 deletions sqlx-postgres/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::message::{
};
use crate::{PgConnectOptions, PgConnection};

use super::PgConnectionInner;

// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11

Expand Down Expand Up @@ -134,18 +136,20 @@ impl PgConnection {
}

Ok(PgConnection {
stream,
process_id,
secret_key,
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
cache_elem_type_to_array: HashMap::new(),
log_settings: options.log_settings.clone(),
inner: Box::new(PgConnectionInner {
stream,
process_id,
secret_key,
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
cache_elem_type_to_array: HashMap::new(),
log_settings: options.log_settings.clone(),
}),
})
}
}
53 changes: 29 additions & 24 deletions sqlx-postgres/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ async fn prepare(
parameters: &[PgTypeInfo],
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
let id = conn.next_statement_id;
conn.next_statement_id = id.next();
let id = conn.inner.next_statement_id;
conn.inner.next_statement_id = id.next();

// build a list of type OIDs to send to the database in the PARSE command
// we have not yet started the query sequence, so we are *safe* to cleanly make
Expand All @@ -43,23 +43,25 @@ async fn prepare(
conn.wait_until_ready().await?;

// next we send the PARSE command to the server
conn.stream.write_msg(Parse {
conn.inner.stream.write_msg(Parse {
param_types: &param_types,
query: sql,
statement: id,
})?;

if metadata.is_none() {
// get the statement columns and parameters
conn.stream.write_msg(message::Describe::Statement(id))?;
conn.inner
.stream
.write_msg(message::Describe::Statement(id))?;
}

// we ask for the server to immediately send us the result of the PARSE command
conn.write_sync();
conn.stream.flush().await?;
conn.inner.stream.flush().await?;

// indicates that the SQL query string is now successfully parsed and has semantic validity
conn.stream.recv_expect::<ParseComplete>().await?;
conn.inner.stream.recv_expect::<ParseComplete>().await?;

let metadata = if let Some(metadata) = metadata {
// each SYNC produces one READY FOR QUERY
Expand Down Expand Up @@ -94,11 +96,11 @@ async fn prepare(
}

async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
conn.stream.recv_expect().await
conn.inner.stream.recv_expect().await
}

async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
let rows: Option<RowDescription> = match conn.stream.recv().await? {
let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
// describes the rows that will be returned when the statement is eventually executed
message if message.format == BackendMessageFormat::RowDescription => {
Some(message.decode()?)
Expand All @@ -123,7 +125,7 @@ impl PgConnection {
pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
// we need to wait for the [CloseComplete] to be returned from the server
while count > 0 {
match self.stream.recv().await? {
match self.inner.stream.recv().await? {
message if message.format == BackendMessageFormat::PortalSuspended => {
// there was an open portal
// this can happen if the last time a statement was used it was not fully executed
Expand All @@ -148,12 +150,13 @@ impl PgConnection {

#[inline(always)]
pub(crate) fn write_sync(&mut self) {
self.stream
self.inner
.stream
.write_msg(message::Sync)
.expect("BUG: Sync should not be too big for protocol");

// all SYNC messages will return a ReadyForQuery
self.pending_ready_for_query_count += 1;
self.inner.pending_ready_for_query_count += 1;
}

async fn get_or_prepare<'a>(
Expand All @@ -166,18 +169,18 @@ impl PgConnection {
// a statement object
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
return Ok((*statement).clone());
}

let statement = prepare(self, sql, parameters, metadata).await?;

if store_to_cache && self.cache_statement.is_enabled() {
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
self.stream.write_msg(Close::Statement(id))?;
if store_to_cache && self.inner.cache_statement.is_enabled() {
if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
self.inner.stream.write_msg(Close::Statement(id))?;
self.write_sync();

self.stream.flush().await?;
self.inner.stream.flush().await?;

self.wait_for_close_complete(1).await?;
self.recv_ready_for_query().await?;
Expand All @@ -195,7 +198,7 @@ impl PgConnection {
persistent: bool,
metadata_opt: Option<Arc<PgStatementMetadata>>,
) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
let mut logger = QueryLogger::new(query, self.log_settings.clone());
let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());

// before we continue, wait until we are "ready" to accept more queries
self.wait_until_ready().await?;
Expand Down Expand Up @@ -231,7 +234,7 @@ impl PgConnection {
self.wait_until_ready().await?;

// bind to attach the arguments to the statement and create a portal
self.stream.write_msg(Bind {
self.inner.stream.write_msg(Bind {
portal: PortalId::UNNAMED,
statement,
formats: &[PgValueFormat::Binary],
Expand All @@ -242,7 +245,7 @@ impl PgConnection {

// executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write_msg(message::Execute {
self.inner.stream.write_msg(message::Execute {
portal: PortalId::UNNAMED,
limit: limit.into(),
})?;
Expand All @@ -255,7 +258,9 @@ impl PgConnection {

// we ask the database server to close the unnamed portal and free the associated resources
// earlier - after the execution of the current query.
self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
self.inner
.stream
.write_msg(Close::Portal(PortalId::UNNAMED))?;

// finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
Expand All @@ -268,8 +273,8 @@ impl PgConnection {
PgValueFormat::Binary
} else {
// Query will trigger a ReadyForQuery
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;
self.inner.stream.write_msg(Query(query))?;
self.inner.pending_ready_for_query_count += 1;

// metadata starts out as "nothing"
metadata = Arc::new(PgStatementMetadata::default());
Expand All @@ -278,11 +283,11 @@ impl PgConnection {
PgValueFormat::Text
};

self.stream.flush().await?;
self.inner.stream.flush().await?;

Ok(try_stream! {
loop {
let message = self.stream.recv().await?;
let message = self.inner.stream.recv().await?;

match message.format {
BackendMessageFormat::BindComplete
Expand Down
Loading

0 comments on commit e710cda

Please sign in to comment.