Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "db-migrator"
version = "0.2.5"
edition = "2021"
version = "0.3.0"
edition = "2024"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -15,6 +15,7 @@ clap = { version = "4", features = ["derive"] }
chrono = { version = "0.4" }
toml = "0.8"
async-trait = "0.1"
async-stream = "0.3"
hex = "0.4"
futures = "0.3"
tiberius = { version = "0.12.3" }
Expand Down
2 changes: 1 addition & 1 deletion src/common/schema.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Context, Result};
use anyhow::{Context, Result, anyhow};
use tiberius::Row;

use crate::common::constraints::Constraint;
Expand Down
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use toml::Value;

#[derive(Debug)]
Expand Down
2 changes: 1 addition & 1 deletion src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use anyhow::{Context, Result};
use async_trait::async_trait;
use bb8::Pool;
use bb8_tiberius::ConnectionManager;
use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlPoolOptions};
use sqlx::ConnectOptions;
use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlPoolOptions};
use tiberius::{AuthMethod, Config, EncryptionLevel};

use crate::config::DatabaseConfig;
Expand Down
69 changes: 45 additions & 24 deletions src/extract/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
use crate::common::schema::ColumnSchema;
use crate::common::sql::{escape_mssql_identifier, escape_sql_string};
use crate::extract::format::format_row_values;
use anyhow::{anyhow, Context, Result};
use bb8::{Pool, PooledConnection};
use crate::extract::traits::Extractor;
use anyhow::{Context, Result, anyhow};
use async_stream::stream;
use async_trait::async_trait;
use bb8::Pool;
use bb8_tiberius::ConnectionManager;
use futures::stream::{BoxStream, StreamExt};

#[derive(Clone)]
pub struct DatabaseExtractor {
pub pool: Pool<ConnectionManager>,
pool: Pool<ConnectionManager>,
}

impl DatabaseExtractor {
pub fn new(pool: Pool<ConnectionManager>) -> Self {
DatabaseExtractor { pool }
}
}

