Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rewrite statement splitter #138

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: minor stuff
psteinroe committed Aug 23, 2024
commit 69bf227368e3a55dd860abe13f45541b882dbf26
164 changes: 126 additions & 38 deletions crates/pg_statement_splitter/src/data.rs
Original file line number Diff line number Diff line change
@@ -78,6 +78,7 @@ pub struct StatementDefinition {
pub stmt: SyntaxKind,
pub tokens: Vec<SyntaxDefinition>,
pub prohibited_following_statements: Vec<SyntaxKind>,
pub prohibited_tokens: Vec<SyntaxKind>,
}

impl StatementDefinition {
@@ -86,9 +87,15 @@ impl StatementDefinition {
stmt,
tokens: b.build(),
prohibited_following_statements: Vec::new(),
prohibited_tokens: Vec::new(),
}
}

/// Builder-style setter: replaces this definition's `prohibited_tokens` list with
/// `prohibited` and returns the updated definition so calls can be chained.
fn with_prohibited_tokens(mut self, prohibited: Vec<SyntaxKind>) -> Self {
self.prohibited_tokens = prohibited;
self
}

fn with_prohibited_following_statements(mut self, prohibited: Vec<SyntaxKind>) -> Self {
self.prohibited_following_statements = prohibited;
self
@@ -223,7 +230,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.optional_if_exists_group()
.optional_token(SyntaxKind::Only)
.optional_schema_name_group()
.required_token(SyntaxKind::Ident)
.one_of(vec![
SyntaxKind::Ident,
SyntaxKind::VersionP,
SyntaxKind::Simple,
])
.any_token(),
));

@@ -273,13 +284,16 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Ascii41),
));

m.push(StatementDefinition::new(
SyntaxKind::AlterDefaultPrivilegesStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Default)
.required_token(SyntaxKind::Privileges),
));
m.push(
StatementDefinition::new(
SyntaxKind::AlterDefaultPrivilegesStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Default)
.required_token(SyntaxKind::Privileges),
)
.with_prohibited_following_statements(vec![SyntaxKind::GrantStmt]),
);

m.push(StatementDefinition::new(
SyntaxKind::ClusterStmt,
@@ -387,6 +401,17 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Ident),
));

m.push(StatementDefinition::new(
SyntaxKind::DropStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Drop)
.required_token(SyntaxKind::Materialized)
.required_token(SyntaxKind::View)
.optional_if_exists_group()
.optional_schema_name_group()
.required_token(SyntaxKind::Ident),
));

m.push(StatementDefinition::new(
SyntaxKind::DropStmt,
SyntaxBuilder::new()
@@ -822,6 +847,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
SyntaxBuilder::new().required_token(SyntaxKind::BeginP),
));

m.push(StatementDefinition::new(
SyntaxKind::TransactionStmt,
SyntaxBuilder::new().required_token(SyntaxKind::EndP),
));

m.push(StatementDefinition::new(
SyntaxKind::TransactionStmt,
SyntaxBuilder::new()
@@ -942,7 +972,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Table)
.optional_if_not_exists_group()
.optional_schema_name_group()
.required_token(SyntaxKind::Ident)
.one_of(vec![
SyntaxKind::Ident,
SyntaxKind::VersionP,
SyntaxKind::Simple,
])
.any_tokens(None)
.required_token(SyntaxKind::As)
.any_token(),
@@ -973,7 +1007,19 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
m.push(
StatementDefinition::new(
SyntaxKind::ExplainStmt,
SyntaxBuilder::new().required_token(SyntaxKind::Explain),
SyntaxBuilder::new()
.required_token(SyntaxKind::Explain)
.one_of(vec![
SyntaxKind::Analyze,
SyntaxKind::Ascii40,
SyntaxKind::Select,
SyntaxKind::Insert,
SyntaxKind::Update,
SyntaxKind::DeleteP,
SyntaxKind::Merge,
SyntaxKind::Execute,
SyntaxKind::Create,
]),
)
.with_prohibited_following_statements(vec![
SyntaxKind::VacuumStmt,
@@ -983,6 +1029,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
SyntaxKind::UpdateStmt,
SyntaxKind::MergeStmt,
SyntaxKind::ExecuteStmt,
SyntaxKind::CreateTableAsStmt,
]),
);

@@ -1105,6 +1152,18 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Ident),
));

m.push(
StatementDefinition::new(
SyntaxKind::AlterRoleSetStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Role)
.required_token(SyntaxKind::Ident)
.required_token(SyntaxKind::Set),
)
.with_prohibited_following_statements(vec![SyntaxKind::VariableSetStmt]),
);

