diff --git a/Cargo.toml b/Cargo.toml index 1d86bae..0ec6417 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ license = "MIT" repository = "https://github.com/jayy-lmao/sql-gen" [dependencies] -sqlx = { version = "0.7", features = ["postgres","runtime-tokio", "mysql"] } -sqlx-cli = "0.7" +sqlx = { version = "0.8.3", features = ["postgres","runtime-tokio", "mysql"] } +sqlx-cli = "0.8.3" clap = { version = "4.5.31", features = ["derive"] } regex = "1.5" chrono = "0.4" diff --git a/src/core/models/db.rs b/src/core/models/db.rs index cac9726..51e4ab0 100644 --- a/src/core/models/db.rs +++ b/src/core/models/db.rs @@ -43,6 +43,7 @@ pub struct CustomEnum { pub struct TableColumnBuilder { column_name: String, column_comment: Option, + recommended_rust_type: Option, udt_name: String, data_type: String, is_nullable: bool, @@ -59,6 +60,7 @@ impl TableColumnBuilder { column_name: impl ToString, udt_name: impl ToString, data_type: impl ToString, + recommended_rust_type: Option, ) -> Self { Self { column_name: column_name.to_string(), @@ -72,6 +74,7 @@ impl TableColumnBuilder { foreign_key_id: None, is_auto_populated: false, array_depth: 0, + recommended_rust_type, } } @@ -117,7 +120,7 @@ impl TableColumnBuilder { pub fn build(self) -> TableColumn { TableColumn { column_name: self.column_name, - recommended_rust_type: convert_data_type(&self.udt_name), + recommended_rust_type: self.recommended_rust_type, udt_name: self.udt_name, data_type: self.data_type, is_nullable: self.is_nullable, diff --git a/src/core/translators/convert_table_to_struct_test.rs b/src/core/translators/convert_table_to_struct_test.rs index 4c02a9a..7427037 100644 --- a/src/core/translators/convert_table_to_struct_test.rs +++ b/src/core/translators/convert_table_to_struct_test.rs @@ -88,7 +88,9 @@ fn should_convert_table_with_basic_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("title", "text", "text").build()], + columns: vec![ + TableColumnBuilder::new("title", "text", "text", Some("String".to_string())).build(), + ], ..Default::default() }; @@ -117,14 +119,14 @@ fn should_convert_table_with_each_column_attribute_type() { table_name: "products".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "uuid", "uuid") + TableColumnBuilder::new("id", "uuid", "uuid", Some("uuid::Uuid".to_string())) .is_auto_populated() .is_primary_key() .build(), - TableColumnBuilder::new("title", "text", "text") + TableColumnBuilder::new("title", "text", "text", Some("String".to_string())) .is_unique() .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -170,9 +172,14 @@ fn should_convert_table_with_optional_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("description", "text", "text") - .is_nullable() - .build()], + columns: vec![TableColumnBuilder::new( + "description", + "text", + "text", + Some("String".to_string()), + ) + .is_nullable() + .build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -200,10 +207,15 @@ fn should_convert_table_with_array_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("tags", "_text", "ARRAY") - .is_nullable() - .array_depth(1) - .build()], + columns: vec![TableColumnBuilder::new( + "tags", + "_text", + "ARRAY", + Some("String".to_string()), + ) + .is_nullable() + .array_depth(1) + .build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -232,7 +244,9 @@ fn should_convert_table_with_enum_column() { let table = Table { table_name: "orders".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("order_status", "status", "USER-DEFINED").build()], + columns: vec![ + TableColumnBuilder::new("order_status", "status", "USER-DEFINED", None).build(), + ], ..Default::default() }; let enums: Vec = vec![CustomEnum { @@ -277,7 +291,7 @@ fn should_ignore_columns_with_invalid_types() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("title", "badtype", "badtype").build()], + columns: vec![TableColumnBuilder::new("title", "badtype", "badtype", None).build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -300,7 +314,9 @@ fn should_convert_table_with_column_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("id", "i32", "i32").build()], + columns: vec![ + TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; @@ -335,7 +351,9 @@ fn should_convert_table_with_global_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("id", "int4", "int4").build()], + columns: vec![ + TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; @@ -370,7 +388,9 @@ fn column_override_takes_preference_over_global_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("price", "int4", "int4").build()], + columns: vec![ + TableColumnBuilder::new("price", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 1d0e9a7..1224c5a 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,4 +1,4 @@ pub mod models; pub mod queries; #[cfg(test)] -mod test_helper; +pub mod test_helper; diff --git a/src/mysql/queries/convert_type.rs b/src/mysql/queries/convert_type.rs index c1738dd..b26ab1b 100644 --- a/src/mysql/queries/convert_type.rs +++ b/src/mysql/queries/convert_type.rs @@ -9,22 +9,42 @@ pub fn convert_data_type(udt_type: &str) -> Option { } match udt_type { - "bool" | "boolean" => Some("bool".to_string()), - "bytea" => Some("u8".to_string()), // is this right? - "char" | "bpchar" | "character" => Some("String".to_string()), + // Boolean: MySQL treats TINYINT(1), BOOLEAN and BOOL as booleans. + "bool" | "boolean" | "tinyint(1)" => Some("bool".to_string()), + + // Numeric types + "tinyint unsigned" => Some("u8".to_string()), + "tinyint" => Some("i8".to_string()), + "smallint unsigned" => Some("u16".to_string()), + "smallint" => Some("i16".to_string()), + "int unsigned" => Some("u32".to_string()), + "int" => Some("i32".to_string()), + "bigint unsigned" => Some("u64".to_string()), + "bigint" => Some("i64".to_string()), + "float" => Some("f32".to_string()), + "double" => Some("f64".to_string()), + + // String types: VARCHAR, CHAR and TEXT map to Rust’s String. + "varchar" | "char" | "text" => Some("String".to_string()), + + // Binary types: VARBINARY, BINARY and BLOB map to Vec + "varbinary" | "binary" | "blob" => Some("Vec".to_string()), + + // Date and time types "date" => Some("chrono::NaiveDate".to_string()), - "float4" | "real" => Some("f32".to_string()), - "float8" | "double precision" => Some("f64".to_string()), - "int2" | "smallint" | "smallserial" => Some("i16".to_string()), - "int4" | "int" | "serial" => Some("i32".to_string()), - "int8" | "bigint" | "bigserial" => Some("i64".to_string()), - "void" => Some("()".to_string()), - "jsonb" | "json" => Some("serde_json::Value".to_string()), - "text" | "varchar" | "name" | "citext" => Some("String".to_string()), + "datetime" => Some("chrono::NaiveDateTime".to_string()), + "timestamp" => Some("chrono::DateTime".to_string()), "time" => Some("chrono::NaiveTime".to_string()), - "timestamp" => Some("chrono::NaiveDateTime".to_string()), - "timestamptz" => Some("chrono::DateTime".to_string()), + + // Decimal type + "decimal" => Some("rust_decimal::Decimal".to_string()), + + // UUID type: MySQL often stores UUIDs as BINARY(16) "uuid" => Some("uuid::Uuid".to_string()), + + // JSON type: maps to a serde_json value. + "json" => Some("serde_json::JsonValue".to_string()), + _ => None, } } diff --git a/src/mysql/queries/get_enums.rs b/src/mysql/queries/get_enums.rs index dfe5927..0e58511 100644 --- a/src/mysql/queries/get_enums.rs +++ b/src/mysql/queries/get_enums.rs @@ -12,12 +12,12 @@ pub async fn get_mysql_enums(pool: &MySqlPool) -> Result, sqlx:: WITH RECURSIVE enum_split AS ( -- Initial row: extract the full list of enum values from COLUMN_TYPE SELECT - c.TABLE_SCHEMA AS `schema`, + CAST(c.TABLE_SCHEMA AS CHAR) AS `schema`, CONCAT(c.TABLE_NAME, '.', c.COLUMN_NAME) AS enum_type, c.COLUMN_TYPE, - c.COLUMN_COMMENT AS enum_type_comment, + CAST(c.COLUMN_COMMENT AS CHAR) AS enum_type_comment, -- Remove the leading "enum(" and trailing ")" then extract the first value. - TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS enum_value, + CAST(TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS CHAR) AS enum_value, CASE WHEN LOCATE(',', SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1)) > 0 THEN TRIM(LEADING ' ' FROM SUBSTRING( diff --git a/src/mysql/queries/get_enums_test.rs b/src/mysql/queries/get_enums_test.rs index aeae23c..70095a5 100644 --- a/src/mysql/queries/get_enums_test.rs +++ b/src/mysql/queries/get_enums_test.rs @@ -7,7 +7,7 @@ use std::error::Error; #[tokio::test] async fn test_get_mysql_enums() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; sqlx::query("DROP TABLE IF EXISTS test_table;") .execute(&pool) @@ -51,7 +51,7 @@ async fn test_get_mysql_enums() -> Result<(), Box> { #[tokio::test] async fn test_get_mysql_enums_with_comments() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // Clean up any existing type // Drop any existing table that uses the enum. sqlx::query("DROP TABLE IF EXISTS test_weather;") diff --git a/src/mysql/queries/get_tables.rs b/src/mysql/queries/get_tables.rs index e8d6080..15b7a22 100644 --- a/src/mysql/queries/get_tables.rs +++ b/src/mysql/queries/get_tables.rs @@ -17,17 +17,17 @@ pub async fn get_tables( // get all tables from the database let query = " SELECT - c.TABLE_NAME AS table_name, - c.COLUMN_NAME AS column_name, - c.COLUMN_TYPE AS udt_name, - c.DATA_TYPE AS data_type, + CAST(c.TABLE_NAME AS CHAR) AS table_name, + CAST(c.COLUMN_NAME AS CHAR) AS column_name, + CAST(c.COLUMN_TYPE AS CHAR) AS udt_name, + CAST(c.DATA_TYPE AS CHAR) AS data_type, '' AS table_schema, (c.IS_NULLABLE = 'YES') AS is_nullable, (c.COLUMN_KEY = 'PRI') AS is_primary_key, (c.COLUMN_KEY = 'UNI') AS is_unique, - kcu.REFERENCED_TABLE_NAME AS foreign_key_table, + CAST(kcu.REFERENCED_TABLE_NAME AS CHAR) AS foreign_key_table, kcu.REFERENCED_COLUMN_NAME AS foreign_key_id, - NULLIF(c.COLUMN_COMMENT, '') AS column_comment, + NULLIF(CAST(c.COLUMN_COMMENT as CHAR), '') AS column_comment, NULLIF(t.TABLE_COMMENT, '') AS table_comment, CASE WHEN c.COLUMN_DEFAULT IS NOT NULL diff --git a/src/mysql/queries/get_tables_test.rs b/src/mysql/queries/get_tables_test.rs index 445967f..f9dad45 100644 --- a/src/mysql/queries/get_tables_test.rs +++ b/src/mysql/queries/get_tables_test.rs @@ -26,7 +26,7 @@ async fn test_table( #[tokio::test] async fn test_basic_mysql_tables() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; test_table( &pool, &["CREATE TABLE test_table_0 ( @@ -40,18 +40,23 @@ async fn test_basic_mysql_tables() -> Result<(), Box> { table_name: "test_table_0".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("name", "varchar(255)", "varchar") - .is_unique() + TableColumnBuilder::new( + "name", + "varchar(255)", + "varchar", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), - TableColumnBuilder::new("description", "text", "text") - .is_nullable() - .build(), - TableColumnBuilder::new("parent_id", "int", "int") + TableColumnBuilder::new("parent_id", "int", "int", Some("i32".to_string())) .is_nullable() .foreign_key_table("test_table_0") .foreign_key_id("id") @@ -67,7 +72,7 @@ async fn test_basic_mysql_tables() -> Result<(), Box> { #[tokio::test] async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; test_table( &pool, &["CREATE TABLE test_table_with_comments ( @@ -80,16 +85,21 @@ async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { table_comment: Some("Some test table comment".to_string()), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .add_column_comment("Some test table column comment") .build(), - TableColumnBuilder::new("name", "varchar(255)", "varchar") - .is_unique() - .is_nullable() - .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new( + "name", + "varchar(255)", + "varchar", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -102,7 +112,7 @@ async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { #[tokio::test] async fn test_basic_mysql_table_with_array() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // MySQL does not support array types. We use a JSON column instead. test_table( &pool, @@ -114,14 +124,19 @@ async fn test_basic_mysql_table_with_array() -> Result<(), Box> { table_name: "test_table_1".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), // Note: instead of an array, we expect a JSON type without array depth. - TableColumnBuilder::new("names", "json", "json") - .is_nullable() - .build(), + TableColumnBuilder::new( + "names", + "json", + "json", + Some("serde_json::JsonValue".to_string()), + ) + .is_nullable() + .build(), ], ..Default::default() }], @@ -133,7 +148,7 @@ async fn test_basic_mysql_table_with_array() -> Result<(), Box> { #[tokio::test] async fn test_mysql_table_with_custom_type() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // MySQL does not support custom types (CREATE TYPE). Instead, define ENUM directly. test_table( @@ -146,12 +161,12 @@ async fn test_mysql_table_with_custom_type() -> Result<(), Box> { table_name: "test_orders_status_0".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), // The expected type is now 'enum' instead of a custom type. - TableColumnBuilder::new("order_status", "order_status", "enum").build(), + TableColumnBuilder::new("order_status", "order_status", "enum", None).build(), ], ..Default::default() }], diff --git a/src/mysql/test_helper.rs b/src/mysql/test_helper.rs index 86a9070..892cf5c 100644 --- a/src/mysql/test_helper.rs +++ b/src/mysql/test_helper.rs @@ -45,7 +45,7 @@ impl ContainerGuard { } } -pub async fn setup_mysql_db() -> MySqlPool { +pub async fn setup_mysql_db() -> (MySqlPool, String) { // Ensure exactly one container guard is created (across threads within the same process) let _guard = CONTAINER_GUARD.get_or_init(|| Arc::new(ContainerGuard::new())); @@ -67,11 +67,12 @@ pub async fn setup_mysql_db() -> MySqlPool { let test_db_url = format!("mysql://root:root@localhost:3307/{}", db_name); - MySqlPoolOptions::new() + let pool = MySqlPoolOptions::new() .max_connections(5) .connect(&test_db_url) .await - .expect("Failed to connect to test database") + .expect("Failed to connect to test database"); + (pool, test_db_url) } async fn wait_for_mysql_ready(db_url: &str) -> MySqlPool { diff --git a/src/postgres/queries/convert_type.rs b/src/postgres/queries/convert_type.rs index c1738dd..3cac28b 100644 --- a/src/postgres/queries/convert_type.rs +++ b/src/postgres/queries/convert_type.rs @@ -20,11 +20,22 @@ pub fn convert_data_type(udt_type: &str) -> Option { "int8" | "bigint" | "bigserial" => Some("i64".to_string()), "void" => Some("()".to_string()), "jsonb" | "json" => Some("serde_json::Value".to_string()), - "text" | "varchar" | "name" | "citext" => Some("String".to_string()), + "text" | "varchar" | "name" => Some("String".to_string()), "time" => Some("chrono::NaiveTime".to_string()), "timestamp" => Some("chrono::NaiveDateTime".to_string()), "timestamptz" => Some("chrono::DateTime".to_string()), "uuid" => Some("uuid::Uuid".to_string()), + "cube" => Some("sqlx::postgres::types::PgCube".to_string()), + "point" => Some("sqlx::postgres::types::PgPoint".to_string()), + "line" => Some("sqlx::postgres::types::PgLine".to_string()), + "money" => Some("sqlx::postgres::types::PgMoney".to_string()), + "interval" => Some("sqlx::postgres::types::PgInterval".to_string()), + "ltree" => Some("sqlx::postgres::types::PgLTree".to_string()), + "lquery" => Some("sqlx::postgres::types::PgLQuery".to_string()), + "citext" => Some("sqlx::postgres::types::PgCiText".to_string()), + "hstore" => Some("sqlx::postgres::types::PgHstore".to_string()), + "bit" | "varbit" => Some("bit_vec::BitVec".to_string()), + "macaddr" => Some("mac_address::MacAddress".to_string()), _ => None, } } diff --git a/src/postgres/queries/get_tables_test.rs b/src/postgres/queries/get_tables_test.rs index 953bca6..2a5e730 100644 --- a/src/postgres/queries/get_tables_test.rs +++ b/src/postgres/queries/get_tables_test.rs @@ -33,10 +33,10 @@ async fn test_basic_postgres_tables() -> Result<(), Box> { table_name: "test_table_0".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer").is_primary_key().is_auto_populated().build(), - TableColumnBuilder::new("name", "varchar", "character varying").is_unique().is_nullable().build(), - TableColumnBuilder::new("description", "text", "text").is_nullable().build(), - TableColumnBuilder::new("parent_id", "int4", "integer").is_nullable().foreign_key_table("test_table_0").foreign_key_id("id").build(), + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())).is_primary_key().is_auto_populated().build(), + TableColumnBuilder::new("name", "varchar", "character varying", Some("String".to_string())).is_unique().is_nullable().build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())).is_nullable().build(), + TableColumnBuilder::new("parent_id", "int4", "integer", Some("i32".to_string())).is_nullable().foreign_key_table("test_table_0").foreign_key_id("id").build(), ], ..Default::default() }], @@ -65,16 +65,21 @@ async fn test_basic_postgres_tables_with_comments() -> Result<(), Box table_comment: Some("Some test table comment".to_string()), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .add_column_comment("Some test table column comment") .build(), - TableColumnBuilder::new("name", "varchar", "character varying") - .is_unique() - .is_nullable() - .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new( + "name", + "varchar", + "character varying", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -95,11 +100,11 @@ async fn test_basic_postgres_table_with_array() -> Result<(), Box> { table_name: "test_table_1".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("names", "_text", "ARRAY") + TableColumnBuilder::new("names", "_text", "ARRAY", Some("String".to_string())) .is_nullable() .array_depth(1) .build(), @@ -130,11 +135,11 @@ async fn test_postgres_table_with_custom_type() -> Result<(), Box> { table_name: "test_orders_status_0".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("order_status", "status", "USER-DEFINED").build(), + TableColumnBuilder::new("order_status", "status", "USER-DEFINED", None).build(), ], ..Default::default() }], diff --git a/src/tests.rs b/src/tests.rs index 50dd074..57a96e7 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,43 +1,43 @@ -use clap::Parser as _; -use pretty_assertions::assert_eq; -use sqlx::query; - -use crate::{generate_rust_from_database, postgres::test_helper::setup_pg_db, Cli}; -use std::error::Error; - -#[tokio::test] -async fn test_basic_postgres_tables() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement = "CREATE TABLE test_table_0 (id SERIAL PRIMARY KEY, name VARCHAR(255) UNIQUE, description TEXT, parent_id INTEGER REFERENCES test_table_0 (id));"; - query(statement).execute(&pool).await?; - - let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); - - let writer = generate_rust_from_database(&args).await; - - assert_eq!( - writer.write_to_string().trim(), - r#"#[derive(Debug, Clone, sqlx::FromRow)] +mod postgres { + use clap::Parser as _; + use pretty_assertions::assert_eq; + use sqlx::query; + + use crate::{generate_rust_from_database, postgres::test_helper::setup_pg_db, Cli}; + use std::error::Error; + #[tokio::test] + async fn test_basic_postgres_tables() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement = "CREATE TABLE test_table_0 (id SERIAL PRIMARY KEY, name VARCHAR(255) UNIQUE, description TEXT, parent_id INTEGER REFERENCES test_table_0 (id));"; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#"#[derive(Debug, Clone, sqlx::FromRow)] pub struct TestTable0 { id: i32, name: Option, description: Option, parent_id: Option, }"# - .to_string() - ); + .to_string() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_basic_postgres_table_with_enum() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_basic_postgres_table_with_enum() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -48,22 +48,22 @@ CREATE TABLE todos ( ); "; - let statement_3 = " + let statement_3 = " -- Add a comment to the table. COMMENT ON TABLE todos IS 'Table to store todo items with tags and status information.'; "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - query(statement_3).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + query(statement_3).execute(&pool).await?; - let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, PartialEq, sqlx::Type)] #[sqlx(type_name = "todo_status")] pub enum TodoStatus { @@ -85,21 +85,21 @@ pub struct Todo { status: TodoStatus, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_type_override_for_enum_type() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_type_override_for_enum_type() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -110,22 +110,22 @@ CREATE TABLE todos ( ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--type-overrides", - "todo_status=String", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--type-overrides", + "todo_status=String", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct Todo { id: i32, @@ -135,21 +135,21 @@ pub struct Todo { status: String, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_field_override() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_field_override() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -160,22 +160,22 @@ CREATE TABLE todos ( ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--table-overrides", - "status=String", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "status=String", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct Todo { id: i32, @@ -185,21 +185,21 @@ pub struct Todo { status: String, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_table_specific_field_override() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_table_specific_field_override() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -210,30 +210,30 @@ CREATE TABLE todos ( ); "; - let statement_3 = " + let statement_3 = " CREATE TABLE other_todos_table ( id SERIAL PRIMARY KEY, -- Primary key status todo_status NOT NULL DEFAULT 'pending' -- Enum field with a default value ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - query(statement_3).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + query(statement_3).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--table-overrides", - "todos.status=String,toto.status=i32", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "todos.status=String,toto.status=i32", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, PartialEq, sqlx::Type)] #[sqlx(type_name = "todo_status")] pub enum TodoStatus { @@ -262,66 +262,66 @@ pub struct Todo { "# - .to_string() - .trim() - ); - - Ok(()) -} - -/// Test using the include_tables flag: if multiple tables exist, only the specified table is generated. -#[tokio::test] -async fn test_include_tables_filter() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - // Create two tables in the database. - let statement1 = "CREATE TABLE table_one (id SERIAL PRIMARY KEY, data TEXT);"; - let statement2 = "CREATE TABLE table_two (id SERIAL PRIMARY KEY, info TEXT);"; - query(statement1).execute(&pool).await?; - query(statement2).execute(&pool).await?; - - // Only include table_one in the generated output. - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--include-tables", - "table_one", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + .to_string() + .trim() + ); + + Ok(()) + } + + /// Test using the include_tables flag: if multiple tables exist, only the specified table is generated. + #[tokio::test] + async fn test_include_tables_filter() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + // Create two tables in the database. + let statement1 = "CREATE TABLE table_one (id SERIAL PRIMARY KEY, data TEXT);"; + let statement2 = "CREATE TABLE table_two (id SERIAL PRIMARY KEY, info TEXT);"; + query(statement1).execute(&pool).await?; + query(statement2).execute(&pool).await?; + + // Only include table_one in the generated output. + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--include-tables", + "table_one", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct TableOne { id: i32, data: Option, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) -} - -/// Test passing extra derives for an enum type using the enum-derive flag. -#[tokio::test] -async fn test_enum_derives_flag() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = "CREATE TYPE color AS ENUM ('red', 'green', 'blue');"; - let statement_2 = " + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } + + /// Test passing extra derives for an enum type using the enum-derive flag. + #[tokio::test] + async fn test_enum_derives_flag() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = "CREATE TYPE color AS ENUM ('red', 'green', 'blue');"; + let statement_2 = " CREATE TABLE palette ( id SERIAL PRIMARY KEY, favorite color NOT NULL ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--enum-derive", - "Debug,PartialEq", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--enum-derive", + "Debug,PartialEq", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, PartialEq)] #[sqlx(type_name = "color")] pub enum Color { @@ -339,32 +339,233 @@ pub struct Palette { favorite: Color, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) -} - -/// Test passing extra derives for the model using the model-derive flag. -#[tokio::test] -async fn test_model_derives_flag() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement = "CREATE TABLE simple (id SERIAL PRIMARY KEY, value TEXT);"; - query(statement).execute(&pool).await?; - - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--model-derive", - "Debug,Clone,PartialEq", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } + + /// Test passing extra derives for the model using the model-derive flag. + #[tokio::test] + async fn test_model_derives_flag() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement = "CREATE TABLE simple (id SERIAL PRIMARY KEY, value TEXT);"; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--model-derive", + "Debug,Clone,PartialEq", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, Clone, PartialEq)] pub struct Simple { id: i32, value: Option, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } +} + +mod mysql { + use clap::Parser as _; + use pretty_assertions::assert_eq; + use sqlx::query; + + use crate::{generate_rust_from_database, mysql::test_helper::setup_mysql_db, Cli}; + use std::error::Error; + + #[tokio::test] + async fn test_basic_mysql_tables() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement = r#" + CREATE TABLE test_table_0 ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) UNIQUE, + description TEXT, + parent_id INT, + FOREIGN KEY (parent_id) REFERENCES test_table_0(id) + ); + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#"#[derive(Debug, Clone, sqlx::FromRow)] +pub struct TestTable0 { + id: i32, + name: Option, + description: Option, + parent_id: Option, +}"# + .to_string() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_basic_mysql_table_with_enum() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + // No separate CREATE TYPE needed for MySQL. + let statement = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ) COMMENT='Table to store todo items with tags and status information.'; + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, PartialEq, sqlx::Type)] +pub enum TodoStatus { + #[sqlx(rename = "completed")] + Completed, + #[sqlx(rename = "in_progress")] + InProgress, + #[sqlx(rename = "pending")] + Pending, +} + +/// Table to store todo items with tags and status information. +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: TodoStatus, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_field_override() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "status=String", + ]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: String, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_table_specific_field_override() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement_todos = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + let statement_other = r#" + CREATE TABLE other_todos_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + query(statement_todos).execute(&pool).await?; + query(statement_other).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "todos.status=String,toto.status=i32", + ]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, PartialEq, sqlx::Type)] +pub enum OtherTodosTableStatus { + #[sqlx(rename = "completed")] + Completed, + #[sqlx(rename = "in_progress")] + InProgress, + #[sqlx(rename = "pending")] + Pending, +} + +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct OtherTodosTable { + id: i32, + status: OtherTodosTableStatus, +} + +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: String, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } }