Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move sql-only functions to plpgsql instead of using SPI #361

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lantern_extras/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_extras"
version = "0.5.0"
version = "0.6.0"
edition = "2021"

[lib]
Expand Down
4 changes: 2 additions & 2 deletions lantern_extras/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ To add a new embedding job, use the `add_embedding_job` function:

```sql
SELECT add_embedding_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column for embeddings
dst_column => 'content_embedding', -- Destination column for embeddings (will be created automatically)
model => 'text-embedding-3-small', -- Model for runtime to use (default: 'text-embedding-3-small')
Expand Down Expand Up @@ -224,7 +224,7 @@ To add a new completion job, use the `add_completion_job` function:

```sql
SELECT add_completion_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column containing the text sent to the LLM
dst_column => 'content_summary', -- Destination column for llm response (will be created automatically)
system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '')
Expand Down
296 changes: 176 additions & 120 deletions lantern_extras/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,126 +226,182 @@ fn add_completion_job<'a>(
Ok(id.unwrap())
}

/// Reports the initialization state of a single embedding job.
///
/// Looks up `job_id` in `_lantern_extras_internal.embedding_generation_jobs` and yields one
/// `(status, progress, error)` row. `status` is derived from the job's timestamp columns and
/// is one of `failed`, `canceled`, `enabled`, `in_progress` or `queued`; `progress` and
/// `error` come from `init_progress` and `init_failure_reason`.
///
/// If the SPI call fails, a single all-NULL row is returned instead of propagating the error.
///
/// Declared `stable` rather than `immutable`: the result is read from a table and changes as
/// the job advances, so the planner must not pre-evaluate or cache it across statements.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_embedding_job_status<'a>(
    job_id: i32,
) -> Result<
    TableIterator<
        'static,
        (
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    let tuple = Spi::get_three_with_args(
        r#"
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=$1;
"#,
        vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())],
    );

    // Best effort: an SPI failure is reported as an all-NULL row, not as an error.
    Ok(TableIterator::once(tuple.unwrap_or((None, None, None))))
}

/// Lists the failure records stored for a job.
///
/// Returns every `(row_id, value)` pair from `_lantern_extras_internal.embedding_failure_info`
/// whose `job_id` matches the argument. SPI errors and column conversion errors are
/// propagated to the caller.
///
/// Declared `stable` rather than `immutable` because the result depends on table contents.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_completion_job_failures<'a>(
    job_id: i32,
) -> Result<
    TableIterator<'static, (name!(row_id, Option<i32>), name!(value, Option<String>))>,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT row_id, value FROM _lantern_extras_internal.embedding_failure_info WHERE job_id=$1",
                None,
                Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
            )?
            .map(|row| Ok((row["row_id"].value()?, row["value"].value()?)))
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Lists all jobs of type `embedding_generation` together with their status.
///
/// Each returned row is `(id, status, progress, error)`, where the last three columns are
/// produced by calling the SQL function `get_embedding_job_status(id)` for that job.
///
/// Declared `stable` rather than `immutable` because the result depends on table contents.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_embedding_jobs<'a>() -> Result<
    TableIterator<
        'static,
        (
            name!(id, Option<i32>),
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'embedding_generation'",
                None,
                None,
            )?
            .map(|row| {
                Ok((
                    row["id"].value()?,
                    row["status"].value()?,
                    row["progress"].value()?,
                    row["error"].value()?,
                ))
            })
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Lists all jobs of type `completion` together with their status.
///
/// Each returned row is `(id, status, progress, error)`, where the last three columns are
/// produced by calling the SQL function `get_embedding_job_status(id)` for that job.
///
/// Declared `stable` rather than `immutable` because the result depends on table contents.
#[pg_extern(stable, parallel_safe, security_definer)]
fn get_completion_jobs<'a>() -> Result<
    TableIterator<
        'static,
        (
            name!(id, Option<i32>),
            name!(status, Option<String>),
            name!(progress, Option<i16>),
            name!(error, Option<String>),
        ),
    >,
    anyhow::Error,
> {
    Spi::connect(|client| {
        client
            .select(
                "SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'completion'",
                None,
                None,
            )?
            .map(|row| {
                Ok((
                    row["id"].value()?,
                    row["status"].value()?,
                    row["progress"].value()?,
                    row["error"].value()?,
                ))
            })
            .collect::<Result<Vec<_>, _>>()
    })
    .map(TableIterator::new)
}

/// Marks a job as canceled by setting its `canceled_at` column to `NOW()`.
///
/// SPI errors are propagated to the caller.
///
/// The function performs an `UPDATE`, so it must not be declared `immutable` or
/// `parallel_safe`; it is left at the default (volatile) classification.
#[pg_extern(security_definer)]
fn cancel_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
    Spi::run_with_args(
        r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=$1;
"#,
        Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
    )?;

    Ok(())
}

