Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: New Functionalities #13

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ description = "A CLI tool for generating models based on a SQL Database using SQ
license = "MIT"

[dependencies]
sqlx = { version = "0.7", features = ["postgres","runtime-tokio"] }
sqlx = { version = "0.7", features = ["postgres", "runtime-tokio"] }
sqlx-cli = "0.7"
clap = "3.0"
regex = "1.5"
chrono = "0.4"
# regex = "1.5"
tokio = { version = "1", features = ["full"] }
dotenv = "0.15.0"
testcontainers = { version ="0.15.0" }
testcontainers = { version = "0.15.0" }
testcontainers-modules = { version = "0.3.5", features = ["postgres"] }
27 changes: 26 additions & 1 deletion src/db_queries.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use sqlx::PgPool;

use crate::models::TableColumn;
use crate::models::{TableColumn, UserDefinedEnums};

pub async fn get_table_columns(
pool: &PgPool,
Expand Down Expand Up @@ -98,3 +98,28 @@ ORDER BY
.await?;
Ok(rows)
}

pub async fn get_user_defined_enums(
udt_names: &Vec<String>,
pool: &PgPool,
) -> sqlx::Result<Vec<UserDefinedEnums>> {
let query = "
SELECT
t.typname AS enum_name,
e.enumlabel AS enum_value
FROM
pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
WHERE
t.typname = ANY($1)
ORDER BY
t.typname,
e.enumsortorder;
";

let rows = sqlx::query_as::<_, UserDefinedEnums>(query)
.bind(udt_names)
.fetch_all(pool)
.await?;
Ok(rows)
}
68 changes: 62 additions & 6 deletions src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@ use sqlx::PgPool;
use std::fs;
use std::path::Path;

use crate::db_queries::get_table_columns;
use crate::db_queries::{get_table_columns, get_user_defined_enums};
use crate::models::TableColumn;
use crate::utils::{generate_struct_code, to_pascal_case, to_snake_case};
use crate::utils::{generate_enum_code, generate_struct_code, to_pascal_case, to_snake_case};

use crate::query_generate::generate_query_code;
use crate::utils::{DateTimeLib, SqlGenState};
use crate::STATE;

