diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/auth/mod.rs | 6 | ||||
| -rw-r--r-- | src/auth/password.rs | 56 | ||||
| -rw-r--r-- | src/auth/pin.rs | 66 | ||||
| -rw-r--r-- | src/auth/session.rs | 129 | ||||
| -rw-r--r-- | src/auth/token.rs | 99 | ||||
| -rw-r--r-- | src/config.rs | 845 | ||||
| -rw-r--r-- | src/db/mod.rs | 3 | ||||
| -rw-r--r-- | src/db/pool.rs | 110 | ||||
| -rw-r--r-- | src/logging/logger.rs | 373 | ||||
| -rw-r--r-- | src/logging/mod.rs | 3 | ||||
| -rw-r--r-- | src/main.rs | 359 | ||||
| -rw-r--r-- | src/models/mod.rs | 294 | ||||
| -rw-r--r-- | src/permissions/mod.rs | 3 | ||||
| -rw-r--r-- | src/permissions/rbac.rs | 119 | ||||
| -rw-r--r-- | src/routes/auth.rs | 513 | ||||
| -rw-r--r-- | src/routes/mod.rs | 111 | ||||
| -rw-r--r-- | src/routes/preferences.rs | 423 | ||||
| -rw-r--r-- | src/routes/query.rs | 3255 | ||||
| -rw-r--r-- | src/scheduler.rs | 185 | ||||
| -rw-r--r-- | src/sql/builder.rs | 493 | ||||
| -rw-r--r-- | src/sql/mod.rs | 8 |
21 files changed, 7453 insertions, 0 deletions
diff --git a/src/auth/mod.rs b/src/auth/mod.rs new file mode 100644 index 0000000..65cba10 --- /dev/null +++ b/src/auth/mod.rs @@ -0,0 +1,6 @@ +pub mod password; +pub mod pin; +pub mod session; +pub mod token; + +pub use session::SessionManager; diff --git a/src/auth/password.rs b/src/auth/password.rs new file mode 100644 index 0000000..580ad1f --- /dev/null +++ b/src/auth/password.rs @@ -0,0 +1,56 @@ +// Password authentication module +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use bcrypt::verify; +use sqlx::MySqlPool; + +pub async fn authenticate_password( + pool: &MySqlPool, + username: &str, + password: &str, +) -> Result<Option<(User, Role)>> { + // Fetch user from database + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE username = ? AND active = TRUE + "#, + ) + .bind(username) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + // Verify password + let password_valid = + verify(password, &user.password).context("Failed to verify password")?; + + if password_valid { + // Fetch user's role + let role: Role = sqlx::query_as::<_, Role>( + "SELECT id, name, power, created_at FROM roles WHERE id = ?", + ) + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } + } else { + Ok(None) + } +} diff --git a/src/auth/pin.rs b/src/auth/pin.rs new file mode 100644 index 0000000..4d1993c --- /dev/null +++ b/src/auth/pin.rs @@ -0,0 +1,66 @@ +// PIN authentication module +use crate::config::SecurityConfig; +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use sqlx::MySqlPool; + +pub async fn authenticate_pin( + pool: &MySqlPool, + username: &str, + pin: &str, + security_config: &SecurityConfig, +) -> Result<Option<(User, Role)>> { + // Fetch user from database + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE username = ? AND active = TRUE AND pin_code IS NOT NULL + "#, + ) + .bind(username) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + // Check if user has a PIN set + if let Some(user_pin) = &user.pin_code { + // Verify PIN - either bcrypt hash or plaintext depending on config + let pin_valid = if security_config.hash_pins { + bcrypt::verify(pin, user_pin).unwrap_or(false) + } else { + user_pin == pin + }; + + if pin_valid { + // Fetch user's role + let role: Role = sqlx::query_as::<_, Role>( + "SELECT id, name, power, created_at FROM roles WHERE id = ?", + ) + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } + } else { + // User doesn't have a PIN set + Ok(None) + } + } else { + Ok(None) + } +} diff --git a/src/auth/session.rs b/src/auth/session.rs new file mode 100644 index 0000000..277cef2 --- /dev/null +++ b/src/auth/session.rs @@ -0,0 +1,129 @@ +// Session management for SeckelAPI +use crate::config::Config; +use crate::models::Session; +use chrono::{Duration, Utc}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use uuid::Uuid; + +#[derive(Clone)] +pub struct SessionManager { + sessions: Arc<RwLock<HashMap<String, Session>>>, + config: Arc<Config>, +} + +impl SessionManager { + pub fn new(config: Arc<Config>) -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + fn get_timeout_for_power(&self, power: i32) -> u64 { + self.config.get_session_timeout(power) + } + + pub fn create_session( + &self, + user_id: i32, + username: String, + role_id: i32, + role_name: String, + power: i32, + ) -> String { + let token = Uuid::new_v4().to_string(); + let now = Utc::now(); + + let session = Session { + user_id, + username, + role_id, + role_name, + power, + created_at: now, + last_accessed: now, + }; + + if let Ok(mut sessions) = self.sessions.write() { + // Check concurrent session limit for this user + let max_sessions = self.config.get_max_concurrent_sessions(power); + let user_sessions: Vec<(String, chrono::DateTime<Utc>)> = sessions + .iter() + .filter(|(_, s)| s.user_id == user_id) + .map(|(token, s)| (token.clone(), s.created_at)) + .collect(); + + // If at limit, remove the oldest session + if user_sessions.len() >= max_sessions as usize { + if let Some(oldest_token) = user_sessions + .iter() + .min_by_key(|(_, created)| created) + .map(|(token, _)| token.clone()) + { + sessions.remove(&oldest_token); + } + } + + sessions.insert(token.clone(), session); + } + + token + } + + pub fn get_session(&self, token: &str) -> Option<Session> { + let mut session_to_update = None; + + { + if let Ok(sessions) = self.sessions.read() { + if let Some(session) = sessions.get(token) { + let now = Utc::now(); + let timeout_for_user = self.get_timeout_for_power(session.power); + let timeout_duration = Duration::minutes(timeout_for_user as i64); + + // Check if session has expired + if now - session.last_accessed > timeout_duration { + return None; // Session expired, will be cleaned up later + } else { + session_to_update = Some(session.clone()); + } + } + } + } + + if let Some(mut session) = session_to_update { + // Update last accessed time only if refresh_on_activity is enabled + if self.config.security.refresh_session_on_activity { + session.last_accessed = Utc::now(); + + if let Ok(mut sessions) = self.sessions.write() { + sessions.insert(token.to_string(), session.clone()); + } + } + + Some(session) + } else { + None + } + } + + pub fn remove_session(&self, token: &str) -> bool { + if let Ok(mut sessions) = self.sessions.write() { + sessions.remove(token).is_some() + } else { + false + } + } + + pub fn cleanup_expired_sessions(&self) { + let now = Utc::now(); + + if let Ok(mut sessions) = self.sessions.write() { + sessions.retain(|_, session| { + let timeout_for_user = self.get_timeout_for_power(session.power); + let timeout_duration = Duration::minutes(timeout_for_user as i64); + now - session.last_accessed <= timeout_duration + }); + } + } +} diff --git a/src/auth/token.rs b/src/auth/token.rs new file mode 100644 index 0000000..17f75e4 --- /dev/null +++ b/src/auth/token.rs @@ -0,0 +1,99 @@ +// Token/RFID authentication module +use crate::config::SecurityConfig; +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use sqlx::MySqlPool; + +pub async fn authenticate_token( + pool: &MySqlPool, + login_string: &str, + security_config: &SecurityConfig, +) -> Result<Option<(User, Role)>> { + // If hashing is enabled, we can't use WHERE login_string = ? directly + // Need to fetch all users and verify hashes + if security_config.hash_tokens { + // Fetch all active users with login_string set + let users: Vec<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE login_string IS NOT NULL AND active = TRUE + "#, + ) + .fetch_all(pool) + .await + .context("Failed to fetch users from database")?; + + // Find matching user by verifying bcrypt hash + for user in users { + if let Some(ref stored_hash) = user.login_string { + if bcrypt::verify(login_string, stored_hash).unwrap_or(false) { + // Found matching user + return authenticate_user_by_id(pool, user.id).await; + } + } + } + Ok(None) + } else { + // Plaintext comparison - direct database query + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE login_string = ? AND active = TRUE + "#, + ) + .bind(login_string) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + authenticate_user_by_id(pool, user.id).await + } else { + Ok(None) + } + } +} + +async fn authenticate_user_by_id(pool: &MySqlPool, user_id: i32) -> Result<Option<(User, Role)>> { + // Fetch user + let user: User = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE id = ? + "#, + ) + .bind(user_id) + .fetch_one(pool) + .await + .context("Failed to fetch user")?; + + if user.active { + // Fetch user's role + let role: Role = + sqlx::query_as::<_, Role>("SELECT id, name, power, created_at FROM roles WHERE id = ?") + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..602c0d8 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,845 @@ +// Configuration module for SeckelAPI +use anyhow::{Context, Result}; +use ipnet::IpNet; +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::str::FromStr; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub database: DatabaseConfig, + pub server: ServerConfig, + pub security: SecurityConfig, + pub permissions: PermissionsConfig, + pub logging: LoggingConfig, + pub auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + pub scheduled_queries: Option<ScheduledQueriesConfig>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PermissionsConfig { + #[serde(flatten)] + pub power_levels: HashMap<String, PowerLevelPermissions>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PowerLevelPermissions { + pub basic_rules: Vec<String>, + pub advanced_rules: Option<Vec<String>>, + pub max_limit: Option<u32>, + pub max_where_conditions: Option<u32>, + pub session_timeout_minutes: Option<u64>, + pub max_concurrent_sessions: Option<u32>, + #[serde(default = "default_true")] + pub rollback_on_error: bool, + #[serde(default = "default_false")] + pub allow_batch_operations: bool, + #[serde(default = "default_user_settings_access")] + pub user_settings_access: UserSettingsAccess, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum UserSettingsAccess { + ReadOwnOnly, + ReadWriteOwn, + ReadWriteAll, +} + +fn default_user_settings_access() -> UserSettingsAccess { + UserSettingsAccess::ReadWriteOwn +} + +#[derive(Debug, Clone, Deserialize)] +pub struct DatabaseConfig { + pub host: String, + pub port: u16, + pub database: String, + pub username: String, + pub password: String, + #[serde(default = "default_min_connections")] + pub min_connections: u32, + #[serde(default = "default_max_connections")] + pub max_connections: u32, + #[serde(default = "default_connection_timeout_seconds")] + pub connection_timeout_seconds: u64, + #[serde(default = "default_connection_timeout_wait")] + pub connection_timeout_wait: u64, + #[serde(default = "default_connection_check_interval")] + pub connection_check: u64, +} + +fn default_min_connections() -> u32 { + 1 +} + +fn default_max_connections() -> u32 { + 10 +} + +fn default_connection_timeout_seconds() -> u64 { + 30 +} + +fn default_connection_timeout_wait() -> u64 { + 5 +} + +fn default_connection_check_interval() -> u64 { + 30 +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + pub host: String, + pub port: u16, + #[serde(default = "default_request_body_limit_mb")] + pub request_body_limit_mb: usize, +} + +fn default_request_body_limit_mb() -> usize { + 10 // 10 MB default +} + +#[derive(Debug, Clone, Deserialize)] +pub struct SecurityConfig { + pub whitelisted_pin_ips: Vec<String>, + pub whitelisted_string_ips: Vec<String>, + pub session_timeout_minutes: u64, + pub refresh_session_on_activity: bool, + pub max_concurrent_sessions: u32, + pub session_cleanup_interval_minutes: u64, + pub default_max_limit: u32, + pub default_max_where_conditions: u32, + #[serde(default = "default_hash_pins")] + pub hash_pins: bool, // Whether to use bcrypt for PINs (false = plaintext) + #[serde(default = "default_hash_tokens")] + pub hash_tokens: bool, // Whether to use bcrypt for login_strings (false = plaintext) + #[serde(default = "default_pin_column")] + pub pin_column: String, // Database column name for PINs + #[serde(default = "default_token_column")] + pub token_column: String, // Database column name for login strings + // Rate limiting + #[serde(default = "default_enable_rate_limiting")] + pub enable_rate_limiting: bool, // Master switch for rate limiting (disable for debugging) + + // Auth rate limiting + #[serde(default = "default_auth_rate_limit_per_minute")] + pub auth_rate_limit_per_minute: u32, // Max auth requests per IP per minute + #[serde(default = "default_auth_rate_limit_per_second")] + pub auth_rate_limit_per_second: u32, // Max auth requests per IP per second (burst protection) + + // API rate limiting + #[serde(default = "default_api_rate_limit_per_minute")] + pub api_rate_limit_per_minute: u32, // Max API calls per user per minute + #[serde(default = "default_api_rate_limit_per_second")] + pub api_rate_limit_per_second: u32, // Max API calls per user per second (burst protection) + + // Table configuration (moved from basics.toml) + #[serde(default = "default_known_tables")] + pub known_tables: Vec<String>, + #[serde(default = "default_read_only_tables")] + pub read_only_tables: Vec<String>, + #[serde(default = "default_global_write_protected_columns")] + pub global_write_protected_columns: Vec<String>, + + // User preferences access control + #[serde(default = "default_user_settings_access")] + pub default_user_settings_access: UserSettingsAccess, +} + +fn default_known_tables() -> Vec<String> { + vec![] // Empty by default, must be configured +} + +fn default_read_only_tables() -> Vec<String> { + vec![] // Empty by default +} + +fn default_global_write_protected_columns() -> Vec<String> { + vec!["id".to_string()] // Protect 'id' by default at minimum +} + +fn default_hash_pins() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_hash_tokens() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_pin_column() -> String { + "pin_code".to_string() +} + +fn default_token_column() -> String { + "login_string".to_string() +} + +fn default_enable_rate_limiting() -> bool { + true // Enable by default for security +} + +fn default_auth_rate_limit_per_minute() -> u32 { + 10 // 10 login attempts per IP per minute (prevents brute force) +} + +fn default_auth_rate_limit_per_second() -> u32 { + 5 // Max 5 login attempts per second (burst protection) +} + +fn default_api_rate_limit_per_minute() -> u32 { + 60 // 60 API calls per user per minute +} + +fn default_api_rate_limit_per_second() -> u32 { + 10 // Max 10 API calls per second (burst protection) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct LoggingConfig { + pub request_log: Option<String>, + pub query_log: Option<String>, + pub error_log: Option<String>, + pub warning_log: Option<String>, // Warning messages + pub info_log: Option<String>, // Info messages + pub combined_log: Option<String>, // Unified log with request IDs + pub level: String, + pub mask_passwords: bool, + #[serde(default)] + pub sensitive_fields: Vec<String>, // Fields to mask beyond password/pin + #[serde(default)] + pub custom_filters: Vec<CustomLogFilter>, // Regex-based log routing +} + +#[derive(Debug, Clone, Deserialize)] +pub struct CustomLogFilter { + pub name: String, + pub output_file: String, + pub pattern: String, // Regex pattern to match + #[serde(default = "default_filter_enabled")] + pub enabled: bool, +} + +fn default_filter_enabled() -> bool { + true +} + +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AutoGenerationConfig { + pub field: String, + #[serde(rename = "type")] + pub gen_type: String, + pub length: Option<u32>, + pub range_min: Option<u64>, + pub range_max: Option<u64>, + pub max_attempts: Option<u32>, + #[serde(default = "default_on_action")] + pub on_action: String, // "insert", "update", or "both" +} + +fn default_on_action() -> String { + "insert".to_string() +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueriesConfig { + #[serde(default)] + pub tasks: Vec<ScheduledQueryTask>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueryTask { + pub name: String, + pub description: String, + pub query: String, + pub interval_minutes: u64, + #[serde(default = "default_enabled")] + pub enabled: bool, + #[serde(default = "default_run_on_startup")] + pub run_on_startup: bool, +} + +fn default_enabled() -> bool { + true +} + +fn default_run_on_startup() -> bool { + true // Run immediately on startup by default +} + +impl Config { + pub fn load() -> Result<Self> { + Self::load_from_folder() + } + + fn load_from_folder() -> Result<Self> { + // Load consolidated config files + let basics_content = + fs::read_to_string("config/basics.toml").context("Can't read config/basics.toml")?; + let security_content = fs::read_to_string("config/security.toml") + .context("Can't read config/security.toml")?; + let logging_content = + fs::read_to_string("config/logging.toml").context("Can't read config/logging.toml")?; + + // Parse individual sections + #[derive(Deserialize)] + struct BasicsWrapper { + server: ServerConfig, + database: DatabaseConfig, + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + } + #[derive(Deserialize)] + struct SecurityWrapper { + security: SecurityConfig, + permissions: PermissionsConfig, + } + #[derive(Deserialize)] + struct LoggingWrapper { + logging: LoggingConfig, + } + + #[derive(Deserialize)] + struct FunctionsWrapper { + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + #[serde(default)] + scheduled_queries: Option<ScheduledQueriesConfig>, + } + + let basics: BasicsWrapper = + toml::from_str(&basics_content).context("Failed to parse basics.toml")?; + let security: SecurityWrapper = + toml::from_str(&security_content).context("Failed to parse security.toml")?; + let logging: LoggingWrapper = + toml::from_str(&logging_content).context("Failed to parse logging.toml")?; + + // Load functions.toml if it exists, otherwise use basics fallback + let functions: FunctionsWrapper = if fs::metadata("config/functions.toml").is_ok() { + let functions_content = fs::read_to_string("config/functions.toml") + .context("Can't read config/functions.toml")?; + toml::from_str(&functions_content).context("Failed to parse functions.toml")? + } else { + // Fallback to basics.toml for auto_generation + FunctionsWrapper { + auto_generation: basics.auto_generation.clone(), + scheduled_queries: None, + } + }; + + let config = Config { + database: basics.database, + server: basics.server, + security: security.security, + permissions: security.permissions, + logging: logging.logging, + auto_generation: functions.auto_generation, + scheduled_queries: functions.scheduled_queries, + }; + + // Validate configuration + config.validate()?; + + Ok(config) + } + + /// Validate configuration values + fn validate(&self) -> Result<()> { + // Validate server port (u16 is already limited to 0-65535, just check for 0) + if self.server.port == 0 { + anyhow::bail!( + "Invalid server.port: {} (must be 1-65535)", + self.server.port + ); + } + + // Validate database port (u16 is already limited to 0-65535, just check for 0) + if self.database.port == 0 { + anyhow::bail!( + "Invalid database.port: {} (must be 1-65535)", + self.database.port + ); + } + + // Validate database connection details + if self.database.host.trim().is_empty() { + anyhow::bail!("database.host cannot be empty"); + } + if self.database.database.trim().is_empty() { + anyhow::bail!("database.database cannot be empty"); + } + if self.database.username.trim().is_empty() { + anyhow::bail!("database.username cannot be empty"); + } + + // Validate PIN whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_pin_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_pin_ips: {}", ip); + } + } + + // Validate string auth whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_string_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_string_ips: {}", ip); + } + } + + // Validate session timeout + if self.security.session_timeout_minutes == 0 { + anyhow::bail!("security.session_timeout_minutes must be greater than 0"); + } + + // Validate permission syntax + for (power_level, perms) in &self.permissions.power_levels { + // Validate power level is numeric + if power_level.parse::<i32>().is_err() { + anyhow::bail!("Invalid power level '{}': must be numeric", power_level); + } + + // Validate basic rules format + for rule in &perms.basic_rules { + Self::validate_permission_rule(rule).with_context(|| { + format!("Invalid basic rule for power level {}", power_level) + })?; + } + + // Validate advanced rules format + if let Some(advanced_rules) = &perms.advanced_rules { + for rule in advanced_rules { + Self::validate_advanced_rule(rule).with_context(|| { + format!("Invalid advanced rule for power level {}", power_level) + })?; + } + } + } + + // Validate logging configuration (paths are now optional) + // No validation needed for optional log paths + + // Validate log level + let valid_levels = ["trace", "debug", "info", "warn", "error"]; + if !valid_levels.contains(&self.logging.level.to_lowercase().as_str()) { + anyhow::bail!( + "Invalid logging.level '{}': must be one of: {}", + self.logging.level, + valid_levels.join(", ") + ); + } + + Ok(()) + } + + /// Validate basic permission rule format (table:permission) + fn validate_permission_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Permission rule '{}' must be in format 'table:permission'", + rule + ); + } + + let table = parts[0]; + let permission = parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "rw", "rwd"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + /// Validate advanced permission rule format (table.column:permission) + fn validate_advanced_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Advanced rule '{}' must be in format 'table.column:permission'", + rule + ); + } + + let table_col = parts[0]; + let permission = parts[1]; + + // Validate table.column format + let col_parts: Vec<&str> = table_col.split('.').collect(); + if col_parts.len() != 2 { + anyhow::bail!("Advanced rule '{}' must have 'table.column' format", rule); + } + + let table = col_parts[0]; + let column = col_parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "w", "rw", "block"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + pub fn get_database_url(&self) -> String { + format!( + "mysql://{}:{}@{}:{}/{}", + self.database.username, + self.database.password, + self.database.host, + self.database.port, + self.database.database + ) + } + + pub fn is_pin_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_pin_ips) + } + + pub fn is_string_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_string_ips) + } + + fn is_ip_in_whitelist(&self, ip: &str, whitelist: &[String]) -> bool { + let client_ip = match ip.parse::<std::net::IpAddr>() { + Ok(addr) => addr, + Err(_) => return false, + }; + + for allowed in whitelist { + // Try to parse as a network (CIDR notation) + if let Ok(network) = IpNet::from_str(allowed) { + if network.contains(&client_ip) { + return true; + } + } + // Try to parse as a single IP address + else if let Ok(allowed_ip) = allowed.parse::<std::net::IpAddr>() { + if client_ip == allowed_ip { + return true; + } + } + } + false + } + + pub fn get_role_permissions(&self, power: i32) -> Vec<(String, String)> { + let power_str = power.to_string(); + + if let Some(power_perms) = self.permissions.power_levels.get(&power_str) { + power_perms + .basic_rules + .iter() + .filter_map(|perm| { + let parts: Vec<&str> = perm.split(':').collect(); + if parts.len() == 2 { + Some((parts[0].to_string(), parts[1].to_string())) + } else { + None + } + }) + .collect() + } else { + Vec::new() + } + } + + // Helper methods for new configuration sections + pub fn get_known_tables(&self) -> Vec<String> { + if self.security.known_tables.is_empty() { + tracing::warn!("No known_tables configured in security.toml - returning empty list. Wildcard permissions (*) will not work."); + } + self.security.known_tables.clone() + } + + pub fn is_read_only_table(&self, table: &str) -> bool { + self.security.read_only_tables.contains(&table.to_string()) + } + + pub fn get_auto_generation_config(&self, table: &str) -> Option<&AutoGenerationConfig> { + self.auto_generation + .as_ref() + .and_then(|configs| configs.get(table)) + } + + pub fn get_basic_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .map(|p| &p.basic_rules) + } + + pub fn get_advanced_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.advanced_rules.as_ref()) + } + + pub fn filter_readable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "r" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + _ => {} + } + } + } + } + } + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, return all requested columns + requested_columns.to_vec() + } + } + + pub fn filter_writable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + // First, apply global write-protected columns (these override everything) + let mut globally_blocked: Vec<String> = + self.security.global_write_protected_columns.clone(); + + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "w" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + "r" => { + // Read-only: block from writing but not from reading + blocked_columns.push(column.to_string()); + } + _ => {} + } + } + } + } + } + + // Merge advanced_rules blocked columns with globally protected + blocked_columns.append(&mut globally_blocked); + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, just filter out globally protected columns + requested_columns + .iter() + .filter(|col| !globally_blocked.contains(col)) + .cloned() + .collect() + } + } + + /// Get the max_limit for a specific power level (with fallback to next lower power level) + pub fn get_max_limit(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + } + + // Ultimate fallback to default + self.security.default_max_limit + } + + /// Get the max_where_conditions for a specific power level (with fallback to next lower power level) + pub fn get_max_where_conditions(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + } + + // Ultimate fallback to default + self.security.default_max_where_conditions + } + + /// Find the next lower configured power level (e.g., power=60 → fallback to 50) + fn find_fallback_power_level(&self, power: i32) -> Option<i32> { + let mut available_powers: Vec<i32> = self + .permissions + .power_levels + .keys() + .filter_map(|k| k.parse::<i32>().ok()) + .filter(|&p| p < power) // Only consider lower power levels + .collect(); + + available_powers.sort_by(|a, b| b.cmp(a)); // Sort descending + available_powers.first().copied() + } + + /// Get the session_timeout_minutes for a specific power level (with fallback to default) + pub fn get_session_timeout(&self, power: i32) -> u64 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.session_timeout_minutes) + .unwrap_or(self.security.session_timeout_minutes) + } + + /// Get the max_concurrent_sessions for a specific power level (with fallback to default) + pub fn get_max_concurrent_sessions(&self, power: i32) -> u32 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.max_concurrent_sessions) + .unwrap_or(self.security.max_concurrent_sessions) + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..cbee57f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,3 @@ +pub mod pool; + +pub use pool::{Database, DatabaseInitError}; diff --git a/src/db/pool.rs b/src/db/pool.rs new file mode 100644 index 0000000..4390739 --- /dev/null +++ b/src/db/pool.rs @@ -0,0 +1,110 @@ +// Database connection pool module +use crate::config::DatabaseConfig; +use anyhow::{Context, Result as AnyResult}; +use sqlx::{Error as SqlxError, MySqlPool}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +#[derive(Debug)] +pub enum DatabaseInitError { + Fatal(anyhow::Error), + Retryable(anyhow::Error), +} + +impl DatabaseInitError { + fn fatal(err: anyhow::Error) -> Self { + Self::Fatal(err) + } + + fn retryable(err: anyhow::Error) -> Self { + Self::Retryable(err) + } +} + +#[derive(Clone)] +pub struct Database { + pool: MySqlPool, + availability: Arc<AtomicBool>, +} + +impl Database { + pub async fn new(config: &DatabaseConfig) -> Result<Self, DatabaseInitError> { + let database_url = format!( + "mysql://{}:{}@{}:{}/{}", + config.username, config.password, config.host, config.port, config.database + ); + + let pool = sqlx::mysql::MySqlPoolOptions::new() + .min_connections(config.min_connections) + .max_connections(config.max_connections) + .acquire_timeout(std::time::Duration::from_secs( + config.connection_timeout_seconds, + )) + .connect(&database_url) + .await + .map_err(|err| map_sqlx_error(err, "Failed to connect to database"))?; + + // Test the connection + sqlx::query("SELECT 1") + .execute(&pool) + .await + .map_err(|err| map_sqlx_error(err, "Failed to test database connection"))?; + + Ok(Database { + pool, + availability: Arc::new(AtomicBool::new(true)), + }) + } + + pub fn pool(&self) -> &MySqlPool { + &self.pool + } + + pub fn is_available(&self) -> bool { + self.availability.load(Ordering::Relaxed) + } + + pub fn mark_available(&self) { + self.availability.store(true, Ordering::Relaxed); + } + + pub fn mark_unavailable(&self) { + self.availability.store(false, Ordering::Relaxed); + } + + pub async fn set_current_user(&self, user_id: i32) -> AnyResult<()> { + sqlx::query("SET @current_user_id = ?") + .bind(user_id) + .execute(&self.pool) + .await + .context("Failed to set current user ID")?; + + Ok(()) + } + + pub async fn close(&self) { + self.pool.close().await; + } +} + +fn map_sqlx_error(err: SqlxError, context: &str) -> DatabaseInitError { + let retryable = is_retryable_sqlx_error(&err); + let wrapped = anyhow::Error::new(err).context(context.to_string()); + if retryable { + DatabaseInitError::retryable(wrapped) + } else { + DatabaseInitError::fatal(wrapped) + } +} + +fn is_retryable_sqlx_error(err: &SqlxError) -> bool { + matches!( + err, + SqlxError::Io(_) + | SqlxError::PoolTimedOut + | SqlxError::PoolClosed + | SqlxError::WorkerCrashed + ) +} diff --git a/src/logging/logger.rs b/src/logging/logger.rs new file mode 100644 index 0000000..e063f8f --- /dev/null +++ b/src/logging/logger.rs @@ -0,0 +1,373 @@ +// Audit logging module with request ID tracing and custom filters +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use regex::Regex; +use serde_json::Value; +use std::fs::OpenOptions; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Clone)] +struct CustomFilter { + name: String, + pattern: Regex, + file: Arc<Mutex<std::fs::File>>, +} + +#[derive(Clone)] +pub struct AuditLogger { + mask_passwords: bool, + sensitive_fields: Vec<String>, + request_file: Option<Arc<Mutex<std::fs::File>>>, + query_file: Option<Arc<Mutex<std::fs::File>>>, + error_file: Option<Arc<Mutex<std::fs::File>>>, + warning_file: Option<Arc<Mutex<std::fs::File>>>, + info_file: Option<Arc<Mutex<std::fs::File>>>, + combined_file: Option<Arc<Mutex<std::fs::File>>>, + custom_filters: Vec<CustomFilter>, +} + +impl AuditLogger { + pub fn new( + request_log_path: Option<String>, + query_log_path: Option<String>, + error_log_path: Option<String>, + warning_log_path: Option<String>, + info_log_path: Option<String>, + combined_log_path: Option<String>, + mask_passwords: bool, + sensitive_fields: Vec<String>, + custom_filter_configs: Vec<crate::config::CustomLogFilter>, + ) -> Result<Self> { + // Helper function to open a log file if path is provided + let open_log_file = |path: &Option<String>| -> Result<Option<Arc<Mutex<std::fs::File>>>> { + if let Some(path_str) = path { + // Ensure log directories exist + if let Some(parent) = Path::new(path_str).parent() { + std::fs::create_dir_all(parent).context("Failed to create log directory")?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(path_str) + .context(format!("Failed to open log file: {}", path_str))?; + + Ok(Some(Arc::new(Mutex::new(file)))) + } else { + Ok(None) + } + }; + + // Initialize custom filters + let mut custom_filters = Vec::new(); + for filter_config in custom_filter_configs { + if filter_config.enabled { + // Compile regex pattern + let pattern = Regex::new(&filter_config.pattern).context(format!( + "Invalid regex pattern in filter '{}': {}", + filter_config.name, filter_config.pattern + ))?; + + // Open filter output file + if let Some(parent) = Path::new(&filter_config.output_file).parent() { + std::fs::create_dir_all(parent) + .context("Failed to create filter log directory")?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&filter_config.output_file) + .context(format!( + "Failed to open filter log file: {}", + filter_config.output_file + ))?; + + custom_filters.push(CustomFilter { + name: filter_config.name.clone(), + pattern, + file: Arc::new(Mutex::new(file)), + }); + } + } + + Ok(Self { + mask_passwords, + sensitive_fields, + request_file: open_log_file(&request_log_path)?, + query_file: open_log_file(&query_log_path)?, + error_file: open_log_file(&error_log_path)?, + warning_file: open_log_file(&warning_log_path)?, + info_file: open_log_file(&info_log_path)?, + combined_file: open_log_file(&combined_log_path)?, + custom_filters, + }) + } + + /// Generate a unique request ID for transaction tracing + pub fn generate_request_id() -> String { + format!("{}", uuid::Uuid::new_v4().as_u128() & 0xFFFFFFFF_FFFFFFFF) // 16 hex chars + } + + /// Write to combined log and apply custom filters + async fn write_combined_and_filter(&self, entry: &str) -> Result<()> { + // Write to combined log if configured + if let Some(ref file_mutex) = self.combined_file { + let mut file = file_mutex.lock().await; + file.write_all(entry.as_bytes()) + .context("Failed to write to combined log")?; + file.flush().context("Failed to flush combined log")?; + } + + // Apply custom filters + for filter in &self.custom_filters { + if filter.pattern.is_match(entry) { + let mut file = filter.file.lock().await; + file.write_all(entry.as_bytes()) + .context(format!("Failed to write to filter log: {}", filter.name))?; + file.flush() + .context(format!("Failed to flush filter log: {}", filter.name))?; + } + } + + Ok(()) + } + + pub async fn log_request( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + _ip: &str, + user: Option<&str>, + power: Option<i32>, + endpoint: &str, + payload: &Value, + ) -> Result<()> { + let mut masked_payload = payload.clone(); + + if self.mask_passwords { + self.mask_sensitive_data(&mut masked_payload); + } + + let user_str = user.unwrap_or("anonymous"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | REQUEST | user={} | {} | endpoint={} | payload={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + endpoint, + serde_json::to_string(&masked_payload).unwrap_or_else(|_| "invalid_json".to_string()) + ); + + // Write to legacy request log if configured + if let Some(ref file_mutex) = self.request_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to request log")?; + file.flush().context("Failed to flush request log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_query( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + user: &str, + power: Option<i32>, + query: &str, + parameters: Option<&Value>, + rows_affected: Option<u64>, + ) -> Result<()> { + let params_str = if let Some(params) = parameters { + serde_json::to_string(params).unwrap_or_else(|_| "invalid_json".to_string()) + } else { + "null".to_string() + }; + + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + let rows_str = rows_affected + .map(|r| format!("rows={}", r)) + .unwrap_or_else(|| "rows=0".to_string()); + + let log_entry = format!( + "{} [{}] | QUERY | user={} | {} | {} | query={} | params={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user, + power_str, + rows_str, + query, + params_str + ); + + // Write to legacy query log if configured + if let Some(ref file_mutex) = self.query_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to query log")?; + file.flush().context("Failed to flush query log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_error( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + error: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("unknown"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | ERROR | user={} | {} | context={} | error={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + error + ); + + // Write to legacy error log if configured + if let Some(ref file_mutex) = self.error_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to error log")?; + file.flush().context("Failed to flush error log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_warning( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + message: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("unknown"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | WARNING | user={} | {} | context={} | message={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + message + ); + + // Write to warning log if configured + if let Some(ref file_mutex) = self.warning_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to warning log")?; + file.flush().context("Failed to flush warning log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_info( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + message: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("system"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | INFO | user={} | {} | context={} | message={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + message + ); + + // Write to info log if configured + if let Some(ref file_mutex) = self.info_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to info log")?; + file.flush().context("Failed to flush info log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + fn mask_sensitive_data(&self, value: &mut Value) { + match value { + Value::Object(map) => { + for (key, val) in map.iter_mut() { + // Always mask password and pin + if key == "password" || key == "pin" { + *val = Value::String("***MASKED***".to_string()); + } + // Also mask any configured sensitive fields + else if self.sensitive_fields.contains(key) { + *val = Value::String("***MASKED***".to_string()); + } else { + self.mask_sensitive_data(val); + } + } + } + Value::Array(arr) => { + for item in arr.iter_mut() { + self.mask_sensitive_data(item); + } + } + _ => {} + } + } +} diff --git a/src/logging/mod.rs b/src/logging/mod.rs new file mode 100644 index 0000000..a6752e2 --- /dev/null +++ b/src/logging/mod.rs @@ -0,0 +1,3 @@ +pub mod logger; + +pub use logger::AuditLogger; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..17702f4 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,359 @@ +// Main entry point for the SeckelAPI server +use anyhow::Result; +use axum::{ + http::Method, + response::Json, + routing::{get, post}, + Router, +}; +use chrono::Utc; +use serde_json::json; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::time::{sleep, Duration}; +use tower_governor::{ + governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer, +}; +use tower_http::cors::{Any, CorsLayer}; +use tracing::{error, info, Level}; +use tracing_subscriber; + +mod auth; +mod config; +mod db; +mod logging; +mod models; +mod permissions; +mod routes; +mod scheduler; +mod sql; + +use auth::SessionManager; +use config::{Config, DatabaseConfig}; +use logging::AuditLogger; +use scheduler::QueryScheduler; + +use axum::extract::State; +use db::{Database, DatabaseInitError}; +use permissions::RBACManager; + +// Version pulled from Cargo.toml +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +// Root handler - simple version info +async fn root() -> Json<serde_json::Value> { + Json(json!({ + "name": "SeckelAPI", + "version": VERSION + })) +} + +// Health check handler with database connectivity test +async fn health_check(State(state): State<AppState>) -> Json<serde_json::Value> { + // Test database connectivity + let db_status = match sqlx::query("SELECT 1") + .fetch_one(state.database.pool()) + .await + { + Ok(_) => "connected", + Err(_) => "disconnected", + }; + + Json(json!({ + "status": "running", + "message": format!("SeckelAPI v{}: The schizophrenic database application JSON API", VERSION), + "database": db_status + })) +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); + + info!("Starting SeckelAPI Server..."); + + // Load configuration + let config = Config::load()?; + let config_arc = Arc::new(config.clone()); + info!("- Configuration might have loaded successfully"); + + // Initialize database connection + let database = wait_for_database(&config.database).await?; + info!("- Database said 'Hello there!'"); + + // Initialize session manager + let session_manager = SessionManager::new(config_arc.clone()); + info!("- Session manager didn't crash on startup (yay!)"); + + // Spawn background task for session cleanup + let cleanup_manager = session_manager.clone(); + let cleanup_interval = config.security.session_cleanup_interval_minutes; + tokio::spawn(async move { + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval * 60)); + loop { + interval.tick().await; + cleanup_manager.cleanup_expired_sessions(); + info!("- Ran session cleanup task"); + } + }); + info!( + "- Background session cleanup task spawned (interval: {}min)", + cleanup_interval + ); + + // Initialize RBAC manager + let rbac = RBACManager::new(&config); + info!("- Overcomplicated RBAC manager has initialized"); + + // Initialize audit logger + let logging = AuditLogger::new( + config.logging.request_log.clone(), + config.logging.query_log.clone(), + config.logging.error_log.clone(), + config.logging.warning_log.clone(), + config.logging.info_log.clone(), + config.logging.combined_log.clone(), + config.logging.mask_passwords, + config.logging.sensitive_fields.clone(), + config.logging.custom_filters.clone(), + )?; + info!("- CIA Surveillance Service Agents have been initialized on your system (just kidding, just the logging stack)"); + if !config.logging.custom_filters.is_empty() { + let enabled_count = config + .logging + .custom_filters + .iter() + .filter(|f| f.enabled) + .count(); + info!("- {} custom log filter(s) active", enabled_count); + } + + spawn_database_monitor(database.clone(), config.database.clone(), logging.clone()); + + // Initialize and spawn scheduled query tasks + let scheduler = QueryScheduler::new(config_arc.clone(), database.clone(), logging.clone()); + scheduler.spawn_tasks(); + + // Create CORS layer + let cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_headers(Any) + .allow_origin(Any); + + // Build auth routes with rate limiting + let mut auth_routes = Router::new() + .route("/auth/login", post(routes::auth::login)) + .route("/auth/logout", post(routes::auth::logout)) + .route("/auth/status", get(routes::auth::status)); + + if config.security.enable_rate_limiting { + info!( + "- Auth rate limiting set to: {}/min, {}/sec per IP", + config.security.auth_rate_limit_per_minute, config.security.auth_rate_limit_per_second + ); + let auth_governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_second(config.security.auth_rate_limit_per_second.max(1) as u64) + .burst_size(config.security.auth_rate_limit_per_minute.max(1)) + .key_extractor(SmartIpKeyExtractor) + .finish() + .expect("Failed to build auth rate limiter config ... you sure its configured?"), + ); + auth_routes = auth_routes.layer(GovernorLayer::new(auth_governor_conf)); + } else { + info!("- Auth rate limiting DISABLED (dont run in production like this plz) "); + } + + // Build API routes with rate limiting + let mut api_routes = Router::new() + .route("/query", post(routes::query::execute_query)) + .route("/permissions", get(routes::query::get_permissions)) + .route( + "/preferences", + post(routes::preferences::handle_preferences), + ); + + if config.security.enable_rate_limiting { + info!( + "- API rate limiting set to: {}/min, {}/sec per IP", + config.security.api_rate_limit_per_minute, config.security.api_rate_limit_per_second + ); + let api_governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_second(config.security.api_rate_limit_per_second.max(1) as u64) + .burst_size(config.security.api_rate_limit_per_minute.max(1)) + .key_extractor(SmartIpKeyExtractor) + .finish() + .expect( + "Failed to build API rate limiter config, ugh ... you sure its configured?", + ), + ); + api_routes = api_routes.layer(GovernorLayer::new(api_governor_conf)); + } + + // das so routes was es chan + let app = Router::new() + .route("/", get(root)) + .route("/health", get(health_check)) + .merge(auth_routes) + .merge(api_routes) + .layer(tower_http::limit::RequestBodyLimitLayer::new( + config.server.request_body_limit_mb * 1024 * 1024, + )) + .layer(cors) + .with_state(AppState { + config: config.clone(), + database, + session_manager, + rbac, + logging, + }); + let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port)); + info!( + "- SeckelAPI somehow started and should now be listening on {} :)", + addr + ); + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::<SocketAddr>(), + ) + .await?; + Ok(()) +} + +async fn wait_for_database(config: &DatabaseConfig) -> Result<Database> { + let retry_delay = config.connection_timeout_wait.max(1); + + loop { + match Database::new(config).await { + Ok(db) => return Ok(db), + Err(DatabaseInitError::Retryable(err)) => { + error!( + "Database unavailable (retrying in {}s): {}", + retry_delay, err + ); + sleep(Duration::from_secs(retry_delay)).await; + } + Err(DatabaseInitError::Fatal(err)) => { + error!("Fatal database configuration error: {}", err); + return Err(err); + } + } + } +} + +fn spawn_database_monitor(database: Database, db_config: DatabaseConfig, logging: AuditLogger) { + let heartbeat_interval = db_config.connection_check.max(5); + let retry_delay = db_config.connection_timeout_wait.max(1); + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(Duration::from_secs(heartbeat_interval)); + loop { + ticker.tick().await; + match sqlx::query("SELECT 1").fetch_one(database.pool()).await { + Ok(_) => { + if !database.is_available() { + info!("Database connectivity restored"); + log_database_event( + &logging, + "database_reconnect", + "Database connectivity restored", + false, + ) + .await; + } + database.mark_available(); + } + Err(err) => { + database.mark_unavailable(); + error!( + "Database heartbeat failed (retrying every {}s): {}", + retry_delay, err + ); + + log_database_event( + &logging, + "database_heartbeat_failure", + &format!("Heartbeat failed: {}", err), + true, + ) + .await; + + loop { + sleep(Duration::from_secs(retry_delay)).await; + match sqlx::query("SELECT 1").fetch_one(database.pool()).await { + Ok(_) => { + database.mark_available(); + info!("Database reconnected successfully"); + log_database_event( + &logging, + "database_reconnect", + "Database reconnected successfully", + false, + ) + .await; + break; + } + Err(retry_err) => { + error!( + "Database still unavailable (retrying in {}s): {}", + retry_delay, retry_err + ); + log_database_event( + &logging, + "database_retry_failure", + &format!("Database retry failed: {}", retry_err), + true, + ) + .await; + } + } + } + } + } + } + }); +} + +async fn log_database_event(logging: &AuditLogger, context: &str, message: &str, as_error: bool) { + let request_id = AuditLogger::generate_request_id(); + let timestamp = Utc::now(); + let result = if as_error { + logging + .log_error( + &request_id, + timestamp, + message, + Some(context), + Some("system"), + None, + ) + .await + } else { + logging + .log_info( + &request_id, + timestamp, + message, + Some(context), + Some("system"), + None, + ) + .await + }; + + if let Err(err) = result { + error!("Failed to record database event ({}): {}", context, err); + } +} +#[derive(Clone)] +pub struct AppState { + pub config: Config, + pub database: Database, + pub session_manager: SessionManager, + pub rbac: RBACManager, + pub logging: AuditLogger, +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..4d9f734 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,294 @@ +// Data models for SeckelAPI +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// Authentication models +#[derive(Debug, Deserialize, Serialize)] +pub struct LoginRequest { + pub method: AuthMethod, + pub username: Option<String>, + pub password: Option<String>, + pub pin: Option<String>, + pub login_string: Option<String>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthMethod { + Password, + Pin, + Token, +} + +#[derive(Debug, Serialize)] +pub struct LoginResponse { + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub token: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<UserInfo>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, +} + +#[derive(Debug, Serialize)] +pub struct UserInfo { + pub id: i32, + pub username: String, + pub name: String, + pub role: String, + pub power: i32, +} + +// Database query models +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct QueryRequest { + // Single query mode (when queries is None) + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option<QueryAction>, + #[serde(skip_serializing_if = "Option::is_none")] + pub table: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub columns: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + // Enhanced WHERE clause - supports both simple and complex conditions + #[serde(rename = "where", skip_serializing_if = "Option::is_none")] + pub where_clause: Option<serde_json::Value>, + // New structured filter for complex queries + #[serde(skip_serializing_if = "Option::is_none")] + pub filter: Option<FilterCondition>, + // JOIN support - allows multi-table queries with permission validation + #[serde(skip_serializing_if = "Option::is_none")] + pub joins: Option<Vec<Join>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub order_by: Option<Vec<OrderBy>>, + + // Batch mode (when queries is Some) - action/table apply to ALL queries in batch + #[serde(skip_serializing_if = "Option::is_none")] + pub queries: Option<Vec<BatchQuery>>, + /// Whether to rollback on error in batch mode (defaults to config setting) + #[serde(skip_serializing_if = "Option::is_none")] + pub rollback_on_error: Option<bool>, +} + +/// Individual query in a batch - inherits action/table from parent QueryRequest +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct BatchQuery { + // Only the variable parts per query - no action/table duplication + #[serde(skip_serializing_if = "Option::is_none")] + pub columns: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + #[serde(rename = "where", skip_serializing_if = "Option::is_none")] + pub where_clause: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub filter: Option<FilterCondition>, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub order_by: Option<Vec<OrderBy>>, +} + +/// JOIN specification for multi-table queries +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Join { + /// Type of join (INNER, LEFT, RIGHT) + #[serde(rename = "type")] + pub join_type: JoinType, + /// Table to join with + pub table: String, + /// Join condition (e.g., "assets.category_id = categories.id") + pub on: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "UPPERCASE")] +pub enum JoinType { + Inner, + Left, + Right, +} + +/// Enhanced filter condition supporting operators and nested logic +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(untagged)] +pub enum FilterCondition { + /// Simple condition: {"column": "name", "op": "=", "value": "John"} + Simple { + column: String, + #[serde(rename = "op")] + operator: FilterOperator, + value: serde_json::Value, + }, + /// Logical AND/OR: {"and": [condition1, condition2]} + Logical { + #[serde(rename = "and")] + and_conditions: Option<Vec<FilterCondition>>, + #[serde(rename = "or")] + or_conditions: Option<Vec<FilterCondition>>, + }, + /// NOT condition: {"not": condition} + Not { not: Box<FilterCondition> }, +} + +/// Supported WHERE clause operators +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum FilterOperator { + #[serde(rename = "=")] + Eq, // Equal + #[serde(rename = "!=")] + Ne, // Not equal + #[serde(rename = ">")] + Gt, // Greater than + #[serde(rename = ">=")] + Gte, // Greater than or equal + #[serde(rename = "<")] + Lt, // Less than + #[serde(rename = "<=")] + Lte, // Less than or equal + Like, // LIKE pattern matching + #[serde(rename = "not_like")] + NotLike, // NOT LIKE + In, // IN (value1, value2, ...) + #[serde(rename = "not_in")] + NotIn, // NOT IN (...) + #[serde(rename = "is_null")] + IsNull, // IS NULL + #[serde(rename = "is_not_null")] + IsNotNull, // IS NOT NULL + Between, // BETWEEN value1 AND value2 +} + +/// ORDER BY clause +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct OrderBy { + pub column: String, + #[serde(default = "default_order_direction")] + pub direction: OrderDirection, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "UPPERCASE")] +pub enum OrderDirection { + ASC, + DESC, +} + +fn default_order_direction() -> OrderDirection { + OrderDirection::ASC +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum QueryAction { + Select, + Insert, + Update, + Delete, + Count, +} + +#[derive(Debug, Serialize, Clone)] +pub struct QueryResponse { + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub rows_affected: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub warning: Option<String>, + // Batch results (when queries field was used) + #[serde(skip_serializing_if = "Option::is_none")] + pub results: Option<Vec<QueryResponse>>, +} + +// Database entities +#[derive(Debug, sqlx::FromRow)] +#[allow(dead_code)] // Fields are used for database serialization +pub struct User { + pub id: i32, + pub name: String, + pub username: String, + pub password: String, + pub pin_code: Option<String>, + pub login_string: Option<String>, + pub role_id: i32, + pub email: Option<String>, + pub phone: Option<String>, + pub notes: Option<String>, + pub active: bool, + pub last_login_date: Option<DateTime<Utc>>, + pub created_date: DateTime<Utc>, + pub password_reset_token: Option<String>, + pub password_reset_expiry: Option<DateTime<Utc>>, +} + +#[derive(Debug, sqlx::FromRow)] +#[allow(dead_code)] // Fields are used for database serialization +pub struct Role { + pub id: i32, + pub name: String, + pub power: i32, + pub created_at: DateTime<Utc>, +} + +// Session management +#[derive(Debug, Clone)] +pub struct Session { + pub user_id: i32, + pub username: String, + pub role_id: i32, + pub role_name: String, + pub power: i32, + pub created_at: DateTime<Utc>, + pub last_accessed: DateTime<Utc>, +} + +// Permission types +#[derive(Debug, Clone, PartialEq)] +pub enum Permission { + Read, + Write, + ReadWrite, + None, +} + +impl Permission { + pub fn from_str(s: &str) -> Self { + match s { + "r" => Permission::Read, + "w" => Permission::Write, + "rw" => Permission::ReadWrite, + _ => Permission::None, + } + } + + pub fn can_read(&self) -> bool { + matches!(self, Permission::Read | Permission::ReadWrite) + } + + pub fn can_write(&self) -> bool { + matches!(self, Permission::Write | Permission::ReadWrite) + } +} + +// Permissions response +#[derive(Debug, Serialize)] +pub struct PermissionsResponse { + pub success: bool, + pub user: UserInfo, + pub permissions: HashMap<String, String>, + pub security_clearance: Option<String>, + pub user_settings_access: String, +} diff --git a/src/permissions/mod.rs b/src/permissions/mod.rs new file mode 100644 index 0000000..4c43932 --- /dev/null +++ b/src/permissions/mod.rs @@ -0,0 +1,3 @@ +pub mod rbac; + +pub use rbac::RBACManager; diff --git a/src/permissions/rbac.rs b/src/permissions/rbac.rs new file mode 100644 index 0000000..c1835e5 --- /dev/null +++ b/src/permissions/rbac.rs @@ -0,0 +1,119 @@ +// Role-based access control module +use crate::config::Config; +use crate::models::{Permission, QueryAction}; +use std::collections::HashMap; + +#[derive(Clone)] +pub struct RBACManager { + permissions: HashMap<i32, HashMap<String, Permission>>, +} + +impl RBACManager { + pub fn new(config: &Config) -> Self { + let mut permissions = HashMap::new(); + + // Parse permissions from new config format + for (power_str, power_perms) in &config.permissions.power_levels { + if let Ok(power) = power_str.parse::<i32>() { + let mut table_permissions = HashMap::new(); + + for perm in &power_perms.basic_rules { + let parts: Vec<&str> = perm.split(':').collect(); + if parts.len() == 2 { + let table = parts[0]; + let permission = Permission::from_str(parts[1]); + + if table == "*" { + // Grant permission to all known tables from config + let all_tables = config.get_known_tables(); + + for table_name in all_tables { + // Don't overwrite existing specific permissions + if !table_permissions.contains_key(&table_name) { + table_permissions.insert(table_name, permission.clone()); + } + } + } else { + table_permissions.insert(table.to_string(), permission); + } + } + } + + permissions.insert(power, table_permissions); + } + } + + Self { permissions } + } + + fn base_table_name(table: &str) -> String { + // Accept formats: "table", "table alias", "table AS alias" (AS case-insensitive) + let parts: Vec<&str> = table.trim().split_whitespace().collect(); + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") { + parts[0].to_string() + } else { + // For simple "table alias" or just "table", take the first token + parts.get(0).cloned().unwrap_or("").to_string() + } + } + + pub fn check_permission( + &self, + config: &Config, + power: i32, + table: &str, + action: &QueryAction, + ) -> bool { + // Normalize potential alias usage to the base table name for permission lookup + let base_table = Self::base_table_name(table); + + if let Some(table_permissions) = self.permissions.get(&power) { + if let Some(permission) = table_permissions.get(&base_table) { + // Check if table is read-only through config and enforce read-only constraint + if config.is_read_only_table(&base_table) && !matches!(action, QueryAction::Select) + { + return false; // Write operations not allowed on read-only tables + } + + match action { + QueryAction::Select | QueryAction::Count => permission.can_read(), + QueryAction::Insert | QueryAction::Update | QueryAction::Delete => { + permission.can_write() + } + } + } else { + false // No permission for this table + } + } else { + false // No permissions for this power level + } + } + + pub fn get_table_permissions(&self, config: &Config, power: i32) -> HashMap<String, String> { + let mut result = HashMap::new(); + + if let Some(table_permissions) = self.permissions.get(&power) { + for (table, permission) in table_permissions { + let perm_str = match permission { + Permission::Read => "r", + Permission::Write => "w", + Permission::ReadWrite => "rw", + Permission::None => "", + }; + + if !perm_str.is_empty() { + result.insert(table.clone(), perm_str.to_string()); + } + } + } + + // Ensure read-only tables from config are marked as read-only + for table in &config.get_known_tables() { + if config.is_read_only_table(table) && result.contains_key(table) { + result.insert(table.clone(), "r".to_string()); + } + } + + result + } +} diff --git a/src/routes/auth.rs b/src/routes/auth.rs new file mode 100644 index 0000000..c124e03 --- /dev/null +++ b/src/routes/auth.rs @@ -0,0 +1,513 @@ +// Authentication routes +use axum::{ + extract::{ConnectInfo, State}, + http::StatusCode, + Json, +}; +use chrono::Utc; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::logging::AuditLogger; +use crate::models::{AuthMethod, LoginRequest, LoginResponse, UserInfo}; +use crate::{auth, AppState}; + +pub async fn login( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + Json(payload): Json<LoginRequest>, +) -> Result<Json<LoginResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/login", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Validate request based on auth method + let (user, role) = match payload.method { + AuthMethod::Password => { + // Password auth - allowed from any IP + if let (Some(username), Some(password)) = (&payload.username, &payload.password) { + match auth::password::authenticate_password( + state.database.pool(), + username, + password, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed password authentication for user: {} - Invalid credentials", + username + ), + Some("password_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during password authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Password auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Password authentication attempted without username or password"); + super::log_warning_async( + &state.logging, + &request_id, + "Password authentication attempted without username or password", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Pin => { + // PIN auth - only from whitelisted IPs + if !state.config.is_pin_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "PIN authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let (Some(username), Some(pin)) = (&payload.username, &payload.pin) { + match auth::pin::authenticate_pin( + state.database.pool(), + username, + pin, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("Failed PIN authentication for user: {} - Invalid PIN or PIN not configured", username), + Some("pin_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during PIN authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("PIN auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("PIN authentication attempted without username or PIN"); + super::log_warning_async( + &state.logging, + &request_id, + "PIN authentication attempted without username or PIN", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Token => { + // Token/RFID auth - only from whitelisted IPs + if !state.config.is_string_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Token authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let Some(login_string) = &payload.login_string { + match auth::token::authenticate_token( + state.database.pool(), + login_string, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed token authentication for login_string: {} - Invalid token", + login_string + ), + Some("token_auth"), + Some(login_string), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during token authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Token auth error: {}", e), + Some("authentication"), + Some(login_string), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Token authentication attempted without login_string"); + super::log_warning_async( + &state.logging, + &request_id, + "Token authentication attempted without login_string", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + }; + + // Create session token + let token = state.session_manager.create_session( + user.id, + user.username.clone(), + role.id, + role.name.clone(), + role.power, + ); + + // Log successful login + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Successful login for user: {} ({})", + user.username, user.name + ), + Some("authentication"), + Some(&user.username), + Some(role.power), + ); + + Ok(Json(LoginResponse { + success: true, + token: Some(token), + user: Some(UserInfo { + id: user.id, + username: user.username, + name: user.name, + role: role.name, + power: role.power, + }), + error: None, + })) +} + +pub async fn logout( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token = match headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "error": "No authorization token provided" + }))); + } + }; + + // Get username for logging before removing session + let username = state + .session_manager + .get_session(&token) + .map(|s| s.username.clone()); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + username.as_deref(), + None, + "/auth/logout", + &serde_json::json!({"action": "logout"}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let removed = state.session_manager.remove_session(&token); + + if removed { + // Log successful logout + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} logged out successfully", + username.as_deref().unwrap_or("unknown") + ), + Some("authentication"), + username.as_deref(), + None, + ); + Ok(Json(serde_json::json!({ + "success": true, + "message": "Logged out successfully" + }))) + } else { + Ok(Json(serde_json::json!({ + "success": false, + "error": "Invalid or expired token" + }))) + } +} + +pub async fn status( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token_opt = headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }); + + let token = match token_opt { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "valid": false, + "error": "No authorization token provided" + }))); + } + }; + + // Check session validity + match state.session_manager.get_session(&token) { + Some(session) => { + let now = Utc::now(); + let timeout_minutes = state.config.get_session_timeout(session.power); + let elapsed = (now - session.last_accessed).num_seconds(); + let timeout_seconds = (timeout_minutes * 60) as i64; + let remaining_seconds = timeout_seconds - elapsed; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/auth/status", + &serde_json::json!({"token_provided": true}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": true, + "user": { + "id": session.user_id, + "username": session.username, + "name": session.username, + "role": session.role_name, + "power": session.power + }, + "session": { + "created_at": session.created_at.to_rfc3339(), + "last_accessed": session.last_accessed.to_rfc3339(), + "timeout_minutes": timeout_minutes, + "remaining_seconds": remaining_seconds.max(0), + "expires_at": (session.last_accessed + chrono::Duration::minutes(timeout_minutes as i64)).to_rfc3339() + } + }))) + } + None => { + // Log the request for invalid token + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/status", + &serde_json::json!({"token_provided": true, "valid": false}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": false, + "message": "Session expired or invalid" + }))) + } + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..b81a0a1 --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,111 @@ +pub mod auth; +pub mod preferences; +pub mod query; + +use crate::logging::logger::AuditLogger; +use tracing::{error, info, warn}; + +/// Helper function to log errors to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_error_async( + logging: &AuditLogger, + request_id: &str, + error_msg: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + error!("[{}] {}", request_id, error_msg); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let error_msg = error_msg.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_error( + &req_id, + chrono::Utc::now(), + &error_msg, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log warnings to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_warning_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + warn!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_warning( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log info messages to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_info_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + info!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_info( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} diff --git a/src/routes/preferences.rs b/src/routes/preferences.rs new file mode 100644 index 0000000..e58d823 --- /dev/null +++ b/src/routes/preferences.rs @@ -0,0 +1,423 @@ +// User preferences routes +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::config::UserSettingsAccess; +use crate::logging::AuditLogger; +use crate::AppState; + +// Request/Response structures matching the query route pattern +#[derive(Debug, Deserialize)] +pub struct PreferencesRequest { + pub action: String, // "get", "set", "reset" + pub user_id: Option<i32>, // For admin access to other users + pub preferences: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize)] +pub struct PreferencesResponse { + pub success: bool, + pub preferences: Option<serde_json::Value>, + pub error: Option<String>, +} + +/// Extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +/// POST /preferences - Handle all preference operations (get, set, reset) +pub async fn handle_preferences( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<PreferencesRequest>, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Session not found".to_string()), + })); + } + }; + + // Determine target user ID + let target_user_id = payload.user_id.unwrap_or(session.user_id); + + // Get user's permission level for preferences + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + // Check permissions for cross-user access + if target_user_id != session.user_id { + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("User {} (power {}) attempted to access preferences of user {} without permission", + session.username, session.power, target_user_id), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + + // Check write permissions for set/reset actions + if payload.action == "set" || payload.action == "reset" { + if target_user_id == session.user_id { + // Writing own preferences - need at least ReadWriteOwn + if *user_settings_permission == UserSettingsAccess::ReadOwnOnly { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } else { + // Writing others' preferences - need ReadWriteAll + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/preferences", + &serde_json::json!({"action": payload.action, "target_user_id": target_user_id}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Handle the action + match payload.action.as_str() { + "get" => { + handle_get_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + "set" => { + handle_set_preferences( + state, + request_id, + target_user_id, + payload.preferences, + session.username.clone(), + session.power, + ) + .await + } + "reset" => { + handle_reset_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + _ => Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some(format!("Invalid action: {}", payload.action)), + })), + } +} + +async fn handle_get_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + // Cast JSON column to CHAR to get string representation + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let preferences = match row { + Some((Some(prefs_str),)) => serde_json::from_str(&prefs_str).unwrap_or_else(|e| { + warn!( + "[{}] Failed to parse preferences JSON for user {}: {}", + request_id, user_id, e + ); + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed to parse preferences JSON for user {}: {}", + user_id, e + ), + Some("data_integrity"), + Some(&username), + Some(power), + ); + serde_json::json!({}) + }), + _ => serde_json::json!({}), + }; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} retrieved preferences for user {}", + username, user_id + ), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(preferences), + error: None, + })) +} + +async fn handle_set_preferences( + state: AppState, + request_id: String, + user_id: i32, + new_preferences: Option<serde_json::Value>, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let new_prefs = match new_preferences { + Some(prefs) => prefs, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Missing preferences field".to_string()), + })); + } + }; + + // Get current preferences for merging + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Deep merge the preferences + let mut merged_prefs = match row { + Some((Some(prefs_str),)) => { + serde_json::from_str(&prefs_str).unwrap_or_else(|_| serde_json::json!({})) + } + _ => serde_json::json!({}), + }; + + // Merge function + fn merge_json(base: &mut serde_json::Value, update: &serde_json::Value) { + if let (Some(base_obj), Some(update_obj)) = (base.as_object_mut(), update.as_object()) { + for (key, value) in update_obj { + if let Some(base_value) = base_obj.get_mut(key) { + if base_value.is_object() && value.is_object() { + merge_json(base_value, value); + } else { + *base_value = value.clone(); + } + } else { + base_obj.insert(key.clone(), value.clone()); + } + } + } + } + + merge_json(&mut merged_prefs, &new_prefs); + + // Save to database + let prefs_str = serde_json::to_string(&merged_prefs).map_err(|e| { + error!("[{}] Failed to serialize preferences: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let update_query = "UPDATE users SET preferences = ? WHERE id = ? AND active = TRUE"; + sqlx::query(update_query) + .bind(&prefs_str) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error updating preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} updated preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(merged_prefs), + error: None, + })) +} + +async fn handle_reset_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let query = "UPDATE users SET preferences = NULL WHERE id = ? AND active = TRUE"; + sqlx::query(query) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error resetting preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} reset preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(serde_json::json!({})), + error: None, + })) +} diff --git a/src/routes/query.rs b/src/routes/query.rs new file mode 100644 index 0000000..9814215 --- /dev/null +++ b/src/routes/query.rs @@ -0,0 +1,3255 @@ +// Query routes and execution +use anyhow::{Context, Result}; +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use rand::Rng; +use serde_json::Value; +use sqlx::{Column, Row}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tracing::{error, info, warn}; + +use crate::logging::AuditLogger; +use crate::models::{PermissionsResponse, QueryAction, QueryRequest, QueryResponse, UserInfo}; +use crate::sql::{ + build_filter_clause, build_legacy_where_clause, build_order_by_clause, validate_column_name, + validate_column_names, validate_table_name, +}; +use crate::AppState; + +// Helper function to extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +fn database_unavailable_response(request_id: &str) -> QueryResponse { + QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database temporarily unavailable, try again in a moment [request_id: {}]", + request_id + )), + warning: None, + results: None, + } +} + +fn database_unavailable_batch_response(request_id: &str) -> QueryResponse { + let mut base = database_unavailable_response(request_id); + base.results = Some(vec![]); + base +} + +async fn log_database_unavailable_event( + logger: &AuditLogger, + request_id: &str, + username: Option<&str>, + power: Option<i32>, + detail: &str, +) { + if let Err(err) = logger + .log_error( + request_id, + Utc::now(), + detail, + Some("database_unavailable"), + username, + power, + ) + .await + { + error!("[{}] Failed to record database outage: {}", request_id, err); + } +} + +pub async fn execute_query( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<QueryRequest>, +) -> Result<Json<QueryResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + warning: None, + results: None, + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Session not found".to_string()), + warning: None, + results: None, + })); + } + }; + + // Detect batch mode - if queries field is present, handle as batch operation + if payload.queries.is_some() { + // SECURITY: Check if user has permission to use batch operations + let power_perms = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .ok_or(StatusCode::FORBIDDEN)?; + + if !power_perms.allow_batch_operations { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} (power {}) attempted batch operation without permission", + session.username, session.power + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch operations not permitted for your role".to_string()), + warning: None, + results: None, + })); + } + + return execute_batch_mode(state, session, request_id, timestamp, client_ip, &payload) + .await; + } + + // Validate input for very basic security vulnerabilities (null bytes, etc.) + if let Err(security_error) = validate_input_security(&payload) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Security validation failed for user {}: {}", + session.username, security_error + ), + Some("security_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid input detected, how did you even manage to do that? [request_id: {}]", + request_id + )), + warning: None, + results: None, + })); + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Clone payload before extracting fields (to avoid partial move issues) + let payload_clone = payload.clone(); + + // Single query mode - validate required fields + let action = payload.action.ok_or_else(|| { + let error_msg = "Missing action field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + let table = payload.table.ok_or_else(|| { + let error_msg = "Missing table field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + // SECURITY: Validate table name before any operations + if let Err(e) = validate_table_name(&table, &state.config) { + let error_msg = format!("Invalid table name '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("table_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + // SECURITY: Validate column names if specified + if let Some(ref columns) = payload.columns { + if let Err(e) = validate_column_names(columns) { + let error_msg = format!("Invalid column names on table '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("column_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid column name: {} [request_id: {}]", + e, request_id + )), + warning: None, + results: None, + })); + } + } + + // Check permissions (after validation to avoid leaking table existence) + if !state + .rbac + .check_permission(&state.config, session.power, &table, &action) + { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} attempted unauthorized {} on table {}", + session.username, action_str, table + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + // Log security violation + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Permission denied: {} on table {}", action_str, table), + Some("authorization"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log permission denial: {}", + request_id, log_err + ); + } + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Insufficient permissions for this operation, Might as well give up [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable, returning graceful error", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before transaction", + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // ALL database operations now use transactions with user context set + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + }; + + // Set user context and request ID in transaction - ALL queries have user context now + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // Execute the query within the transaction + let result = match action { + QueryAction::Select => { + execute_select_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Insert => { + execute_insert_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_update_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_delete_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Count => { + execute_count_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_info_async( + &state.logging, + &request_id, + &format!("Query executed successfully: {} on {}", action_str, table), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } + Err(e) => { + error!("[{}] Query execution failed: {}", request_id, e); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Query execution error: {}", e), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database query failed [request_id: {}]", + request_id + )), + warning: None, + results: None, + })) + } + } +} + +pub async fn get_permissions( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, +) -> Result<Json<PermissionsResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/permissions", + &serde_json::json!({}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let permissions = state + .rbac + .get_table_permissions(&state.config, session.power); + + // Get user settings access permission + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + let user_settings_access_str = match user_settings_permission { + crate::config::UserSettingsAccess::ReadOwnOnly => "read-own-only", + crate::config::UserSettingsAccess::ReadWriteOwn => "read-write-own", + crate::config::UserSettingsAccess::ReadWriteAll => "read-write-all", + }; + + Ok(Json(PermissionsResponse { + success: true, + permissions, + user: UserInfo { + id: session.user_id, + username: session.username, + name: "".to_string(), // We don't store name in session, would need to fetch from DB + role: session.role_name, + power: session.power, + }, + security_clearance: None, + user_settings_access: user_settings_access_str.to_string(), + })) +} + +// ===== SAFE SQL QUERY BUILDERS WITH VALIDATION ===== +// All functions validate table/column names to prevent SQL injection + +/// Build WHERE clause with column name validation (legacy simple format) +fn build_where_clause(where_clause: &Value) -> anyhow::Result<(String, Vec<String>)> { + // Use the new validated builder from sql module + build_legacy_where_clause(where_clause) +} + +/// Build INSERT data with column name validation +fn build_insert_data( + data: &Value, +) -> anyhow::Result<(Vec<String>, Vec<String>, Vec<Option<String>>)> { + let mut columns = Vec::new(); + let mut placeholders = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in INSERT: {}", key))?; + + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("INSERT data must be a JSON object"); + } + + if columns.is_empty() { + anyhow::bail!("INSERT data cannot be empty"); + } + + Ok((columns, placeholders, values)) +} + +/// Build UPDATE SET clause with column name validation +fn build_update_set_clause(data: &Value) -> anyhow::Result<(String, Vec<Option<String>>)> { + let mut set_clauses = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in UPDATE: {}", key))?; + + set_clauses.push(format!("{} = ?", key)); + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("UPDATE data must be a JSON object"); + } + + if set_clauses.is_empty() { + anyhow::bail!("UPDATE data cannot be empty"); + } + + Ok((set_clauses.join(", "), values)) +} + +/// Convert JSON value to Option<String> for SQL binding +/// Properly handles booleans (true/false -> "1"/"0" for MySQL TINYINT/BOOLEAN) +/// NULL values return None for proper SQL NULL handling +fn json_value_to_option_string(value: &Value) -> Option<String> { + match value { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + // MySQL uses TINYINT(1) for booleans: true -> 1, false -> 0 + Value::Bool(b) => Some(if *b { "1".to_string() } else { "0".to_string() }), + Value::Null => None, // Return None for proper SQL NULL + // Complex types (objects, arrays) get JSON serialized + _ => Some(serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())), + } +} +/// Convert SQL values from MySQL to JSON +/// Handles ALL MySQL types automatically with proper NULL handling +/// Returns booleans as true/false (MySQL TINYINT(1) -> JSON bool) +fn convert_sql_value_to_json( + row: &sqlx::mysql::MySqlRow, + index: usize, + request_id: Option<&str>, + username: Option<&str>, + power: Option<i32>, + state: Option<&AppState>, +) -> anyhow::Result<Value> { + let column = &row.columns()[index]; + + use sqlx::TypeInfo; + let type_name = column.type_info().name(); + + // Comprehensive MySQL type handling - no need to manually add types anymore! + let result = match type_name { + // ===== String types ===== + "VARCHAR" | "TEXT" | "CHAR" | "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "SET" | "ENUM" => { + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Integer types ===== + "INT" | "BIGINT" | "MEDIUMINT" | "SMALLINT" | "INTEGER" => row + .try_get::<Option<i64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // Unsigned integers + "INT UNSIGNED" | "BIGINT UNSIGNED" | "MEDIUMINT UNSIGNED" | "SMALLINT UNSIGNED" => row + .try_get::<Option<u64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== Boolean type (MySQL TINYINT(1)) ===== + // Returns proper JSON true/false instead of 1/0 + "TINYINT" | "BOOLEAN" | "BOOL" => { + // Try as bool first (for TINYINT(1)) + if let Ok(opt_bool) = row.try_get::<Option<bool>, _>(index) { + return Ok(opt_bool.map(Value::Bool).unwrap_or(Value::Null)); + } + // Fallback to i8 for regular TINYINT + row.try_get::<Option<i8>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)) + } + + // ===== Decimal/Numeric types ===== + "DECIMAL" | "NUMERIC" => { + row.try_get::<Option<rust_decimal::Decimal>, _>(index) + .map(|opt| { + opt.map(|decimal| { + // Keep precision by converting to string then parsing + let decimal_str = decimal.to_string(); + if let Ok(f) = decimal_str.parse::<f64>() { + serde_json::json!(f) + } else { + Value::String(decimal_str) + } + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Floating point types ===== + "FLOAT" | "DOUBLE" | "REAL" => row + .try_get::<Option<f64>, _>(index) + .map(|opt| opt.map(|v| serde_json::json!(v)).unwrap_or(Value::Null)), + + // ===== Date/Time types ===== + "DATE" => { + use chrono::NaiveDate; + row.try_get::<Option<NaiveDate>, _>(index).map(|opt| { + opt.map(|date| Value::String(date.format("%Y-%m-%d").to_string())) + .unwrap_or(Value::Null) + }) + } + "DATETIME" => { + use chrono::NaiveDateTime; + row.try_get::<Option<NaiveDateTime>, _>(index).map(|opt| { + opt.map(|datetime| Value::String(datetime.format("%Y-%m-%d %H:%M:%S").to_string())) + .unwrap_or(Value::Null) + }) + } + "TIMESTAMP" => { + use chrono::{DateTime, Utc}; + row.try_get::<Option<DateTime<Utc>>, _>(index).map(|opt| { + opt.map(|timestamp| Value::String(timestamp.to_rfc3339())) + .unwrap_or(Value::Null) + }) + } + "TIME" => { + // TIME values come as strings in HH:MM:SS format + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + "YEAR" => row + .try_get::<Option<i32>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== JSON type ===== + "JSON" => row.try_get::<Option<String>, _>(index).map(|opt| { + opt.and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or(Value::Null) + }), + + // ===== Binary types ===== + // Return as base64-encoded strings for safe JSON transmission + "BLOB" | "MEDIUMBLOB" | "LONGBLOB" | "TINYBLOB" | "BINARY" | "VARBINARY" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + use base64::{engine::general_purpose, Engine as _}; + Value::String(general_purpose::STANDARD.encode(&bytes)) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Bit type ===== + "BIT" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + // Convert bit value to number + let mut val: i64 = 0; + for &byte in &bytes { + val = (val << 8) | byte as i64; + } + Value::Number(val.into()) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Spatial/Geometry types ===== + "GEOMETRY" | "POINT" | "LINESTRING" | "POLYGON" | "MULTIPOINT" | "MULTILINESTRING" + | "MULTIPOLYGON" | "GEOMETRYCOLLECTION" => { + // Return as WKT (Well-Known Text) string + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Catch-all for unknown/new types ===== + // This ensures forward compatibility if MySQL adds new types + _ => { + warn!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ), + Some("data_conversion"), + username, + power, + ); + } + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + }; + + // Robust error handling with fallback + match result { + Ok(value) => Ok(value), + Err(e) => { + // Final fallback: try as string + match row.try_get::<Option<String>, _>(index) { + Ok(opt) => { + warn!("Primary conversion failed for column '{}' (type: {}), used string fallback", + column.name(), type_name); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!("Primary conversion failed for column '{}' (type: {}), used string fallback", column.name(), type_name), + Some("data_conversion"), + username, + power, + ); + } + Ok(opt.map(Value::String).unwrap_or(Value::Null)) + } + Err(_) => { + error!( + "Complete failure to decode column '{}' (index: {}, type: {}): {}", + column.name(), + index, + type_name, + e + ); + // Return NULL instead of failing the entire query + Ok(Value::Null) + } + } + } + } +} + +// Generate auto values based on configuration +async fn generate_auto_value( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + match config.gen_type.as_str() { + "numeric" => generate_unique_numeric_id(state, table, config).await, + _ => Err(anyhow::anyhow!( + "Unsupported auto-generation type: {}", + config.gen_type + )), + } +} + +// Generate a unique numeric ID based on configuration +async fn generate_unique_numeric_id( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + let range_min = config.range_min.unwrap_or(10000000); + let range_max = config.range_max.unwrap_or(99999999); + let max_attempts = config.max_attempts.unwrap_or(10) as usize; + let field_name = &config.field; + + for _attempt in 0..max_attempts { + // Generate random number in specified range + let id = { + let mut rng = rand::rng(); + rng.random_range(range_min..=range_max) + }; + let id_str = id.to_string(); + + // Check if this ID already exists + let query_str = format!( + "SELECT COUNT(*) as count FROM {} WHERE {} = ?", + table, field_name + ); + let exists = sqlx::query(&query_str) + .bind(&id_str) + .fetch_one(state.database.pool()) + .await?; + + let count: i64 = exists.try_get("count")?; + + if count == 0 { + return Ok(id_str); + } + } + + Err(anyhow::anyhow!( + "Failed to generate unique {} for table {} after {} attempts", + field_name, + table, + max_attempts + )) +} + +// Security validation functions +fn validate_input_security(payload: &QueryRequest) -> Result<(), String> { + // Check for null bytes in table name + if let Some(ref table) = payload.table { + if table.contains('\0') { + return Err("Null byte detected in table name".to_string()); + } + } + + // Check for null bytes in column names + if let Some(columns) = &payload.columns { + for column in columns { + if column.contains('\0') { + return Err("Null byte detected in column name".to_string()); + } + } + } + + // Check for null bytes in data values + if let Some(data) = &payload.data { + if contains_null_bytes_in_value(data) { + return Err("Null byte detected in data values".to_string()); + } + } + + // Check for null bytes in WHERE clause + if let Some(where_clause) = &payload.where_clause { + if contains_null_bytes_in_value(where_clause) { + return Err("Null byte detected in WHERE clause".to_string()); + } + } + + Ok(()) +} + +fn contains_null_bytes_in_value(value: &Value) -> bool { + match value { + Value::String(s) => s.contains('\0'), + Value::Array(arr) => arr.iter().any(contains_null_bytes_in_value), + Value::Object(map) => { + map.keys().any(|k| k.contains('\0')) || map.values().any(contains_null_bytes_in_value) + } + _ => false, + } +} + +// Core execution functions that work with mutable transaction references +// These are used by batch operations to execute multiple queries in a single atomic transaction + +async fn execute_select_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + fn count_filter_conditions(filter: &crate::models::FilterCondition) -> usize { + use crate::models::FilterCondition; + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + let requested_columns = if let Some(ref cols) = payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {} for your power level, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, defaulted to {} (max for power level {})", + max_limit, session.power + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut **tx).await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where (same as other functions) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter_conditions(filter: &FilterCondition) -> usize { + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut **tx).await?; + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +async fn execute_update_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + let writable_columns = data + .as_object() + .unwrap() + .keys() + .cloned() + .collect::<Vec<_>>(); + if writable_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + + let set_clause = writable_columns + .iter() + .map(|col| format!("{} = ?", col)) + .collect::<Vec<_>>() + .join(", "); + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for col in &writable_columns { + if let Some(val) = data.get(col) { + sqlx_query = sqlx_query.bind(val.clone()); + } + } + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +// Transaction-based execution functions for user context operations +// These create their own transactions and commit them - used for single query operations + +async fn execute_select_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Apply granular column filtering based on user's power level + let requested_columns = if let Some(cols) = &payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + + // Build and append JOIN SQL + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE clause - support both legacy and new filter format + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Add ORDER BY if provided + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Add OFFSET if provided + if let Some(offset) = payload.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, // Row count will be known after execution + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut *tx).await?; + tx.commit().await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_insert_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for INSERT"))? + .clone(); + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + // This validates what the user is trying to write + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation requirements based on config + // Auto-generated fields bypass write protection since they're system-generated + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + // Check if auto-generation is enabled for INSERT action + if auto_config.on_action == "insert" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to insert + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in INSERT data".to_string()), + warning: None, + results: None, + }); + } + } + + let (columns_vec, placeholders_vec, values) = build_insert_data(&data)?; + let columns = columns_vec.join(", "); + let placeholders = placeholders_vec.join(", "); + + let query = format!( + "INSERT INTO {} ({}) VALUES ({})", + table, columns, placeholders + ); + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + let insert_id = result.last_insert_id(); + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(insert_id)), + rows_affected: Some(result.rows_affected()), + error: None, + warning: None, // INSERT queries don't have LIMIT, + results: None, + }) +} + +async fn execute_update_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to update + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + } + + let (set_clause, mut values) = build_update_set_clause(&data)?; + let (where_sql, where_values) = build_where_clause(where_clause)?; + // Convert where_values to Option<String> to match set values + values.extend(where_values.into_iter().map(Some)); + + let mut query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let (where_sql, values) = build_where_clause(where_clause)?; + + let mut query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + sqlx_query = sqlx_query.bind(value); + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Check read permissions for the table + if !state.rbac.check_permission( + &state.config, + session.power, + table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to COUNT from table '{}' [request_id: {}]", + table, request_id + )), + warning: None, + results: None, + }); + } + + // Helper to count conditions in filter/where (same as select function) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut *tx).await?; + tx.commit().await?; + + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +/// Execute multiple queries in a single transaction +async fn execute_batch_mode( + state: AppState, + session: crate::models::Session, + request_id: String, + timestamp: chrono::DateTime<chrono::Utc>, + client_ip: String, + payload: &QueryRequest, +) -> Result<Json<QueryResponse>, StatusCode> { + let queries = payload.queries.as_ref().unwrap(); + + // Check if batch is empty + if queries.is_empty() { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch request cannot be empty".to_string()), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Batch query request from user {} (power {}): {} queries", + request_id, + session.username, + session.power, + queries.len() + ); + + // Log the batch request (log full payload including action/table) + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log batch request: {}", request_id, e); + } + + // Get action and table from parent request (all batch queries share these) + let action = payload + .action + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + let table = payload + .table + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + // Validate table name once for the entire batch + if let Err(e) = validate_table_name(table, &state.config) { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name: {}", e)), + warning: None, + results: Some(vec![]), + })); + } + + // Check RBAC permission once for the entire batch + if !state + .rbac + .check_permission(&state.config, session.power, table, action) + { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions for {} on table '{}'", + match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }, + table + )), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Validated batch: {} x {:?} on table '{}'", + request_id, + queries.len(), + action, + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Validated batch: {} x {:?} on table '{}'", + queries.len(), + action, + table + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable before batch execution", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before batch", + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + // Start a SINGLE transaction for the batch - proper atomic operation + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin batch transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin batch transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + }; + + // Set user context and request ID in transaction + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set batch user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + let rollback_on_error = payload.rollback_on_error.unwrap_or(false); + + info!( + "[{}] Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + request_id, + queries.len(), + action, + table, + rollback_on_error + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + queries.len(), + action, + table, + rollback_on_error + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + // Execute as SINGLE native batch query based on action type + let result = match action { + QueryAction::Insert => { + execute_batch_insert( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_batch_update( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_batch_delete( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Select => { + // SELECT batches are less common but we'll execute them individually + // (combining SELECTs into one query doesn't make sense as they return different results) + execute_batch_selects( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Count => { + // COUNT batches execute individually (each returns different results) + execute_batch_counts( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + if response.success || !rollback_on_error { + // Commit the transaction + tx.commit().await.map_err(|e| { + error!("[{}] Failed to commit batch transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + info!( + "[{}] Native batch committed: {} operations", + request_id, + queries.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Native batch committed: {} operations", queries.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } else { + // Rollback on error + error!( + "[{}] Rolling back batch transaction due to error", + request_id + ); + tx.rollback().await.map_err(|e| { + error!("[{}] Failed to rollback transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(Json(response)) + } + } + Err(e) => { + error!("[{}] Batch execution failed: {}", request_id, e); + tx.rollback().await.map_err(|e2| { + error!("[{}] Failed to rollback after error: {}", request_id, e2); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Batch execution failed: {}", e)), + warning: None, + results: Some(vec![]), + })) + } + } +} + +// ===== NATIVE BATCH EXECUTION FUNCTIONS ===== + +/// Execute batch INSERT using MySQL multi-value INSERT +async fn execute_batch_insert( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + _action: &QueryAction, + _username: &str, + state: &AppState, + _session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + use serde_json::Value; + + // Extract all data objects, apply auto-generation, and validate they have the same columns + let mut all_data = Vec::new(); + let mut column_set: Option<std::collections::HashSet<String>> = None; + + // Check for auto-generation config + let auto_config = state.config.get_auto_generation_config(&table); + + for (idx, query) in queries.iter().enumerate() { + let mut data = query + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Query {} missing data field for INSERT", idx + 1))? + .clone(); + + // Apply auto-generation if configured for INSERT + if let Some(ref auto_cfg) = auto_config { + if auto_cfg.on_action == "insert" || auto_cfg.on_action == "both" { + if let Value::Object(ref mut map) = data { + let field_name = &auto_cfg.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_cfg).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + if let Value::Object(map) = data { + let cols: std::collections::HashSet<String> = map.keys().cloned().collect(); + + if let Some(ref expected_cols) = column_set { + if *expected_cols != cols { + anyhow::bail!("All INSERT queries must have the same columns. Query {} has different columns", idx + 1); + } + } else { + column_set = Some(cols); + } + + all_data.push(map); + } else { + anyhow::bail!("Query {} data must be an object", idx + 1); + } + } + + let columns: Vec<String> = column_set.unwrap().into_iter().collect(); + + // Validate column names + for col in &columns { + validate_column_name(col)?; + } + + // Build multi-value INSERT: INSERT INTO table (col1, col2) VALUES (?, ?), (?, ?), ... + let column_list = columns.join(", "); + let value_placeholder = format!("({})", vec!["?"; columns.len()].join(", ")); + let values_clause = vec![value_placeholder; all_data.len()].join(", "); + + let sql = format!( + "INSERT INTO {} ({}) VALUES {}", + table, column_list, values_clause + ); + + info!( + "[{}] Native batch INSERT: {} rows into {}", + request_id, + all_data.len(), + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Native batch INSERT: {} rows into {}", + all_data.len(), + table + ), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + // Bind all values + let mut query = sqlx::query(&sql); + for data_map in &all_data { + for col in &columns { + let value = data_map.get(col).and_then(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + Value::Bool(b) => Some(b.to_string()), + Value::Null => None, + _ => Some(v.to_string()), + }); + query = query.bind(value); + } + } + + // Execute the batch INSERT + let result = query.execute(&mut **tx).await?; + let rows_affected = result.rows_affected(); + + info!( + "[{}] Batch INSERT affected {} rows", + request_id, rows_affected + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch INSERT affected {} rows", rows_affected), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(rows_affected), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch UPDATE (executes individually for now, could be optimized with CASE statements) +async fn execute_batch_update( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_update_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch UPDATE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch UPDATE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch DELETE using IN clause when possible +async fn execute_batch_delete( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + // For now, execute individually + // TODO: Optimize by detecting simple ID-based deletes and combining with IN clause + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_delete_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch DELETE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch DELETE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch SELECT (executes individually since they return different results) +async fn execute_batch_selects( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_select_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch SELECT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch SELECT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} + +/// Execute batch COUNT (executes individually since they return different results) +async fn execute_batch_counts( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_count_core( + &format!("{}:{}", request_id, idx), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch COUNT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch COUNT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} diff --git a/src/scheduler.rs b/src/scheduler.rs new file mode 100644 index 0000000..ec2710b --- /dev/null +++ b/src/scheduler.rs @@ -0,0 +1,185 @@ +// Scheduled query execution module +use crate::config::{Config, ScheduledQueryTask}; +use crate::db::Database; +use crate::logging::AuditLogger; +use chrono::Utc; +use std::sync::Arc; +use tokio::time::{interval, Duration}; +use tracing::{error, info, warn}; + +pub struct QueryScheduler { + config: Arc<Config>, + database: Database, + logging: AuditLogger, +} + +impl QueryScheduler { + pub fn new(config: Arc<Config>, database: Database, logging: AuditLogger) -> Self { + Self { + config, + database, + logging, + } + } + + /// Spawn background tasks for all enabled scheduled queries + pub fn spawn_tasks(&self) { + let Some(scheduled_config) = &self.config.scheduled_queries else { + info!("No scheduled queries configured"); + return; + }; + + if scheduled_config.tasks.is_empty() { + info!("No scheduled query tasks defined"); + return; + } + + let enabled_tasks: Vec<&ScheduledQueryTask> = scheduled_config + .tasks + .iter() + .filter(|task| task.enabled) + .collect(); + + if enabled_tasks.is_empty() { + info!("No enabled scheduled query tasks"); + return; + } + + info!( + "Spawning {} enabled scheduled query task(s)", + enabled_tasks.len() + ); + + for task in enabled_tasks { + let task_clone = task.clone(); + let database = self.database.clone(); + let logging = self.logging.clone(); + + tokio::spawn(async move { + Self::run_scheduled_task(task_clone, database, logging).await; + }); + } + } + + async fn run_scheduled_task( + task: ScheduledQueryTask, + database: Database, + logging: AuditLogger, + ) { + let mut interval_timer = interval(Duration::from_secs(task.interval_minutes * 60)); + let mut first_run = true; + + info!( + "Scheduled task '{}' started (interval: {}min, run_on_startup: {}): {}", + task.name, task.interval_minutes, task.run_on_startup, task.description + ); + + loop { + interval_timer.tick().await; + + if !database.is_available() { + warn!( + "Skipping scheduled task '{}' because database is unavailable", + task.name + ); + log_task_event( + &logging, + "scheduler_skip", + &task.name, + "Database unavailable, task skipped", + false, + ) + .await; + continue; + } + + // Skip first execution if run_on_startup is false + if first_run && !task.run_on_startup { + first_run = false; + info!( + "Scheduled task '{}' skipping initial run (run_on_startup=false)", + task.name + ); + continue; + } + first_run = false; + + match sqlx::query(&task.query).execute(database.pool()).await { + Ok(result) => { + database.mark_available(); + info!( + "Scheduled task '{}' executed successfully (rows affected: {})", + task.name, + result.rows_affected() + ); + log_task_event( + &logging, + "scheduler_success", + &task.name, + &format!( + "Task executed successfully (rows affected: {})", + result.rows_affected() + ), + false, + ) + .await; + } + Err(e) => { + database.mark_unavailable(); + error!( + "Scheduled task '{}' failed: {} (query: {})", + task.name, e, task.query + ); + log_task_event( + &logging, + "scheduler_failure", + &task.name, + &format!("Task failed: {}", e), + true, + ) + .await; + } + } + } + } +} + +async fn log_task_event( + logging: &AuditLogger, + context: &str, + task_name: &str, + message: &str, + as_error: bool, +) { + let request_id = AuditLogger::generate_request_id(); + let timestamp = Utc::now(); + let full_message = format!("{}: {}", task_name, message); + + let result = if as_error { + logging + .log_error( + &request_id, + timestamp, + &full_message, + Some(context), + Some("system"), + None, + ) + .await + } else { + logging + .log_info( + &request_id, + timestamp, + &full_message, + Some(context), + Some("system"), + None, + ) + .await + }; + + if let Err(err) = result { + error!("Failed to record scheduler event ({}): {}", context, err); + } +} diff --git a/src/sql/builder.rs b/src/sql/builder.rs new file mode 100644 index 0000000..7ba84f2 --- /dev/null +++ b/src/sql/builder.rs @@ -0,0 +1,493 @@ +// Enhanced SQL query builder with proper validation and complex WHERE support +use anyhow::{Context, Result}; +use regex::Regex; +use serde_json::Value; + +use crate::config::Config; +use crate::models::{FilterCondition, FilterOperator, OrderBy, OrderDirection}; + +/// Parse a table reference possibly containing an alias into (base_table, alias) +/// Accepts formats like: "table", "table alias", "table AS alias" (AS case-insensitive) +pub fn parse_table_and_alias(input: &str) -> (String, Option<String>) { + let s = input.trim(); + // Normalize multiple spaces + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.is_empty() { + return (String::new(), None); + } + + if parts.len() == 1 { + return (parts[0].to_string(), None); + } + + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") { + // table AS alias [ignore extra] + return (parts[0].to_string(), Some(parts[2].to_string())); + } + + // table alias + (parts[0].to_string(), Some(parts[1].to_string())) +} + +/// Validates a table name against known tables and SQL injection patterns +pub fn validate_table_name(table: &str, config: &Config) -> Result<()> { + // Check if empty + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty"); + } + + // Check against known tables list + let known_tables = config.get_known_tables(); + if !known_tables.contains(&table.to_string()) { + anyhow::bail!("Table '{}' is not in the known tables list", table); + } + + // Validate format: only alphanumeric and underscores, must start with letter + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table name regex")?; + + if !table_regex.is_match(table) { + anyhow::bail!("Invalid table name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", table); + } + + // Additional security: check for SQL keywords and dangerous patterns + let dangerous_patterns = [ + "--", "/*", "*/", ";", "DROP", "ALTER", "CREATE", "EXEC", "EXECUTE", + ]; + let table_upper = table.to_uppercase(); + for pattern in &dangerous_patterns { + if table_upper.contains(pattern) { + anyhow::bail!("Table name contains forbidden pattern: '{}'", pattern); + } + } + + Ok(()) +} + +/// Validates a column name against SQL injection patterns +/// Supports: column, table.column, table.column as alias, table.* +pub fn validate_column_name(column: &str) -> Result<()> { + // Check if empty + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty"); + } + + // Allow * for SELECT all + if column == "*" { + return Ok(()); + } + + // Check for SQL injection patterns (comments and statement terminators) + let dangerous_chars = ["--", "/*", "*/", ";"]; + for pattern in &dangerous_chars { + if column.contains(pattern) { + anyhow::bail!("Column name contains forbidden pattern: '{}'", pattern); + } + } + + // Note: We don't block SQL keywords like DROP, CREATE, ALTER etc. in column names + // because legitimate columns like "created_date", "created_by" contain these as substrings. + // The regex validation below ensures only alphanumeric + underscore characters are allowed, + // which prevents actual SQL injection while allowing valid column names. + + // Support multiple formats: + // 1. Simple column: column_name + // 2. Qualified column: table.column_name + // 3. Wildcard: table.* + // 4. Alias: column_name as alias or table.column_name as alias + + // Remove AS alias if present (case insensitive) + let column_upper = column.to_uppercase(); + let column_without_alias = if column_upper.contains(" AS ") { + column.split_whitespace().next().unwrap_or("") + } else { + column + }; + + // Check for qualified column (table.column or table.*) + if column_without_alias.contains('.') { + let parts: Vec<&str> = column_without_alias.split('.').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Invalid qualified column format: '{}'. Must be table.column", + column + ); + } + + // Validate table part + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table regex")?; + if !table_regex.is_match(parts[0]) { + anyhow::bail!("Invalid table name in qualified column: '{}'", parts[0]); + } + + // Validate column part (allow * for table.*) + if parts[1] != "*" { + let column_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile column regex")?; + if !column_regex.is_match(parts[1]) { + anyhow::bail!("Invalid column name in qualified column: '{}'", parts[1]); + } + } + } else { + // Simple column name validation + let column_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$") + .context("Failed to compile column name regex")?; + + if !column_regex.is_match(column_without_alias) { + anyhow::bail!("Invalid column name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", column); + } + } + + Ok(()) +} + +/// Validates multiple column names +pub fn validate_column_names(columns: &[String]) -> Result<()> { + for column in columns { + validate_column_name(column).with_context(|| format!("Invalid column name: {}", column))?; + } + Ok(()) +} + +/// Build WHERE clause from enhanced FilterCondition +/// Supports complex operators (=, !=, >, <, LIKE, IN, IS NULL, etc.) and nested logic (AND, OR, NOT) +pub fn build_filter_clause(filter: &FilterCondition) -> Result<(String, Vec<String>)> { + match filter { + FilterCondition::Simple { + column, + operator, + value, + } => { + validate_column_name(column)?; + build_simple_condition(column, operator, value) + } + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + if let Some(and_conds) = and_conditions { + if and_conds.is_empty() { + anyhow::bail!("AND conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in and_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" AND "), all_values)) + } else if let Some(or_conds) = or_conditions { + if or_conds.is_empty() { + anyhow::bail!("OR conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in or_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" OR "), all_values)) + } else { + anyhow::bail!("Logical condition must have either 'and' or 'or' field"); + } + } + FilterCondition::Not { not } => { + let (sql, values) = build_filter_clause(not)?; + Ok((format!("NOT ({})", sql), values)) + } + } +} + +/// Build a simple condition (column operator value) +fn build_simple_condition( + column: &str, + operator: &FilterOperator, + value: &Value, +) -> Result<(String, Vec<String>)> { + match operator { + FilterOperator::Eq => { + if value.is_null() { + Ok((format!("{} IS NULL", column), vec![])) + } else { + Ok((format!("{} = ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Ne => { + if value.is_null() { + Ok((format!("{} IS NOT NULL", column), vec![])) + } else { + Ok((format!("{} != ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Gt => Ok((format!("{} > ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Gte => Ok((format!("{} >= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lt => Ok((format!("{} < ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lte => Ok((format!("{} <= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Like => Ok(( + format!("{} LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::NotLike => Ok(( + format!("{} NOT LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::In => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("IN operator requires array value"); + } + } + FilterOperator::NotIn => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("NOT IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} NOT IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("NOT IN operator requires array value"); + } + } + FilterOperator::IsNull => { + // Value is ignored for IS NULL + Ok((format!("{} IS NULL", column), vec![])) + } + FilterOperator::IsNotNull => { + // Value is ignored for IS NOT NULL + Ok((format!("{} IS NOT NULL", column), vec![])) + } + FilterOperator::Between => { + if let Value::Array(arr) = value { + if arr.len() != 2 { + anyhow::bail!("BETWEEN operator requires array with exactly 2 values"); + } + let val1 = json_to_sql_string(&arr[0]); + let val2 = json_to_sql_string(&arr[1]); + Ok((format!("{} BETWEEN ? AND ?", column), vec![val1, val2])) + } else { + anyhow::bail!("BETWEEN operator requires array with [min, max] values"); + } + } + } +} + +/// Convert JSON value to SQL string representation +/// Properly handles all JSON types including booleans (true/false -> 1/0) +fn json_to_sql_string(value: &Value) -> String { + match value { + Value::String(s) => s.clone(), + Value::Number(n) => n.to_string(), + Value::Bool(b) => { + if *b { + "1".to_string() + } else { + "0".to_string() + } + } + Value::Null => "NULL".to_string(), + _ => serde_json::to_string(value).unwrap_or_else(|_| "NULL".to_string()), + } +} + +/// Build legacy WHERE clause from simple key-value JSON (for backward compatibility) +pub fn build_legacy_where_clause(where_clause: &Value) -> Result<(String, Vec<String>)> { + let mut conditions = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = where_clause { + for (key, value) in map { + validate_column_name(key)?; + + if value.is_null() { + // Handle NULL values with IS NULL + conditions.push(format!("{} IS NULL", key)); + // Don't add to values since IS NULL doesn't need a parameter + } else { + conditions.push(format!("{} = ?", key)); + values.push(json_to_sql_string(value)); + } + } + } else { + anyhow::bail!("WHERE clause must be an object"); + } + + if conditions.is_empty() { + anyhow::bail!("WHERE clause cannot be empty"); + } + + Ok((conditions.join(" AND "), values)) +} + +/// Build ORDER BY clause with column validation +pub fn build_order_by_clause(order_by: &[OrderBy]) -> Result<String> { + if order_by.is_empty() { + return Ok(String::new()); + } + + let mut clauses = Vec::new(); + for order in order_by { + validate_column_name(&order.column)?; + let direction = match order.direction { + OrderDirection::ASC => "ASC", + OrderDirection::DESC => "DESC", + }; + clauses.push(format!("{} {}", order.column, direction)); + } + + Ok(format!(" ORDER BY {}", clauses.join(", "))) +} + +/// Build JOIN clause from Join specifications +/// Validates table names and join conditions for security +pub fn build_join_clause(joins: &[crate::models::Join], config: &Config) -> Result<String> { + if joins.is_empty() { + return Ok(String::new()); + } + + let mut join_sql = String::new(); + + for join in joins { + // Extract base table and optional alias + let (base_table, alias) = parse_table_and_alias(&join.table); + + // Validate joined base table name + validate_table_name(&base_table, config)?; + + // Optionally validate alias format (same rules as table/column names) + if let Some(alias_name) = &alias { + let alias_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile alias regex")?; + if !alias_regex.is_match(alias_name) { + anyhow::bail!("Invalid table alias format: '{}'", alias_name); + } + } + + // Validate join condition (must be in format "table1.column1 = table2.column2") + validate_join_condition(&join.on)?; + + // Build JOIN clause based on type + let join_type_str = match join.join_type { + crate::models::JoinType::Inner => "INNER JOIN", + crate::models::JoinType::Left => "LEFT JOIN", + crate::models::JoinType::Right => "RIGHT JOIN", + }; + + // Reconstruct safe table reference + let table_ref = match alias { + Some(a) => format!("{} AS {}", base_table, a), + None => base_table, + }; + join_sql.push_str(&format!(" {} {} ON {}", join_type_str, table_ref, join.on)); + } + + Ok(join_sql) +} + +/// Validate JOIN ON condition +/// Must be in format: "table1.column1 = table2.column2" or similar simple conditions +fn validate_join_condition(condition: &str) -> Result<()> { + // Basic validation: must contain = and table.column references + if !condition.contains('=') { + anyhow::bail!("JOIN condition must contain = operator"); + } + + // Split by = and validate both sides + let parts: Vec<&str> = condition.split('=').map(|s| s.trim()).collect(); + if parts.len() != 2 { + anyhow::bail!("JOIN condition must have exactly one = operator"); + } + + // Validate both sides are table.column format + for part in parts { + validate_table_column_reference(part)?; + } + + Ok(()) +} + +/// Validate table.column reference (e.g., "assets.category_id") +fn validate_table_column_reference(reference: &str) -> Result<()> { + let parts: Vec<&str> = reference.split('.').collect(); + + if parts.len() != 2 { + anyhow::bail!( + "Table column reference must be in format 'table.column', got: {}", + reference + ); + } + + let table = parts[0].trim(); + let column = parts[1].trim(); + + // Validate table and column names follow safe patterns + let name_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile name regex")?; + + if !name_regex.is_match(table) { + anyhow::bail!("Invalid table name in JOIN condition: {}", table); + } + + if !name_regex.is_match(column) { + anyhow::bail!("Invalid column name in JOIN condition: {}", column); + } + + // Additional guard against comment/terminator tokens (redundant given regex, but safe) + // Intentionally do NOT block SQL keywords like CREATE/ALTER since identifiers like + // 'created_at' are legitimate and already validated by the regex above. + let dangerous_tokens = ["--", "/*", "*/", ";"]; + for token in &dangerous_tokens { + if reference.contains(token) { + anyhow::bail!("JOIN condition contains forbidden token: '{}'", token); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_table_name() { + // Valid names + assert!(validate_column_name("users").is_ok()); + assert!(validate_column_name("asset_id").is_ok()); + assert!(validate_column_name("table123").is_ok()); + + // Invalid names + assert!(validate_column_name("123table").is_err()); // starts with number + assert!(validate_column_name("table-name").is_err()); // contains hyphen + assert!(validate_column_name("table name").is_err()); // contains space + assert!(validate_column_name("table;DROP").is_err()); // SQL injection + assert!(validate_column_name("").is_err()); // empty + } + + #[test] + fn test_validate_column_name() { + // Valid names + assert!(validate_column_name("user_id").is_ok()); + assert!(validate_column_name("firstName").is_ok()); + assert!(validate_column_name("*").is_ok()); // wildcard allowed + + // Invalid names + assert!(validate_column_name("123column").is_err()); + assert!(validate_column_name("col-umn").is_err()); + assert!(validate_column_name("col umn").is_err()); + assert!(validate_column_name("").is_err()); + } +} diff --git a/src/sql/mod.rs b/src/sql/mod.rs new file mode 100644 index 0000000..fd5b959 --- /dev/null +++ b/src/sql/mod.rs @@ -0,0 +1,8 @@ +// SQL query building and validation module +pub mod builder; + +// Re-export commonly used functions +pub use builder::{ + build_filter_clause, build_join_clause, build_legacy_where_clause, build_order_by_clause, + validate_column_name, validate_column_names, validate_table_name, +}; |