/// Resumes a previously canceled job by resetting its `canceled_at` column to `NULL`.
///
/// SPI errors are propagated to the caller.
///
/// The function performs an `UPDATE`, so it must not be declared `immutable` or
/// `parallel_safe`; it is left at the default (volatile) classification.
#[pg_extern(security_definer)]
fn resume_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
    Spi::run_with_args(
        r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=$1;
"#,
        Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
    )?;

    Ok(())
}
// Defines the plpgsql function `get_embedding_job_status(job_id INT)`.
// It returns one `(status, progress, error)` row for the matching job in
// `_lantern_extras_internal.embedding_generation_jobs`; `status` is derived from the
// job's timestamp columns. Declared STABLE (not IMMUTABLE) because the result is read
// from a table and changes as the job advances.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_embedding_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do these need to be security definer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the embedding jobs table is being created mostly by superusers and then when regular user tries to call the get_embedding_job_status or related function it does not have permission to read from the table.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we grant SELECT on the jobs table to PUBLIC instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are correct, I did grant select to public already. Added tests to make sure that functions work without security definer.

LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=job_id;
END
$$;
"#,
name = "get_embedding_job_status"
);

// Defines the plpgsql function `get_completion_job_status(job_id INT)`, which simply
// forwards to `get_embedding_job_status(job_id)` and returns the same
// `(status, progress, error)` row. Declared STABLE (not IMMUTABLE) because the
// underlying lookup reads a table.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT * FROM get_embedding_job_status(job_id);
END
$$;
"#,
name = "get_completion_job_status",
requires = ["get_embedding_job_status"]
);

// Defines the plpgsql function `get_completion_job_failures(job_id INT)`, returning every
// `(row_id, value)` pair from `_lantern_extras_internal.embedding_failure_info` for the
// given job. The parameter is referenced as `get_completion_job_failures.job_id` to
// distinguish it from the `job_id` column. Declared STABLE (not IMMUTABLE) because the
// result depends on table contents.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_job_failures(job_id INT)
RETURNS TABLE (row_id INT, value TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT info.row_id, info.value
FROM _lantern_extras_internal.embedding_failure_info info
WHERE info.job_id=get_completion_job_failures.job_id;
END
$$;
"#,
name = "get_completion_job_failures",
);

// Defines the plpgsql function `get_embedding_jobs()`, returning
// `(id, status, progress, error)` for every job whose `job_type` is
// 'embedding_generation'; the status columns come from `get_embedding_job_status`.
// Declared STABLE (not IMMUTABLE) because the result depends on table contents.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_embedding_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_embedding_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'embedding_generation';
END
$$;
"#,
name = "get_embedding_jobs",
requires = ["get_embedding_job_status"]
);

// Defines the plpgsql function `get_completion_jobs()`, returning
// `(id, status, progress, error)` for every job whose `job_type` is 'completion';
// the status columns come from `get_completion_job_status`.
// Declared STABLE (not IMMUTABLE) because the result depends on table contents.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT STABLE PARALLEL SAFE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_completion_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'completion';
END
$$;
"#,
name = "get_completion_jobs",
requires = ["get_completion_job_status"]
);

// Defines the plpgsql function `cancel_embedding_job(job_id INT)`, which sets
// `canceled_at` to NOW() on the matching row of
// `_lantern_extras_internal.embedding_generation_jobs`.
// NOTE(review): the UPDATE filters on `id` only, not on `job_type` — confirm that
// cancelling a job of another type through this function is acceptable.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_embedding_job",
);

// Defines the plpgsql function `cancel_completion_job(job_id INT)`, which sets
// `canceled_at` to NOW() on the matching row of
// `_lantern_extras_internal.embedding_generation_jobs` (the same table and statement
// used by `cancel_embedding_job`).
// NOTE(review): the UPDATE filters on `id` only, not on `job_type` — confirm that
// cancelling a non-completion job through this function is acceptable.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_completion_job",
);

// Defines the plpgsql function `resume_embedding_job(job_id INT)`, which clears
// `canceled_at` (sets it to NULL) on the matching row of
// `_lantern_extras_internal.embedding_generation_jobs`.
// NOTE(review): the UPDATE filters on `id` only, not on `job_type` — confirm that
// resuming a job of another type through this function is acceptable.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_embedding_job",
);

// Defines the plpgsql function `resume_completion_job(job_id INT)`, which clears
// `canceled_at` (sets it to NULL) on the matching row of
// `_lantern_extras_internal.embedding_generation_jobs` (the same table and statement
// used by `resume_embedding_job`).
// NOTE(review): the UPDATE filters on `id` only, not on `job_type` — confirm that
// resuming a non-completion job through this function is acceptable.
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_completion_job",
);

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
Expand Down
Loading