diff options
| author | UMTS at Teleco <crt@teleco.ch> | 2025-12-13 02:48:13 +0100 |
|---|---|---|
| committer | UMTS at Teleco <crt@teleco.ch> | 2025-12-13 02:48:13 +0100 |
| commit | e52b8e1c2e110d0feb74feb7905c2ff064b51d55 (patch) | |
| tree | 3090814e422250e07e72cf1c83241ffd95cf20f7 /src/routes | |
Diffstat (limited to 'src/routes')
| -rw-r--r-- | src/routes/auth.rs | 513 | ||||
| -rw-r--r-- | src/routes/mod.rs | 111 | ||||
| -rw-r--r-- | src/routes/preferences.rs | 423 | ||||
| -rw-r--r-- | src/routes/query.rs | 3255 |
4 files changed, 4302 insertions, 0 deletions
diff --git a/src/routes/auth.rs b/src/routes/auth.rs new file mode 100644 index 0000000..c124e03 --- /dev/null +++ b/src/routes/auth.rs @@ -0,0 +1,513 @@ +// Authentication routes +use axum::{ + extract::{ConnectInfo, State}, + http::StatusCode, + Json, +}; +use chrono::Utc; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::logging::AuditLogger; +use crate::models::{AuthMethod, LoginRequest, LoginResponse, UserInfo}; +use crate::{auth, AppState}; + +pub async fn login( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + Json(payload): Json<LoginRequest>, +) -> Result<Json<LoginResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/login", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Validate request based on auth method + let (user, role) = match payload.method { + AuthMethod::Password => { + // Password auth - allowed from any IP + if let (Some(username), Some(password)) = (&payload.username, &payload.password) { + match auth::password::authenticate_password( + state.database.pool(), + username, + password, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed password authentication for user: {} - Invalid credentials", + username + ), + Some("password_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during password authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Password auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Password authentication attempted without username or password"); + super::log_warning_async( + &state.logging, + &request_id, + "Password authentication attempted without username or password", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Pin => { + // PIN auth - only from whitelisted IPs + if !state.config.is_pin_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "PIN authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let (Some(username), Some(pin)) = (&payload.username, &payload.pin) { + match auth::pin::authenticate_pin( + state.database.pool(), + username, + pin, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("Failed PIN authentication for user: {} - Invalid PIN or PIN not configured", username), + Some("pin_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during PIN authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("PIN auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("PIN authentication attempted without username or PIN"); + super::log_warning_async( + &state.logging, + &request_id, + "PIN authentication attempted without username or PIN", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Token => { + // Token/RFID auth - only from whitelisted IPs + if !state.config.is_string_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Token authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let Some(login_string) = &payload.login_string { + match auth::token::authenticate_token( + state.database.pool(), + login_string, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed token authentication for login_string: {} - Invalid token", + login_string + ), + Some("token_auth"), + Some(login_string), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during token authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Token auth error: {}", e), + Some("authentication"), + Some(login_string), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Token authentication attempted without login_string"); + super::log_warning_async( + &state.logging, + &request_id, + "Token authentication attempted without login_string", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + }; + + // Create session token + let token = state.session_manager.create_session( + user.id, + user.username.clone(), + role.id, + role.name.clone(), + role.power, + ); + + // Log successful login + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Successful login for user: {} ({})", + user.username, user.name + ), + Some("authentication"), + Some(&user.username), + Some(role.power), + ); + + Ok(Json(LoginResponse { + success: true, + token: Some(token), + user: Some(UserInfo { + id: user.id, + username: user.username, + name: user.name, + role: role.name, + power: role.power, + }), + error: None, + })) +} + +pub async fn logout( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token = match headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "error": "No authorization token provided" + }))); + } + }; + + // Get username for logging before removing session + let username = state + .session_manager + .get_session(&token) + .map(|s| s.username.clone()); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + username.as_deref(), + None, + "/auth/logout", + &serde_json::json!({"action": "logout"}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let removed = state.session_manager.remove_session(&token); + + if removed { + // Log successful logout + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} logged out successfully", + username.as_deref().unwrap_or("unknown") + ), + Some("authentication"), + username.as_deref(), + None, + ); + Ok(Json(serde_json::json!({ + "success": true, + "message": "Logged out successfully" + }))) + } else { + Ok(Json(serde_json::json!({ + "success": false, + "error": "Invalid or expired token" + }))) + } +} + +pub async fn status( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token_opt = headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }); + + let token = match token_opt { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "valid": false, + "error": "No authorization token provided" + }))); + } + }; + + // Check session validity + match state.session_manager.get_session(&token) { + Some(session) => { + let now = Utc::now(); + let timeout_minutes = state.config.get_session_timeout(session.power); + let elapsed = (now - session.last_accessed).num_seconds(); + let timeout_seconds = (timeout_minutes * 60) as i64; + let remaining_seconds = timeout_seconds - elapsed; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/auth/status", + &serde_json::json!({"token_provided": true}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": true, + "user": { + "id": session.user_id, + "username": session.username, + "name": session.username, + "role": session.role_name, + "power": session.power + }, + "session": { + "created_at": session.created_at.to_rfc3339(), + "last_accessed": session.last_accessed.to_rfc3339(), + "timeout_minutes": timeout_minutes, + "remaining_seconds": remaining_seconds.max(0), + "expires_at": (session.last_accessed + chrono::Duration::minutes(timeout_minutes as i64)).to_rfc3339() + } + }))) + } + None => { + // Log the request for invalid token + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/status", + &serde_json::json!({"token_provided": true, "valid": false}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": false, + "message": "Session expired or invalid" + }))) + } + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..b81a0a1 --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,111 @@ +pub mod auth; +pub mod preferences; +pub mod query; + +use crate::logging::logger::AuditLogger; +use tracing::{error, info, warn}; + +/// Helper function to log errors to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_error_async( + logging: &AuditLogger, + request_id: &str, + error_msg: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + error!("[{}] {}", request_id, error_msg); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let error_msg = error_msg.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_error( + &req_id, + chrono::Utc::now(), + &error_msg, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log warnings to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_warning_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + warn!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_warning( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log info messages to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_info_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + info!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_info( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} diff --git a/src/routes/preferences.rs b/src/routes/preferences.rs new file mode 100644 index 0000000..e58d823 --- /dev/null +++ b/src/routes/preferences.rs @@ -0,0 +1,423 @@ +// User preferences routes +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::config::UserSettingsAccess; +use crate::logging::AuditLogger; +use crate::AppState; + +// Request/Response structures matching the query route pattern +#[derive(Debug, Deserialize)] +pub struct PreferencesRequest { + pub action: String, // "get", "set", "reset" + pub user_id: Option<i32>, // For admin access to other users + pub preferences: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize)] +pub struct PreferencesResponse { + pub success: bool, + pub preferences: Option<serde_json::Value>, + pub error: Option<String>, +} + +/// Extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +/// POST /preferences - Handle all preference operations (get, set, reset) +pub async fn handle_preferences( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<PreferencesRequest>, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Session not found".to_string()), + })); + } + }; + + // Determine target user ID + let target_user_id = payload.user_id.unwrap_or(session.user_id); + + // Get user's permission level for preferences + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + // Check permissions for cross-user access + if target_user_id != session.user_id { + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("User {} (power {}) attempted to access preferences of user {} without permission", + session.username, session.power, target_user_id), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + + // Check write permissions for set/reset actions + if payload.action == "set" || payload.action == "reset" { + if target_user_id == session.user_id { + // Writing own preferences - need at least ReadWriteOwn + if *user_settings_permission == UserSettingsAccess::ReadOwnOnly { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } else { + // Writing others' preferences - need ReadWriteAll + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/preferences", + &serde_json::json!({"action": payload.action, "target_user_id": target_user_id}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Handle the action + match payload.action.as_str() { + "get" => { + handle_get_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + "set" => { + handle_set_preferences( + state, + request_id, + target_user_id, + payload.preferences, + session.username.clone(), + session.power, + ) + .await + } + "reset" => { + handle_reset_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + _ => Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some(format!("Invalid action: {}", payload.action)), + })), + } +} + +async fn handle_get_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + // Cast JSON column to CHAR to get string representation + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let preferences = match row { + Some((Some(prefs_str),)) => serde_json::from_str(&prefs_str).unwrap_or_else(|e| { + warn!( + "[{}] Failed to parse preferences JSON for user {}: {}", + request_id, user_id, e + ); + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed to parse preferences JSON for user {}: {}", + user_id, e + ), + Some("data_integrity"), + Some(&username), + Some(power), + ); + serde_json::json!({}) + }), + _ => serde_json::json!({}), + }; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} retrieved preferences for user {}", + username, user_id + ), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(preferences), + error: None, + })) +} + +async fn handle_set_preferences( + state: AppState, + request_id: String, + user_id: i32, + new_preferences: Option<serde_json::Value>, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let new_prefs = match new_preferences { + Some(prefs) => prefs, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Missing preferences field".to_string()), + })); + } + }; + + // Get current preferences for merging + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Deep merge the preferences + let mut merged_prefs = match row { + Some((Some(prefs_str),)) => { + serde_json::from_str(&prefs_str).unwrap_or_else(|_| serde_json::json!({})) + } + _ => serde_json::json!({}), + }; + + // Merge function + fn merge_json(base: &mut serde_json::Value, update: &serde_json::Value) { + if let (Some(base_obj), Some(update_obj)) = (base.as_object_mut(), update.as_object()) { + for (key, value) in update_obj { + if let Some(base_value) = base_obj.get_mut(key) { + if base_value.is_object() && value.is_object() { + merge_json(base_value, value); + } else { + *base_value = value.clone(); + } + } else { + base_obj.insert(key.clone(), value.clone()); + } + } + } + } + + merge_json(&mut merged_prefs, &new_prefs); + + // Save to database + let prefs_str = serde_json::to_string(&merged_prefs).map_err(|e| { + error!("[{}] Failed to serialize preferences: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let update_query = "UPDATE users SET preferences = ? WHERE id = ? AND active = TRUE"; + sqlx::query(update_query) + .bind(&prefs_str) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error updating preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} updated preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(merged_prefs), + error: None, + })) +} + +async fn handle_reset_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let query = "UPDATE users SET preferences = NULL WHERE id = ? AND active = TRUE"; + sqlx::query(query) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error resetting preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} reset preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(serde_json::json!({})), + error: None, + })) +} diff --git a/src/routes/query.rs b/src/routes/query.rs new file mode 100644 index 0000000..9814215 --- /dev/null +++ b/src/routes/query.rs @@ -0,0 +1,3255 @@ +// Query routes and execution +use anyhow::{Context, Result}; +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use rand::Rng; +use serde_json::Value; +use sqlx::{Column, Row}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tracing::{error, info, warn}; + +use crate::logging::AuditLogger; +use crate::models::{PermissionsResponse, QueryAction, QueryRequest, QueryResponse, UserInfo}; +use crate::sql::{ + build_filter_clause, build_legacy_where_clause, build_order_by_clause, validate_column_name, + validate_column_names, validate_table_name, +}; +use crate::AppState; + +// Helper function to extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +fn database_unavailable_response(request_id: &str) -> QueryResponse { + QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database temporarily unavailable, try again in a moment [request_id: {}]", + request_id + )), + warning: None, + results: None, + } +} + +fn database_unavailable_batch_response(request_id: &str) -> QueryResponse { + let mut base = database_unavailable_response(request_id); + base.results = Some(vec![]); + base +} + +async fn log_database_unavailable_event( + logger: &AuditLogger, + request_id: &str, + username: Option<&str>, + power: Option<i32>, + detail: &str, +) { + if let Err(err) = logger + .log_error( + request_id, + Utc::now(), + detail, + Some("database_unavailable"), + username, + power, + ) + .await + { + error!("[{}] Failed to record database outage: {}", request_id, err); + } +} + +pub async fn execute_query( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<QueryRequest>, +) -> Result<Json<QueryResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + warning: None, + results: None, + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Session not found".to_string()), + warning: None, + results: None, + })); + } + }; + + // Detect batch mode - if queries field is present, handle as batch operation + if payload.queries.is_some() { + // SECURITY: Check if user has permission to use batch operations + let power_perms = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .ok_or(StatusCode::FORBIDDEN)?; + + if !power_perms.allow_batch_operations { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} (power {}) attempted batch operation without permission", + session.username, session.power + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch operations not permitted for your role".to_string()), + warning: None, + results: None, + })); + } + + return execute_batch_mode(state, session, request_id, timestamp, client_ip, &payload) + .await; + } + + // Validate input for very basic security vulnerabilities (null bytes, etc.) + if let Err(security_error) = validate_input_security(&payload) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Security validation failed for user {}: {}", + session.username, security_error + ), + Some("security_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid input detected, how did you even manage to do that? [request_id: {}]", + request_id + )), + warning: None, + results: None, + })); + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Clone payload before extracting fields (to avoid partial move issues) + let payload_clone = payload.clone(); + + // Single query mode - validate required fields + let action = payload.action.ok_or_else(|| { + let error_msg = "Missing action field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + let table = payload.table.ok_or_else(|| { + let error_msg = "Missing table field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + // SECURITY: Validate table name before any operations + if let Err(e) = validate_table_name(&table, &state.config) { + let error_msg = format!("Invalid table name '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("table_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + // SECURITY: Validate column names if specified + if let Some(ref columns) = payload.columns { + if let Err(e) = validate_column_names(columns) { + let error_msg = format!("Invalid column names on table '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("column_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid column name: {} [request_id: {}]", + e, request_id + )), + warning: None, + results: None, + })); + } + } + + // Check permissions (after validation to avoid leaking table existence) + if !state + .rbac + .check_permission(&state.config, session.power, &table, &action) + { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} attempted unauthorized {} on table {}", + session.username, action_str, table + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + // Log security violation + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Permission denied: {} on table {}", action_str, table), + Some("authorization"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log permission denial: {}", + request_id, log_err + ); + } + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Insufficient permissions for this operation, Might as well give up [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable, returning graceful error", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before transaction", + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // ALL database operations now use transactions with user context set + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + }; + + // Set user context and request ID in transaction - ALL queries have user context now + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // Execute the query within the transaction + let result = match action { + QueryAction::Select => { + execute_select_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Insert => { + execute_insert_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_update_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_delete_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Count => { + execute_count_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_info_async( + &state.logging, + &request_id, + &format!("Query executed successfully: {} on {}", action_str, table), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } + Err(e) => { + error!("[{}] Query execution failed: {}", request_id, e); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Query execution error: {}", e), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database query failed [request_id: {}]", + request_id + )), + warning: None, + results: None, + })) + } + } +} + +pub async fn get_permissions( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, +) -> Result<Json<PermissionsResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/permissions", + &serde_json::json!({}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let permissions = state + .rbac + .get_table_permissions(&state.config, session.power); + + // Get user settings access permission + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + let user_settings_access_str = match user_settings_permission { + crate::config::UserSettingsAccess::ReadOwnOnly => "read-own-only", + crate::config::UserSettingsAccess::ReadWriteOwn => "read-write-own", + crate::config::UserSettingsAccess::ReadWriteAll => "read-write-all", + }; + + Ok(Json(PermissionsResponse { + success: true, + permissions, + user: UserInfo { + id: session.user_id, + username: session.username, + name: "".to_string(), // We don't store name in session, would need to fetch from DB + role: session.role_name, + power: session.power, + }, + security_clearance: None, + user_settings_access: user_settings_access_str.to_string(), + })) +} + +// ===== SAFE SQL QUERY BUILDERS WITH VALIDATION ===== +// All functions validate table/column names to prevent SQL injection + +/// Build WHERE clause with column name validation (legacy simple format) +fn build_where_clause(where_clause: &Value) -> anyhow::Result<(String, Vec<String>)> { + // Use the new validated builder from sql module + build_legacy_where_clause(where_clause) +} + +/// Build INSERT data with column name validation +fn build_insert_data( + data: &Value, +) -> anyhow::Result<(Vec<String>, Vec<String>, Vec<Option<String>>)> { + let mut columns = Vec::new(); + let mut placeholders = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in INSERT: {}", key))?; + + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("INSERT data must be a JSON object"); + } + + if columns.is_empty() { + anyhow::bail!("INSERT data cannot be empty"); + } + + Ok((columns, placeholders, values)) +} + +/// Build UPDATE SET clause with column name validation +fn build_update_set_clause(data: &Value) -> anyhow::Result<(String, Vec<Option<String>>)> { + let mut set_clauses = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in UPDATE: {}", key))?; + + set_clauses.push(format!("{} = ?", key)); + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("UPDATE data must be a JSON object"); + } + + if set_clauses.is_empty() { + anyhow::bail!("UPDATE data cannot be empty"); + } + + Ok((set_clauses.join(", "), values)) +} + +/// Convert JSON value to Option<String> for SQL binding +/// Properly handles booleans (true/false -> "1"/"0" for MySQL TINYINT/BOOLEAN) +/// NULL values return None for proper SQL NULL handling +fn json_value_to_option_string(value: &Value) -> Option<String> { + match value { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + // MySQL uses TINYINT(1) for booleans: true -> 1, false -> 0 + Value::Bool(b) => Some(if *b { "1".to_string() } else { "0".to_string() }), + Value::Null => None, // Return None for proper SQL NULL + // Complex types (objects, arrays) get JSON serialized + _ => Some(serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())), + } +} +/// Convert SQL values from MySQL to JSON +/// Handles ALL MySQL types automatically with proper NULL handling +/// Returns booleans as true/false (MySQL TINYINT(1) -> JSON bool) +fn convert_sql_value_to_json( + row: &sqlx::mysql::MySqlRow, + index: usize, + request_id: Option<&str>, + username: Option<&str>, + power: Option<i32>, + state: Option<&AppState>, +) -> anyhow::Result<Value> { + let column = &row.columns()[index]; + + use sqlx::TypeInfo; + let type_name = column.type_info().name(); + + // Comprehensive MySQL type handling - no need to manually add types anymore! + let result = match type_name { + // ===== String types ===== + "VARCHAR" | "TEXT" | "CHAR" | "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "SET" | "ENUM" => { + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Integer types ===== + "INT" | "BIGINT" | "MEDIUMINT" | "SMALLINT" | "INTEGER" => row + .try_get::<Option<i64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // Unsigned integers + "INT UNSIGNED" | "BIGINT UNSIGNED" | "MEDIUMINT UNSIGNED" | "SMALLINT UNSIGNED" => row + .try_get::<Option<u64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== Boolean type (MySQL TINYINT(1)) ===== + // Returns proper JSON true/false instead of 1/0 + "TINYINT" | "BOOLEAN" | "BOOL" => { + // Try as bool first (for TINYINT(1)) + if let Ok(opt_bool) = row.try_get::<Option<bool>, _>(index) { + return Ok(opt_bool.map(Value::Bool).unwrap_or(Value::Null)); + } + // Fallback to i8 for regular TINYINT + row.try_get::<Option<i8>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)) + } + + // ===== Decimal/Numeric types ===== + "DECIMAL" | "NUMERIC" => { + row.try_get::<Option<rust_decimal::Decimal>, _>(index) + .map(|opt| { + opt.map(|decimal| { + // Keep precision by converting to string then parsing + let decimal_str = decimal.to_string(); + if let Ok(f) = decimal_str.parse::<f64>() { + serde_json::json!(f) + } else { + Value::String(decimal_str) + } + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Floating point types ===== + "FLOAT" | "DOUBLE" | "REAL" => row + .try_get::<Option<f64>, _>(index) + .map(|opt| opt.map(|v| serde_json::json!(v)).unwrap_or(Value::Null)), + + // ===== Date/Time types ===== + "DATE" => { + use chrono::NaiveDate; + row.try_get::<Option<NaiveDate>, _>(index).map(|opt| { + opt.map(|date| Value::String(date.format("%Y-%m-%d").to_string())) + .unwrap_or(Value::Null) + }) + } + "DATETIME" => { + use chrono::NaiveDateTime; + row.try_get::<Option<NaiveDateTime>, _>(index).map(|opt| { + opt.map(|datetime| Value::String(datetime.format("%Y-%m-%d %H:%M:%S").to_string())) + .unwrap_or(Value::Null) + }) + } + "TIMESTAMP" => { + use chrono::{DateTime, Utc}; + row.try_get::<Option<DateTime<Utc>>, _>(index).map(|opt| { + opt.map(|timestamp| Value::String(timestamp.to_rfc3339())) + .unwrap_or(Value::Null) + }) + } + "TIME" => { + // TIME values come as strings in HH:MM:SS format + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + "YEAR" => row + .try_get::<Option<i32>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== JSON type ===== + "JSON" => row.try_get::<Option<String>, _>(index).map(|opt| { + opt.and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or(Value::Null) + }), + + // ===== Binary types ===== + // Return as base64-encoded strings for safe JSON transmission + "BLOB" | "MEDIUMBLOB" | "LONGBLOB" | "TINYBLOB" | "BINARY" | "VARBINARY" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + use base64::{engine::general_purpose, Engine as _}; + Value::String(general_purpose::STANDARD.encode(&bytes)) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Bit type ===== + "BIT" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + // Convert bit value to number + let mut val: i64 = 0; + for &byte in &bytes { + val = (val << 8) | byte as i64; + } + Value::Number(val.into()) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Spatial/Geometry types ===== + "GEOMETRY" | "POINT" | "LINESTRING" | "POLYGON" | "MULTIPOINT" | "MULTILINESTRING" + | "MULTIPOLYGON" | "GEOMETRYCOLLECTION" => { + // Return as WKT (Well-Known Text) string + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Catch-all for unknown/new types ===== + // This ensures forward compatibility if MySQL adds new types + _ => { + warn!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ), + Some("data_conversion"), + username, + power, + ); + } + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + }; + + // Robust error handling with fallback + match result { + Ok(value) => Ok(value), + Err(e) => { + // Final fallback: try as string + match row.try_get::<Option<String>, _>(index) { + Ok(opt) => { + warn!("Primary conversion failed for column '{}' (type: {}), used string fallback", + column.name(), type_name); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!("Primary conversion failed for column '{}' (type: {}), used string fallback", column.name(), type_name), + Some("data_conversion"), + username, + power, + ); + } + Ok(opt.map(Value::String).unwrap_or(Value::Null)) + } + Err(_) => { + error!( + "Complete failure to decode column '{}' (index: {}, type: {}): {}", + column.name(), + index, + type_name, + e + ); + // Return NULL instead of failing the entire query + Ok(Value::Null) + } + } + } + } +} + +// Generate auto values based on configuration +async fn generate_auto_value( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + match config.gen_type.as_str() { + "numeric" => generate_unique_numeric_id(state, table, config).await, + _ => Err(anyhow::anyhow!( + "Unsupported auto-generation type: {}", + config.gen_type + )), + } +} + +// Generate a unique numeric ID based on configuration +async fn generate_unique_numeric_id( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + let range_min = config.range_min.unwrap_or(10000000); + let range_max = config.range_max.unwrap_or(99999999); + let max_attempts = config.max_attempts.unwrap_or(10) as usize; + let field_name = &config.field; + + for _attempt in 0..max_attempts { + // Generate random number in specified range + let id = { + let mut rng = rand::rng(); + rng.random_range(range_min..=range_max) + }; + let id_str = id.to_string(); + + // Check if this ID already exists + let query_str = format!( + "SELECT COUNT(*) as count FROM {} WHERE {} = ?", + table, field_name + ); + let exists = sqlx::query(&query_str) + .bind(&id_str) + .fetch_one(state.database.pool()) + .await?; + + let count: i64 = exists.try_get("count")?; + + if count == 0 { + return Ok(id_str); + } + } + + Err(anyhow::anyhow!( + "Failed to generate unique {} for table {} after {} attempts", + field_name, + table, + max_attempts + )) +} + +// Security validation functions +fn validate_input_security(payload: &QueryRequest) -> Result<(), String> { + // Check for null bytes in table name + if let Some(ref table) = payload.table { + if table.contains('\0') { + return Err("Null byte detected in table name".to_string()); + } + } + + // Check for null bytes in column names + if let Some(columns) = &payload.columns { + for column in columns { + if column.contains('\0') { + return Err("Null byte detected in column name".to_string()); + } + } + } + + // Check for null bytes in data values + if let Some(data) = &payload.data { + if contains_null_bytes_in_value(data) { + return Err("Null byte detected in data values".to_string()); + } + } + + // Check for null bytes in WHERE clause + if let Some(where_clause) = &payload.where_clause { + if contains_null_bytes_in_value(where_clause) { + return Err("Null byte detected in WHERE clause".to_string()); + } + } + + Ok(()) +} + +fn contains_null_bytes_in_value(value: &Value) -> bool { + match value { + Value::String(s) => s.contains('\0'), + Value::Array(arr) => arr.iter().any(contains_null_bytes_in_value), + Value::Object(map) => { + map.keys().any(|k| k.contains('\0')) || map.values().any(contains_null_bytes_in_value) + } + _ => false, + } +} + +// Core execution functions that work with mutable transaction references +// These are used by batch operations to execute multiple queries in a single atomic transaction + +async fn execute_select_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + fn count_filter_conditions(filter: &crate::models::FilterCondition) -> usize { + use crate::models::FilterCondition; + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + let requested_columns = if let Some(ref cols) = payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {} for your power level, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, defaulted to {} (max for power level {})", + max_limit, session.power + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut **tx).await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where (same as other functions) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter_conditions(filter: &FilterCondition) -> usize { + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut **tx).await?; + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +async fn execute_update_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + let writable_columns = data + .as_object() + .unwrap() + .keys() + .cloned() + .collect::<Vec<_>>(); + if writable_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + + let set_clause = writable_columns + .iter() + .map(|col| format!("{} = ?", col)) + .collect::<Vec<_>>() + .join(", "); + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for col in &writable_columns { + if let Some(val) = data.get(col) { + sqlx_query = sqlx_query.bind(val.clone()); + } + } + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +// Transaction-based execution functions for user context operations +// These create their own transactions and commit them - used for single query operations + +async fn execute_select_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Apply granular column filtering based on user's power level + let requested_columns = if let Some(cols) = &payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + + // Build and append JOIN SQL + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE clause - support both legacy and new filter format + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Add ORDER BY if provided + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Add OFFSET if provided + if let Some(offset) = payload.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, // Row count will be known after execution + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut *tx).await?; + tx.commit().await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_insert_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for INSERT"))? + .clone(); + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + // This validates what the user is trying to write + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation requirements based on config + // Auto-generated fields bypass write protection since they're system-generated + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + // Check if auto-generation is enabled for INSERT action + if auto_config.on_action == "insert" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to insert + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in INSERT data".to_string()), + warning: None, + results: None, + }); + } + } + + let (columns_vec, placeholders_vec, values) = build_insert_data(&data)?; + let columns = columns_vec.join(", "); + let placeholders = placeholders_vec.join(", "); + + let query = format!( + "INSERT INTO {} ({}) VALUES ({})", + table, columns, placeholders + ); + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + let insert_id = result.last_insert_id(); + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(insert_id)), + rows_affected: Some(result.rows_affected()), + error: None, + warning: None, // INSERT queries don't have LIMIT, + results: None, + }) +} + +async fn execute_update_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to update + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + } + + let (set_clause, mut values) = build_update_set_clause(&data)?; + let (where_sql, where_values) = build_where_clause(where_clause)?; + // Convert where_values to Option<String> to match set values + values.extend(where_values.into_iter().map(Some)); + + let mut query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let (where_sql, values) = build_where_clause(where_clause)?; + + let mut query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + sqlx_query = sqlx_query.bind(value); + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Check read permissions for the table + if !state.rbac.check_permission( + &state.config, + session.power, + table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to COUNT from table '{}' [request_id: {}]", + table, request_id + )), + warning: None, + results: None, + }); + } + + // Helper to count conditions in filter/where (same as select function) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut *tx).await?; + tx.commit().await?; + + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +/// Execute multiple queries in a single transaction +async fn execute_batch_mode( + state: AppState, + session: crate::models::Session, + request_id: String, + timestamp: chrono::DateTime<chrono::Utc>, + client_ip: String, + payload: &QueryRequest, +) -> Result<Json<QueryResponse>, StatusCode> { + let queries = payload.queries.as_ref().unwrap(); + + // Check if batch is empty + if queries.is_empty() { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch request cannot be empty".to_string()), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Batch query request from user {} (power {}): {} queries", + request_id, + session.username, + session.power, + queries.len() + ); + + // Log the batch request (log full payload including action/table) + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log batch request: {}", request_id, e); + } + + // Get action and table from parent request (all batch queries share these) + let action = payload + .action + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + let table = payload + .table + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + // Validate table name once for the entire batch + if let Err(e) = validate_table_name(table, &state.config) { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name: {}", e)), + warning: None, + results: Some(vec![]), + })); + } + + // Check RBAC permission once for the entire batch + if !state + .rbac + .check_permission(&state.config, session.power, table, action) + { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions for {} on table '{}'", + match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }, + table + )), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Validated batch: {} x {:?} on table '{}'", + request_id, + queries.len(), + action, + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Validated batch: {} x {:?} on table '{}'", + queries.len(), + action, + table + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable before batch execution", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before batch", + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + // Start a SINGLE transaction for the batch - proper atomic operation + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin batch transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin batch transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + }; + + // Set user context and request ID in transaction + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set batch user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + let rollback_on_error = payload.rollback_on_error.unwrap_or(false); + + info!( + "[{}] Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + request_id, + queries.len(), + action, + table, + rollback_on_error + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + queries.len(), + action, + table, + rollback_on_error + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + // Execute as SINGLE native batch query based on action type + let result = match action { + QueryAction::Insert => { + execute_batch_insert( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_batch_update( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_batch_delete( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Select => { + // SELECT batches are less common but we'll execute them individually + // (combining SELECTs into one query doesn't make sense as they return different results) + execute_batch_selects( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Count => { + // COUNT batches execute individually (each returns different results) + execute_batch_counts( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + if response.success || !rollback_on_error { + // Commit the transaction + tx.commit().await.map_err(|e| { + error!("[{}] Failed to commit batch transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + info!( + "[{}] Native batch committed: {} operations", + request_id, + queries.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Native batch committed: {} operations", queries.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } else { + // Rollback on error + error!( + "[{}] Rolling back batch transaction due to error", + request_id + ); + tx.rollback().await.map_err(|e| { + error!("[{}] Failed to rollback transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(Json(response)) + } + } + Err(e) => { + error!("[{}] Batch execution failed: {}", request_id, e); + tx.rollback().await.map_err(|e2| { + error!("[{}] Failed to rollback after error: {}", request_id, e2); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Batch execution failed: {}", e)), + warning: None, + results: Some(vec![]), + })) + } + } +} + +// ===== NATIVE BATCH EXECUTION FUNCTIONS ===== + +/// Execute batch INSERT using MySQL multi-value INSERT +async fn execute_batch_insert( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + _action: &QueryAction, + _username: &str, + state: &AppState, + _session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + use serde_json::Value; + + // Extract all data objects, apply auto-generation, and validate they have the same columns + let mut all_data = Vec::new(); + let mut column_set: Option<std::collections::HashSet<String>> = None; + + // Check for auto-generation config + let auto_config = state.config.get_auto_generation_config(&table); + + for (idx, query) in queries.iter().enumerate() { + let mut data = query + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Query {} missing data field for INSERT", idx + 1))? + .clone(); + + // Apply auto-generation if configured for INSERT + if let Some(ref auto_cfg) = auto_config { + if auto_cfg.on_action == "insert" || auto_cfg.on_action == "both" { + if let Value::Object(ref mut map) = data { + let field_name = &auto_cfg.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_cfg).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + if let Value::Object(map) = data { + let cols: std::collections::HashSet<String> = map.keys().cloned().collect(); + + if let Some(ref expected_cols) = column_set { + if *expected_cols != cols { + anyhow::bail!("All INSERT queries must have the same columns. Query {} has different columns", idx + 1); + } + } else { + column_set = Some(cols); + } + + all_data.push(map); + } else { + anyhow::bail!("Query {} data must be an object", idx + 1); + } + } + + let columns: Vec<String> = column_set.unwrap().into_iter().collect(); + + // Validate column names + for col in &columns { + validate_column_name(col)?; + } + + // Build multi-value INSERT: INSERT INTO table (col1, col2) VALUES (?, ?), (?, ?), ... + let column_list = columns.join(", "); + let value_placeholder = format!("({})", vec!["?"; columns.len()].join(", ")); + let values_clause = vec![value_placeholder; all_data.len()].join(", "); + + let sql = format!( + "INSERT INTO {} ({}) VALUES {}", + table, column_list, values_clause + ); + + info!( + "[{}] Native batch INSERT: {} rows into {}", + request_id, + all_data.len(), + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Native batch INSERT: {} rows into {}", + all_data.len(), + table + ), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + // Bind all values + let mut query = sqlx::query(&sql); + for data_map in &all_data { + for col in &columns { + let value = data_map.get(col).and_then(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + Value::Bool(b) => Some(b.to_string()), + Value::Null => None, + _ => Some(v.to_string()), + }); + query = query.bind(value); + } + } + + // Execute the batch INSERT + let result = query.execute(&mut **tx).await?; + let rows_affected = result.rows_affected(); + + info!( + "[{}] Batch INSERT affected {} rows", + request_id, rows_affected + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch INSERT affected {} rows", rows_affected), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(rows_affected), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch UPDATE (executes individually for now, could be optimized with CASE statements) +async fn execute_batch_update( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_update_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch UPDATE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch UPDATE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch DELETE using IN clause when possible +async fn execute_batch_delete( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + // For now, execute individually + // TODO: Optimize by detecting simple ID-based deletes and combining with IN clause + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_delete_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch DELETE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch DELETE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch SELECT (executes individually since they return different results) +async fn execute_batch_selects( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_select_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch SELECT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch SELECT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} + +/// Execute batch COUNT (executes individually since they return different results) +async fn execute_batch_counts( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_count_core( + &format!("{}:{}", request_id, idx), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch COUNT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch COUNT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} |