m.push(StatementDefinition::new(
SyntaxKind::DropRoleStmt,
SyntaxBuilder::new()
@@ -1160,12 +1219,23 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
SyntaxBuilder::new().required_token(SyntaxKind::Checkpoint),
));

m.push(StatementDefinition::new(
SyntaxKind::CreateSchemaStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Create)
.required_token(SyntaxKind::Schema),
));
// CREATE SCHEMA accepts CREATE TABLE, CREATE VIEW, CREATE INDEX, CREATE SEQUENCE, CREATE TRIGGER and GRANT as schema elements, so these must not be split off as separate statements
m.push(
StatementDefinition::new(
SyntaxKind::CreateSchemaStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Create)
.required_token(SyntaxKind::Schema),
)
.with_prohibited_following_statements(vec![
SyntaxKind::CreateTableAsStmt,
SyntaxKind::CreateStmt,
SyntaxKind::IndexStmt,
SyntaxKind::CreateSeqStmt,
SyntaxKind::CreateTrigStmt,
SyntaxKind::GrantStmt,
]),
);

m.push(StatementDefinition::new(
SyntaxKind::AlterDatabaseStmt,
@@ -1233,18 +1303,21 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Ident),
));

m.push(StatementDefinition::new(
SyntaxKind::AlterOpFamilyStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Operator)
.required_token(SyntaxKind::Family)
.optional_schema_name_group()
.required_token(SyntaxKind::Ident)
.required_token(SyntaxKind::Using)
.required_token(SyntaxKind::Ident)
.one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]),
));
m.push(
StatementDefinition::new(
SyntaxKind::AlterOpFamilyStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Operator)
.required_token(SyntaxKind::Family)
.optional_schema_name_group()
.required_token(SyntaxKind::Ident)
.required_token(SyntaxKind::Using)
.required_token(SyntaxKind::Ident)
.one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]),
)
.with_prohibited_tokens(vec![SyntaxKind::Rename]),
);

m.push(
StatementDefinition::new(
@@ -1256,9 +1329,21 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::As)
.any_token(),
)
.with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]),
.with_prohibited_following_statements(vec![
SyntaxKind::SelectStmt,
SyntaxKind::InsertStmt,
SyntaxKind::UpdateStmt,
SyntaxKind::DeleteStmt,
]),
);

m.push(StatementDefinition::new(
SyntaxKind::ClosePortalStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Close)
.one_of(vec![SyntaxKind::Ident, SyntaxKind::All]),
));

m.push(StatementDefinition::new(
SyntaxKind::DeallocateStmt,
SyntaxBuilder::new()
@@ -1331,15 +1416,18 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
.required_token(SyntaxKind::Ident),
));

m.push(StatementDefinition::new(
SyntaxKind::AlterFdwStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Foreign)
.required_token(SyntaxKind::DataP)
.required_token(SyntaxKind::Wrapper)
.required_token(SyntaxKind::Ident),
));
m.push(
StatementDefinition::new(
SyntaxKind::AlterFdwStmt,
SyntaxBuilder::new()
.required_token(SyntaxKind::Alter)
.required_token(SyntaxKind::Foreign)
.required_token(SyntaxKind::DataP)
.required_token(SyntaxKind::Wrapper)
.required_token(SyntaxKind::Ident),
)
.with_prohibited_tokens(vec![SyntaxKind::Rename]),
);

