Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod constraints;
pub mod errors;
pub mod helpers;
pub mod schema;
pub mod sql;
98 changes: 98 additions & 0 deletions src/common/sql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/// Escape a MSSQL identifier by wrapping in `[]` and doubling any `]` inside.
/// Example: `my]table` -> `[my]]table]`
pub fn escape_mssql_identifier(name: &str) -> String {
format!("[{}]", name.replace(']', "]]"))
}

/// Escape a MySQL identifier by wrapping in backticks and doubling any backtick inside.
/// Example: `my`table` -> `` `my``table` ``
pub fn escape_mysql_identifier(name: &str) -> String {
format!("`{}`", name.replace('`', "``"))
}

/// Escape a string value for use in a SQL string literal (single-quoted).
/// Doubles single quotes and escapes backslashes.
/// Example: `O'Brien` -> `O''Brien`
pub fn escape_sql_string(value: &str) -> String {
value.replace('\\', "\\\\").replace('\'', "''")
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_escape_mssql_identifier_simple() {
assert_eq!(escape_mssql_identifier("users"), "[users]");
}

#[test]
fn test_escape_mssql_identifier_with_bracket() {
assert_eq!(escape_mssql_identifier("my]table"), "[my]]table]");
}

#[test]
fn test_escape_mssql_identifier_with_spaces() {
assert_eq!(escape_mssql_identifier("my table"), "[my table]");
}

#[test]
fn test_escape_mssql_identifier_reserved_word() {
assert_eq!(escape_mssql_identifier("select"), "[select]");
}

#[test]
fn test_escape_mysql_identifier_simple() {
assert_eq!(escape_mysql_identifier("users"), "`users`");
}

#[test]
fn test_escape_mysql_identifier_with_backtick() {
assert_eq!(escape_mysql_identifier("my`table"), "`my``table`");
}

#[test]
fn test_escape_mysql_identifier_with_spaces() {
assert_eq!(escape_mysql_identifier("my table"), "`my table`");
}

#[test]
fn test_escape_mysql_identifier_reserved_word() {
assert_eq!(escape_mysql_identifier("select"), "`select`");
}

#[test]
fn test_escape_sql_string_simple() {
assert_eq!(escape_sql_string("hello"), "hello");
}

#[test]
fn test_escape_sql_string_with_quote() {
assert_eq!(escape_sql_string("O'Brien"), "O''Brien");
}

#[test]
fn test_escape_sql_string_with_backslash() {
assert_eq!(escape_sql_string("path\\to"), "path\\\\to");
}

#[test]
fn test_escape_sql_string_with_both() {
assert_eq!(escape_sql_string("it's a\\path"), "it''s a\\\\path");
}

#[test]
fn test_escape_sql_string_empty() {
assert_eq!(escape_sql_string(""), "");
}

#[test]
fn test_escape_mysql_identifier_empty() {
assert_eq!(escape_mysql_identifier(""), "``");
}

#[test]
fn test_escape_mssql_identifier_multiple_brackets() {
assert_eq!(escape_mssql_identifier("a]b]c"), "[a]]b]]c]");
}
}
5 changes: 3 additions & 2 deletions src/extract/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::common::schema::ColumnSchema;
use crate::common::sql::{escape_mssql_identifier, escape_sql_string};
use crate::extract::format::format_row_values;
use anyhow::{anyhow, Context, Result};
use bb8::{Pool, PooledConnection};
Expand Down Expand Up @@ -72,7 +73,7 @@ impl DatabaseExtractor {
FROM
INFORMATION_SCHEMA.COLUMNS c
WHERE c.TABLE_NAME = '{}';",
table
escape_sql_string(table)
);

