diff --git a/src/args.rs b/src/args.rs index 166bebd..6709606 100644 --- a/src/args.rs +++ b/src/args.rs @@ -19,9 +19,19 @@ pub(crate) struct Args { pub(crate) verbose: u8, #[arg(short = 't', long = "table")] - #[arg(help = "Table or views to query. Can be used multiple times.")] - #[arg(action=ArgAction::Set)] - pub(crate) table: Option>, + #[arg(help = "Table or view to query. Can be used multiple times")] + #[arg(action=ArgAction::Append)] + pub(crate) table: Vec, + + #[arg(short = 's', long = "sql")] + #[arg(help = "SQL query to run. Can be used multiple times")] + #[arg(action=ArgAction::Append)] + pub(crate) query: Vec, + + #[arg(short = 'i', long = "ignore")] + #[arg(help = "Ignore non-readonly queries")] + #[arg(action=ArgAction::SetTrue)] + pub(crate) ignore_non_readonly: bool, #[arg(help = "Pattern to match every cell with")] pub(crate) pattern: String, diff --git a/src/cell_to_string.rs b/src/cell_to_string.rs index 111f196..5ecafdb 100644 --- a/src/cell_to_string.rs +++ b/src/cell_to_string.rs @@ -1,21 +1,10 @@ -use std::fmt::Display; - -use log::warn; use sqlx::sqlite::SqliteValueRef; use sqlx::Decode; use sqlx::Sqlite; use sqlx::Type; -use sqlx::TypeInfo; use sqlx::ValueRef; -// REVIEW: better way convert an error to a string? -fn errr_format(value: impl Display) -> String { - format!("{value}") -} - pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result, String> { - // TODO: add an option to override types in some extent - if value_ref.is_null() { return Ok(None); } @@ -24,41 +13,45 @@ pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result>::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(value)); } // // INTEGER, INT4 if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // REAL if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // BOOL? if >::compatible(&type_info) { - let value = >::decode(value_ref).map_err(errr_format)?; + let value = + >::decode(value_ref).map_err(|value| value.to_string())?; return Ok(Some(format!("{value}"))); } // DateTime if as Type>::compatible(&type_info) { let value = as Decode>::decode(value_ref) - .map_err(errr_format)?; + .map_err(|value| value.to_string())?; return Ok(Some(value.to_rfc3339())); } // Date if >::compatible(&type_info) { - let value = - >::decode(value_ref).map_err(errr_format)?; + let value = >::decode(value_ref) + .map_err(|value| value.to_string())?; return Ok(Some(value.format("%Y-%m-%d").to_string())); } // Time if >::compatible(&type_info) { - let value = - >::decode(value_ref).map_err(errr_format)?; + let value = >::decode(value_ref) + .map_err(|value| value.to_string())?; return Ok(Some(value.format("%H:%M:%S").to_string())); } @@ -73,6 +66,5 @@ pub(crate) fn sqlite_cell_to_string(value_ref: SqliteValueRef) -> Result i32 { + match self { + SQLError::Regex(error) => { + log::log!(level, "Regex error: {error}"); + + 64 + } + SQLError::QueryError(query_error) => { + match query_error { + QueryError::ReadOnlyQueryAllowed => { + log::log!(level, "Only readonly query is allowed"); + } + } + + 65 + } + SQLError::ParseError(error) => { + log::log!(level, "Unable to parse SQL: {error}"); + + 66 + } + SQLError::Io((context, error)) => { + let context = format_context(context); + log::log!(level, "IO error{context}: {error}"); + + 70 + } + SQLError::SqlX((context, error)) => { + let context = format_context(context); + log::log!(level, "SQL error{context}: {error}"); + + 74 + } + SQLError::ConvertCell((context, error)) => { + let context = format_context(context); + + log::log!(level, "Cell conversion error{context}: {error}"); + + 73 + } + } + } +} + +#[inline] +fn format_context(context: &String) -> String { + if context.is_empty() { + String::new() + } else { + format!(" ({context})") + } +} diff --git a/src/main.rs b/src/main.rs index d3657f9..a15b1a0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,28 @@ mod args; mod cell_to_string; +mod error; +mod matching; +mod pattern; +mod query; mod select; +use std::fs::OpenOptions; +use std::io::stdin; + +use error::Level; +use error::SQLError; +use matching::sqlite_check_rows; +use pattern::{Pattern, PatternKind}; +use query::{prepare_queries, SelectVariant}; use sqlparser::dialect::SQLiteDialect; -use sqlx::sqlite::SqliteConnectOptions; -use sqlx::Error; -use sqlx::{Column, Executor, Pool, Row, Sqlite, SqlitePool}; +use sqlx::{sqlite::SqliteConnectOptions, Executor as _, Pool, Row as _, Sqlite, SqlitePool}; #[tokio::main()] async fn main() { let args = args::parse_args(); // Set default log level to 2. - let quiet_level: i16 = 2 + args.verbose as i16 - args.quiet as i16; + let quiet_level: i16 = 2 + i16::from(args.verbose) - i16::from(args.quiet); stderrlog::new() .module(module_path!()) @@ -23,175 +33,149 @@ async fn main() { .init() .unwrap(); - let pattern = match regex::Regex::new(&args.pattern) { + let pattern = match Pattern::new(args.pattern.as_str(), &PatternKind::Regex) { Ok(pattern) => pattern, - Err(err) => { - log::error!("Unable to compile pattern: {}", err); - std::process::exit(64); - } + Err(err) => std::process::exit(err.report(Level::Error)), }; - let select_variant = match args.table { - None => SelectVariant::WholeDB, - Some(table_names) => SelectVariant::SpecificTables(table_names), + let queries = match read_queries(args.query) { + Ok(queries) => queries, + Err(err) => std::process::exit(err.report(Level::Error)), }; - match process_sqlite_database(args.database_uri, &pattern, select_variant).await { - Ok(_) => {} - Err(err) => { - log::error!("Unable read tables from database: {}", err); - std::process::exit(74); - } + match process_sqlite_database( + args.database_uri, + pattern, + args.table, + queries, + args.ignore_non_readonly, + ) + .await + { + Ok(()) => {} + Err(err) => std::process::exit(err.report(Level::Error)), } } async fn process_sqlite_database( database_uri: String, - pattern: ®ex::Regex, - select_variant: SelectVariant, -) -> Result<(), Error> { + pattern: Pattern, + tables: Vec, + queries: Vec, + ignore_non_read: bool, +) -> Result<(), SQLError> { let dialect = SQLiteDialect {}; - let options: SqliteConnectOptions = match database_uri.parse::() { - Ok(options) => options.read_only(true).immutable(true), - Err(err) => { - log::error!("Database URI error: {}", err); - std::process::exit(64); - } - }; - - let db = match SqlitePool::connect_with(options).await { - Ok(db) => db, - Err(err) => { - log::error!("Database connection error: {}", err); - std::process::exit(74); - } - }; - - match select_variant { - SelectVariant::WholeDB => match sqlite_select_tables(&db).await { - Ok(table_names) => { - let mut table_names = table_names; - sqlite_select_from_tables(&db, &mut table_names, pattern, &dialect).await + let options: SqliteConnectOptions = database_uri + .parse::() + .map(|options| options.read_only(true).immutable(true)) + .map_err(|error| SQLError::SqlX(("Database URI".into(), error)))?; + + let db = SqlitePool::connect_with(options) + .await + .map_err(|error| SQLError::SqlX(("Database connection".into(), error)))?; + + let select_variant = prepare_queries( + tables.into_iter(), + queries.into_iter(), + &dialect, + ignore_non_read, + )?; + + let queries = match select_variant { + SelectVariant::Queries(queries) => queries, + SelectVariant::WholeDB => { + let tables = sqlite_select_tables(&db).await?; + let select_variant = prepare_queries( + tables.into_iter(), + vec![].into_iter(), + &dialect, + ignore_non_read, + )?; + match select_variant { + SelectVariant::WholeDB => vec![], + SelectVariant::Queries(queries) => queries, } - Err(err) => Err(err), - }, - SelectVariant::SpecificTables(table_names) => { - let mut table_names = table_names.into_iter(); - sqlite_select_from_tables(&db, &mut table_names, pattern, &dialect).await } - } -} + }; -async fn sqlite_select_from_tables( - db: &Pool, - table_names: &mut Iter, - pattern: ®ex::Regex, - dialect: &SQLiteDialect, -) -> Result<(), Error> -where - Iter: Iterator, -{ - for table_name in table_names { - let select_query = select::generate_select(table_name.as_str(), dialect); - sqlite_check_rows(&table_name, db, select_query.as_str(), pattern).await; + for (query_id, query) in queries { + sqlite_check_rows(&db, query_id.as_str(), query.as_str(), &pattern).await; } Ok(()) } -async fn sqlite_select_tables(db: &Pool) -> Result, Error> { - let select_query = "SELECT name - FROM sqlite_schema - WHERE type ='table'"; +async fn sqlite_select_tables(db: &Pool) -> Result, SQLError> { + let select_query = "SELECT name FROM sqlite_schema WHERE type = 'table'"; log::debug!("Execute query: {select_query}"); - let result = db.fetch_all(select_query).await?; + + let result = db + .fetch_all(select_query) + .await + .map_err(|err| SQLError::SqlX(("fetch tables".into(), err)))?; Ok(result .into_iter() .filter_map(|row| match row.try_get::("name") { Ok(value) => Some(value), Err(err) => { - log::warn!("Error while reading from table `sqlite_schema`: {}", err); + SQLError::SqlX(("fetch tables".into(), err)).report(Level::Warn); None } - })) + }) + .collect()) } -async fn sqlite_check_rows( - table_name: &String, - db: &Pool, - select_query: &str, - pattern: ®ex::Regex, -) { - use futures::TryStreamExt; - use std::sync::atomic::AtomicI64; - use std::sync::atomic::Ordering; +fn read_queries(queries: Vec) -> Result, SQLError> { + let mut acc = vec![]; - log::debug!("Execute query: {select_query}"); - let mut rows = db.fetch(select_query); + queries.into_iter().try_fold((), |(), query| { + if query.is_empty() { + return Ok(()); + } - log::debug!("==> {table_name}"); - let row_idx: AtomicI64 = AtomicI64::new(-1); - loop { - row_idx.fetch_add(1, Ordering::SeqCst); - let idx = row_idx.load(Ordering::SeqCst); + if query == "-" { + return read_query(&mut stdin(), "").map(|query| { + acc.push(query); + }); + } - let row = match rows.try_next().await { - Ok(None) => break, - Ok(Some(row)) => row, - Err(err) => { - log::warn!( - "Error while reading row {idx} from table `{table_name}`: {}", - err - ); - continue; + match query.strip_prefix('@') { + None => { + acc.push(query); + Ok(()) } - }; + Some(filename) => read_from_file(filename).map(|query| { + acc.push(query); + }), + } + })?; - sqlite_process_row(idx, row, table_name, pattern); - } + Ok(acc) } -fn sqlite_process_row( - row_idx: i64, - row: sqlx::sqlite::SqliteRow, - table_name: &String, - pattern: ®ex::Regex, -) { - use sqlx::TypeInfo; - let columns = row.columns(); - for column in columns { - let index = column.ordinal(); - let column_name = column.name().to_owned(); - - let value_ref = match row.try_get_raw(index) { - Ok(value_ref) => value_ref, - Err(err) => { - log::warn!("Error while reading row {row_idx} from table {table_name} column {column_name}: {}", err); - continue; - } - }; +#[inline] +fn read_from_file(filename: &str) -> Result { + let mut file = OpenOptions::new() + .read(true) + .write(false) + .create(false) + .open(filename) + .expect("Unable to open"); - let value_str = match cell_to_string::sqlite_cell_to_string(value_ref) { - Ok(Some(value_str)) => value_str, - Ok(None) => continue, - Err(err) => { - let column_type = column.type_info().name(); - log::warn!("Error while converting data from row {row_idx} from table {table_name} column {column_name} of type {column_type}: {}", err); - continue; - } - }; - - if pattern.is_match(&value_str) { - println!("{table_name}::{row_idx}::{column_name} => {value_str:?}"); - } - } + read_query(&mut file, filename) } -#[non_exhaustive] -enum SelectVariant { - WholeDB, - SpecificTables(Vec), +#[inline] +fn read_query(file: &mut File, filename: &str) -> Result +where + File: std::io::Read, +{ + let mut query = String::new(); + + file.read_to_string(&mut query) + .map_err(|error| SQLError::Io((format!("read {filename}"), error))) + .map(|_| query) } diff --git a/src/matching.rs b/src/matching.rs new file mode 100644 index 0000000..889f062 --- /dev/null +++ b/src/matching.rs @@ -0,0 +1,75 @@ +use crate::cell_to_string::sqlite_cell_to_string; +use crate::error::Level; +use crate::{Pattern, SQLError}; + +use sqlx::{Column, Executor, Pool, Row, Sqlite}; + +pub async fn sqlite_check_rows( + db: &Pool, + query_id: &str, + select_query: &str, + pattern: &Pattern, +) { + use futures::TryStreamExt; + use std::sync::atomic::AtomicU64; + use std::sync::atomic::Ordering; + + log::debug!("{query_id}: {select_query}"); + + let mut rows = db.fetch(select_query); + + let row_counter: AtomicU64 = AtomicU64::new(0); + loop { + let row_idx = row_counter.load(Ordering::SeqCst); + + let row = match rows.try_next().await { + Ok(None) => break, + Ok(Some(row)) => row, + Err(error) => { + SQLError::SqlX((format!("{query_id}::{row_idx}"), error)).report(Level::Warn); + continue; + } + }; + + sqlite_process_row(row_idx, &row, query_id, pattern); + row_counter.fetch_add(1, Ordering::SeqCst); + } +} + +fn sqlite_process_row( + row_idx: u64, + row: &sqlx::sqlite::SqliteRow, + query_id: &str, + pattern: &Pattern, +) { + use sqlx::TypeInfo; + let columns = row.columns(); + for column in columns { + let index = column.ordinal(); + let column_name = column.name().to_owned(); + let column_type = column.type_info().name(); + let row_id = format!("{query_id}::{row_idx}::{column_name}"); + + let value_ref = match row.try_get_raw(index) { + Ok(value_ref) => value_ref, + Err(error) => { + SQLError::SqlX((row_id, error)).report(Level::Warn); + continue; + } + }; + + let value_str = match sqlite_cell_to_string(value_ref) { + Ok(Some(value_str)) => value_str, + Ok(None) => continue, + Err(error) => { + let error_context = format!("{row_id} cell type {column_type}"); + SQLError::ConvertCell((error_context, error)).report(Level::Warn); + continue; + } + }; + + if pattern.is_match(&value_str) { + println!("{row_id} => {value_str}"); + } + } +} diff --git a/src/pattern.rs b/src/pattern.rs new file mode 100644 index 0000000..60d1832 --- /dev/null +++ b/src/pattern.rs @@ -0,0 +1,25 @@ +use crate::error::SQLError; + +pub(crate) enum PatternKind { + Regex, +} + +pub(crate) enum Pattern { + Regex(regex::Regex), +} + +impl Pattern { + pub fn new(pattern: &str, kind: &PatternKind) -> Result { + match kind { + PatternKind::Regex => regex::Regex::new(pattern) + .map(Self::Regex) + .map_err(SQLError::Regex), + } + } + + pub fn is_match(&self, value: &str) -> bool { + match self { + Pattern::Regex(regex) => regex.is_match(value), + } + } +} diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..36918f5 --- /dev/null +++ b/src/query.rs @@ -0,0 +1,43 @@ +use sqlparser::dialect::Dialect; + +use crate::error::SQLError; +use crate::select::{escape_table_name, generate_select, read_verify_query}; + +#[non_exhaustive] +pub(crate) enum SelectVariant { + WholeDB, + Queries(Vec<(String, String)>), +} + +pub(crate) fn prepare_queries( + table: T, + queries: T, + dialect: &impl Dialect, + ignore_non_read: bool, +) -> Result +where + T: Iterator, +{ + let mut queries_result: Vec<(String, String)> = table + .map(|table_name| { + ( + format!("Table {}", escape_table_name(table_name.as_str(), dialect)), + generate_select(&table_name, dialect), + ) + }) + .collect(); + + let mut idx = 0usize; + queries.into_iter().try_fold((), |(), sql| { + read_verify_query(&sql, dialect, ignore_non_read, &mut idx)? + .into_iter() + .for_each(|query| queries_result.push((format!("Query #{idx}"), query))); + Ok(()) + })?; + + Ok(if queries_result.is_empty() { + SelectVariant::WholeDB + } else { + SelectVariant::Queries(queries_result) + }) +} diff --git a/src/select.rs b/src/select.rs index 09f80ea..a639086 100644 --- a/src/select.rs +++ b/src/select.rs @@ -1,8 +1,14 @@ use sqlparser::ast::helpers::attached_token::AttachedToken; -use sqlparser::ast::*; +use sqlparser::ast::{ + GroupByExpr, Ident, Select, SelectFlavor, SelectItem, SetExpr, Statement, TableFactor, + TableWithJoins, WildcardAdditionalOptions, +}; use sqlparser::dialect::Dialect; +use sqlparser::parser::Parser; use sqlparser::tokenizer::Span; +use crate::error::{QueryError, SQLError}; + /// /// Generates wildcard select for given dialect: /// @@ -70,3 +76,36 @@ pub(crate) fn generate_select(table_name: &str, dialect: &impl Dialect) -> Strin ast.to_string() } + +pub(crate) fn escape_table_name(table_name: &str, dialect: &impl Dialect) -> String { + Ident { + value: table_name.to_owned(), + quote_style: dialect.identifier_quote_style(table_name), + span: Span::empty(), + } + .to_string() +} + +/// Checks and reformat select +pub(crate) fn read_verify_query( + sql: &str, + dialect: &impl Dialect, + ignore_non_read: bool, + idx: &mut usize, +) -> Result, SQLError> { + let ast = Parser::parse_sql(dialect, sql).map_err(SQLError::ParseError)?; + let mut acc: Vec = vec![]; + ast.iter().try_fold((), |(), statement| { + if matches!(statement, Statement::Query(_)) { + acc.push(statement.to_string()); + *idx += 1; + Ok(()) + } else if ignore_non_read { + Ok(()) + } else { + Err(SQLError::QueryError(QueryError::ReadOnlyQueryAllowed)) + } + })?; + + Ok(acc) +}