diff options
Diffstat (limited to 'src/sql')
| -rw-r--r-- | src/sql/builder.rs | 493 | ||||
| -rw-r--r-- | src/sql/mod.rs | 8 |
2 files changed, 501 insertions, 0 deletions
diff --git a/src/sql/builder.rs b/src/sql/builder.rs new file mode 100644 index 0000000..7ba84f2 --- /dev/null +++ b/src/sql/builder.rs @@ -0,0 +1,493 @@ +// Enhanced SQL query builder with proper validation and complex WHERE support +use anyhow::{Context, Result}; +use regex::Regex; +use serde_json::Value; + +use crate::config::Config; +use crate::models::{FilterCondition, FilterOperator, OrderBy, OrderDirection}; + +/// Parse a table reference possibly containing an alias into (base_table, alias) +/// Accepts formats like: "table", "table alias", "table AS alias" (AS case-insensitive) +pub fn parse_table_and_alias(input: &str) -> (String, Option<String>) { + let s = input.trim(); + // Normalize multiple spaces + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.is_empty() { + return (String::new(), None); + } + + if parts.len() == 1 { + return (parts[0].to_string(), None); + } + + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") { + // table AS alias [ignore extra] + return (parts[0].to_string(), Some(parts[2].to_string())); + } + + // table alias + (parts[0].to_string(), Some(parts[1].to_string())) +} + +/// Validates a table name against known tables and SQL injection patterns +pub fn validate_table_name(table: &str, config: &Config) -> Result<()> { + // Check if empty + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty"); + } + + // Check against known tables list + let known_tables = config.get_known_tables(); + if !known_tables.contains(&table.to_string()) { + anyhow::bail!("Table '{}' is not in the known tables list", table); + } + + // Validate format: only alphanumeric and underscores, must start with letter + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table name regex")?; + + if !table_regex.is_match(table) { + anyhow::bail!("Invalid table name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", table); + } + + // Additional security: check for SQL keywords and dangerous patterns + let dangerous_patterns = [ + "--", "/*", "*/", ";", "DROP", "ALTER", "CREATE", "EXEC", "EXECUTE", + ]; + let table_upper = table.to_uppercase(); + for pattern in &dangerous_patterns { + if table_upper.contains(pattern) { + anyhow::bail!("Table name contains forbidden pattern: '{}'", pattern); + } + } + + Ok(()) +} + +/// Validates a column name against SQL injection patterns +/// Supports: column, table.column, table.column as alias, table.* +pub fn validate_column_name(column: &str) -> Result<()> { + // Check if empty + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty"); + } + + // Allow * for SELECT all + if column == "*" { + return Ok(()); + } + + // Check for SQL injection patterns (comments and statement terminators) + let dangerous_chars = ["--", "/*", "*/", ";"]; + for pattern in &dangerous_chars { + if column.contains(pattern) { + anyhow::bail!("Column name contains forbidden pattern: '{}'", pattern); + } + } + + // Note: We don't block SQL keywords like DROP, CREATE, ALTER etc. in column names + // because legitimate columns like "created_date", "created_by" contain these as substrings. + // The regex validation below ensures only alphanumeric + underscore characters are allowed, + // which prevents actual SQL injection while allowing valid column names. + + // Support multiple formats: + // 1. Simple column: column_name + // 2. Qualified column: table.column_name + // 3. Wildcard: table.* + // 4. Alias: column_name as alias or table.column_name as alias + + // Remove AS alias if present (case insensitive) + let column_upper = column.to_uppercase(); + let column_without_alias = if column_upper.contains(" AS ") { + column.split_whitespace().next().unwrap_or("") + } else { + column + }; + + // Check for qualified column (table.column or table.*) + if column_without_alias.contains('.') { + let parts: Vec<&str> = column_without_alias.split('.').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Invalid qualified column format: '{}'. Must be table.column", + column + ); + } + + // Validate table part + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table regex")?; + if !table_regex.is_match(parts[0]) { + anyhow::bail!("Invalid table name in qualified column: '{}'", parts[0]); + } + + // Validate column part (allow * for table.*) + if parts[1] != "*" { + let column_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile column regex")?; + if !column_regex.is_match(parts[1]) { + anyhow::bail!("Invalid column name in qualified column: '{}'", parts[1]); + } + } + } else { + // Simple column name validation + let column_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$") + .context("Failed to compile column name regex")?; + + if !column_regex.is_match(column_without_alias) { + anyhow::bail!("Invalid column name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", column); + } + } + + Ok(()) +} + +/// Validates multiple column names +pub fn validate_column_names(columns: &[String]) -> Result<()> { + for column in columns { + validate_column_name(column).with_context(|| format!("Invalid column name: {}", column))?; + } + Ok(()) +} + +/// Build WHERE clause from enhanced FilterCondition +/// Supports complex operators (=, !=, >, <, LIKE, IN, IS NULL, etc.) and nested logic (AND, OR, NOT) +pub fn build_filter_clause(filter: &FilterCondition) -> Result<(String, Vec<String>)> { + match filter { + FilterCondition::Simple { + column, + operator, + value, + } => { + validate_column_name(column)?; + build_simple_condition(column, operator, value) + } + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + if let Some(and_conds) = and_conditions { + if and_conds.is_empty() { + anyhow::bail!("AND conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in and_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" AND "), all_values)) + } else if let Some(or_conds) = or_conditions { + if or_conds.is_empty() { + anyhow::bail!("OR conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in or_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" OR "), all_values)) + } else { + anyhow::bail!("Logical condition must have either 'and' or 'or' field"); + } + } + FilterCondition::Not { not } => { + let (sql, values) = build_filter_clause(not)?; + Ok((format!("NOT ({})", sql), values)) + } + } +} + +/// Build a simple condition (column operator value) +fn build_simple_condition( + column: &str, + operator: &FilterOperator, + value: &Value, +) -> Result<(String, Vec<String>)> { + match operator { + FilterOperator::Eq => { + if value.is_null() { + Ok((format!("{} IS NULL", column), vec![])) + } else { + Ok((format!("{} = ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Ne => { + if value.is_null() { + Ok((format!("{} IS NOT NULL", column), vec![])) + } else { + Ok((format!("{} != ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Gt => Ok((format!("{} > ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Gte => Ok((format!("{} >= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lt => Ok((format!("{} < ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lte => Ok((format!("{} <= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Like => Ok(( + format!("{} LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::NotLike => Ok(( + format!("{} NOT LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::In => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("IN operator requires array value"); + } + } + FilterOperator::NotIn => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("NOT IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} NOT IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("NOT IN operator requires array value"); + } + } + FilterOperator::IsNull => { + // Value is ignored for IS NULL + Ok((format!("{} IS NULL", column), vec![])) + } + FilterOperator::IsNotNull => { + // Value is ignored for IS NOT NULL + Ok((format!("{} IS NOT NULL", column), vec![])) + } + FilterOperator::Between => { + if let Value::Array(arr) = value { + if arr.len() != 2 { + anyhow::bail!("BETWEEN operator requires array with exactly 2 values"); + } + let val1 = json_to_sql_string(&arr[0]); + let val2 = json_to_sql_string(&arr[1]); + Ok((format!("{} BETWEEN ? AND ?", column), vec![val1, val2])) + } else { + anyhow::bail!("BETWEEN operator requires array with [min, max] values"); + } + } + } +} + +/// Convert JSON value to SQL string representation +/// Properly handles all JSON types including booleans (true/false -> 1/0) +fn json_to_sql_string(value: &Value) -> String { + match value { + Value::String(s) => s.clone(), + Value::Number(n) => n.to_string(), + Value::Bool(b) => { + if *b { + "1".to_string() + } else { + "0".to_string() + } + } + Value::Null => "NULL".to_string(), + _ => serde_json::to_string(value).unwrap_or_else(|_| "NULL".to_string()), + } +} + +/// Build legacy WHERE clause from simple key-value JSON (for backward compatibility) +pub fn build_legacy_where_clause(where_clause: &Value) -> Result<(String, Vec<String>)> { + let mut conditions = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = where_clause { + for (key, value) in map { + validate_column_name(key)?; + + if value.is_null() { + // Handle NULL values with IS NULL + conditions.push(format!("{} IS NULL", key)); + // Don't add to values since IS NULL doesn't need a parameter + } else { + conditions.push(format!("{} = ?", key)); + values.push(json_to_sql_string(value)); + } + } + } else { + anyhow::bail!("WHERE clause must be an object"); + } + + if conditions.is_empty() { + anyhow::bail!("WHERE clause cannot be empty"); + } + + Ok((conditions.join(" AND "), values)) +} + +/// Build ORDER BY clause with column validation +pub fn build_order_by_clause(order_by: &[OrderBy]) -> Result<String> { + if order_by.is_empty() { + return Ok(String::new()); + } + + let mut clauses = Vec::new(); + for order in order_by { + validate_column_name(&order.column)?; + let direction = match order.direction { + OrderDirection::ASC => "ASC", + OrderDirection::DESC => "DESC", + }; + clauses.push(format!("{} {}", order.column, direction)); + } + + Ok(format!(" ORDER BY {}", clauses.join(", "))) +} + +/// Build JOIN clause from Join specifications +/// Validates table names and join conditions for security +pub fn build_join_clause(joins: &[crate::models::Join], config: &Config) -> Result<String> { + if joins.is_empty() { + return Ok(String::new()); + } + + let mut join_sql = String::new(); + + for join in joins { + // Extract base table and optional alias + let (base_table, alias) = parse_table_and_alias(&join.table); + + // Validate joined base table name + validate_table_name(&base_table, config)?; + + // Optionally validate alias format (same rules as table/column names) + if let Some(alias_name) = &alias { + let alias_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile alias regex")?; + if !alias_regex.is_match(alias_name) { + anyhow::bail!("Invalid table alias format: '{}'", alias_name); + } + } + + // Validate join condition (must be in format "table1.column1 = table2.column2") + validate_join_condition(&join.on)?; + + // Build JOIN clause based on type + let join_type_str = match join.join_type { + crate::models::JoinType::Inner => "INNER JOIN", + crate::models::JoinType::Left => "LEFT JOIN", + crate::models::JoinType::Right => "RIGHT JOIN", + }; + + // Reconstruct safe table reference + let table_ref = match alias { + Some(a) => format!("{} AS {}", base_table, a), + None => base_table, + }; + join_sql.push_str(&format!(" {} {} ON {}", join_type_str, table_ref, join.on)); + } + + Ok(join_sql) +} + +/// Validate JOIN ON condition +/// Must be in format: "table1.column1 = table2.column2" or similar simple conditions +fn validate_join_condition(condition: &str) -> Result<()> { + // Basic validation: must contain = and table.column references + if !condition.contains('=') { + anyhow::bail!("JOIN condition must contain = operator"); + } + + // Split by = and validate both sides + let parts: Vec<&str> = condition.split('=').map(|s| s.trim()).collect(); + if parts.len() != 2 { + anyhow::bail!("JOIN condition must have exactly one = operator"); + } + + // Validate both sides are table.column format + for part in parts { + validate_table_column_reference(part)?; + } + + Ok(()) +} + +/// Validate table.column reference (e.g., "assets.category_id") +fn validate_table_column_reference(reference: &str) -> Result<()> { + let parts: Vec<&str> = reference.split('.').collect(); + + if parts.len() != 2 { + anyhow::bail!( + "Table column reference must be in format 'table.column', got: {}", + reference + ); + } + + let table = parts[0].trim(); + let column = parts[1].trim(); + + // Validate table and column names follow safe patterns + let name_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile name regex")?; + + if !name_regex.is_match(table) { + anyhow::bail!("Invalid table name in JOIN condition: {}", table); + } + + if !name_regex.is_match(column) { + anyhow::bail!("Invalid column name in JOIN condition: {}", column); + } + + // Additional guard against comment/terminator tokens (redundant given regex, but safe) + // Intentionally do NOT block SQL keywords like CREATE/ALTER since identifiers like + // 'created_at' are legitimate and already validated by the regex above. + let dangerous_tokens = ["--", "/*", "*/", ";"]; + for token in &dangerous_tokens { + if reference.contains(token) { + anyhow::bail!("JOIN condition contains forbidden token: '{}'", token); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_table_name() { + // Valid names + assert!(validate_column_name("users").is_ok()); + assert!(validate_column_name("asset_id").is_ok()); + assert!(validate_column_name("table123").is_ok()); + + // Invalid names + assert!(validate_column_name("123table").is_err()); // starts with number + assert!(validate_column_name("table-name").is_err()); // contains hyphen + assert!(validate_column_name("table name").is_err()); // contains space + assert!(validate_column_name("table;DROP").is_err()); // SQL injection + assert!(validate_column_name("").is_err()); // empty + } + + #[test] + fn test_validate_column_name() { + // Valid names + assert!(validate_column_name("user_id").is_ok()); + assert!(validate_column_name("firstName").is_ok()); + assert!(validate_column_name("*").is_ok()); // wildcard allowed + + // Invalid names + assert!(validate_column_name("123column").is_err()); + assert!(validate_column_name("col-umn").is_err()); + assert!(validate_column_name("col umn").is_err()); + assert!(validate_column_name("").is_err()); + } +} diff --git a/src/sql/mod.rs b/src/sql/mod.rs new file mode 100644 index 0000000..fd5b959 --- /dev/null +++ b/src/sql/mod.rs @@ -0,0 +1,8 @@ +// SQL query building and validation module +pub mod builder; + +// Re-export commonly used functions +pub use builder::{ + build_filter_clause, build_join_clause, build_legacy_where_clause, build_order_by_clause, + validate_column_name, validate_column_names, validate_table_name, +}; |
