aboutsummaryrefslogtreecommitdiff
path: root/src/sql
diff options
context:
space:
mode:
Diffstat (limited to 'src/sql')
-rw-r--r--src/sql/builder.rs493
-rw-r--r--src/sql/mod.rs8
2 files changed, 501 insertions, 0 deletions
diff --git a/src/sql/builder.rs b/src/sql/builder.rs
new file mode 100644
index 0000000..7ba84f2
--- /dev/null
+++ b/src/sql/builder.rs
@@ -0,0 +1,493 @@
+// Enhanced SQL query builder with proper validation and complex WHERE support
+use anyhow::{Context, Result};
+use regex::Regex;
+use serde_json::Value;
+
+use crate::config::Config;
+use crate::models::{FilterCondition, FilterOperator, OrderBy, OrderDirection};
+
+/// Parse a table reference possibly containing an alias into (base_table, alias)
+/// Accepts formats like: "table", "table alias", "table AS alias" (AS case-insensitive)
+pub fn parse_table_and_alias(input: &str) -> (String, Option<String>) {
+ let s = input.trim();
+ // Normalize multiple spaces
+ let parts: Vec<&str> = s.split_whitespace().collect();
+ if parts.is_empty() {
+ return (String::new(), None);
+ }
+
+ if parts.len() == 1 {
+ return (parts[0].to_string(), None);
+ }
+
+ if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") {
+ // table AS alias [ignore extra]
+ return (parts[0].to_string(), Some(parts[2].to_string()));
+ }
+
+ // table alias
+ (parts[0].to_string(), Some(parts[1].to_string()))
+}
+
+/// Validates a table name against known tables and SQL injection patterns
+pub fn validate_table_name(table: &str, config: &Config) -> Result<()> {
+ // Check if empty
+ if table.trim().is_empty() {
+ anyhow::bail!("Table name cannot be empty");
+ }
+
+ // Check against known tables list
+ let known_tables = config.get_known_tables();
+ if !known_tables.contains(&table.to_string()) {
+ anyhow::bail!("Table '{}' is not in the known tables list", table);
+ }
+
+ // Validate format: only alphanumeric and underscores, must start with letter
+ let table_regex =
+ Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table name regex")?;
+
+ if !table_regex.is_match(table) {
+ anyhow::bail!("Invalid table name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", table);
+ }
+
+ // Additional security: check for SQL keywords and dangerous patterns
+ let dangerous_patterns = [
+ "--", "/*", "*/", ";", "DROP", "ALTER", "CREATE", "EXEC", "EXECUTE",
+ ];
+ let table_upper = table.to_uppercase();
+ for pattern in &dangerous_patterns {
+ if table_upper.contains(pattern) {
+ anyhow::bail!("Table name contains forbidden pattern: '{}'", pattern);
+ }
+ }
+
+ Ok(())
+}
+
+/// Validates a column name against SQL injection patterns
+/// Supports: column, table.column, table.column as alias, table.*
+pub fn validate_column_name(column: &str) -> Result<()> {
+ // Check if empty
+ if column.trim().is_empty() {
+ anyhow::bail!("Column name cannot be empty");
+ }
+
+ // Allow * for SELECT all
+ if column == "*" {
+ return Ok(());
+ }
+
+ // Check for SQL injection patterns (comments and statement terminators)
+ let dangerous_chars = ["--", "/*", "*/", ";"];
+ for pattern in &dangerous_chars {
+ if column.contains(pattern) {
+ anyhow::bail!("Column name contains forbidden pattern: '{}'", pattern);
+ }
+ }
+
+ // Note: We don't block SQL keywords like DROP, CREATE, ALTER etc. in column names
+ // because legitimate columns like "created_date", "created_by" contain these as substrings.
+ // The regex validation below ensures only alphanumeric + underscore characters are allowed,
+ // which prevents actual SQL injection while allowing valid column names.
+
+ // Support multiple formats:
+ // 1. Simple column: column_name
+ // 2. Qualified column: table.column_name
+ // 3. Wildcard: table.*
+ // 4. Alias: column_name as alias or table.column_name as alias
+
+ // Remove AS alias if present (case insensitive)
+ let column_upper = column.to_uppercase();
+ let column_without_alias = if column_upper.contains(" AS ") {
+ column.split_whitespace().next().unwrap_or("")
+ } else {
+ column
+ };
+
+ // Check for qualified column (table.column or table.*)
+ if column_without_alias.contains('.') {
+ let parts: Vec<&str> = column_without_alias.split('.').collect();
+ if parts.len() != 2 {
+ anyhow::bail!(
+ "Invalid qualified column format: '{}'. Must be table.column",
+ column
+ );
+ }
+
+ // Validate table part
+ let table_regex =
+ Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table regex")?;
+ if !table_regex.is_match(parts[0]) {
+ anyhow::bail!("Invalid table name in qualified column: '{}'", parts[0]);
+ }
+
+ // Validate column part (allow * for table.*)
+ if parts[1] != "*" {
+ let column_regex =
+ Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile column regex")?;
+ if !column_regex.is_match(parts[1]) {
+ anyhow::bail!("Invalid column name in qualified column: '{}'", parts[1]);
+ }
+ }
+ } else {
+ // Simple column name validation
+ let column_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$")
+ .context("Failed to compile column name regex")?;
+
+ if !column_regex.is_match(column_without_alias) {
+ anyhow::bail!("Invalid column name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", column);
+ }
+ }
+
+ Ok(())
+}
+
+/// Validates multiple column names
+pub fn validate_column_names(columns: &[String]) -> Result<()> {
+ for column in columns {
+ validate_column_name(column).with_context(|| format!("Invalid column name: {}", column))?;
+ }
+ Ok(())
+}
+
+/// Build WHERE clause from enhanced FilterCondition
+/// Supports complex operators (=, !=, >, <, LIKE, IN, IS NULL, etc.) and nested logic (AND, OR, NOT)
+pub fn build_filter_clause(filter: &FilterCondition) -> Result<(String, Vec<String>)> {
+ match filter {
+ FilterCondition::Simple {
+ column,
+ operator,
+ value,
+ } => {
+ validate_column_name(column)?;
+ build_simple_condition(column, operator, value)
+ }
+ FilterCondition::Logical {
+ and_conditions,
+ or_conditions,
+ } => {
+ if let Some(and_conds) = and_conditions {
+ if and_conds.is_empty() {
+ anyhow::bail!("AND conditions array cannot be empty");
+ }
+ let mut conditions = Vec::new();
+ let mut all_values = Vec::new();
+
+ for cond in and_conds {
+ let (sql, values) = build_filter_clause(cond)?;
+ conditions.push(format!("({})", sql));
+ all_values.extend(values);
+ }
+
+ Ok((conditions.join(" AND "), all_values))
+ } else if let Some(or_conds) = or_conditions {
+ if or_conds.is_empty() {
+ anyhow::bail!("OR conditions array cannot be empty");
+ }
+ let mut conditions = Vec::new();
+ let mut all_values = Vec::new();
+
+ for cond in or_conds {
+ let (sql, values) = build_filter_clause(cond)?;
+ conditions.push(format!("({})", sql));
+ all_values.extend(values);
+ }
+
+ Ok((conditions.join(" OR "), all_values))
+ } else {
+ anyhow::bail!("Logical condition must have either 'and' or 'or' field");
+ }
+ }
+ FilterCondition::Not { not } => {
+ let (sql, values) = build_filter_clause(not)?;
+ Ok((format!("NOT ({})", sql), values))
+ }
+ }
+}
+
+/// Build a simple condition (column operator value)
+fn build_simple_condition(
+ column: &str,
+ operator: &FilterOperator,
+ value: &Value,
+) -> Result<(String, Vec<String>)> {
+ match operator {
+ FilterOperator::Eq => {
+ if value.is_null() {
+ Ok((format!("{} IS NULL", column), vec![]))
+ } else {
+ Ok((format!("{} = ?", column), vec![json_to_sql_string(value)]))
+ }
+ }
+ FilterOperator::Ne => {
+ if value.is_null() {
+ Ok((format!("{} IS NOT NULL", column), vec![]))
+ } else {
+ Ok((format!("{} != ?", column), vec![json_to_sql_string(value)]))
+ }
+ }
+ FilterOperator::Gt => Ok((format!("{} > ?", column), vec![json_to_sql_string(value)])),
+ FilterOperator::Gte => Ok((format!("{} >= ?", column), vec![json_to_sql_string(value)])),
+ FilterOperator::Lt => Ok((format!("{} < ?", column), vec![json_to_sql_string(value)])),
+ FilterOperator::Lte => Ok((format!("{} <= ?", column), vec![json_to_sql_string(value)])),
+ FilterOperator::Like => Ok((
+ format!("{} LIKE ?", column),
+ vec![json_to_sql_string(value)],
+ )),
+ FilterOperator::NotLike => Ok((
+ format!("{} NOT LIKE ?", column),
+ vec![json_to_sql_string(value)],
+ )),
+ FilterOperator::In => {
+ if let Value::Array(arr) = value {
+ if arr.is_empty() {
+ anyhow::bail!("IN operator requires non-empty array");
+ }
+ let placeholders = vec!["?"; arr.len()].join(", ");
+ let values: Vec<String> = arr.iter().map(json_to_sql_string).collect();
+ Ok((format!("{} IN ({})", column, placeholders), values))
+ } else {
+ anyhow::bail!("IN operator requires array value");
+ }
+ }
+ FilterOperator::NotIn => {
+ if let Value::Array(arr) = value {
+ if arr.is_empty() {
+ anyhow::bail!("NOT IN operator requires non-empty array");
+ }
+ let placeholders = vec!["?"; arr.len()].join(", ");
+ let values: Vec<String> = arr.iter().map(json_to_sql_string).collect();
+ Ok((format!("{} NOT IN ({})", column, placeholders), values))
+ } else {
+ anyhow::bail!("NOT IN operator requires array value");
+ }
+ }
+ FilterOperator::IsNull => {
+ // Value is ignored for IS NULL
+ Ok((format!("{} IS NULL", column), vec![]))
+ }
+ FilterOperator::IsNotNull => {
+ // Value is ignored for IS NOT NULL
+ Ok((format!("{} IS NOT NULL", column), vec![]))
+ }
+ FilterOperator::Between => {
+ if let Value::Array(arr) = value {
+ if arr.len() != 2 {
+ anyhow::bail!("BETWEEN operator requires array with exactly 2 values");
+ }
+ let val1 = json_to_sql_string(&arr[0]);
+ let val2 = json_to_sql_string(&arr[1]);
+ Ok((format!("{} BETWEEN ? AND ?", column), vec![val1, val2]))
+ } else {
+ anyhow::bail!("BETWEEN operator requires array with [min, max] values");
+ }
+ }
+ }
+}
+
+/// Convert JSON value to SQL string representation
+/// Properly handles all JSON types including booleans (true/false -> 1/0)
+fn json_to_sql_string(value: &Value) -> String {
+ match value {
+ Value::String(s) => s.clone(),
+ Value::Number(n) => n.to_string(),
+ Value::Bool(b) => {
+ if *b {
+ "1".to_string()
+ } else {
+ "0".to_string()
+ }
+ }
+ Value::Null => "NULL".to_string(),
+ _ => serde_json::to_string(value).unwrap_or_else(|_| "NULL".to_string()),
+ }
+}
+
+/// Build legacy WHERE clause from simple key-value JSON (for backward compatibility)
+pub fn build_legacy_where_clause(where_clause: &Value) -> Result<(String, Vec<String>)> {
+ let mut conditions = Vec::new();
+ let mut values = Vec::new();
+
+ if let Value::Object(map) = where_clause {
+ for (key, value) in map {
+ validate_column_name(key)?;
+
+ if value.is_null() {
+ // Handle NULL values with IS NULL
+ conditions.push(format!("{} IS NULL", key));
+ // Don't add to values since IS NULL doesn't need a parameter
+ } else {
+ conditions.push(format!("{} = ?", key));
+ values.push(json_to_sql_string(value));
+ }
+ }
+ } else {
+ anyhow::bail!("WHERE clause must be an object");
+ }
+
+ if conditions.is_empty() {
+ anyhow::bail!("WHERE clause cannot be empty");
+ }
+
+ Ok((conditions.join(" AND "), values))
+}
+
+/// Build ORDER BY clause with column validation
+pub fn build_order_by_clause(order_by: &[OrderBy]) -> Result<String> {
+ if order_by.is_empty() {
+ return Ok(String::new());
+ }
+
+ let mut clauses = Vec::new();
+ for order in order_by {
+ validate_column_name(&order.column)?;
+ let direction = match order.direction {
+ OrderDirection::ASC => "ASC",
+ OrderDirection::DESC => "DESC",
+ };
+ clauses.push(format!("{} {}", order.column, direction));
+ }
+
+ Ok(format!(" ORDER BY {}", clauses.join(", ")))
+}
+
+/// Build JOIN clause from Join specifications
+/// Validates table names and join conditions for security
+pub fn build_join_clause(joins: &[crate::models::Join], config: &Config) -> Result<String> {
+ if joins.is_empty() {
+ return Ok(String::new());
+ }
+
+ let mut join_sql = String::new();
+
+ for join in joins {
+ // Extract base table and optional alias
+ let (base_table, alias) = parse_table_and_alias(&join.table);
+
+ // Validate joined base table name
+ validate_table_name(&base_table, config)?;
+
+ // Optionally validate alias format (same rules as table/column names)
+ if let Some(alias_name) = &alias {
+ let alias_regex =
+ Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile alias regex")?;
+ if !alias_regex.is_match(alias_name) {
+ anyhow::bail!("Invalid table alias format: '{}'", alias_name);
+ }
+ }
+
+ // Validate join condition (must be in format "table1.column1 = table2.column2")
+ validate_join_condition(&join.on)?;
+
+ // Build JOIN clause based on type
+ let join_type_str = match join.join_type {
+ crate::models::JoinType::Inner => "INNER JOIN",
+ crate::models::JoinType::Left => "LEFT JOIN",
+ crate::models::JoinType::Right => "RIGHT JOIN",
+ };
+
+ // Reconstruct safe table reference
+ let table_ref = match alias {
+ Some(a) => format!("{} AS {}", base_table, a),
+ None => base_table,
+ };
+ join_sql.push_str(&format!(" {} {} ON {}", join_type_str, table_ref, join.on));
+ }
+
+ Ok(join_sql)
+}
+
+/// Validate JOIN ON condition
+/// Must be in format: "table1.column1 = table2.column2" or similar simple conditions
+fn validate_join_condition(condition: &str) -> Result<()> {
+ // Basic validation: must contain = and table.column references
+ if !condition.contains('=') {
+ anyhow::bail!("JOIN condition must contain = operator");
+ }
+
+ // Split by = and validate both sides
+ let parts: Vec<&str> = condition.split('=').map(|s| s.trim()).collect();
+ if parts.len() != 2 {
+ anyhow::bail!("JOIN condition must have exactly one = operator");
+ }
+
+ // Validate both sides are table.column format
+ for part in parts {
+ validate_table_column_reference(part)?;
+ }
+
+ Ok(())
+}
+
+/// Validate table.column reference (e.g., "assets.category_id")
+fn validate_table_column_reference(reference: &str) -> Result<()> {
+ let parts: Vec<&str> = reference.split('.').collect();
+
+ if parts.len() != 2 {
+ anyhow::bail!(
+ "Table column reference must be in format 'table.column', got: {}",
+ reference
+ );
+ }
+
+ let table = parts[0].trim();
+ let column = parts[1].trim();
+
+ // Validate table and column names follow safe patterns
+ let name_regex =
+ Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile name regex")?;
+
+ if !name_regex.is_match(table) {
+ anyhow::bail!("Invalid table name in JOIN condition: {}", table);
+ }
+
+ if !name_regex.is_match(column) {
+ anyhow::bail!("Invalid column name in JOIN condition: {}", column);
+ }
+
+ // Additional guard against comment/terminator tokens (redundant given regex, but safe)
+ // Intentionally do NOT block SQL keywords like CREATE/ALTER since identifiers like
+ // 'created_at' are legitimate and already validated by the regex above.
+ let dangerous_tokens = ["--", "/*", "*/", ";"];
+ for token in &dangerous_tokens {
+ if reference.contains(token) {
+ anyhow::bail!("JOIN condition contains forbidden token: '{}'", token);
+ }
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_validate_table_name() {
+ // Valid names
+ assert!(validate_column_name("users").is_ok());
+ assert!(validate_column_name("asset_id").is_ok());
+ assert!(validate_column_name("table123").is_ok());
+
+ // Invalid names
+ assert!(validate_column_name("123table").is_err()); // starts with number
+ assert!(validate_column_name("table-name").is_err()); // contains hyphen
+ assert!(validate_column_name("table name").is_err()); // contains space
+ assert!(validate_column_name("table;DROP").is_err()); // SQL injection
+ assert!(validate_column_name("").is_err()); // empty
+ }
+
+ #[test]
+ fn test_validate_column_name() {
+ // Valid names
+ assert!(validate_column_name("user_id").is_ok());
+ assert!(validate_column_name("firstName").is_ok());
+ assert!(validate_column_name("*").is_ok()); // wildcard allowed
+
+ // Invalid names
+ assert!(validate_column_name("123column").is_err());
+ assert!(validate_column_name("col-umn").is_err());
+ assert!(validate_column_name("col umn").is_err());
+ assert!(validate_column_name("").is_err());
+ }
+}
diff --git a/src/sql/mod.rs b/src/sql/mod.rs
new file mode 100644
index 0000000..fd5b959
--- /dev/null
+++ b/src/sql/mod.rs
@@ -0,0 +1,8 @@
+// SQL query building and validation module
+pub mod builder;
+
+// Re-export commonly used functions
+pub use builder::{
+ build_filter_clause, build_join_clause, build_legacy_where_clause, build_order_by_clause,
+ validate_column_name, validate_column_names, validate_table_name,
+};