let rows = conn.simple_query(query).await?.into_first_result().await?;
Expand All @@ -91,7 +92,7 @@ pub async fn open_row_stream<'a>(
conn: &'a mut PooledConnection<'_, ConnectionManager>,
table: &'a str,
) -> Result<BoxStream<'a, Result<Vec<String>, anyhow::Error>>> {
let query = format!("SELECT * FROM [{}]", table);
let query = format!("SELECT * FROM {}", escape_mssql_identifier(table));
let stream = conn
.simple_query(query)
.await?
Expand Down
9 changes: 7 additions & 2 deletions src/insert/inserter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use anyhow::{anyhow, Context, Result};
use sqlx::{Acquire, Executor, MySqlPool, Row};

use crate::common::sql::{escape_mysql_identifier, escape_sql_string};

use crate::common::schema::ColumnSchema;
use crate::insert::query::{build_create_constraints, build_create_table_query, build_reset_query};
use crate::insert::table_action::TableAction;
Expand Down Expand Up @@ -145,7 +147,7 @@ impl DatabaseInserter {
pub async fn table_exists(&mut self, table_name: &str) -> Result<bool> {
let query = format!(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
table_name
escape_sql_string(table_name)
);

let count: i64 = sqlx::query_scalar(&query).fetch_one(&self.pool).await?;
Expand All @@ -154,7 +156,10 @@ impl DatabaseInserter {
}

pub async fn table_rows_count(&mut self, table_name: &str) -> Result<i64> {
let query = format!("SELECT COUNT(*) FROM `{}`", table_name);
let query = format!(
"SELECT COUNT(*) FROM {}",
escape_mysql_identifier(table_name)
);

let count: i64 = sqlx::query_scalar(&query).fetch_one(&self.pool).await?;

Expand Down
85 changes: 67 additions & 18 deletions src/insert/query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use crate::common::constraints::Constraint;
use crate::common::schema::ColumnSchema;
use crate::common::sql::escape_mysql_identifier;
use crate::insert::table_action::TableAction;

pub fn build_insert_statement(table_name: &str, schema: &[ColumnSchema]) -> String {
let column_names_string = schema
.iter()
.map(|column| column.column_name.as_str())
.map(|column| escape_mysql_identifier(&column.column_name))
.collect::<Vec<_>>()
.join(", ");

format!(
"INSERT INTO `{}` ({}) VALUES",
table_name, column_names_string
"INSERT INTO {} ({}) VALUES",
escape_mysql_identifier(table_name),
column_names_string
)
}

Expand All @@ -20,9 +22,9 @@ pub fn build_reset_query(tables: &[String], action: &TableAction) -> String {
.iter()
.map(|table_name| {
format!(
"{} TABLE `{}`;",
"{} TABLE {};",
action.to_string().to_uppercase(),
table_name
escape_mysql_identifier(table_name)
)
})
.collect::<Vec<_>>()
Expand Down Expand Up @@ -62,15 +64,18 @@ pub fn build_create_constraints(
}
})
.map(|constraints| match constraints {
//Constraint::PrimaryKey => format!("ADD PRIMARY KEY(`{}`)", column.column_name),
Constraint::ForeignKey {
referenced_table,
referenced_column,
} => format!(
"ADD FOREIGN KEY(`{}`) REFERENCES `{}`(`{}`) ON DELETE CASCADE",
column.column_name, referenced_table, referenced_column
"ADD FOREIGN KEY({}) REFERENCES {}({}) ON DELETE CASCADE",
escape_mysql_identifier(&column.column_name),
escape_mysql_identifier(referenced_table),
escape_mysql_identifier(referenced_column)
),
Constraint::Unique => format!("ADD UNIQUE(`{}`)", column.column_name),
Constraint::Unique => {
format!("ADD UNIQUE({})", escape_mysql_identifier(&column.column_name))
}
Constraint::Check(check_clause) => format!("ADD CHECK ({})", check_clause),
Constraint::Default(default_value) => format!("ADD DEFAULT {}", default_value),
_ => String::new(),
Expand All @@ -84,8 +89,8 @@ pub fn build_create_constraints(
}

let alter_table_query = format!(
"SET FOREIGN_KEY_CHECKS=0; ALTER TABLE `{}` {}",
table_name,
"SET FOREIGN_KEY_CHECKS=0; ALTER TABLE {} {}",
escape_mysql_identifier(table_name),
constraints.join(", ")
);

Expand All @@ -98,7 +103,7 @@ pub fn build_create_table_query(table_name: &str, schema: &[ColumnSchema]) -> St
.map(|column| {
let mut result_str = String::new();

result_str.push_str(&column.column_name);
result_str.push_str(&escape_mysql_identifier(&column.column_name));
result_str.push(' ');

result_str.push_str(&column.data_type);
Expand Down Expand Up @@ -131,7 +136,11 @@ pub fn build_create_table_query(table_name: &str, schema: &[ColumnSchema]) -> St
.collect();

let columns = columns.join(", ");
format!("CREATE TABLE `{}` ({})", table_name, columns)
format!(
"CREATE TABLE {} ({})",
escape_mysql_identifier(table_name),
columns
)
}

#[cfg(test)]
Expand All @@ -157,14 +166,14 @@ mod tests {
make_column("name", "varchar", true),
];
let result = build_insert_statement("users", &schema);
assert_eq!(result, "INSERT INTO `users` (id, name) VALUES");
assert_eq!(result, "INSERT INTO `users` (`id`, `name`) VALUES");
}

#[test]
fn test_build_insert_statement_single_column() {
let schema = vec![make_column("id", "int", false)];
let result = build_insert_statement("test", &schema);
assert_eq!(result, "INSERT INTO `test` (id) VALUES");
assert_eq!(result, "INSERT INTO `test` (`id`) VALUES");
}

#[test]
Expand All @@ -175,8 +184,8 @@ mod tests {
];
let result = build_create_table_query("users", &schema);
assert!(result.starts_with("CREATE TABLE `users`"));
assert!(result.contains("id int NOT NULL"));
assert!(result.contains("name varchar NULL"));
assert!(result.contains("`id` int NOT NULL"));
assert!(result.contains("`name` varchar NULL"));
}

#[test]
Expand Down Expand Up @@ -243,7 +252,9 @@ mod tests {
let result = build_create_constraints("orders", &schema, &formatted_tables);
assert!(result.is_some());
let query = result.unwrap();
assert!(query.contains("ADD FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)"));
assert!(
query.contains("ADD FOREIGN KEY(`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")
);
}

#[test]
Expand Down Expand Up @@ -288,4 +299,42 @@ mod tests {
assert!(result.is_some());
assert!(result.unwrap().contains("ADD CHECK (age > 0)"));
}

#[test]
fn test_build_insert_reserved_word_column() {
let schema = vec![make_column("select", "int", false)];
let result = build_insert_statement("order", &schema);
assert_eq!(result, "INSERT INTO `order` (`select`) VALUES");
}

#[test]
fn test_build_create_table_backtick_in_name() {
let schema = vec![make_column("col`name", "int", false)];
let result = build_create_table_query("my`table", &schema);
assert!(result.contains("CREATE TABLE `my``table`"));
assert!(result.contains("`col``name`"));
}

#[test]
fn test_build_reset_query_reserved_word() {
let tables = vec!["order".to_string(), "select".to_string()];
let result = build_reset_query(&tables, &TableAction::Drop);
assert!(result.contains("DROP TABLE `order`;"));
assert!(result.contains("DROP TABLE `select`;"));
}

#[test]
fn test_build_create_constraints_escaped_fk() {
let mut col = make_column("group", "int", false);
col.constraints = Some(Constraint::ForeignKey {
referenced_table: "order".to_string(),
referenced_column: "select".to_string(),
});
let schema = vec![col];
let formatted_tables = vec!["order".to_string()];
let result = build_create_constraints("test", &schema, &formatted_tables);
assert!(result.is_some());
let query = result.unwrap();
assert!(query.contains("ADD FOREIGN KEY(`group`) REFERENCES `order`(`select`)"));
}
}
Loading