pub async fn generate(
output_folder: &str,
database_url: &str,
context: Option<&str>,
force: bool,
include_tables: Option<Vec<&str>>,
exclude_tables: Vec<String>,
schemas: Option<Vec<&str>>,
date_time_lib: DateTimeLib,
struct_derives: Vec<String>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ty for doing this 🙏

enum_derives: Vec<String>,
) -> Result<(), Box<dyn std::error::Error>> {
// Connect to the Postgres database
let pool = PgPoolOptions::new()
Expand All @@ -28,18 +34,58 @@ pub async fn generate(

let default_schema = "public";
let rows = get_table_columns(&pool, schemas.unwrap_or(vec![default_schema]), None).await?;
let user_defined = rows
.iter()
.filter_map(|e| {
if e.data_type.as_str() == "USER-DEFINED" && e.udt_name.as_str() != "geometry" {
Some(e.udt_name.clone())
} else {
None
}
})
.collect::<Vec<String>>();

let enum_rows = get_user_defined_enums(&user_defined, &pool).await?;
let mut unique_enums = std::collections::BTreeSet::new();
for row in &enum_rows {
unique_enums.insert(row.enum_name.clone());
}
let enums = unique_enums.into_iter().collect::<Vec<String>>();
// Create the output folder if it doesn't exist
fs::create_dir_all(output_folder)?;

let mut unique = std::collections::BTreeSet::new();
for row in &rows {
unique.insert(row.table_name.clone());
}
let tables: Vec<String> = unique.into_iter().collect::<Vec<String>>();

let tables: Vec<String> = unique
.into_iter()
.collect::<Vec<String>>()
.into_iter()
.filter(|e| !exclude_tables.contains(e))
.collect();

if !enums.is_empty() {
println!("Outputting user defined enums: {:?}", enums);
}
println!("Outputting tables: {:?}", tables);

STATE
.set(SqlGenState {
user_defined: enums.clone(),
date_time_lib,
struct_derives,
enum_derives,
})
.expect("Unable to set state");

let mut rs_enums = Vec::new();

for user_enum in enums {
let enum_code = generate_enum_code(&user_enum, &enum_rows);
rs_enums.push(enum_code);
}

// Generate structs and queries for each table
for table in &tables {
if let Some(ts) = include_tables.clone() {
Expand Down Expand Up @@ -75,18 +121,28 @@ pub async fn generate(
}
}

let context_code = generate_db_context(context.unwrap_or(&database_name), &tables, &rows);
let context_code =
generate_db_context(context.unwrap_or(&database_name), &rs_enums, &tables, &rows);
let context_file_path = format!("{}/mod.rs", output_folder);
fs::write(context_file_path, context_code)?;
Ok(())
}

fn generate_db_context(database_name: &str, tables: &[String], _rows: &[TableColumn]) -> String {
fn generate_db_context(
database_name: &str,
enums: &[String],
tables: &[String],
_rows: &[TableColumn],
) -> String {
let mut db_context_code = String::new();

db_context_code.push_str("#![allow(dead_code)]\n");
db_context_code
.push_str("// Generated with sql-gen\n//https://github.com/jayy-lmao/sql-gen\n\n");
for enum_item in enums {
db_context_code.push_str(enum_item);
db_context_code.push_str("\n\n");
}
for table in tables {
db_context_code.push_str(&format!("pub mod {};\n", to_snake_case(table)));
db_context_code.push_str(&format!(
Expand Down
117 changes: 116 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::BTreeSet, sync::OnceLock};

use clap::{App, Arg, SubCommand};
use utils::{DateTimeLib, SqlGenState};

mod db_queries;
mod generate;
Expand All @@ -7,6 +10,8 @@ mod models;
mod query_generate;
mod utils;

pub(crate) static STATE: OnceLock<SqlGenState> = OnceLock::new();

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
dotenv::dotenv().ok();
Expand All @@ -22,6 +27,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.help("Sets the output folder for generated structs")
.takes_value(true),
)
.arg(
Arg::with_name("serde")
.long("serde")
.default_value("true")
.value_name("SQLGEN_ENABLE_SERDE")
.help("Adds Serde derices to created structs")
.takes_value(false),
)
.arg(
Arg::with_name("migrations")
.short('m')
Expand Down Expand Up @@ -68,6 +81,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.use_delimiter(true)
.help("Specify the table name(s)"),
)
.arg(
Arg::with_name("exclude")
.short('e')
.long("exclude")
.takes_value(true)
.value_name("SQLGEN_EXCLUDE")
.multiple(true)
.use_delimiter(true)
.help("Specify the excluded table name(s)"),
)
Comment on lines +84 to +93
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea

.arg(
Arg::new("force")
.short('f')
Expand All @@ -76,6 +99,31 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.takes_value(false)
.required(false)
.help("Overwrites existing files sharing names in that folder"),
)
.arg(
Arg::with_name("datetime-lib")
.long("datetime-lib")
.default_value("chrono")
.possible_values(&["chrono", "time"])
.value_name("SQLGEN_DATETIME_LIB")
.help("Specifies the library to use for date and time handling")
.takes_value(true),
)
.arg(
Arg::with_name("struct-derive")
.long("struct-derive")
.value_name("SQLGEN_STRUCT_DERIVE")
.help("Derive created structs with given values")
.multiple(true)
.takes_value(true),
)
.arg(
Arg::with_name("enum-derive")
.long("enum-derive")
.value_name("SQLGEN_ENUM_DERIVE")
.help("Derive created enums with given values")
.multiple(true)
.takes_value(true),
);

let migrate_subcommand = SubCommand::with_name("migrate")
Expand Down Expand Up @@ -195,7 +243,74 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let schemas: Option<Vec<&str>> =
matches.values_of("schema").map(|schemas| schemas.collect());
let force = matches.is_present("force");
generate::generate(output_folder, database_url, context, force, None, schemas).await?;
let include_tables = matches.values_of("table").map(|v| v.collect::<Vec<&str>>());
let exclude_tables = matches
.values_of("exclude")
.map(|v| {
v.into_iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
})
.unwrap_or(vec![]);

if !exclude_tables.is_empty() {
println!("Excluding tables: {:?}", exclude_tables);
}

let enable_serde = matches.is_present("serde");
let mut struct_derives = matches
.values_of("struct-derive")
.map(|v| {
v.into_iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
})
.unwrap_or_default();
let mut enum_derives = matches
.values_of("enum-derive")
.map(|v| {
v.into_iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
})
.unwrap_or_default();

if enable_serde {
let mut unique_struct_derivies = struct_derives
.clone()
.into_iter()
.collect::<BTreeSet<String>>();
let mut unique_enum_derivies = enum_derives
.clone()
.into_iter()
.collect::<BTreeSet<String>>();
for serde_derive in ["serde::Serialize", "serde::Deserialize"] {
unique_struct_derivies.insert(serde_derive.to_string());
unique_enum_derivies.insert(serde_derive.to_string());
}

struct_derives = unique_struct_derivies.into_iter().collect();
enum_derives = unique_enum_derivies.into_iter().collect();
}
let date_time_lib = matches
.value_of("datetime-lib")
.map(|e| e.to_string())
.unwrap();
let date_time_lib = DateTimeLib::from(date_time_lib);

generate::generate(
output_folder,
database_url,
context,
force,
include_tables,
exclude_tables,
schemas,
date_time_lib,
struct_derives,
enum_derives,
)
.await?;
} else if let Some(matches) = matches.subcommand_matches("migrate") {
let input_migrations_folder = matches.value_of("migrations").unwrap_or("./migrations");
println!(
Expand Down
8 changes: 7 additions & 1 deletion src/models.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[derive(sqlx::FromRow)]
#[derive(sqlx::FromRow, Clone)]
pub struct TableColumn {
pub(crate) table_name: String,
pub(crate) column_name: String,
Expand All @@ -12,3 +12,9 @@ pub struct TableColumn {
// #todo
pub(crate) table_schema: String,
}

#[derive(sqlx::FromRow, Clone)]
pub struct UserDefinedEnums {
pub(crate) enum_name: String,
pub(crate) enum_value: String,
}
Loading