Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ license = "MIT"
repository = "https://github.com/jayy-lmao/sql-gen"

[dependencies]
sqlx = { version = "0.7", features = ["postgres","runtime-tokio", "mysql"] }
sqlx-cli = "0.7"
sqlx = { version = "0.8.3", features = ["postgres","runtime-tokio", "mysql"] }
sqlx-cli = "0.8.3"
clap = { version = "4.5.31", features = ["derive"] }
regex = "1.5"
chrono = "0.4"
Expand Down
5 changes: 4 additions & 1 deletion src/core/models/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub struct CustomEnum {
pub struct TableColumnBuilder {
column_name: String,
column_comment: Option<String>,
recommended_rust_type: Option<String>,
udt_name: String,
data_type: String,
is_nullable: bool,
Expand All @@ -59,6 +60,7 @@ impl TableColumnBuilder {
column_name: impl ToString,
udt_name: impl ToString,
data_type: impl ToString,
recommended_rust_type: Option<String>,
) -> Self {
Self {
column_name: column_name.to_string(),
Expand All @@ -72,6 +74,7 @@ impl TableColumnBuilder {
foreign_key_id: None,
is_auto_populated: false,
array_depth: 0,
recommended_rust_type,
}
}

Expand Down Expand Up @@ -117,7 +120,7 @@ impl TableColumnBuilder {
pub fn build(self) -> TableColumn {
TableColumn {
column_name: self.column_name,
recommended_rust_type: convert_data_type(&self.udt_name),
recommended_rust_type: self.recommended_rust_type,
udt_name: self.udt_name,
data_type: self.data_type,
is_nullable: self.is_nullable,
Expand Down
52 changes: 36 additions & 16 deletions src/core/translators/convert_table_to_struct_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ fn should_convert_table_with_basic_column() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("title", "text", "text").build()],
columns: vec![
TableColumnBuilder::new("title", "text", "text", Some("String".to_string())).build(),
],
..Default::default()
};

Expand Down Expand Up @@ -117,14 +119,14 @@ fn should_convert_table_with_each_column_attribute_type() {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![
TableColumnBuilder::new("id", "uuid", "uuid")
TableColumnBuilder::new("id", "uuid", "uuid", Some("uuid::Uuid".to_string()))
.is_auto_populated()
.is_primary_key()
.build(),
TableColumnBuilder::new("title", "text", "text")
TableColumnBuilder::new("title", "text", "text", Some("String".to_string()))
.is_unique()
.build(),
TableColumnBuilder::new("description", "text", "text")
TableColumnBuilder::new("description", "text", "text", Some("String".to_string()))
.is_nullable()
.build(),
],
Expand Down Expand Up @@ -170,9 +172,14 @@ fn should_convert_table_with_optional_column() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("description", "text", "text")
.is_nullable()
.build()],
columns: vec![TableColumnBuilder::new(
"description",
"text",
"text",
Some("String".to_string()),
)
.is_nullable()
.build()],
..Default::default()
};
let mut options = CodegenOptions::default();
Expand Down Expand Up @@ -200,10 +207,15 @@ fn should_convert_table_with_array_column() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("tags", "_text", "ARRAY")
.is_nullable()
.array_depth(1)
.build()],
columns: vec![TableColumnBuilder::new(
"tags",
"_text",
"ARRAY",
Some("String".to_string()),
)
.is_nullable()
.array_depth(1)
.build()],
..Default::default()
};
let mut options = CodegenOptions::default();
Expand Down Expand Up @@ -232,7 +244,9 @@ fn should_convert_table_with_enum_column() {
let table = Table {
table_name: "orders".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("order_status", "status", "USER-DEFINED").build()],
columns: vec![
TableColumnBuilder::new("order_status", "status", "USER-DEFINED", None).build(),
],
..Default::default()
};
let enums: Vec<CustomEnum> = vec![CustomEnum {
Expand Down Expand Up @@ -277,7 +291,7 @@ fn should_ignore_columns_with_invalid_types() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("title", "badtype", "badtype").build()],
columns: vec![TableColumnBuilder::new("title", "badtype", "badtype", None).build()],
..Default::default()
};
let mut options = CodegenOptions::default();
Expand All @@ -300,7 +314,9 @@ fn should_convert_table_with_column_type_override() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("id", "i32", "i32").build()],
columns: vec![
TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(),
],
..Default::default()
};

Expand Down Expand Up @@ -335,7 +351,9 @@ fn should_convert_table_with_global_type_override() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("id", "int4", "int4").build()],
columns: vec![
TableColumnBuilder::new("id", "int4", "int4", Some("i32".to_string())).build(),
],
..Default::default()
};

