From 7bd5b713881a857d1df0e2029455bec1066f410c Mon Sep 17 00:00:00 2001 From: James Holman Date: Wed, 14 Aug 2024 19:58:47 +1000 Subject: [PATCH 1/4] feat: cube support --- Cargo.toml | 4 ++-- src/utils.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbf0cad..220e12c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,8 @@ description = "A CLI tool for generating models based on a SQL Database using SQ license = "MIT" [dependencies] -sqlx = { version = "0.7", features = ["postgres","runtime-tokio"] } -sqlx-cli = "0.7" +sqlx = { version = "0.8", features = ["postgres","runtime-tokio"] } +sqlx-cli = "0.8" clap = "3.0" regex = "1.5" chrono = "0.4" diff --git a/src/utils.rs b/src/utils.rs index cc051cd..314a863 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -73,6 +73,7 @@ pub fn convert_data_type(data_type: &str) -> String { "timestamp" => "chrono::NaiveDateTime", "timestamptz" => "chrono::DateTime", "uuid" => "uuid::Uuid", + "cube" => "sqlx::postgres::types::PgCube", _ => panic!("Unknown type: {}", data_type), } .to_string() From cbd4eff2bd5f1dd43d929d27378a15be3272ab40 Mon Sep 17 00:00:00 2001 From: James Holman Date: Tue, 25 Mar 2025 19:20:14 +1100 Subject: [PATCH 2/4] fix: support more types --- src/mysql/queries/convert_type.rs | 13 ++- src/utils.rs | 174 ------------------------------ 2 files changed, 12 insertions(+), 175 deletions(-) delete mode 100644 src/utils.rs diff --git a/src/mysql/queries/convert_type.rs b/src/mysql/queries/convert_type.rs index c1738dd..3cac28b 100644 --- a/src/mysql/queries/convert_type.rs +++ b/src/mysql/queries/convert_type.rs @@ -20,11 +20,22 @@ pub fn convert_data_type(udt_type: &str) -> Option { "int8" | "bigint" | "bigserial" => Some("i64".to_string()), "void" => Some("()".to_string()), "jsonb" | "json" => Some("serde_json::Value".to_string()), - "text" | "varchar" | "name" | "citext" => Some("String".to_string()), + "text" | "varchar" | "name" => Some("String".to_string()), "time" => Some("chrono::NaiveTime".to_string()), "timestamp" => Some("chrono::NaiveDateTime".to_string()), "timestamptz" => Some("chrono::DateTime".to_string()), "uuid" => Some("uuid::Uuid".to_string()), + "cube" => Some("sqlx::postgres::types::PgCube".to_string()), + "point" => Some("sqlx::postgres::types::PgPoint".to_string()), + "line" => Some("sqlx::postgres::types::PgLine".to_string()), + "money" => Some("sqlx::postgres::types::PgMoney".to_string()), + "interval" => Some("sqlx::postgres::types::PgInterval".to_string()), + "ltree" => Some("sqlx::postgres::types::PgLTree".to_string()), + "lquery" => Some("sqlx::postgres::types::PgLQuery".to_string()), + "citext" => Some("sqlx::postgres::types::PgCiText".to_string()), + "hstore" => Some("sqlx::postgres::types::PgHstore".to_string()), + "bit" | "varbit" => Some("bit_vec::BitVec".to_string()), + "macaddr" => Some("mac_address::MacAddress".to_string()), _ => None, } } diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 314a863..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,174 +0,0 @@ -use crate::models::TableColumn; - -pub(crate) fn to_snake_case(input: &str) -> String { - let mut output = String::new(); - let mut prev_is_uppercase = false; - - for c in input.chars() { - if c.is_ascii_uppercase() { - if !output.is_empty() && !prev_is_uppercase { - output.push('_'); - } - output.extend(c.to_lowercase()); - prev_is_uppercase = true; - } else { - output.push(c); - prev_is_uppercase = false; - } - } - - output -} - -pub fn generate_struct_code(table_name: &str, rows: &Vec) -> String { - let struct_name = to_pascal_case(table_name); - let mut struct_code = String::new(); - - struct_code.push_str("#![allow(dead_code)]\n"); - struct_code.push_str("// Generated with sql-gen\n// https://github.com/jayy-lmao/sql-gen\n\n"); - struct_code.push_str("#[derive(sqlx::FromRow, Debug)]\n"); - struct_code.push_str(&format!("pub struct {} {{\n", struct_name)); - - for row in rows { - if row.table_name == table_name { - let column_name = to_snake_case(&row.column_name); - let mut data_type = convert_data_type(&row.udt_name); - let optional_type = format!("Option<{}>", data_type); - if row.is_nullable { - data_type = optional_type; - } - - struct_code.push_str(&format!(" pub {}: {},\n", column_name, data_type)); - } - } - struct_code.push_str("}\n"); - - struct_code -} - -pub fn convert_data_type(data_type: &str) -> String { - if data_type.to_lowercase().contains("char(") { - return "String".to_string(); - } - if data_type.starts_with("_") { - let array_of_type = convert_data_type(&data_type[1..]); - let vec_type = format!("Vec<{}>", array_of_type); - return vec_type; - } - - match data_type { - "bool" | "boolean" => "bool", - "bytea" => "Vec", // is this right? - "char" | "bpchar" | "character" => "String", - "date" => "chrono::NaiveDate", - "float4" | "real" => "f32", - "float8" | "double precision" => "f64", - "int2" | "smallint" | "smallserial" => "i16", - "int4" | "int" | "serial" => "i32", - "int8" | "bigint" | "bigserial" => "i64", - "void" => "()", - "jsonb" | "json" => "serde_json::Value", - "text" | "varchar" | "name" | "citext" => "String", - "time" => "chrono::NaiveTime", - "timestamp" => "chrono::NaiveDateTime", - "timestamptz" => "chrono::DateTime", - "uuid" => "uuid::Uuid", - "cube" => "sqlx::postgres::types::PgCube", - _ => panic!("Unknown type: {}", data_type), - } - .to_string() -} - -pub fn convert_data_type_from_pg(data_type: &str) -> String { - if data_type.contains("Json<") { - return "jsonb".to_string(); - } - if data_type.contains("Vec<") { - let array_type = convert_data_type_from_pg(&data_type[4..data_type.len() - 1]); - return format!("{}[]", array_type); - } - match data_type { - "i64" => "int8", - "i32" => "int4", - "i16" => "int2", - "String" => "text", - "serde_json::Value" => "jsonb", - "chrono::DateTime" => "timestamptz", - "chrono::NaiveDateTime" => "timestamp", - "DateTime" => "timestamptz", - "chrono::NaiveDate" => "date", - "f32" => "float4", - "f64" => "float8", - "uuid::Uuid" => "uuid", - "bool" => "boolean", - "Vec" => "bytea", // is this right ? - _ => panic!("Unknown type: {}", data_type), - } - .to_string() -} - -fn generate_query_code(_row: &TableColumn) -> String { - // ... (implementation of generate_query_code) - // query_code - todo!() -} - -pub fn parse_struct_fields(struct_code: &str) -> Vec<(String, String, bool)> { - let lines = struct_code.lines(); - let mut fields = Vec::new(); - - for line in lines { - let trimmed_line = line.trim(); - if !trimmed_line.starts_with("pub") { - continue; - } - - let parts: Vec<&str> = trimmed_line.split(": ").collect(); - if parts.len() != 2 { - continue; - } - - let field = parts[0].trim().trim_start_matches("pub").trim(); - //let data_type_optional = parts[1].trim().trim_end_matches(",").trim(); - let mut is_nullable = false; - - let data_type = if parts[1].trim().starts_with("Option") { - is_nullable = true; - parts[1] - .trim() - .trim_start_matches("Option<") - .trim_end_matches(">,") - } else { - parts[1].trim().trim_end_matches(',') - }; - - fields.push((field.to_owned(), data_type.to_owned(), is_nullable)); - } - - fields -} - -#[cfg(test)] -mod tests { - // ... (unit tests can be defined here) -} - -pub fn to_pascal_case(input: &str) -> String { - let mut output = String::new(); - let mut capitalize_next = true; - - for c in input.chars() { - if c.is_ascii_alphanumeric() { - if capitalize_next { - output.extend(c.to_uppercase()); - capitalize_next = false; - } else { - output.push(c); - } - } else { - capitalize_next = true; - } - } - - output -} From e904ebda496817d141282a24e06d605451725747 Mon Sep 17 00:00:00 2001 From: James Holman Date: Tue, 25 Mar 2025 19:24:58 +1100 Subject: [PATCH 3/4] fix: mysql types --- src/mysql/queries/convert_type.rs | 57 ++++++++++++++++------------ src/postgres/queries/convert_type.rs | 13 ++++++- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/mysql/queries/convert_type.rs b/src/mysql/queries/convert_type.rs index 3cac28b..b26ab1b 100644 --- a/src/mysql/queries/convert_type.rs +++ b/src/mysql/queries/convert_type.rs @@ -9,33 +9,42 @@ pub fn convert_data_type(udt_type: &str) -> Option { } match udt_type { - "bool" | "boolean" => Some("bool".to_string()), - "bytea" => Some("u8".to_string()), // is this right? - "char" | "bpchar" | "character" => Some("String".to_string()), + // Boolean: MySQL treats TINYINT(1), BOOLEAN and BOOL as booleans. + "bool" | "boolean" | "tinyint(1)" => Some("bool".to_string()), + + // Numeric types + "tinyint unsigned" => Some("u8".to_string()), + "tinyint" => Some("i8".to_string()), + "smallint unsigned" => Some("u16".to_string()), + "smallint" => Some("i16".to_string()), + "int unsigned" => Some("u32".to_string()), + "int" => Some("i32".to_string()), + "bigint unsigned" => Some("u64".to_string()), + "bigint" => Some("i64".to_string()), + "float" => Some("f32".to_string()), + "double" => Some("f64".to_string()), + + // String types: VARCHAR, CHAR and TEXT map to Rust’s String. + "varchar" | "char" | "text" => Some("String".to_string()), + + // Binary types: VARBINARY, BINARY and BLOB map to Vec + "varbinary" | "binary" | "blob" => Some("Vec".to_string()), + + // Date and time types "date" => Some("chrono::NaiveDate".to_string()), - "float4" | "real" => Some("f32".to_string()), - "float8" | "double precision" => Some("f64".to_string()), - "int2" | "smallint" | "smallserial" => Some("i16".to_string()), - "int4" | "int" | "serial" => Some("i32".to_string()), - "int8" | "bigint" | "bigserial" => Some("i64".to_string()), - "void" => Some("()".to_string()), - "jsonb" | "json" => Some("serde_json::Value".to_string()), - "text" | "varchar" | "name" => Some("String".to_string()), + "datetime" => Some("chrono::NaiveDateTime".to_string()), + "timestamp" => Some("chrono::DateTime".to_string()), "time" => Some("chrono::NaiveTime".to_string()), - "timestamp" => Some("chrono::NaiveDateTime".to_string()), - "timestamptz" => Some("chrono::DateTime".to_string()), + + // Decimal type + "decimal" => Some("rust_decimal::Decimal".to_string()), + + // UUID type: MySQL often stores UUIDs as BINARY(16) "uuid" => Some("uuid::Uuid".to_string()), - "cube" => Some("sqlx::postgres::types::PgCube".to_string()), - "point" => Some("sqlx::postgres::types::PgPoint".to_string()), - "line" => Some("sqlx::postgres::types::PgLine".to_string()), - "money" => Some("sqlx::postgres::types::PgMoney".to_string()), - "interval" => Some("sqlx::postgres::types::PgInterval".to_string()), - "ltree" => Some("sqlx::postgres::types::PgLTree".to_string()), - "lquery" => Some("sqlx::postgres::types::PgLQuery".to_string()), - "citext" => Some("sqlx::postgres::types::PgCiText".to_string()), - "hstore" => Some("sqlx::postgres::types::PgHstore".to_string()), - "bit" | "varbit" => Some("bit_vec::BitVec".to_string()), - "macaddr" => Some("mac_address::MacAddress".to_string()), + + // JSON type: maps to a serde_json value. + "json" => Some("serde_json::JsonValue".to_string()), + _ => None, } } diff --git a/src/postgres/queries/convert_type.rs b/src/postgres/queries/convert_type.rs index c1738dd..3cac28b 100644 --- a/src/postgres/queries/convert_type.rs +++ b/src/postgres/queries/convert_type.rs @@ -20,11 +20,22 @@ pub fn convert_data_type(udt_type: &str) -> Option { "int8" | "bigint" | "bigserial" => Some("i64".to_string()), "void" => Some("()".to_string()), "jsonb" | "json" => Some("serde_json::Value".to_string()), - "text" | "varchar" | "name" | "citext" => Some("String".to_string()), + "text" | "varchar" | "name" => Some("String".to_string()), "time" => Some("chrono::NaiveTime".to_string()), "timestamp" => Some("chrono::NaiveDateTime".to_string()), "timestamptz" => Some("chrono::DateTime".to_string()), "uuid" => Some("uuid::Uuid".to_string()), + "cube" => Some("sqlx::postgres::types::PgCube".to_string()), + "point" => Some("sqlx::postgres::types::PgPoint".to_string()), + "line" => Some("sqlx::postgres::types::PgLine".to_string()), + "money" => Some("sqlx::postgres::types::PgMoney".to_string()), + "interval" => Some("sqlx::postgres::types::PgInterval".to_string()), + "ltree" => Some("sqlx::postgres::types::PgLTree".to_string()), + "lquery" => Some("sqlx::postgres::types::PgLQuery".to_string()), + "citext" => Some("sqlx::postgres::types::PgCiText".to_string()), + "hstore" => Some("sqlx::postgres::types::PgHstore".to_string()), + "bit" | "varbit" => Some("bit_vec::BitVec".to_string()), + "macaddr" => Some("mac_address::MacAddress".to_string()), _ => None, } } From 2781d19d5e49de0198e983129e1671cb6827cbfc Mon Sep 17 00:00:00 2001 From: James Holman Date: Wed, 26 Mar 2025 22:38:55 +1100 Subject: [PATCH 4/4] fix: mysql --- src/core/models/db.rs | 5 +- .../convert_table_to_struct_test.rs | 52 +- src/mysql/mod.rs | 2 +- src/mysql/queries/get_enums.rs | 6 +- src/mysql/queries/get_enums_test.rs | 4 +- src/mysql/queries/get_tables.rs | 12 +- src/mysql/queries/get_tables_test.rs | 61 +- src/mysql/test_helper.rs | 7 +- src/postgres/queries/get_tables_test.rs | 33 +- src/tests.rs | 561 ++++++++++++------ 10 files changed, 494 insertions(+), 249 deletions(-) diff --git a/src/core/models/db.rs b/src/core/models/db.rs index cac9726..51e4ab0 100644 --- a/src/core/models/db.rs +++ b/src/core/models/db.rs @@ -43,6 +43,7 @@ pub struct CustomEnum { pub struct TableColumnBuilder { column_name: String, column_comment: Option, + recommended_rust_type: Option, udt_name: String, data_type: String, is_nullable: bool, @@ -59,6 +60,7 @@ impl TableColumnBuilder { column_name: impl ToString, udt_name: impl ToString, data_type: impl ToString, + recommended_rust_type: Option, ) -> Self { Self { column_name: column_name.to_string(), @@ -72,6 +74,7 @@ impl TableColumnBuilder { foreign_key_id: None, is_auto_populated: false, array_depth: 0, + recommended_rust_type, } } @@ -117,7 +120,7 @@ impl TableColumnBuilder { pub fn build(self) -> TableColumn { TableColumn { column_name: self.column_name, - recommended_rust_type: convert_data_type(&self.udt_name), + recommended_rust_type: self.recommended_rust_type, udt_name: self.udt_name, data_type: self.data_type, is_nullable: self.is_nullable, diff --git a/src/core/translators/convert_table_to_struct_test.rs b/src/core/translators/convert_table_to_struct_test.rs index 4c02a9a..7427037 100644 --- a/src/core/translators/convert_table_to_struct_test.rs +++ b/src/core/translators/convert_table_to_struct_test.rs @@ -88,7 +88,9 @@ fn should_convert_table_with_basic_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("title", "text", "text").build()], + columns: vec![ + TableColumnBuilder::new("title", "text", "text", Some("String".to_string())).build(), + ], ..Default::default() }; @@ -117,14 +119,14 @@ fn should_convert_table_with_each_column_attribute_type() { table_name: "products".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "uuid", "uuid") + TableColumnBuilder::new("id", "uuid", "uuid", Some("uuid::Uuid".to_string())) .is_auto_populated() .is_primary_key() .build(), - TableColumnBuilder::new("title", "text", "text") + TableColumnBuilder::new("title", "text", "text", Some("String".to_string())) .is_unique() .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -170,9 +172,14 @@ fn should_convert_table_with_optional_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("description", "text", "text") - .is_nullable() - .build()], + columns: vec![TableColumnBuilder::new( + "description", + "text", + "text", + Some("String".to_string()), + ) + .is_nullable() + .build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -200,10 +207,15 @@ fn should_convert_table_with_array_column() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("tags", "_text", "ARRAY") - .is_nullable() - .array_depth(1) - .build()], + columns: vec![TableColumnBuilder::new( + "tags", + "_text", + "ARRAY", + Some("String".to_string()), + ) + .is_nullable() + .array_depth(1) + .build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -232,7 +244,9 @@ fn should_convert_table_with_enum_column() { let table = Table { table_name: "orders".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("order_status", "status", "USER-DEFINED").build()], + columns: vec![ + TableColumnBuilder::new("order_status", "status", "USER-DEFINED", None).build(), + ], ..Default::default() }; let enums: Vec = vec![CustomEnum { @@ -277,7 +291,7 @@ fn should_ignore_columns_with_invalid_types() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("title", "badtype", "badtype").build()], + columns: vec![TableColumnBuilder::new("title", "badtype", "badtype", None).build()], ..Default::default() }; let mut options = CodegenOptions::default(); @@ -300,7 +314,9 @@ fn should_convert_table_with_column_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("id", "i32", "i32").build()], + columns: vec![ + TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; @@ -335,7 +351,9 @@ fn should_convert_table_with_global_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("id", "int4", "int4").build()], + columns: vec![ + TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; @@ -370,7 +388,9 @@ fn column_override_takes_preference_over_global_type_override() { let table = Table { table_name: "products".to_string(), table_schema: Some("public".to_string()), - columns: vec![TableColumnBuilder::new("price", "int4", "int4").build()], + columns: vec![ + TableColumnBuilder::new("price", "int4", "int4", Some("i32".to_string())).build(), + ], ..Default::default() }; diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 1d0e9a7..1224c5a 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,4 +1,4 @@ pub mod models; pub mod queries; #[cfg(test)] -mod test_helper; +pub mod test_helper; diff --git a/src/mysql/queries/get_enums.rs b/src/mysql/queries/get_enums.rs index dfe5927..0e58511 100644 --- a/src/mysql/queries/get_enums.rs +++ b/src/mysql/queries/get_enums.rs @@ -12,12 +12,12 @@ pub async fn get_mysql_enums(pool: &MySqlPool) -> Result, sqlx:: WITH RECURSIVE enum_split AS ( -- Initial row: extract the full list of enum values from COLUMN_TYPE SELECT - c.TABLE_SCHEMA AS `schema`, + CAST(c.TABLE_SCHEMA AS CHAR) AS `schema`, CONCAT(c.TABLE_NAME, '.', c.COLUMN_NAME) AS enum_type, c.COLUMN_TYPE, - c.COLUMN_COMMENT AS enum_type_comment, + CAST(c.COLUMN_COMMENT AS CHAR) AS enum_type_comment, -- Remove the leading "enum(" and trailing ")" then extract the first value. - TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS enum_value, + CAST(TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS CHAR) AS enum_value, CASE WHEN LOCATE(',', SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1)) > 0 THEN TRIM(LEADING ' ' FROM SUBSTRING( diff --git a/src/mysql/queries/get_enums_test.rs b/src/mysql/queries/get_enums_test.rs index aeae23c..70095a5 100644 --- a/src/mysql/queries/get_enums_test.rs +++ b/src/mysql/queries/get_enums_test.rs @@ -7,7 +7,7 @@ use std::error::Error; #[tokio::test] async fn test_get_mysql_enums() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; sqlx::query("DROP TABLE IF EXISTS test_table;") .execute(&pool) @@ -51,7 +51,7 @@ async fn test_get_mysql_enums() -> Result<(), Box> { #[tokio::test] async fn test_get_mysql_enums_with_comments() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // Clean up any existing type // Drop any existing table that uses the enum. sqlx::query("DROP TABLE IF EXISTS test_weather;") diff --git a/src/mysql/queries/get_tables.rs b/src/mysql/queries/get_tables.rs index e8d6080..15b7a22 100644 --- a/src/mysql/queries/get_tables.rs +++ b/src/mysql/queries/get_tables.rs @@ -17,17 +17,17 @@ pub async fn get_tables( // get all tables from the database let query = " SELECT - c.TABLE_NAME AS table_name, - c.COLUMN_NAME AS column_name, - c.COLUMN_TYPE AS udt_name, - c.DATA_TYPE AS data_type, + CAST(c.TABLE_NAME AS CHAR) AS table_name, + CAST(c.COLUMN_NAME AS CHAR) AS column_name, + CAST(c.COLUMN_TYPE AS CHAR) AS udt_name, + CAST(c.DATA_TYPE AS CHAR) AS data_type, '' AS table_schema, (c.IS_NULLABLE = 'YES') AS is_nullable, (c.COLUMN_KEY = 'PRI') AS is_primary_key, (c.COLUMN_KEY = 'UNI') AS is_unique, - kcu.REFERENCED_TABLE_NAME AS foreign_key_table, + CAST(kcu.REFERENCED_TABLE_NAME AS CHAR) AS foreign_key_table, kcu.REFERENCED_COLUMN_NAME AS foreign_key_id, - NULLIF(c.COLUMN_COMMENT, '') AS column_comment, + NULLIF(CAST(c.COLUMN_COMMENT as CHAR), '') AS column_comment, NULLIF(t.TABLE_COMMENT, '') AS table_comment, CASE WHEN c.COLUMN_DEFAULT IS NOT NULL diff --git a/src/mysql/queries/get_tables_test.rs b/src/mysql/queries/get_tables_test.rs index 445967f..f9dad45 100644 --- a/src/mysql/queries/get_tables_test.rs +++ b/src/mysql/queries/get_tables_test.rs @@ -26,7 +26,7 @@ async fn test_table( #[tokio::test] async fn test_basic_mysql_tables() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; test_table( &pool, &["CREATE TABLE test_table_0 ( @@ -40,18 +40,23 @@ async fn test_basic_mysql_tables() -> Result<(), Box> { table_name: "test_table_0".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("name", "varchar(255)", "varchar") - .is_unique() + TableColumnBuilder::new( + "name", + "varchar(255)", + "varchar", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), - TableColumnBuilder::new("description", "text", "text") - .is_nullable() - .build(), - TableColumnBuilder::new("parent_id", "int", "int") + TableColumnBuilder::new("parent_id", "int", "int", Some("i32".to_string())) .is_nullable() .foreign_key_table("test_table_0") .foreign_key_id("id") @@ -67,7 +72,7 @@ async fn test_basic_mysql_tables() -> Result<(), Box> { #[tokio::test] async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; test_table( &pool, &["CREATE TABLE test_table_with_comments ( @@ -80,16 +85,21 @@ async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { table_comment: Some("Some test table comment".to_string()), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .add_column_comment("Some test table column comment") .build(), - TableColumnBuilder::new("name", "varchar(255)", "varchar") - .is_unique() - .is_nullable() - .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new( + "name", + "varchar(255)", + "varchar", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -102,7 +112,7 @@ async fn test_basic_mysql_tables_with_comments() -> Result<(), Box> { #[tokio::test] async fn test_basic_mysql_table_with_array() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // MySQL does not support array types. We use a JSON column instead. test_table( &pool, @@ -114,14 +124,19 @@ async fn test_basic_mysql_table_with_array() -> Result<(), Box> { table_name: "test_table_1".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), // Note: instead of an array, we expect a JSON type without array depth. - TableColumnBuilder::new("names", "json", "json") - .is_nullable() - .build(), + TableColumnBuilder::new( + "names", + "json", + "json", + Some("serde_json::JsonValue".to_string()), + ) + .is_nullable() + .build(), ], ..Default::default() }], @@ -133,7 +148,7 @@ async fn test_basic_mysql_table_with_array() -> Result<(), Box> { #[tokio::test] async fn test_mysql_table_with_custom_type() -> Result<(), Box> { - let pool = setup_mysql_db().await; + let (pool, _) = setup_mysql_db().await; // MySQL does not support custom types (CREATE TYPE). Instead, define ENUM directly. test_table( @@ -146,12 +161,12 @@ async fn test_mysql_table_with_custom_type() -> Result<(), Box> { table_name: "test_orders_status_0".to_string(), table_schema: None, columns: vec![ - TableColumnBuilder::new("id", "int", "int") + TableColumnBuilder::new("id", "int", "int", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), // The expected type is now 'enum' instead of a custom type. - TableColumnBuilder::new("order_status", "order_status", "enum").build(), + TableColumnBuilder::new("order_status", "order_status", "enum", None).build(), ], ..Default::default() }], diff --git a/src/mysql/test_helper.rs b/src/mysql/test_helper.rs index 86a9070..892cf5c 100644 --- a/src/mysql/test_helper.rs +++ b/src/mysql/test_helper.rs @@ -45,7 +45,7 @@ impl ContainerGuard { } } -pub async fn setup_mysql_db() -> MySqlPool { +pub async fn setup_mysql_db() -> (MySqlPool, String) { // Ensure exactly one container guard is created (across threads within the same process) let _guard = CONTAINER_GUARD.get_or_init(|| Arc::new(ContainerGuard::new())); @@ -67,11 +67,12 @@ pub async fn setup_mysql_db() -> MySqlPool { let test_db_url = format!("mysql://root:root@localhost:3307/{}", db_name); - MySqlPoolOptions::new() + let pool = MySqlPoolOptions::new() .max_connections(5) .connect(&test_db_url) .await - .expect("Failed to connect to test database") + .expect("Failed to connect to test database"); + (pool, test_db_url) } async fn wait_for_mysql_ready(db_url: &str) -> MySqlPool { diff --git a/src/postgres/queries/get_tables_test.rs b/src/postgres/queries/get_tables_test.rs index 953bca6..2a5e730 100644 --- a/src/postgres/queries/get_tables_test.rs +++ b/src/postgres/queries/get_tables_test.rs @@ -33,10 +33,10 @@ async fn test_basic_postgres_tables() -> Result<(), Box> { table_name: "test_table_0".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer").is_primary_key().is_auto_populated().build(), - TableColumnBuilder::new("name", "varchar", "character varying").is_unique().is_nullable().build(), - TableColumnBuilder::new("description", "text", "text").is_nullable().build(), - TableColumnBuilder::new("parent_id", "int4", "integer").is_nullable().foreign_key_table("test_table_0").foreign_key_id("id").build(), + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())).is_primary_key().is_auto_populated().build(), + TableColumnBuilder::new("name", "varchar", "character varying", Some("String".to_string())).is_unique().is_nullable().build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())).is_nullable().build(), + TableColumnBuilder::new("parent_id", "int4", "integer", Some("i32".to_string())).is_nullable().foreign_key_table("test_table_0").foreign_key_id("id").build(), ], ..Default::default() }], @@ -65,16 +65,21 @@ async fn test_basic_postgres_tables_with_comments() -> Result<(), Box table_comment: Some("Some test table comment".to_string()), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .add_column_comment("Some test table column comment") .build(), - TableColumnBuilder::new("name", "varchar", "character varying") - .is_unique() - .is_nullable() - .build(), - TableColumnBuilder::new("description", "text", "text") + TableColumnBuilder::new( + "name", + "varchar", + "character varying", + Some("String".to_string()), + ) + .is_unique() + .is_nullable() + .build(), + TableColumnBuilder::new("description", "text", "text", Some("String".to_string())) .is_nullable() .build(), ], @@ -95,11 +100,11 @@ async fn test_basic_postgres_table_with_array() -> Result<(), Box> { table_name: "test_table_1".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("names", "_text", "ARRAY") + TableColumnBuilder::new("names", "_text", "ARRAY", Some("String".to_string())) .is_nullable() .array_depth(1) .build(), @@ -130,11 +135,11 @@ async fn test_postgres_table_with_custom_type() -> Result<(), Box> { table_name: "test_orders_status_0".to_string(), table_schema: Some("public".to_string()), columns: vec![ - TableColumnBuilder::new("id", "int4", "integer") + TableColumnBuilder::new("id", "int4", "integer", Some("i32".to_string())) .is_primary_key() .is_auto_populated() .build(), - TableColumnBuilder::new("order_status", "status", "USER-DEFINED").build(), + TableColumnBuilder::new("order_status", "status", "USER-DEFINED", None).build(), ], ..Default::default() }], diff --git a/src/tests.rs b/src/tests.rs index 50dd074..57a96e7 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,43 +1,43 @@ -use clap::Parser as _; -use pretty_assertions::assert_eq; -use sqlx::query; - -use crate::{generate_rust_from_database, postgres::test_helper::setup_pg_db, Cli}; -use std::error::Error; - -#[tokio::test] -async fn test_basic_postgres_tables() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement = "CREATE TABLE test_table_0 (id SERIAL PRIMARY KEY, name VARCHAR(255) UNIQUE, description TEXT, parent_id INTEGER REFERENCES test_table_0 (id));"; - query(statement).execute(&pool).await?; - - let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); - - let writer = generate_rust_from_database(&args).await; - - assert_eq!( - writer.write_to_string().trim(), - r#"#[derive(Debug, Clone, sqlx::FromRow)] +mod postgres { + use clap::Parser as _; + use pretty_assertions::assert_eq; + use sqlx::query; + + use crate::{generate_rust_from_database, postgres::test_helper::setup_pg_db, Cli}; + use std::error::Error; + #[tokio::test] + async fn test_basic_postgres_tables() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement = "CREATE TABLE test_table_0 (id SERIAL PRIMARY KEY, name VARCHAR(255) UNIQUE, description TEXT, parent_id INTEGER REFERENCES test_table_0 (id));"; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#"#[derive(Debug, Clone, sqlx::FromRow)] pub struct TestTable0 { id: i32, name: Option, description: Option, parent_id: Option, }"# - .to_string() - ); + .to_string() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_basic_postgres_table_with_enum() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_basic_postgres_table_with_enum() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -48,22 +48,22 @@ CREATE TABLE todos ( ); "; - let statement_3 = " + let statement_3 = " -- Add a comment to the table. COMMENT ON TABLE todos IS 'Table to store todo items with tags and status information.'; "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - query(statement_3).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + query(statement_3).execute(&pool).await?; - let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, PartialEq, sqlx::Type)] #[sqlx(type_name = "todo_status")] pub enum TodoStatus { @@ -85,21 +85,21 @@ pub struct Todo { status: TodoStatus, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_type_override_for_enum_type() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_type_override_for_enum_type() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -110,22 +110,22 @@ CREATE TABLE todos ( ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--type-overrides", - "todo_status=String", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--type-overrides", + "todo_status=String", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct Todo { id: i32, @@ -135,21 +135,21 @@ pub struct Todo { status: String, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_field_override() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_field_override() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -160,22 +160,22 @@ CREATE TABLE todos ( ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--table-overrides", - "status=String", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "status=String", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct Todo { id: i32, @@ -185,21 +185,21 @@ pub struct Todo { status: String, } "# - .to_string() - .trim() - ); + .to_string() + .trim() + ); - Ok(()) -} + Ok(()) + } -#[tokio::test] -async fn test_table_specific_field_override() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = " + #[tokio::test] + async fn test_table_specific_field_override() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = " -- Create an enum type for todo statuses. CREATE TYPE todo_status AS ENUM ('pending', 'in_progress', 'completed'); "; - let statement_2 = " + let statement_2 = " -- Create the todos table. CREATE TABLE todos ( id SERIAL PRIMARY KEY, -- Primary key @@ -210,30 +210,30 @@ CREATE TABLE todos ( ); "; - let statement_3 = " + let statement_3 = " CREATE TABLE other_todos_table ( id SERIAL PRIMARY KEY, -- Primary key status todo_status NOT NULL DEFAULT 'pending' -- Enum field with a default value ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - query(statement_3).execute(&pool).await?; + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + query(statement_3).execute(&pool).await?; - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--table-overrides", - "todos.status=String,toto.status=i32", - ]); + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "todos.status=String,toto.status=i32", + ]); - let writer = generate_rust_from_database(&args).await; + let writer = generate_rust_from_database(&args).await; - assert_eq!( - writer.write_to_string().trim(), - r#" + assert_eq!( + writer.write_to_string().trim(), + r#" #[derive(Debug, Clone, PartialEq, sqlx::Type)] #[sqlx(type_name = "todo_status")] pub enum TodoStatus { @@ -262,66 +262,66 @@ pub struct Todo { "# - .to_string() - .trim() - ); - - Ok(()) -} - -/// Test using the include_tables flag: if multiple tables exist, only the specified table is generated. -#[tokio::test] -async fn test_include_tables_filter() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - // Create two tables in the database. - let statement1 = "CREATE TABLE table_one (id SERIAL PRIMARY KEY, data TEXT);"; - let statement2 = "CREATE TABLE table_two (id SERIAL PRIMARY KEY, info TEXT);"; - query(statement1).execute(&pool).await?; - query(statement2).execute(&pool).await?; - - // Only include table_one in the generated output. - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--include-tables", - "table_one", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + .to_string() + .trim() + ); + + Ok(()) + } + + /// Test using the include_tables flag: if multiple tables exist, only the specified table is generated. + #[tokio::test] + async fn test_include_tables_filter() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + // Create two tables in the database. + let statement1 = "CREATE TABLE table_one (id SERIAL PRIMARY KEY, data TEXT);"; + let statement2 = "CREATE TABLE table_two (id SERIAL PRIMARY KEY, info TEXT);"; + query(statement1).execute(&pool).await?; + query(statement2).execute(&pool).await?; + + // Only include table_one in the generated output. + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--include-tables", + "table_one", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, Clone, sqlx::FromRow)] pub struct TableOne { id: i32, data: Option, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) -} - -/// Test passing extra derives for an enum type using the enum-derive flag. -#[tokio::test] -async fn test_enum_derives_flag() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement_1 = "CREATE TYPE color AS ENUM ('red', 'green', 'blue');"; - let statement_2 = " + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } + + /// Test passing extra derives for an enum type using the enum-derive flag. + #[tokio::test] + async fn test_enum_derives_flag() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement_1 = "CREATE TYPE color AS ENUM ('red', 'green', 'blue');"; + let statement_2 = " CREATE TABLE palette ( id SERIAL PRIMARY KEY, favorite color NOT NULL ); "; - query(statement_1).execute(&pool).await?; - query(statement_2).execute(&pool).await?; - - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--enum-derive", - "Debug,PartialEq", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + query(statement_1).execute(&pool).await?; + query(statement_2).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--enum-derive", + "Debug,PartialEq", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, PartialEq)] #[sqlx(type_name = "color")] pub enum Color { @@ -339,32 +339,233 @@ pub struct Palette { favorite: Color, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) -} - -/// Test passing extra derives for the model using the model-derive flag. -#[tokio::test] -async fn test_model_derives_flag() -> Result<(), Box> { - let (pool, uri) = setup_pg_db().await; - let statement = "CREATE TABLE simple (id SERIAL PRIMARY KEY, value TEXT);"; - query(statement).execute(&pool).await?; - - let args = Cli::parse_from([ - "sql-gen", - "--db-url", - uri.as_str(), - "--model-derive", - "Debug,Clone,PartialEq", - ]); - let writer = generate_rust_from_database(&args).await; - let expected = r#" + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } + + /// Test passing extra derives for the model using the model-derive flag. + #[tokio::test] + async fn test_model_derives_flag() -> Result<(), Box> { + let (pool, uri) = setup_pg_db().await; + let statement = "CREATE TABLE simple (id SERIAL PRIMARY KEY, value TEXT);"; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--model-derive", + "Debug,Clone,PartialEq", + ]); + let writer = generate_rust_from_database(&args).await; + let expected = r#" #[derive(Debug, Clone, PartialEq)] pub struct Simple { id: i32, value: Option, } "#; - assert_eq!(writer.write_to_string().trim(), expected.trim()); - Ok(()) + assert_eq!(writer.write_to_string().trim(), expected.trim()); + Ok(()) + } +} + +mod mysql { + use clap::Parser as _; + use pretty_assertions::assert_eq; + use sqlx::query; + + use crate::{generate_rust_from_database, mysql::test_helper::setup_mysql_db, Cli}; + use std::error::Error; + + #[tokio::test] + async fn test_basic_mysql_tables() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement = r#" + CREATE TABLE test_table_0 ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) UNIQUE, + description TEXT, + parent_id INT, + FOREIGN KEY (parent_id) REFERENCES test_table_0(id) + ); + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#"#[derive(Debug, Clone, sqlx::FromRow)] +pub struct TestTable0 { + id: i32, + name: Option, + description: Option, + parent_id: Option, +}"# + .to_string() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_basic_mysql_table_with_enum() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + // No separate CREATE TYPE needed for MySQL. + let statement = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ) COMMENT='Table to store todo items with tags and status information.'; + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from(["sql-gen", "--db-url", uri.as_str()]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, PartialEq, sqlx::Type)] +pub enum TodoStatus { + #[sqlx(rename = "completed")] + Completed, + #[sqlx(rename = "in_progress")] + InProgress, + #[sqlx(rename = "pending")] + Pending, +} + +/// Table to store todo items with tags and status information. +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: TodoStatus, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_field_override() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + query(statement).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "status=String", + ]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: String, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_table_specific_field_override() -> Result<(), Box> { + let (pool, uri) = setup_mysql_db().await; + let statement_todos = r#" + CREATE TABLE todos ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL UNIQUE, + description TEXT, + tags JSON NOT NULL, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + let statement_other = r#" + CREATE TABLE other_todos_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + status ENUM('pending','in_progress','completed') NOT NULL DEFAULT 'pending' + ); + "#; + query(statement_todos).execute(&pool).await?; + query(statement_other).execute(&pool).await?; + + let args = Cli::parse_from([ + "sql-gen", + "--db-url", + uri.as_str(), + "--table-overrides", + "todos.status=String,toto.status=i32", + ]); + + let writer = generate_rust_from_database(&args).await; + + assert_eq!( + writer.write_to_string().trim(), + r#" +#[derive(Debug, Clone, PartialEq, sqlx::Type)] +pub enum OtherTodosTableStatus { + #[sqlx(rename = "completed")] + Completed, + #[sqlx(rename = "in_progress")] + InProgress, + #[sqlx(rename = "pending")] + Pending, +} + +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct OtherTodosTable { + id: i32, + status: OtherTodosTableStatus, +} + +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct Todo { + id: i32, + title: String, + description: Option, + tags: serde_json::JsonValue, + status: String, +} + "# + .to_string() + .trim() + ); + + Ok(()) + } }