pub async fn fetch_tables(&mut self) -> Result<Vec<String>> {
#[async_trait]
impl Extractor for DatabaseExtractor {
async fn fetch_tables(&self) -> Result<Vec<String>> {
let mut conn = self.pool.get().await?;

let rows = conn
Expand All @@ -42,10 +48,10 @@ impl DatabaseExtractor {
Ok(tables)
}

pub async fn get_table_schema(&mut self, table: &str) -> Result<Vec<ColumnSchema>> {
async fn get_table_schema(&self, table: &str) -> Result<Vec<ColumnSchema>> {
let mut conn = self.pool.get().await?;

let query = format !(
let query = format!(
"SELECT
c.COLUMN_NAME,
c.DATA_TYPE,
Expand Down Expand Up @@ -86,23 +92,38 @@ impl DatabaseExtractor {

Ok(schema)
}
}

pub async fn open_row_stream<'a>(
conn: &'a mut PooledConnection<'_, ConnectionManager>,
table: &'a str,
) -> Result<BoxStream<'a, Result<Vec<String>, anyhow::Error>>> {
let query = format!("SELECT * FROM {}", escape_mssql_identifier(table));
let stream = conn
.simple_query(query)
.await?
.into_row_stream()
.map(|row_result| {
row_result
.map_err(anyhow::Error::from)
.and_then(format_row_values)
})
.boxed();

Ok(stream)
async fn stream_rows(&self, table: &str) -> Result<BoxStream<'static, Result<Vec<String>>>> {
let pool = self.pool.clone();
let table = table.to_owned();

let stream = stream! {
let conn_result = pool.get().await;
let mut conn = match conn_result {
Ok(conn) => conn,
Err(e) => {
yield Err(anyhow::Error::from(e));
return;
}
};

let query = format!("SELECT * FROM {}", escape_mssql_identifier(&table));
let query_stream = match conn.simple_query(query).await {
Ok(qs) => qs.into_row_stream(),
Err(e) => {
yield Err(anyhow::Error::from(e));
return;
}
};

futures::pin_mut!(query_stream);
while let Some(row_result) = query_stream.next().await {
yield row_result
.map_err(anyhow::Error::from)
.and_then(format_row_values);
}
};

Ok(Box::pin(stream))
}
}
2 changes: 1 addition & 1 deletion src/extract/format.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use chrono::DateTime as ChronosDateTime;
use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use hex::encode;
Expand Down
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod extractor;
mod format;
pub mod traits;
12 changes: 12 additions & 0 deletions src/extract/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;

use crate::common::schema::ColumnSchema;

#[async_trait]
pub trait Extractor: Clone + Send + Sync + 'static {
async fn fetch_tables(&self) -> Result<Vec<String>>;
async fn get_table_schema(&self, table: &str) -> Result<Vec<ColumnSchema>>;
async fn stream_rows(&self, table: &str) -> Result<BoxStream<'static, Result<Vec<String>>>>;
}
45 changes: 25 additions & 20 deletions src/insert/inserter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use anyhow::{anyhow, Context, Result};
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use sqlx::{Acquire, Executor, MySqlPool, Row};

use crate::common::sql::{escape_mysql_identifier, escape_sql_string};

use crate::common::schema::ColumnSchema;
use crate::insert::query::{build_create_constraints, build_create_table_query, build_reset_query};
use crate::insert::table_action::TableAction;
use crate::insert::traits::Inserter;

#[derive(Clone)]
pub struct DatabaseInserter {
Expand All @@ -17,7 +19,18 @@ impl DatabaseInserter {
DatabaseInserter { pool }
}

pub async fn create_table(&mut self, table_name: &str, schema: &[ColumnSchema]) -> Result<()> {
async fn get_all_tables(&self) -> Result<Vec<String>> {
let rows = sqlx::query("SHOW TABLES").fetch_all(&self.pool).await?;

let table_names: Vec<String> = rows.iter().map(|row| row.get::<String, _>(0)).collect();

Ok(table_names)
}
}

#[async_trait]
impl Inserter for DatabaseInserter {
async fn create_table(&self, table_name: &str, schema: &[ColumnSchema]) -> Result<()> {
let create_table_query = build_create_table_query(table_name, schema);

debug!("Creating table {}", table_name);
Expand All @@ -31,8 +44,8 @@ impl DatabaseInserter {
Ok(())
}

pub async fn create_constraints(
&mut self,
async fn create_constraints(
&self,
table_name: &str,
schema: &[ColumnSchema],
formatted_tables: &[String],
Expand Down Expand Up @@ -72,7 +85,7 @@ impl DatabaseInserter {
Ok(())
}

pub async fn execute_transactional_query(&mut self, query: &str) -> Result<()> {
async fn execute_transactional_query(&self, query: &str) -> Result<()> {
let mut connection = self.pool.acquire().await?;
let mut transaction = connection.begin().await?;

Expand All @@ -97,18 +110,18 @@ impl DatabaseInserter {
Ok(())
}

pub async fn get_max_allowed_packet(&mut self) -> Result<usize> {
async fn get_max_allowed_packet(&self) -> Result<usize> {
let query = "SELECT @@max_allowed_packet";

let max_allowed_packet: u32 = sqlx::query_scalar(query).fetch_one(&self.pool).await?;

Ok(max_allowed_packet as usize)
}

pub async fn reset_tables(&mut self, tables: &[String], action: TableAction) -> Result<()> {
let mut all_tables = self.get_all_tables().await.with_context(|| {
"Resetting tables encountered an error, cannot obtain existing tables"
})?;
async fn reset_tables(&self, tables: &[String], action: TableAction) -> Result<()> {
let mut all_tables = self.get_all_tables().await.with_context(
|| "Resetting tables encountered an error, cannot obtain existing tables",
)?;

// Filter and keep only the tables that exist in the database and are also present in the `tables` slice
all_tables.retain(|table| {
Expand Down Expand Up @@ -136,15 +149,7 @@ impl DatabaseInserter {
Ok(())
}

async fn get_all_tables(&mut self) -> Result<Vec<String>> {
let rows = sqlx::query("SHOW TABLES").fetch_all(&self.pool).await?;

let table_names: Vec<String> = rows.iter().map(|row| row.get::<String, _>(0)).collect();

Ok(table_names)
}

pub async fn table_exists(&mut self, table_name: &str) -> Result<bool> {
async fn table_exists(&self, table_name: &str) -> Result<bool> {
let query = format!(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
escape_sql_string(table_name)
Expand All @@ -155,7 +160,7 @@ impl DatabaseInserter {
Ok(count > 0)
}

pub async fn table_rows_count(&mut self, table_name: &str) -> Result<i64> {
async fn table_rows_count(&self, table_name: &str) -> Result<i64> {
let query = format!(
"SELECT COUNT(*) FROM {}",
escape_mysql_identifier(table_name)
Expand Down
1 change: 1 addition & 0 deletions src/insert/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod inserter;
pub mod query;
pub mod table_action;
pub mod traits;
8 changes: 4 additions & 4 deletions src/insert/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ pub fn build_create_table_query(table_name: &str, schema: &[ColumnSchema]) -> St
}
}

if let Some(constraint) = &column.constraints {
if *constraint == Constraint::PrimaryKey {
result_str.push_str(" PRIMARY KEY");
}
if let Some(constraint) = &column.constraints
&& *constraint == Constraint::PrimaryKey
{
result_str.push_str(" PRIMARY KEY");
}

result_str.push(' ');
Expand Down
21 changes: 21 additions & 0 deletions src/insert/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use anyhow::Result;
use async_trait::async_trait;

use crate::common::schema::ColumnSchema;
use crate::insert::table_action::TableAction;

#[async_trait]
pub trait Inserter: Clone + Send + Sync + 'static {
async fn create_table(&self, name: &str, schema: &[ColumnSchema]) -> Result<()>;
async fn create_constraints(
&self,
name: &str,
schema: &[ColumnSchema],
tables: &[String],
) -> Result<()>;
async fn execute_transactional_query(&self, query: &str) -> Result<()>;
async fn get_max_allowed_packet(&self) -> Result<usize>;
async fn reset_tables(&self, tables: &[String], action: TableAction) -> Result<()>;
async fn table_exists(&self, name: &str) -> Result<bool>;
async fn table_rows_count(&self, name: &str) -> Result<i64>;
}
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async fn run(options: Args) -> Result<()> {
whitelisted_tables: config.settings().whitelisted_tables.clone(),
};

let mut migrator = DatabaseMigrator::new(extractor, inserter, mappings, migration_options);
let migrator = DatabaseMigrator::new(extractor, inserter, mappings, migration_options);

migrator.run().await.with_context(|| "Migration failed")?;

Expand Down
2 changes: 1 addition & 1 deletion src/mappings.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use std::collections::HashMap;

#[derive(Clone, Debug)]
Expand Down
Loading
Loading