m.push(StatementDefinition::new(
SyntaxKind::CreateForeignServerStmt,
77 changes: 63 additions & 14 deletions crates/pg_statement_splitter/src/statement_splitter.rs
Original file line number Diff line number Diff line change
@@ -37,25 +37,31 @@ impl<'a> StatementSplitter<'a> {
}
}

fn track_nesting(&mut self) {
fn end_nesting(&mut self) {
match self.parser.nth(0, false).kind {
SyntaxKind::Ascii40 => {
// "("
self.sub_stmt_depth += 1;
}
SyntaxKind::Ascii41 => {
// ")"
self.sub_stmt_depth -= 1;
}
SyntaxKind::EndP => {
self.is_within_atomic_block = false;
}
_ => {}
};
}

fn start_nesting(&mut self) {
match self.parser.nth(0, false).kind {
SyntaxKind::Ascii40 => {
// "("
self.sub_stmt_depth += 1;
}
SyntaxKind::Atomic => {
if self.parser.lookbehind(2, true, None).map(|t| t.kind) == Some(SyntaxKind::BeginP)
{
self.is_within_atomic_block = true;
}
}
SyntaxKind::EndP => {
self.is_within_atomic_block = false;
}
_ => {}
};
}
@@ -177,19 +183,19 @@ impl<'a> StatementSplitter<'a> {
.min_by_key(|stmt| stmt.started_at)
.map(|stmt| stmt.started_at)
{
println!(
"earliest complete stmt started at: {:?}",
earliest_complete_stmt_started_at
);
let earliest_complete_stmt = self
.tracked_statements
.iter()
.filter(|s| {
s.started_at == earliest_complete_stmt_started_at && s.could_be_complete()
})
.max_by_key(|stmt| stmt.max_pos())
.max_by_key(|stmt| {
println!("stmt: {:?} max pos: {:?}", stmt.def.stmt, stmt.max_pos());
stmt.max_pos()
})
.unwrap();

println!("earliest complete stmt: {:?}", earliest_complete_stmt);
assert_eq!(
1,
self.tracked_statements
@@ -304,7 +310,7 @@ impl<'a> StatementSplitter<'a> {
.collect::<Vec<_>>()
);

self.track_nesting();
self.start_nesting();

let removed_items_min_started_at = self.advance_tracker();

@@ -328,6 +334,8 @@ impl<'a> StatementSplitter<'a> {
self.close_stmt_with_semicolon();
}

self.end_nesting();

// # This is where the actual parsing happens

// 1. Find the latest complete statement
@@ -1360,6 +1368,47 @@ DROP LANGUAGE IF EXISTS test_language_exists;
assert_eq!(SyntaxKind::DropStmt, result[2].kind);
}

// Checks that a single `ALTER MATERIALIZED VIEW ... SET SCHEMA` statement is
// returned as exactly one statement of kind `AlterObjectSchemaStmt`.
#[test]
fn alter_mat_view() {
let input = "
ALTER MATERIALIZED VIEW mvtest_tvm SET SCHEMA mvtest_mvschema;
";
let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 1);
assert_eq!(SyntaxKind::AlterObjectSchemaStmt, result[0].kind);
}

// Checks that `CREATE TABLE ... AS SELECT ...` is returned as exactly one
// `CreateTableAsStmt` rather than being split at the embedded SELECT.
// NOTE(review): the table is named `simple`, presumably to exercise a keyword
// (`SyntaxKind::Simple`) used as the table name — confirm against the lexer.
#[test]
fn create_tbl_as_2() {
let input = "
create table simple as
select generate_series(1, 20000) AS id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa';
";
let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 1);
assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind);
}

// Checks that a multi-line `CREATE TABLE ... AS SELECT ...` whose select list
// contains several `AS` aliases and string literals such as 'EXPLAIN' is
// returned as exactly one `CreateTableAsStmt`.
#[test]
fn create_tbl_as() {
let input = "
CREATE TABLE tab_settings_flags AS SELECT name, category,
'EXPLAIN' = ANY(flags) AS explain,
'NO_RESET_ALL' = ANY(flags) AS no_reset_all,
'NO_SHOW_ALL' = ANY(flags) AS no_show_all,
'NOT_IN_SAMPLE' = ANY(flags) AS not_in_sample,
'RUNTIME_COMPUTED' = ANY(flags) AS runtime_computed
FROM pg_show_all_settings() AS psas,
pg_settings_get_flags(psas.name) AS flags;
";
let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 1);
assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind);
}

#[allow(clippy::must_use)]
fn debug(input: &str) {
for s in input.split(';').filter_map(|s| {
22 changes: 21 additions & 1 deletion crates/pg_statement_splitter/src/tracker.rs
Original file line number Diff line number Diff line change
@@ -75,8 +75,24 @@ impl<'a> Tracker<'a> {
true
}

/// Returns the max idx of all tracked positions while ignoring non-required tokens
pub fn max_pos(&self) -> usize {
self.positions.iter().max_by_key(|p| p.idx).unwrap().idx
self.positions
.iter()
.map(|p| {
// count only the required tokens (RequiredToken, OneOf, AnyToken) that precede the position
(0..p.idx).fold(0, |acc, idx| {
let token = self.def.tokens.get(idx);
match token {
Some(SyntaxDefinition::RequiredToken(_)) => acc + 1,
Some(SyntaxDefinition::OneOf(_)) => acc + 1,
Some(SyntaxDefinition::AnyToken) => acc + 1,
_ => acc,
}
})
})
.max()
.unwrap()
}

pub fn current_positions(&self) -> Vec<usize> {
@@ -132,6 +148,10 @@ impl<'a> Tracker<'a> {
return true;
}

if self.def.prohibited_tokens.contains(kind) {
return false;
}

let mut new_positions = Vec::with_capacity(self.positions.len());

println!(
7 changes: 7 additions & 0 deletions crates/pg_statement_splitter/tests/skipped.txt
Original file line number Diff line number Diff line change
@@ -15,3 +15,10 @@ comments
dependency
drop_if_exists
groupingsets
index_including_gist
inherit
insert
insert_conflict
numeric_big
opr_sanity
case