Expand Down Expand Up @@ -370,7 +388,9 @@ fn column_override_takes_preference_over_global_type_override() {
let table = Table {
table_name: "products".to_string(),
table_schema: Some("public".to_string()),
columns: vec![TableColumnBuilder::new("price", "int4", "int4").build()],
columns: vec![
TableColumnBuilder::new("price", "int4", "int4", Some("i32".to_string())).build(),
],
..Default::default()
};

Expand Down
2 changes: 1 addition & 1 deletion src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod models;
pub mod queries;
#[cfg(test)]
mod test_helper;
pub mod test_helper;
46 changes: 33 additions & 13 deletions src/mysql/queries/convert_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,42 @@ pub fn convert_data_type(udt_type: &str) -> Option<String> {
}

match udt_type {
"bool" | "boolean" => Some("bool".to_string()),
"bytea" => Some("u8".to_string()), // is this right?
"char" | "bpchar" | "character" => Some("String".to_string()),
// Boolean: MySQL treats TINYINT(1), BOOLEAN and BOOL as booleans.
"bool" | "boolean" | "tinyint(1)" => Some("bool".to_string()),

// Numeric types
"tinyint unsigned" => Some("u8".to_string()),
"tinyint" => Some("i8".to_string()),
"smallint unsigned" => Some("u16".to_string()),
"smallint" => Some("i16".to_string()),
"int unsigned" => Some("u32".to_string()),
"int" => Some("i32".to_string()),
"bigint unsigned" => Some("u64".to_string()),
"bigint" => Some("i64".to_string()),
"float" => Some("f32".to_string()),
"double" => Some("f64".to_string()),

// String types: VARCHAR, CHAR and TEXT map to Rust’s String.
"varchar" | "char" | "text" => Some("String".to_string()),

// Binary types: VARBINARY, BINARY and BLOB map to Vec<u8>
"varbinary" | "binary" | "blob" => Some("Vec<u8>".to_string()),

// Date and time types
"date" => Some("chrono::NaiveDate".to_string()),
"float4" | "real" => Some("f32".to_string()),
"float8" | "double precision" => Some("f64".to_string()),
"int2" | "smallint" | "smallserial" => Some("i16".to_string()),
"int4" | "int" | "serial" => Some("i32".to_string()),
"int8" | "bigint" | "bigserial" => Some("i64".to_string()),
"void" => Some("()".to_string()),
"jsonb" | "json" => Some("serde_json::Value".to_string()),
"text" | "varchar" | "name" | "citext" => Some("String".to_string()),
"datetime" => Some("chrono::NaiveDateTime".to_string()),
"timestamp" => Some("chrono::DateTime<chrono::Utc>".to_string()),
"time" => Some("chrono::NaiveTime".to_string()),
"timestamp" => Some("chrono::NaiveDateTime".to_string()),
"timestamptz" => Some("chrono::DateTime<chrono::Utc>".to_string()),

// Decimal type
"decimal" => Some("rust_decimal::Decimal".to_string()),

// UUID type: MySQL often stores UUIDs as BINARY(16)
"uuid" => Some("uuid::Uuid".to_string()),

// JSON type: maps to a serde_json value.
"json" => Some("serde_json::JsonValue".to_string()),

_ => None,
}
}
6 changes: 3 additions & 3 deletions src/mysql/queries/get_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ pub async fn get_mysql_enums(pool: &MySqlPool) -> Result<Vec<CustomEnum>, sqlx::
WITH RECURSIVE enum_split AS (
-- Initial row: extract the full list of enum values from COLUMN_TYPE
SELECT
c.TABLE_SCHEMA AS `schema`,
CAST(c.TABLE_SCHEMA AS CHAR) AS `schema`,
CONCAT(c.TABLE_NAME, '.', c.COLUMN_NAME) AS enum_type,
c.COLUMN_TYPE,
c.COLUMN_COMMENT AS enum_type_comment,
CAST(c.COLUMN_COMMENT AS CHAR) AS enum_type_comment,
-- Remove the leading "enum(" and trailing ")" then extract the first value.
TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS enum_value,
CAST(TRIM(BOTH '\'' FROM SUBSTRING_INDEX(SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1), ',', 1)) AS CHAR) AS enum_value,
CASE
WHEN LOCATE(',', SUBSTRING(c.COLUMN_TYPE, 6, CHAR_LENGTH(c.COLUMN_TYPE) - 6 - 1)) > 0
THEN TRIM(LEADING ' ' FROM SUBSTRING(
Expand Down
4 changes: 2 additions & 2 deletions src/mysql/queries/get_enums_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::error::Error;

#[tokio::test]
async fn test_get_mysql_enums() -> Result<(), Box<dyn Error>> {
let pool = setup_mysql_db().await;
let (pool, _) = setup_mysql_db().await;

sqlx::query("DROP TABLE IF EXISTS test_table;")
.execute(&pool)
Expand Down Expand Up @@ -51,7 +51,7 @@ async fn test_get_mysql_enums() -> Result<(), Box<dyn Error>> {

#[tokio::test]
async fn test_get_mysql_enums_with_comments() -> Result<(), Box<dyn Error>> {
let pool = setup_mysql_db().await;
let (pool, _) = setup_mysql_db().await;
// Clean up any existing type
// Drop any existing table that uses the enum.
sqlx::query("DROP TABLE IF EXISTS test_weather;")
Expand Down
12 changes: 6 additions & 6 deletions src/mysql/queries/get_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ pub async fn get_tables(
// get all tables from the database
let query = "
SELECT
c.TABLE_NAME AS table_name,
c.COLUMN_NAME AS column_name,
c.COLUMN_TYPE AS udt_name,
c.DATA_TYPE AS data_type,
CAST(c.TABLE_NAME AS CHAR) AS table_name,
CAST(c.COLUMN_NAME AS CHAR) AS column_name,
CAST(c.COLUMN_TYPE AS CHAR) AS udt_name,
CAST(c.DATA_TYPE AS CHAR) AS data_type,
'' AS table_schema,
(c.IS_NULLABLE = 'YES') AS is_nullable,
(c.COLUMN_KEY = 'PRI') AS is_primary_key,
(c.COLUMN_KEY = 'UNI') AS is_unique,
kcu.REFERENCED_TABLE_NAME AS foreign_key_table,
CAST(kcu.REFERENCED_TABLE_NAME AS CHAR) AS foreign_key_table,
kcu.REFERENCED_COLUMN_NAME AS foreign_key_id,
NULLIF(c.COLUMN_COMMENT, '') AS column_comment,
NULLIF(CAST(c.COLUMN_COMMENT as CHAR), '') AS column_comment,
NULLIF(t.TABLE_COMMENT, '') AS table_comment,
CASE
WHEN c.COLUMN_DEFAULT IS NOT NULL
Expand Down
Loading
Loading