fix(Postgres) chunk pg_copy data (#3703)
* fix(postgres) chunk pg_copy data

* fix: cleanup after review
joeydewaal authored Jan 25, 2025
1 parent 74da542 commit 6fa0458
Showing 3 changed files with 39 additions and 8 deletions.
22 changes: 15 additions & 7 deletions sqlx-postgres/src/copy.rs
@@ -129,6 +129,9 @@ impl PgPoolCopyExt for Pool<Postgres> {
     }
 }
 
+// (1 GiB - 1) - 1 - length prefix (4 bytes)
+pub const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4;
+
 /// A connection in streaming `COPY FROM STDIN` mode.
 ///
 /// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
@@ -186,15 +189,20 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
 
     /// Send a chunk of `COPY` data.
     ///
+    /// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6
+    /// bytes) and may be partially sent if this call is cancelled.
+    ///
     /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
     pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
-        self.conn
-            .as_deref_mut()
-            .expect("send_data: conn taken")
-            .inner
-            .stream
-            .send(CopyData(data))
-            .await?;
+        for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
+            self.conn
+                .as_deref_mut()
+                .expect("send_data: conn taken")
+                .inner
+                .stream
+                .send(CopyData(chunk))
+                .await?;
+        }
 
         Ok(self)
     }
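
For context, a minimal standalone sketch (not part of the commit) of the arithmetic behind the new constant and of how `slice::chunks` splits an oversized payload; only `PG_COPY_MAX_DATA_LEN` and its value are taken from the diff above:

    // Copied from the diff above: (1 GiB - 1) - 1 - length prefix (4 bytes).
    const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4;

    fn main() {
        // 0x3fffffff is 2^30 - 1, so the limit works out to 1 GiB - 6 bytes,
        // the figure quoted in the new doc comment on `send`.
        assert_eq!(PG_COPY_MAX_DATA_LEN, 1_073_741_818);

        // `chunks(n)` yields ceil(len / n) slices: every slice is full-sized
        // except possibly the last, shown here with small numbers.
        let data = [0u8; 10];
        assert_eq!(data.chunks(4).count(), 3); // 4 + 4 + 2 bytes

        // A payload one byte over the limit therefore goes out as exactly
        // two CopyData messages instead of one oversized, rejected message.
        let payload_len = PG_COPY_MAX_DATA_LEN + 1;
        assert_eq!(payload_len.div_ceil(PG_COPY_MAX_DATA_LEN), 2);
    }
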
3 changes: 3 additions & 0 deletions sqlx-postgres/src/lib.rs
@@ -34,6 +34,9 @@ mod value;
 #[doc(hidden)]
 pub mod any;
 
+#[doc(hidden)]
+pub use copy::PG_COPY_MAX_DATA_LEN;
+
 #[cfg(feature = "migrate")]
 mod migrate;
 
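
This `#[doc(hidden)]` re-export is what lets the new test reach the constant through the `sqlx` facade crate, which republishes `sqlx_postgres` as `sqlx::postgres`:

    // As the test added below does (hidden from rustdoc, but public):
    use sqlx::postgres::PG_COPY_MAX_DATA_LEN;
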
22 changes: 21 additions & 1 deletion tests/postgres/postgres.rs
@@ -3,7 +3,7 @@ use futures::{Stream, StreamExt, TryStreamExt};
 use sqlx::postgres::types::Oid;
 use sqlx::postgres::{
     PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener,
-    PgPoolOptions, PgRow, PgSeverity, Postgres,
+    PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN,
 };
 use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo};
 use sqlx_core::{bytes::Bytes, error::BoxDynError};
@@ -2042,3 +2042,23 @@ async fn test_issue_3052() {
         "expected encode error, got {too_large_error:?}",
     );
 }
+
+#[sqlx_macros::test]
+async fn test_pg_copy_chunked() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+
+    let mut row = "1".repeat(PG_COPY_MAX_DATA_LEN / 10 - 1);
+    row.push_str("\n");
+
+    // creates a payload with PG_COPY_MAX_DATA_LEN + 1 as size
+    let mut payload = row.repeat(10);
+    payload.push_str("12345678\n");
+
+    assert_eq!(payload.len(), PG_COPY_MAX_DATA_LEN + 1);
+
+    let mut copy = conn.copy_in_raw("COPY products(name) FROM STDIN").await?;
+
+    assert!(copy.send(payload.as_bytes()).await.is_ok());
+    assert!(copy.finish().await.is_ok());
+    Ok(())
+}
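
For reference, the payload arithmetic in the test works out as follows (a worked check, not part of the commit; note that integer division truncates and PG_COPY_MAX_DATA_LEN is not divisible by 10):

    const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4; // 1_073_741_818

    fn main() {
        // One row is (PG_COPY_MAX_DATA_LEN / 10 - 1) '1' digits plus '\n',
        // i.e. exactly PG_COPY_MAX_DATA_LEN / 10 = 107_374_181 bytes.
        let row_len = (PG_COPY_MAX_DATA_LEN / 10 - 1) + 1;

        // Ten rows give 1_073_741_810 bytes; the trailing "12345678\n" row
        // adds 9 more, landing exactly one byte over the limit and forcing
        // `send` to split the payload into two CopyData messages.
        let payload_len = 10 * row_len + 9;
        assert_eq!(payload_len, PG_COPY_MAX_DATA_LEN + 1);
    }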
