diff options
| -rw-r--r-- | .gitignore | 28 | ||||
| -rw-r--r-- | Cargo.toml | 47 | ||||
| -rw-r--r-- | README.md | 281 | ||||
| -rw-r--r-- | config/basics.toml | 18 | ||||
| -rw-r--r-- | config/functions.toml | 45 | ||||
| -rw-r--r-- | config/logging.toml | 31 | ||||
| -rw-r--r-- | config/security.toml | 210 | ||||
| -rw-r--r-- | src/auth/mod.rs | 6 | ||||
| -rw-r--r-- | src/auth/password.rs | 56 | ||||
| -rw-r--r-- | src/auth/pin.rs | 66 | ||||
| -rw-r--r-- | src/auth/session.rs | 129 | ||||
| -rw-r--r-- | src/auth/token.rs | 99 | ||||
| -rw-r--r-- | src/config.rs | 845 | ||||
| -rw-r--r-- | src/db/mod.rs | 3 | ||||
| -rw-r--r-- | src/db/pool.rs | 110 | ||||
| -rw-r--r-- | src/logging/logger.rs | 373 | ||||
| -rw-r--r-- | src/logging/mod.rs | 3 | ||||
| -rw-r--r-- | src/main.rs | 359 | ||||
| -rw-r--r-- | src/models/mod.rs | 294 | ||||
| -rw-r--r-- | src/permissions/mod.rs | 3 | ||||
| -rw-r--r-- | src/permissions/rbac.rs | 119 | ||||
| -rw-r--r-- | src/routes/auth.rs | 513 | ||||
| -rw-r--r-- | src/routes/mod.rs | 111 | ||||
| -rw-r--r-- | src/routes/preferences.rs | 423 | ||||
| -rw-r--r-- | src/routes/query.rs | 3255 | ||||
| -rw-r--r-- | src/scheduler.rs | 185 | ||||
| -rw-r--r-- | src/sql/builder.rs | 493 | ||||
| -rw-r--r-- | src/sql/mod.rs | 8 | ||||
| -rw-r--r-- | todo.md | 2 |
29 files changed, 8115 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..00e3cc6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +# Ignore logs +/logs +*.log + +# Ignore Rust build artifacts +/target/ +Cargo.lock + +# Ignore vscodes bs files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Ignore macOS piss files +.DS_Store + +# Ignore session data +session.json + +/seckelapi + +# Ignore database backups +/database + +# Ignore temporary setup folder +/beepzone-inventory-setup
\ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..a5527e6 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "seckelapi" +version = "0.0.11" +edition = "2021" + +[dependencies] +# Web framework +axum = { version = "0.8", features = ["macros"] } +tokio = { version = "1.0", features = ["full"] } +tower = "0.5" +tower-http = { version = "0.6", features = ["cors", "trace", "limit"] } +tower_governor = "0.8" + +# Database +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "mysql", "json", "chrono", "uuid", "rust_decimal"] } +rust_decimal = "1.0" + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Configuration +toml = "0.9" + +# Authentication and security +bcrypt = "0.17" +uuid = { version = "1.0", features = ["v4", "serde"] } + +# Logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +# Date/time +chrono = { version = "0.4", features = ["serde"] } + +# Error handling +anyhow = "1.0" +thiserror = "2.0" + +# Network utilities for IP handling +ipnet = "2.9" + +# Additional utilities +async-trait = "0.1" +rand = "0.9" +base64 = "0.22" +regex = "1.10"
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3aa184b --- /dev/null +++ b/README.md @@ -0,0 +1,281 @@ +# SeckelAPI + +## What is this even +A hopefully somewhat secure, role based SQL API server built in Rust. Provides a API interface to MariaDB with goofy authentication methods, basic and kinda advanced table authorization, and logging capabilities. + +kinda recycled from an older project and not cleaned up from swearwords or drunk coding sessions reminants at all! +please do not take any insult from the code or its "quality" incase the automatic sanetization by an llm model failed mlol. + +currently contains example config for beepzone inventory system but should be usable for more than just that + +## Purpose (for now): +To server as a basic API plus "firewall" between any BeepZone client and its actual database. + +## Goofy ah Features worth mentioning: + +### **Three auth methods** +- **Password Authentication**: Normal username/password +- **PIN Authentication**: Username+PIN based login with IP whitelisting +- **Token Authentication**: Reusable token strings (for like RFID cards etc. etc.) with IP restrictions + +### **Attempts at Security** +- **Basic and advanced table permissions (kinda RBAC style)**: Control read and write access per table or if you're schizophrenic even as granular as table column specific. +- **IP Whitelisting**: Restrict PIN and token authentication by IPee ranges as these auth meths are by their nature insecure but needed for my application. (can be saved as bcrypt hash or of you really want to in plaintext it technically can do both) +- **Input Validation**: Protection against some of the most basic ass common SQL injections (atleast according to chat gpt i dont do opsec myself im probably worse than an llm model in that aspect) +- **Audit Logging**: Comprehensive request, query, and error logging. You see what comes in from a client, what happens within the API, and what goes out from the API to the database. + +### **Goofy Data-BASED helping Features** +- **Generated Fields**: Automatic generation of defined fields if they are sent empty in request (for like asset tags and stuff with dynamic generation template strings that are too complex fo me to let just the DB triggers handle it) +- **User and Transaction Context**: User and Transaction ID context automatically on ALL operations with optional exclusions of fields your next upcoming dataleak <3 +- **Read Only Tables**: Enforce read only for tables where all stuff is handled by database triggers +- **Connection Pooling**: Somewhat efficient database connection management + +## Das Archtiktur und so + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ Client App │───▶│ SeckelAPI │───▶│ MySQL/MariaDB │ +│ (BeepZone UI) │ │ (Port 8800) │ │ (Port 3306) │ +└─────────────────┘ └─────────────────┘ └─────────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Log Files │ + │ ./logs/*.log │ + └─────────────────┘ +``` + +### Internal Request Flow + +``` +Client Request + │ + ▼ +┌─────────────────────────────────────────┐ +│ 1. Rate Limiting (per IP) │ ──▶ {"success": false, "error": "Too Many Requests"} +│ - Auth: 60/min, 10/sec (configurable)│ +│ - API: 120/min, 20/sec (configurable)│ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 2. Authentication │ +│ - Extract Bearer token │ +│ - Validate session │ ──▶ {"success": false, "error": "Invalid session"} +│ - Set user context (@current_user_id)│ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 3. RBAC Permission Check │ +│ - Check basic_rules (table access) │ ──▶ {"success": false, "error": "Insufficient permissions [request_id: xxx]"} +│ - Apply advanced_rules (column-level)│ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 4. Query Building & Validation │ +│ - Validate table/column names │ +│ - Filter writable columns │ ──▶ {"success": false, "error": "Invalid table/column [request_id: xxx]"} +│ - Auto-generate fields (if needed) │ +│ - Apply LIMIT caps │ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 5. Database Execution │ +│ - Execute via connection pool │ +│ - Triggers run (change log, etc.) │ ──▶ {"success": false, "error": "Database query failed [request_id: xxx]"} +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ 6. Audit Logging │ +│ - Log request, query, result │ +│ - Mask sensitive fields │ +└─────────────────────────────────────────┘ + │ + ▼ + JSON Response +``` + +## Database Requirements + +Your database needs these tables for the API to work: + +### Required Tables +- **`users`** - User accounts with authentication credentials, if you dont use pin or string logins (for kiosk accounts or rfid) they can be left out + - Fields: `id`, `username`, `password` (bcrypt hash), `pin_code`, `login_string` (RFID), `role_id`, `active` + +- **`roles`** - Role definitions with power levels + - Fields: `id`, `name`, `power` (1-100, where 100 = admin) + +### Your Application Tables +The API works with ANY tables you define. Common examples (for BeepZone as an Example): +- `assets`, `categories`, `zones`, `suppliers` (asset management) +- `lending_history`, `borrowers` (lending system) +- `physical_audit_logs`, `physical_audits` (audit system) +- Literally anything else - it's your database + +**Note**: Use database triggers to populate audit fields (`created_by`, `last_modified_by`) using `@current_user_id` session (or even last change transactionid for tracing) variable that the API automatically sets. + +## API Features + +### Authentication Endpoints + +**POST /auth/login** +```json +// Password auth +{"method": "password", "username": "admin", "password": "pass123"} + +// PIN auth (IP restricted) +{"method": "pin", "username": "user1", "pin": "1234"} + +// Token auth (IP restricted) +{"method": "token", "login_string": "RFID_TOKEN_12345"} +``` +Returns: `{"success": true, "token": "session-token-here"}` + +Use token in all subsequent requests: `Authorization: Bearer <token>` + +### Query Endpoint + +**POST /query** - Main data operations + +**SELECT** - Read data +```json +{ + "action": "select", + "table": "assets", + "columns": ["id", "name", "status"], + "where": {"status": "Good"}, + "order_by": [{"column": "name", "direction": "ASC"}], + "limit": 50 +} +``` + +**INSERT** - Create records +```json +{ + "action": "insert", + "table": "assets", + "data": { + "name": "Laptop", + "status": "Good", + "category_id": 5 + // "asset_numeric_id" auto-generated if configured + } +} +``` +Returns: `{"success": true, "data": 123}` (new ID) + +**UPDATE** - Modify records +```json +{ + "action": "update", + "table": "assets", + "data": {"status": "In Repair"}, + "where": {"id": 123} +} +``` + +**DELETE** - Remove records +```json +{ + "action": "delete", + "table": "assets", + "where": {"id": 123} +} +``` + +**BATCH** - Multiple operations in one transaction +```json +{ + "action": "batch", + "queries": [ + {"action": "insert", "table": "assets", "data": {...}}, + {"action": "update", "table": "assets", "data": {...}, "where": {...}} + ], + "rollback_on_error": true // All or nothing +} +``` + +### Advanced Query Features + +**JOINs** - Query across tables +```json +{ + "action": "select", + "table": "assets", + "columns": ["assets.*", "categories.name as category_name"], + "joins": [ + { + "type": "INNER", + "table": "categories", + "on": "assets.category_id = categories.id" + } + ] +} +``` + +**Complex WHERE** - Multiple conditions +```json +{ + "where": { + "status": {"operator": "IN", "value": ["Good", "Attention"]}, + "price": {"operator": ">=", "value": 100}, + "name": {"operator": "LIKE", "value": "%Laptop%"} + } +} +``` + +**Aggregations** - GROUP BY and aggregate functions +```json +{ + "action": "select", + "table": "assets", + "columns": ["category_id", "COUNT(*) as total"], + "group_by": ["category_id"] +} +``` + +### Security Features + +- **Rate Limiting**: 429 if you spam too hard +- **Column Filtering**: Auto removes columns you can't write based on permissions +- **Query Limits**: Max LIMIT capped per power level (prevents SELECT * disasters for big boy databases) +- **WHERE Limits**: Max conditions per query (prevents complex attack queries) +- **Read-Only Tables**: Some tables blocked from writes entirely +- **Session Timeouts**: Auto-expire sessions based on power level +- **Audit Everything**: All operations logged with user context + +### Health Check + +**GET /health** - Check if the API and database are alive +```json +{ + "status": "hurensohn modus aktiviert", + "database": "connected" +} +``` + +## Running It + +1. Set up your MariaDB/MySQL database +2. Configure `config/*.toml` files +3. Build: `cargo build --release` +4. Run: `./target/release/SeckelAPI` + +Server starts on configured port (default 8800). + +Check logs in `logs/` folder to see what's happening. + +## Testing + +Run the workflow test to verify everything works: +```bash +cd testing +./1-workflow.sh +``` + +This creates sample data and tests all features. diff --git a/config/basics.toml b/config/basics.toml new file mode 100644 index 0000000..eb728e4 --- /dev/null +++ b/config/basics.toml @@ -0,0 +1,18 @@ +# Basic ahh configs + +[server] +host = "0.0.0.0" +port = 5777 +request_body_limit_mb = 10 + +[database] +host = "host.containers.internal" +port = 3306 +database = "beepzone" +username = "beepzone" +password = "beepzone" +min_connections = 1 +max_connections = 10 +connection_timeout_seconds = 2 +connection_timeout_wait = 2 +connection_check = 1
\ No newline at end of file diff --git a/config/functions.toml b/config/functions.toml new file mode 100644 index 0000000..deb487b --- /dev/null +++ b/config/functions.toml @@ -0,0 +1,45 @@ +# auto generation of things +[auto_generation] + +[auto_generation.assets] +field = "asset_numeric_id" +type = "numeric" +length = 8 +range_min = 10000000 +range_max = 99999999 +max_attempts = 10 +# on what event seckel api schould try to generate auto gen value incaase client send empty value +on_action = "insert" + +[scheduled_queries] + +# Single idempotent task that sets the correct state atomically to avoid double-trigger inserts +[[scheduled_queries.tasks]] +name = "sync_overdue_and_stolen" +description = "Atomically set lending_status to Overdue (1-13 days late) or Stolen (>=14 days late) only if it changed" +query = """ + -- Use max lateness per asset to avoid flip-flopping due to multiple open lending rows + -- Removed issue_tracker check from WHERE clause to avoid MySQL trigger conflict + UPDATE assets a + INNER JOIN ( + SELECT lh.asset_id, MAX(DATEDIFF(CURDATE(), lh.due_date)) AS days_late + FROM lending_history lh + WHERE lh.return_date IS NULL + AND lh.due_date IS NOT NULL + GROUP BY lh.asset_id + ) late ON a.id = late.asset_id + SET a.lending_status = CASE + WHEN a.asset_type IN ('N','B') AND late.days_late >= 14 THEN 'Stolen' + WHEN a.asset_type IN ('N','B') AND late.days_late BETWEEN 1 AND 13 THEN 'Overdue' + ELSE a.lending_status + END + WHERE a.asset_type IN ('N','B') + AND ( + (late.days_late >= 14 AND a.lending_status <> 'Stolen') + OR + (late.days_late BETWEEN 1 AND 13 AND a.lending_status <> 'Overdue') + ) +""" +interval_minutes = 2 +run_on_startup = true +enabled = true
\ No newline at end of file diff --git a/config/logging.toml b/config/logging.toml new file mode 100644 index 0000000..fdfc08a --- /dev/null +++ b/config/logging.toml @@ -0,0 +1,31 @@ +# Logging Configuration +[logging] +# all logs can be commented out to disable them if you want yk, because you probably dont need more than the combined log +request_log = "./logs/request.log" +query_log = "./logs/queries.log" +error_log = "./logs/error.log" +warning_log = "./logs/warning.log" +info_log = "./logs/info.log" +combined_log = "./logs/sequel.log" + +# Log levels: debug, info, warn, error +level = "info" + +# mask fields that are sensitive in logs (they are hashed anyways but why log bcrypt hashes in ur logs thats dumb) +mask_passwords = true + +# other values that we might not want in query logs (also applies to request logs) +sensitive_fields = ["login_string", "password_reset_token", "pin_code"] + +# Custom log filters, route specific log entries to separate files using regex ... yes I have autism why are you asking? +[[logging.custom_filters]] +name = "security_violations" +output_file = "./logs/security_violations.log" +pattern = "(Permission denied|Too many WHERE|Authentication failed|invalid credentials|invalid PIN|invalid token)" +enabled = true + +[[logging.custom_filters]] +name = "admin_transactions" +output_file = "./logs/admin_activity.log" +pattern = "user=admin|power=100" +enabled = true diff --git a/config/security.toml b/config/security.toml new file mode 100644 index 0000000..1b87c5b --- /dev/null +++ b/config/security.toml @@ -0,0 +1,210 @@ +# prepare for evil ass autism configs! +[security] +# Yk what this is, if not read the fkn readme +whitelisted_pin_ips = ["192.168.1.0/24", "127.0.0.1"] +whitelisted_string_ips = ["192.168.5.0/24", "127.0.0.1"] + +# session stuffs +session_timeout_minutes = 60 # def session timeout (makes session key go bye bye) +refresh_session_on_activity = true # most useless thing ever most likely as nobody will ever disable this but sure you can just kill a users session during active use right? +max_concurrent_sessions = 3 # how many gooning session to allow per user (you can set custom ones per powerlevel btw) +session_cleanup_interval_minutes = 5 # how often to actually check on the session timeout, we aint gotta spam it none stop tbh + +# PIN and Token Auth +hash_pins = false # weather or not to use bcrypt for pin field (left off for dev work) +hash_tokens = false # Same as above +pin_column = "pin_code" +token_column = "login_string" + +# Rate Limiting, need i say more? +enable_rate_limiting = true # Do yuo wahnt raten limitierung or not? + +# If i have to explain these to you just dont use this software +auth_rate_limit_per_minute = 10000 +auth_rate_limit_per_second = 50000 + +# api rape limitz +api_rate_limit_per_minute = 100000 +api_rate_limit_per_second = 100000 + +# default query limits to avoid someone spamming quieries on a table with 271k rows +default_max_limit = 10000 +default_max_where_conditions = 1000 + +# own user preferences level +# Determines what an user can do with their own little preference store +# - "read-own-only": kiosk ah ruling +# - "read-write-own": what you probably want for most users +# - "read-write-all": adminier maybe ? +default_user_settings_access = "read-write-own" + +# define what tables exist +# known tables for wildcard permissions (*:rw) and to prevent SQL injection via table names cuz thats a thing +known_tables = [ + "users", "roles", "assets", "categories", "zones", + "suppliers", "templates", "audit_tasks", "borrowers", + "lending_history", "audit_history", "maintenance_log", + "asset_change_log", "issue_tracker", "issue_tracker_change_log", + "physical_audits", "physical_audit_logs", + "label_templates", "printer_settings", "print_history" +] + +# tables you cant write or change using proxi in any way not even user overrides below +read_only_tables = ["asset_change_log", "issue_tracker_change_log", "print_history"] + +# column names banned from being written to by default (this is however overwritable on a per table per column per user type schizo settings below) +global_write_protected_columns = [ + "id", + "created_date", + "created_at", + "last_modified_date", + "updated_at", + "last_modified_at", +] + +# note to myself how the rbac system kinda works +# Format: role_power contains both basic table rules and advanced column rules +# Basic rules: "table:permission" (r = read, w = write, rw = read+write, * = all tables (for like admins or smth)) +# Advanced rules: "table.column:permission" for more granular column level control +# Column permissions: r = read, w = write, rw = read+write, block = blocked (obviously) +# Use "table.*:block" to block all columns, then "table.specific_column:r" to allow specific ones +# Use "table.*:r" to allow all columns, then "table.sensitive_column:block" to block specific ones + +# In the future even more advaned rules called schizo_rules will be implemented where you can define sql logic based rules +# like "only allow access to rows where user_id = current_user_id" or "only allow access to assets where status != 'Stolen'" + +# i let an llm comment on the crap below so i can understand what ive done in like 3 months when i forget everything + +[permissions] + +[permissions."100"] +# Admin - full access to everything +basic_rules = [ + "*:rw", # Example of wildcard full access to all known tables + "asset_change_log:r", # More or less redundant but whatever + "issue_tracker_change_log:r" # Same as above +] +advanced_rules = [ + # Further granularity wow! + "assets.asset_numeric_id:r", + "assets.created_by:r", + "assets.last_modified_by:r", + "users.password_hash:block", +] +max_limit = 500 +max_where_conditions = 50 +session_timeout_minutes = 120 # Admins get longer sessions (2 hours) +max_concurrent_sessions = 5 # Admins can have more concurrent sessions +rollback_on_error = true # Rollback batch operations on any error +allow_batch_operations = true # Admins can use batch operations +user_settings_access = "read-write-all" # Admins can modify any user's preferences + +[permissions."75"] +# Manager - full asset management, limited user access +rollback_on_error = true # Rollback batch operations on any error +allow_batch_operations = true # Managers can use batch operations +basic_rules = [ + "assets:rw", + "lending_history:rw", + "audit_history:rw", + "maintenance_log:rw", + "borrowers:rw", + "categories:rw", + "zones:rw", + "suppliers:rw", + "templates:rw", + "audit_tasks:rw", + "issue_tracker:rw", + "physical_audits:rw", + "physical_audit_logs:rw", + "label_templates:rw", + "printer_settings:rw", + "print_history:r", + "users:r", # Basic read access, then restricted by advanced rules below + "roles:r", + "asset_change_log:r", + "issue_tracker_change_log:r" +] +advanced_rules = [ + # Table-specific protected (same as admin) + "assets.asset_numeric_id:r", + "assets.created_by:r", + "assets.last_modified_by:r", + # Users table - can read most info but not sensitive auth data + "users.password:block", + "users.password_hash:block", + "users.pin_code:block", + "users.login_string:block", + "users.password_reset_token:block", + "users.password_reset_expiry:block", +] +# Query limits (moderate for managers) +max_limit = 200 +max_where_conditions = 20 +user_settings_access = "read-write-own" # Managers can only modify their own preferences + +[permissions."50"] +# Staff - asset and lending management, NO user access +rollback_on_error = false # Don't rollback batch operations on error (continue processing) +allow_batch_operations = true # Staff can use batch operations +basic_rules = [ + "assets:rw", + "lending_history:rw", + "audit_history:rw", + "maintenance_log:rw", + "borrowers:rw", + "categories:r", + "zones:r", + "suppliers:r", + "templates:r", + "audit_tasks:r", + "issue_tracker:r", + "physical_audits:r", + "physical_audit_logs:r", + "label_templates:r", + "printer_settings:r", + "print_history:r", + "asset_change_log:r", + "issue_tracker_change_log:r" +] +advanced_rules = [ + # Table-specific protected (same as admin/manager) + "assets.asset_numeric_id:r", + "assets.created_by:r", + "assets.last_modified_by:r", +] +# No users table access for staff - security requirement +# Query limits (standard for staff) +max_limit = 100 +max_where_conditions = 10 +user_settings_access = "read-write-own" # Staff can only modify their own preferences + +[permissions."25"] +# Student - read-only access, no financial data, no user access, no change logs +rollback_on_error = true # Rollback batch operations on any error +allow_batch_operations = false # Students cannot use batch operations +basic_rules = [ + "assets:r", + "lending_history:r", + "borrowers:r", + "categories:r", + "zones:r" +] +advanced_rules = [ + # Assets table - hide financial and sensitive info + "assets.price:block", + "assets.purchase_date:block", + "assets.supplier_id:block", + "assets.warranty_expiry:block", + # Borrowers table - hide personal contact info + "borrowers.email:block", + "borrowers.phone_number:block", + "borrowers.notes:block" +] + +# Query limits +max_limit = 50 +max_where_conditions = 5 +user_settings_access = "read-own-only" # Students can only read their own preferences, not modify + + diff --git a/src/auth/mod.rs b/src/auth/mod.rs new file mode 100644 index 0000000..65cba10 --- /dev/null +++ b/src/auth/mod.rs @@ -0,0 +1,6 @@ +pub mod password; +pub mod pin; +pub mod session; +pub mod token; + +pub use session::SessionManager; diff --git a/src/auth/password.rs b/src/auth/password.rs new file mode 100644 index 0000000..580ad1f --- /dev/null +++ b/src/auth/password.rs @@ -0,0 +1,56 @@ +// Password authentication module +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use bcrypt::verify; +use sqlx::MySqlPool; + +pub async fn authenticate_password( + pool: &MySqlPool, + username: &str, + password: &str, +) -> Result<Option<(User, Role)>> { + // Fetch user from database + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE username = ? AND active = TRUE + "#, + ) + .bind(username) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + // Verify password + let password_valid = + verify(password, &user.password).context("Failed to verify password")?; + + if password_valid { + // Fetch user's role + let role: Role = sqlx::query_as::<_, Role>( + "SELECT id, name, power, created_at FROM roles WHERE id = ?", + ) + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } + } else { + Ok(None) + } +} diff --git a/src/auth/pin.rs b/src/auth/pin.rs new file mode 100644 index 0000000..4d1993c --- /dev/null +++ b/src/auth/pin.rs @@ -0,0 +1,66 @@ +// PIN authentication module +use crate::config::SecurityConfig; +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use sqlx::MySqlPool; + +pub async fn authenticate_pin( + pool: &MySqlPool, + username: &str, + pin: &str, + security_config: &SecurityConfig, +) -> Result<Option<(User, Role)>> { + // Fetch user from database + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE username = ? AND active = TRUE AND pin_code IS NOT NULL + "#, + ) + .bind(username) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + // Check if user has a PIN set + if let Some(user_pin) = &user.pin_code { + // Verify PIN - either bcrypt hash or plaintext depending on config + let pin_valid = if security_config.hash_pins { + bcrypt::verify(pin, user_pin).unwrap_or(false) + } else { + user_pin == pin + }; + + if pin_valid { + // Fetch user's role + let role: Role = sqlx::query_as::<_, Role>( + "SELECT id, name, power, created_at FROM roles WHERE id = ?", + ) + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } + } else { + // User doesn't have a PIN set + Ok(None) + } + } else { + Ok(None) + } +} diff --git a/src/auth/session.rs b/src/auth/session.rs new file mode 100644 index 0000000..277cef2 --- /dev/null +++ b/src/auth/session.rs @@ -0,0 +1,129 @@ +// Session management for SeckelAPI +use crate::config::Config; +use crate::models::Session; +use chrono::{Duration, Utc}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use uuid::Uuid; + +#[derive(Clone)] +pub struct SessionManager { + sessions: Arc<RwLock<HashMap<String, Session>>>, + config: Arc<Config>, +} + +impl SessionManager { + pub fn new(config: Arc<Config>) -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + fn get_timeout_for_power(&self, power: i32) -> u64 { + self.config.get_session_timeout(power) + } + + pub fn create_session( + &self, + user_id: i32, + username: String, + role_id: i32, + role_name: String, + power: i32, + ) -> String { + let token = Uuid::new_v4().to_string(); + let now = Utc::now(); + + let session = Session { + user_id, + username, + role_id, + role_name, + power, + created_at: now, + last_accessed: now, + }; + + if let Ok(mut sessions) = self.sessions.write() { + // Check concurrent session limit for this user + let max_sessions = self.config.get_max_concurrent_sessions(power); + let user_sessions: Vec<(String, chrono::DateTime<Utc>)> = sessions + .iter() + .filter(|(_, s)| s.user_id == user_id) + .map(|(token, s)| (token.clone(), s.created_at)) + .collect(); + + // If at limit, remove the oldest session + if user_sessions.len() >= max_sessions as usize { + if let Some(oldest_token) = user_sessions + .iter() + .min_by_key(|(_, created)| created) + .map(|(token, _)| token.clone()) + { + sessions.remove(&oldest_token); + } + } + + sessions.insert(token.clone(), session); + } + + token + } + + pub fn get_session(&self, token: &str) -> Option<Session> { + let mut session_to_update = None; + + { + if let Ok(sessions) = self.sessions.read() { + if let Some(session) = sessions.get(token) { + let now = Utc::now(); + let timeout_for_user = self.get_timeout_for_power(session.power); + let timeout_duration = Duration::minutes(timeout_for_user as i64); + + // Check if session has expired + if now - session.last_accessed > timeout_duration { + return None; // Session expired, will be cleaned up later + } else { + session_to_update = Some(session.clone()); + } + } + } + } + + if let Some(mut session) = session_to_update { + // Update last accessed time only if refresh_on_activity is enabled + if self.config.security.refresh_session_on_activity { + session.last_accessed = Utc::now(); + + if let Ok(mut sessions) = self.sessions.write() { + sessions.insert(token.to_string(), session.clone()); + } + } + + Some(session) + } else { + None + } + } + + pub fn remove_session(&self, token: &str) -> bool { + if let Ok(mut sessions) = self.sessions.write() { + sessions.remove(token).is_some() + } else { + false + } + } + + pub fn cleanup_expired_sessions(&self) { + let now = Utc::now(); + + if let Ok(mut sessions) = self.sessions.write() { + sessions.retain(|_, session| { + let timeout_for_user = self.get_timeout_for_power(session.power); + let timeout_duration = Duration::minutes(timeout_for_user as i64); + now - session.last_accessed <= timeout_duration + }); + } + } +} diff --git a/src/auth/token.rs b/src/auth/token.rs new file mode 100644 index 0000000..17f75e4 --- /dev/null +++ b/src/auth/token.rs @@ -0,0 +1,99 @@ +// Token/RFID authentication module +use crate::config::SecurityConfig; +use crate::models::{Role, User}; +use anyhow::{Context, Result}; +use sqlx::MySqlPool; + +pub async fn authenticate_token( + pool: &MySqlPool, + login_string: &str, + security_config: &SecurityConfig, +) -> Result<Option<(User, Role)>> { + // If hashing is enabled, we can't use WHERE login_string = ? directly + // Need to fetch all users and verify hashes + if security_config.hash_tokens { + // Fetch all active users with login_string set + let users: Vec<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE login_string IS NOT NULL AND active = TRUE + "#, + ) + .fetch_all(pool) + .await + .context("Failed to fetch users from database")?; + + // Find matching user by verifying bcrypt hash + for user in users { + if let Some(ref stored_hash) = user.login_string { + if bcrypt::verify(login_string, stored_hash).unwrap_or(false) { + // Found matching user + return authenticate_user_by_id(pool, user.id).await; + } + } + } + Ok(None) + } else { + // Plaintext comparison - direct database query + let user: Option<User> = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE login_string = ? AND active = TRUE + "#, + ) + .bind(login_string) + .fetch_optional(pool) + .await + .context("Failed to fetch user from database")?; + + if let Some(user) = user { + authenticate_user_by_id(pool, user.id).await + } else { + Ok(None) + } + } +} + +async fn authenticate_user_by_id(pool: &MySqlPool, user_id: i32) -> Result<Option<(User, Role)>> { + // Fetch user + let user: User = sqlx::query_as::<_, User>( + r#" + SELECT id, name, username, password, pin_code, login_string, role_id, + email, phone, notes, active, last_login_date, created_date, + password_reset_token, password_reset_expiry + FROM users + WHERE id = ? + "#, + ) + .bind(user_id) + .fetch_one(pool) + .await + .context("Failed to fetch user")?; + + if user.active { + // Fetch user's role + let role: Role = + sqlx::query_as::<_, Role>("SELECT id, name, power, created_at FROM roles WHERE id = ?") + .bind(user.role_id) + .fetch_one(pool) + .await + .context("Failed to fetch user role")?; + + // Update last login date + sqlx::query("UPDATE users SET last_login_date = NOW() WHERE id = ?") + .bind(user.id) + .execute(pool) + .await + .context("Failed to update last login date")?; + + Ok(Some((user, role))) + } else { + Ok(None) + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..602c0d8 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,845 @@ +// Configuration module for SeckelAPI +use anyhow::{Context, Result}; +use ipnet::IpNet; +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::str::FromStr; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub database: DatabaseConfig, + pub server: ServerConfig, + pub security: SecurityConfig, + pub permissions: PermissionsConfig, + pub logging: LoggingConfig, + pub auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + pub scheduled_queries: Option<ScheduledQueriesConfig>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PermissionsConfig { + #[serde(flatten)] + pub power_levels: HashMap<String, PowerLevelPermissions>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PowerLevelPermissions { + pub basic_rules: Vec<String>, + pub advanced_rules: Option<Vec<String>>, + pub max_limit: Option<u32>, + pub max_where_conditions: Option<u32>, + pub session_timeout_minutes: Option<u64>, + pub max_concurrent_sessions: Option<u32>, + #[serde(default = "default_true")] + pub rollback_on_error: bool, + #[serde(default = "default_false")] + pub allow_batch_operations: bool, + #[serde(default = "default_user_settings_access")] + pub user_settings_access: UserSettingsAccess, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum UserSettingsAccess { + ReadOwnOnly, + ReadWriteOwn, + ReadWriteAll, +} + +fn default_user_settings_access() -> UserSettingsAccess { + UserSettingsAccess::ReadWriteOwn +} + +#[derive(Debug, Clone, Deserialize)] +pub struct DatabaseConfig { + pub host: String, + pub port: u16, + pub database: String, + pub username: String, + pub password: String, + #[serde(default = "default_min_connections")] + pub min_connections: u32, + #[serde(default = "default_max_connections")] + pub max_connections: u32, + #[serde(default = "default_connection_timeout_seconds")] + pub connection_timeout_seconds: u64, + #[serde(default = "default_connection_timeout_wait")] + pub connection_timeout_wait: u64, + #[serde(default = "default_connection_check_interval")] + pub connection_check: u64, +} + +fn default_min_connections() -> u32 { + 1 +} + +fn default_max_connections() -> u32 { + 10 +} + +fn default_connection_timeout_seconds() -> u64 { + 30 +} + +fn default_connection_timeout_wait() -> u64 { + 5 +} + +fn default_connection_check_interval() -> u64 { + 30 +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + pub host: String, + pub port: u16, + #[serde(default = "default_request_body_limit_mb")] + pub request_body_limit_mb: usize, +} + +fn default_request_body_limit_mb() -> usize { + 10 // 10 MB default +} + +#[derive(Debug, Clone, Deserialize)] +pub struct SecurityConfig { + pub whitelisted_pin_ips: Vec<String>, + pub whitelisted_string_ips: Vec<String>, + pub session_timeout_minutes: u64, + pub refresh_session_on_activity: bool, + pub max_concurrent_sessions: u32, + pub session_cleanup_interval_minutes: u64, + pub default_max_limit: u32, + pub default_max_where_conditions: u32, + #[serde(default = "default_hash_pins")] + pub hash_pins: bool, // Whether to use bcrypt for PINs (false = plaintext) + #[serde(default = "default_hash_tokens")] + pub hash_tokens: bool, // Whether to use bcrypt for login_strings (false = plaintext) + #[serde(default = "default_pin_column")] + pub pin_column: String, // Database column name for PINs + #[serde(default = "default_token_column")] + pub token_column: String, // Database column name for login strings + // Rate limiting + #[serde(default = "default_enable_rate_limiting")] + pub enable_rate_limiting: bool, // Master switch for rate limiting (disable for debugging) + + // Auth rate limiting + #[serde(default = "default_auth_rate_limit_per_minute")] + pub auth_rate_limit_per_minute: u32, // Max auth requests per IP per minute + #[serde(default = "default_auth_rate_limit_per_second")] + pub auth_rate_limit_per_second: u32, // Max auth requests per IP per second (burst protection) + + // API rate limiting + #[serde(default = "default_api_rate_limit_per_minute")] + pub api_rate_limit_per_minute: u32, // Max API calls per user per minute + #[serde(default = "default_api_rate_limit_per_second")] + pub api_rate_limit_per_second: u32, // Max API calls per user per second (burst protection) + + // Table configuration (moved from basics.toml) + #[serde(default = "default_known_tables")] + pub known_tables: Vec<String>, + #[serde(default = "default_read_only_tables")] + pub read_only_tables: Vec<String>, + #[serde(default = "default_global_write_protected_columns")] + pub global_write_protected_columns: Vec<String>, + + // User preferences access control + #[serde(default = "default_user_settings_access")] + pub default_user_settings_access: UserSettingsAccess, +} + +fn default_known_tables() -> Vec<String> { + vec![] // Empty by default, must be configured +} + +fn default_read_only_tables() -> Vec<String> { + vec![] // Empty by default +} + +fn default_global_write_protected_columns() -> Vec<String> { + vec!["id".to_string()] // Protect 'id' by default at minimum +} + +fn default_hash_pins() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_hash_tokens() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_pin_column() -> String { + "pin_code".to_string() +} + +fn default_token_column() -> String { + "login_string".to_string() +} + +fn default_enable_rate_limiting() -> bool { + true // Enable by default for security +} + +fn default_auth_rate_limit_per_minute() -> u32 { + 10 // 10 login attempts per IP per minute (prevents brute force) +} + +fn default_auth_rate_limit_per_second() -> u32 { + 5 // Max 5 login attempts per second (burst protection) +} + +fn default_api_rate_limit_per_minute() -> u32 { + 60 // 60 API calls per user per minute +} + +fn default_api_rate_limit_per_second() -> u32 { + 10 // Max 10 API calls per second (burst protection) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct LoggingConfig { + pub request_log: Option<String>, + pub query_log: Option<String>, + pub error_log: Option<String>, + pub warning_log: Option<String>, // Warning messages + pub info_log: Option<String>, // Info messages + pub combined_log: Option<String>, // Unified log with request IDs + pub level: String, + pub mask_passwords: bool, + #[serde(default)] + pub sensitive_fields: Vec<String>, // Fields to mask beyond password/pin + #[serde(default)] + pub custom_filters: Vec<CustomLogFilter>, // Regex-based log routing +} + +#[derive(Debug, Clone, Deserialize)] +pub struct CustomLogFilter { + pub name: String, + pub output_file: String, + pub pattern: String, // Regex pattern to match + #[serde(default = "default_filter_enabled")] + pub enabled: bool, +} + +fn default_filter_enabled() -> bool { + true +} + +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AutoGenerationConfig { + pub field: String, + #[serde(rename = "type")] + pub gen_type: String, + pub length: Option<u32>, + pub range_min: Option<u64>, + pub range_max: Option<u64>, + pub max_attempts: Option<u32>, + #[serde(default = "default_on_action")] + pub on_action: String, // "insert", "update", or "both" +} + +fn default_on_action() -> String { + "insert".to_string() +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueriesConfig { + #[serde(default)] + pub tasks: Vec<ScheduledQueryTask>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueryTask { + pub name: String, + pub description: String, + pub query: String, + pub interval_minutes: u64, + #[serde(default = "default_enabled")] + pub enabled: bool, + #[serde(default = "default_run_on_startup")] + pub run_on_startup: bool, +} + +fn default_enabled() -> bool { + true +} + +fn default_run_on_startup() -> bool { + true // Run immediately on startup by default +} + +impl Config { + pub fn load() -> Result<Self> { + Self::load_from_folder() + } + + fn load_from_folder() -> Result<Self> { + // Load consolidated config files + let basics_content = + fs::read_to_string("config/basics.toml").context("Can't read config/basics.toml")?; + let security_content = fs::read_to_string("config/security.toml") + .context("Can't read config/security.toml")?; + let logging_content = + fs::read_to_string("config/logging.toml").context("Can't read config/logging.toml")?; + + // Parse individual sections + #[derive(Deserialize)] + struct BasicsWrapper { + server: ServerConfig, + database: DatabaseConfig, + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + } + #[derive(Deserialize)] + struct SecurityWrapper { + security: SecurityConfig, + permissions: PermissionsConfig, + } + #[derive(Deserialize)] + struct LoggingWrapper { + logging: LoggingConfig, + } + + #[derive(Deserialize)] + struct FunctionsWrapper { + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + #[serde(default)] + scheduled_queries: Option<ScheduledQueriesConfig>, + } + + let basics: BasicsWrapper = + toml::from_str(&basics_content).context("Failed to parse basics.toml")?; + let security: SecurityWrapper = + toml::from_str(&security_content).context("Failed to parse security.toml")?; + let logging: LoggingWrapper = + toml::from_str(&logging_content).context("Failed to parse logging.toml")?; + + // Load functions.toml if it exists, otherwise use basics fallback + let functions: FunctionsWrapper = if fs::metadata("config/functions.toml").is_ok() { + let functions_content = fs::read_to_string("config/functions.toml") + .context("Can't read config/functions.toml")?; + toml::from_str(&functions_content).context("Failed to parse functions.toml")? + } else { + // Fallback to basics.toml for auto_generation + FunctionsWrapper { + auto_generation: basics.auto_generation.clone(), + scheduled_queries: None, + } + }; + + let config = Config { + database: basics.database, + server: basics.server, + security: security.security, + permissions: security.permissions, + logging: logging.logging, + auto_generation: functions.auto_generation, + scheduled_queries: functions.scheduled_queries, + }; + + // Validate configuration + config.validate()?; + + Ok(config) + } + + /// Validate configuration values + fn validate(&self) -> Result<()> { + // Validate server port (u16 is already limited to 0-65535, just check for 0) + if self.server.port == 0 { + anyhow::bail!( + "Invalid server.port: {} (must be 1-65535)", + self.server.port + ); + } + + // Validate database port (u16 is already limited to 0-65535, just check for 0) + if self.database.port == 0 { + anyhow::bail!( + "Invalid database.port: {} (must be 1-65535)", + self.database.port + ); + } + + // Validate database connection details + if self.database.host.trim().is_empty() { + anyhow::bail!("database.host cannot be empty"); + } + if self.database.database.trim().is_empty() { + anyhow::bail!("database.database cannot be empty"); + } + if self.database.username.trim().is_empty() { + anyhow::bail!("database.username cannot be empty"); + } + + // Validate PIN whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_pin_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_pin_ips: {}", ip); + } + } + + // Validate string auth whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_string_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_string_ips: {}", ip); + } + } + + // Validate session timeout + if self.security.session_timeout_minutes == 0 { + anyhow::bail!("security.session_timeout_minutes must be greater than 0"); + } + + // Validate permission syntax + for (power_level, perms) in &self.permissions.power_levels { + // Validate power level is numeric + if power_level.parse::<i32>().is_err() { + anyhow::bail!("Invalid power level '{}': must be numeric", power_level); + } + + // Validate basic rules format + for rule in &perms.basic_rules { + Self::validate_permission_rule(rule).with_context(|| { + format!("Invalid basic rule for power level {}", power_level) + })?; + } + + // Validate advanced rules format + if let Some(advanced_rules) = &perms.advanced_rules { + for rule in advanced_rules { + Self::validate_advanced_rule(rule).with_context(|| { + format!("Invalid advanced rule for power level {}", power_level) + })?; + } + } + } + + // Validate logging configuration (paths are now optional) + // No validation needed for optional log paths + + // Validate log level + let valid_levels = ["trace", "debug", "info", "warn", "error"]; + if !valid_levels.contains(&self.logging.level.to_lowercase().as_str()) { + anyhow::bail!( + "Invalid logging.level '{}': must be one of: {}", + self.logging.level, + valid_levels.join(", ") + ); + } + + Ok(()) + } + + /// Validate basic permission rule format (table:permission) + fn validate_permission_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Permission rule '{}' must be in format 'table:permission'", + rule + ); + } + + let table = parts[0]; + let permission = parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "rw", "rwd"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + /// Validate advanced permission rule format (table.column:permission) + fn validate_advanced_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Advanced rule '{}' must be in format 'table.column:permission'", + rule + ); + } + + let table_col = parts[0]; + let permission = parts[1]; + + // Validate table.column format + let col_parts: Vec<&str> = table_col.split('.').collect(); + if col_parts.len() != 2 { + anyhow::bail!("Advanced rule '{}' must have 'table.column' format", rule); + } + + let table = col_parts[0]; + let column = col_parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "w", "rw", "block"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + pub fn get_database_url(&self) -> String { + format!( + "mysql://{}:{}@{}:{}/{}", + self.database.username, + self.database.password, + self.database.host, + self.database.port, + self.database.database + ) + } + + pub fn is_pin_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_pin_ips) + } + + pub fn is_string_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_string_ips) + } + + fn is_ip_in_whitelist(&self, ip: &str, whitelist: &[String]) -> bool { + let client_ip = match ip.parse::<std::net::IpAddr>() { + Ok(addr) => addr, + Err(_) => return false, + }; + + for allowed in whitelist { + // Try to parse as a network (CIDR notation) + if let Ok(network) = IpNet::from_str(allowed) { + if network.contains(&client_ip) { + return true; + } + } + // Try to parse as a single IP address + else if let Ok(allowed_ip) = allowed.parse::<std::net::IpAddr>() { + if client_ip == allowed_ip { + return true; + } + } + } + false + } + + pub fn get_role_permissions(&self, power: i32) -> Vec<(String, String)> { + let power_str = power.to_string(); + + if let Some(power_perms) = self.permissions.power_levels.get(&power_str) { + power_perms + .basic_rules + .iter() + .filter_map(|perm| { + let parts: Vec<&str> = perm.split(':').collect(); + if parts.len() == 2 { + Some((parts[0].to_string(), parts[1].to_string())) + } else { + None + } + }) + .collect() + } else { + Vec::new() + } + } + + // Helper methods for new configuration sections + pub fn get_known_tables(&self) -> Vec<String> { + if self.security.known_tables.is_empty() { + tracing::warn!("No known_tables configured in security.toml - returning empty list. Wildcard permissions (*) will not work."); + } + self.security.known_tables.clone() + } + + pub fn is_read_only_table(&self, table: &str) -> bool { + self.security.read_only_tables.contains(&table.to_string()) + } + + pub fn get_auto_generation_config(&self, table: &str) -> Option<&AutoGenerationConfig> { + self.auto_generation + .as_ref() + .and_then(|configs| configs.get(table)) + } + + pub fn get_basic_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .map(|p| &p.basic_rules) + } + + pub fn get_advanced_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.advanced_rules.as_ref()) + } + + pub fn filter_readable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "r" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + _ => {} + } + } + } + } + } + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, return all requested columns + requested_columns.to_vec() + } + } + + pub fn filter_writable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + // First, apply global write-protected columns (these override everything) + let mut globally_blocked: Vec<String> = + self.security.global_write_protected_columns.clone(); + + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "w" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + "r" => { + // Read-only: block from writing but not from reading + blocked_columns.push(column.to_string()); + } + _ => {} + } + } + } + } + } + + // Merge advanced_rules blocked columns with globally protected + blocked_columns.append(&mut globally_blocked); + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, just filter out globally protected columns + requested_columns + .iter() + .filter(|col| !globally_blocked.contains(col)) + .cloned() + .collect() + } + } + + /// Get the max_limit for a specific power level (with fallback to next lower power level) + pub fn get_max_limit(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + } + + // Ultimate fallback to default + self.security.default_max_limit + } + + /// Get the max_where_conditions for a specific power level (with fallback to next lower power level) + pub fn get_max_where_conditions(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + } + + // Ultimate fallback to default + self.security.default_max_where_conditions + } + + /// Find the next lower configured power level (e.g., power=60 → fallback to 50) + fn find_fallback_power_level(&self, power: i32) -> Option<i32> { + let mut available_powers: Vec<i32> = self + .permissions + .power_levels + .keys() + .filter_map(|k| k.parse::<i32>().ok()) + .filter(|&p| p < power) // Only consider lower power levels + .collect(); + + available_powers.sort_by(|a, b| b.cmp(a)); // Sort descending + available_powers.first().copied() + } + + /// Get the session_timeout_minutes for a specific power level (with fallback to default) + pub fn get_session_timeout(&self, power: i32) -> u64 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.session_timeout_minutes) + .unwrap_or(self.security.session_timeout_minutes) + } + + /// Get the max_concurrent_sessions for a specific power level (with fallback to default) + pub fn get_max_concurrent_sessions(&self, power: i32) -> u32 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.max_concurrent_sessions) + .unwrap_or(self.security.max_concurrent_sessions) + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..cbee57f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,3 @@ +pub mod pool; + +pub use pool::{Database, DatabaseInitError}; diff --git a/src/db/pool.rs b/src/db/pool.rs new file mode 100644 index 0000000..4390739 --- /dev/null +++ b/src/db/pool.rs @@ -0,0 +1,110 @@ +// Database connection pool module +use crate::config::DatabaseConfig; +use anyhow::{Context, Result as AnyResult}; +use sqlx::{Error as SqlxError, MySqlPool}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +#[derive(Debug)] +pub enum DatabaseInitError { + Fatal(anyhow::Error), + Retryable(anyhow::Error), +} + +impl DatabaseInitError { + fn fatal(err: anyhow::Error) -> Self { + Self::Fatal(err) + } + + fn retryable(err: anyhow::Error) -> Self { + Self::Retryable(err) + } +} + +#[derive(Clone)] +pub struct Database { + pool: MySqlPool, + availability: Arc<AtomicBool>, +} + +impl Database { + pub async fn new(config: &DatabaseConfig) -> Result<Self, DatabaseInitError> { + let database_url = format!( + "mysql://{}:{}@{}:{}/{}", + config.username, config.password, config.host, config.port, config.database + ); + + let pool = sqlx::mysql::MySqlPoolOptions::new() + .min_connections(config.min_connections) + .max_connections(config.max_connections) + .acquire_timeout(std::time::Duration::from_secs( + config.connection_timeout_seconds, + )) + .connect(&database_url) + .await + .map_err(|err| map_sqlx_error(err, "Failed to connect to database"))?; + + // Test the connection + sqlx::query("SELECT 1") + .execute(&pool) + .await + .map_err(|err| map_sqlx_error(err, "Failed to test database connection"))?; + + Ok(Database { + pool, + availability: Arc::new(AtomicBool::new(true)), + }) + } + + pub fn pool(&self) -> &MySqlPool { + &self.pool + } + + pub fn is_available(&self) -> bool { + self.availability.load(Ordering::Relaxed) + } + + pub fn mark_available(&self) { + self.availability.store(true, Ordering::Relaxed); + } + + pub fn mark_unavailable(&self) { + self.availability.store(false, Ordering::Relaxed); + } + + pub async fn set_current_user(&self, user_id: i32) -> AnyResult<()> { + sqlx::query("SET @current_user_id = ?") + .bind(user_id) + .execute(&self.pool) + .await + .context("Failed to set current user ID")?; + + Ok(()) + } + + pub async fn close(&self) { + self.pool.close().await; + } +} + +fn map_sqlx_error(err: SqlxError, context: &str) -> DatabaseInitError { + let retryable = is_retryable_sqlx_error(&err); + let wrapped = anyhow::Error::new(err).context(context.to_string()); + if retryable { + DatabaseInitError::retryable(wrapped) + } else { + DatabaseInitError::fatal(wrapped) + } +} + +fn is_retryable_sqlx_error(err: &SqlxError) -> bool { + matches!( + err, + SqlxError::Io(_) + | SqlxError::PoolTimedOut + | SqlxError::PoolClosed + | SqlxError::WorkerCrashed + ) +} diff --git a/src/logging/logger.rs b/src/logging/logger.rs new file mode 100644 index 0000000..e063f8f --- /dev/null +++ b/src/logging/logger.rs @@ -0,0 +1,373 @@ +// Audit logging module with request ID tracing and custom filters +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use regex::Regex; +use serde_json::Value; +use std::fs::OpenOptions; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Clone)] +struct CustomFilter { + name: String, + pattern: Regex, + file: Arc<Mutex<std::fs::File>>, +} + +#[derive(Clone)] +pub struct AuditLogger { + mask_passwords: bool, + sensitive_fields: Vec<String>, + request_file: Option<Arc<Mutex<std::fs::File>>>, + query_file: Option<Arc<Mutex<std::fs::File>>>, + error_file: Option<Arc<Mutex<std::fs::File>>>, + warning_file: Option<Arc<Mutex<std::fs::File>>>, + info_file: Option<Arc<Mutex<std::fs::File>>>, + combined_file: Option<Arc<Mutex<std::fs::File>>>, + custom_filters: Vec<CustomFilter>, +} + +impl AuditLogger { + pub fn new( + request_log_path: Option<String>, + query_log_path: Option<String>, + error_log_path: Option<String>, + warning_log_path: Option<String>, + info_log_path: Option<String>, + combined_log_path: Option<String>, + mask_passwords: bool, + sensitive_fields: Vec<String>, + custom_filter_configs: Vec<crate::config::CustomLogFilter>, + ) -> Result<Self> { + // Helper function to open a log file if path is provided + let open_log_file = |path: &Option<String>| -> Result<Option<Arc<Mutex<std::fs::File>>>> { + if let Some(path_str) = path { + // Ensure log directories exist + if let Some(parent) = Path::new(path_str).parent() { + std::fs::create_dir_all(parent).context("Failed to create log directory")?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(path_str) + .context(format!("Failed to open log file: {}", path_str))?; + + Ok(Some(Arc::new(Mutex::new(file)))) + } else { + Ok(None) + } + }; + + // Initialize custom filters + let mut custom_filters = Vec::new(); + for filter_config in custom_filter_configs { + if filter_config.enabled { + // Compile regex pattern + let pattern = Regex::new(&filter_config.pattern).context(format!( + "Invalid regex pattern in filter '{}': {}", + filter_config.name, filter_config.pattern + ))?; + + // Open filter output file + if let Some(parent) = Path::new(&filter_config.output_file).parent() { + std::fs::create_dir_all(parent) + .context("Failed to create filter log directory")?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&filter_config.output_file) + .context(format!( + "Failed to open filter log file: {}", + filter_config.output_file + ))?; + + custom_filters.push(CustomFilter { + name: filter_config.name.clone(), + pattern, + file: Arc::new(Mutex::new(file)), + }); + } + } + + Ok(Self { + mask_passwords, + sensitive_fields, + request_file: open_log_file(&request_log_path)?, + query_file: open_log_file(&query_log_path)?, + error_file: open_log_file(&error_log_path)?, + warning_file: open_log_file(&warning_log_path)?, + info_file: open_log_file(&info_log_path)?, + combined_file: open_log_file(&combined_log_path)?, + custom_filters, + }) + } + + /// Generate a unique request ID for transaction tracing + pub fn generate_request_id() -> String { + format!("{}", uuid::Uuid::new_v4().as_u128() & 0xFFFFFFFF_FFFFFFFF) // 16 hex chars + } + + /// Write to combined log and apply custom filters + async fn write_combined_and_filter(&self, entry: &str) -> Result<()> { + // Write to combined log if configured + if let Some(ref file_mutex) = self.combined_file { + let mut file = file_mutex.lock().await; + file.write_all(entry.as_bytes()) + .context("Failed to write to combined log")?; + file.flush().context("Failed to flush combined log")?; + } + + // Apply custom filters + for filter in &self.custom_filters { + if filter.pattern.is_match(entry) { + let mut file = filter.file.lock().await; + file.write_all(entry.as_bytes()) + .context(format!("Failed to write to filter log: {}", filter.name))?; + file.flush() + .context(format!("Failed to flush filter log: {}", filter.name))?; + } + } + + Ok(()) + } + + pub async fn log_request( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + _ip: &str, + user: Option<&str>, + power: Option<i32>, + endpoint: &str, + payload: &Value, + ) -> Result<()> { + let mut masked_payload = payload.clone(); + + if self.mask_passwords { + self.mask_sensitive_data(&mut masked_payload); + } + + let user_str = user.unwrap_or("anonymous"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | REQUEST | user={} | {} | endpoint={} | payload={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + endpoint, + serde_json::to_string(&masked_payload).unwrap_or_else(|_| "invalid_json".to_string()) + ); + + // Write to legacy request log if configured + if let Some(ref file_mutex) = self.request_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to request log")?; + file.flush().context("Failed to flush request log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_query( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + user: &str, + power: Option<i32>, + query: &str, + parameters: Option<&Value>, + rows_affected: Option<u64>, + ) -> Result<()> { + let params_str = if let Some(params) = parameters { + serde_json::to_string(params).unwrap_or_else(|_| "invalid_json".to_string()) + } else { + "null".to_string() + }; + + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + let rows_str = rows_affected + .map(|r| format!("rows={}", r)) + .unwrap_or_else(|| "rows=0".to_string()); + + let log_entry = format!( + "{} [{}] | QUERY | user={} | {} | {} | query={} | params={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user, + power_str, + rows_str, + query, + params_str + ); + + // Write to legacy query log if configured + if let Some(ref file_mutex) = self.query_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to query log")?; + file.flush().context("Failed to flush query log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_error( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + error: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("unknown"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | ERROR | user={} | {} | context={} | error={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + error + ); + + // Write to legacy error log if configured + if let Some(ref file_mutex) = self.error_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to error log")?; + file.flush().context("Failed to flush error log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_warning( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + message: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("unknown"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | WARNING | user={} | {} | context={} | message={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + message + ); + + // Write to warning log if configured + if let Some(ref file_mutex) = self.warning_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to warning log")?; + file.flush().context("Failed to flush warning log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + pub async fn log_info( + &self, + request_id: &str, + timestamp: DateTime<Utc>, + message: &str, + context: Option<&str>, + user: Option<&str>, + power: Option<i32>, + ) -> Result<()> { + let user_str = user.unwrap_or("system"); + let context_str = context.unwrap_or("general"); + let power_str = power + .map(|p| format!("power={}", p)) + .unwrap_or_else(|| "power=0".to_string()); + + let log_entry = format!( + "{} [{}] | INFO | user={} | {} | context={} | message={}\n", + timestamp.format("%Y-%m-%d %H:%M:%S"), + request_id, + user_str, + power_str, + context_str, + message + ); + + // Write to info log if configured + if let Some(ref file_mutex) = self.info_file { + let mut file = file_mutex.lock().await; + file.write_all(log_entry.as_bytes()) + .context("Failed to write to info log")?; + file.flush().context("Failed to flush info log")?; + } + + // Write to combined log and apply filters + self.write_combined_and_filter(&log_entry).await?; + + Ok(()) + } + + fn mask_sensitive_data(&self, value: &mut Value) { + match value { + Value::Object(map) => { + for (key, val) in map.iter_mut() { + // Always mask password and pin + if key == "password" || key == "pin" { + *val = Value::String("***MASKED***".to_string()); + } + // Also mask any configured sensitive fields + else if self.sensitive_fields.contains(key) { + *val = Value::String("***MASKED***".to_string()); + } else { + self.mask_sensitive_data(val); + } + } + } + Value::Array(arr) => { + for item in arr.iter_mut() { + self.mask_sensitive_data(item); + } + } + _ => {} + } + } +} diff --git a/src/logging/mod.rs b/src/logging/mod.rs new file mode 100644 index 0000000..a6752e2 --- /dev/null +++ b/src/logging/mod.rs @@ -0,0 +1,3 @@ +pub mod logger; + +pub use logger::AuditLogger; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..17702f4 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,359 @@ +// Main entry point for the SeckelAPI server +use anyhow::Result; +use axum::{ + http::Method, + response::Json, + routing::{get, post}, + Router, +}; +use chrono::Utc; +use serde_json::json; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::time::{sleep, Duration}; +use tower_governor::{ + governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer, +}; +use tower_http::cors::{Any, CorsLayer}; +use tracing::{error, info, Level}; +use tracing_subscriber; + +mod auth; +mod config; +mod db; +mod logging; +mod models; +mod permissions; +mod routes; +mod scheduler; +mod sql; + +use auth::SessionManager; +use config::{Config, DatabaseConfig}; +use logging::AuditLogger; +use scheduler::QueryScheduler; + +use axum::extract::State; +use db::{Database, DatabaseInitError}; +use permissions::RBACManager; + +// Version pulled from Cargo.toml +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +// Root handler - simple version info +async fn root() -> Json<serde_json::Value> { + Json(json!({ + "name": "SeckelAPI", + "version": VERSION + })) +} + +// Health check handler with database connectivity test +async fn health_check(State(state): State<AppState>) -> Json<serde_json::Value> { + // Test database connectivity + let db_status = match sqlx::query("SELECT 1") + .fetch_one(state.database.pool()) + .await + { + Ok(_) => "connected", + Err(_) => "disconnected", + }; + + Json(json!({ + "status": "running", + "message": format!("SeckelAPI v{}: The schizophrenic database application JSON API", VERSION), + "database": db_status + })) +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); + + info!("Starting SeckelAPI Server..."); + + // Load configuration + let config = Config::load()?; + let config_arc = Arc::new(config.clone()); + info!("- Configuration might have loaded successfully"); + + // Initialize database connection + let database = wait_for_database(&config.database).await?; + info!("- Database said 'Hello there!'"); + + // Initialize session manager + let session_manager = SessionManager::new(config_arc.clone()); + info!("- Session manager didn't crash on startup (yay!)"); + + // Spawn background task for session cleanup + let cleanup_manager = session_manager.clone(); + let cleanup_interval = config.security.session_cleanup_interval_minutes; + tokio::spawn(async move { + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval * 60)); + loop { + interval.tick().await; + cleanup_manager.cleanup_expired_sessions(); + info!("- Ran session cleanup task"); + } + }); + info!( + "- Background session cleanup task spawned (interval: {}min)", + cleanup_interval + ); + + // Initialize RBAC manager + let rbac = RBACManager::new(&config); + info!("- Overcomplicated RBAC manager has initialized"); + + // Initialize audit logger + let logging = AuditLogger::new( + config.logging.request_log.clone(), + config.logging.query_log.clone(), + config.logging.error_log.clone(), + config.logging.warning_log.clone(), + config.logging.info_log.clone(), + config.logging.combined_log.clone(), + config.logging.mask_passwords, + config.logging.sensitive_fields.clone(), + config.logging.custom_filters.clone(), + )?; + info!("- CIA Surveillance Service Agents have been initialized on your system (just kidding, just the logging stack)"); + if !config.logging.custom_filters.is_empty() { + let enabled_count = config + .logging + .custom_filters + .iter() + .filter(|f| f.enabled) + .count(); + info!("- {} custom log filter(s) active", enabled_count); + } + + spawn_database_monitor(database.clone(), config.database.clone(), logging.clone()); + + // Initialize and spawn scheduled query tasks + let scheduler = QueryScheduler::new(config_arc.clone(), database.clone(), logging.clone()); + scheduler.spawn_tasks(); + + // Create CORS layer + let cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_headers(Any) + .allow_origin(Any); + + // Build auth routes with rate limiting + let mut auth_routes = Router::new() + .route("/auth/login", post(routes::auth::login)) + .route("/auth/logout", post(routes::auth::logout)) + .route("/auth/status", get(routes::auth::status)); + + if config.security.enable_rate_limiting { + info!( + "- Auth rate limiting set to: {}/min, {}/sec per IP", + config.security.auth_rate_limit_per_minute, config.security.auth_rate_limit_per_second + ); + let auth_governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_second(config.security.auth_rate_limit_per_second.max(1) as u64) + .burst_size(config.security.auth_rate_limit_per_minute.max(1)) + .key_extractor(SmartIpKeyExtractor) + .finish() + .expect("Failed to build auth rate limiter config ... you sure its configured?"), + ); + auth_routes = auth_routes.layer(GovernorLayer::new(auth_governor_conf)); + } else { + info!("- Auth rate limiting DISABLED (dont run in production like this plz) "); + } + + // Build API routes with rate limiting + let mut api_routes = Router::new() + .route("/query", post(routes::query::execute_query)) + .route("/permissions", get(routes::query::get_permissions)) + .route( + "/preferences", + post(routes::preferences::handle_preferences), + ); + + if config.security.enable_rate_limiting { + info!( + "- API rate limiting set to: {}/min, {}/sec per IP", + config.security.api_rate_limit_per_minute, config.security.api_rate_limit_per_second + ); + let api_governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_second(config.security.api_rate_limit_per_second.max(1) as u64) + .burst_size(config.security.api_rate_limit_per_minute.max(1)) + .key_extractor(SmartIpKeyExtractor) + .finish() + .expect( + "Failed to build API rate limiter config, ugh ... you sure its configured?", + ), + ); + api_routes = api_routes.layer(GovernorLayer::new(api_governor_conf)); + } + + // das so routes was es chan + let app = Router::new() + .route("/", get(root)) + .route("/health", get(health_check)) + .merge(auth_routes) + .merge(api_routes) + .layer(tower_http::limit::RequestBodyLimitLayer::new( + config.server.request_body_limit_mb * 1024 * 1024, + )) + .layer(cors) + .with_state(AppState { + config: config.clone(), + database, + session_manager, + rbac, + logging, + }); + let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port)); + info!( + "- SeckelAPI somehow started and should now be listening on {} :)", + addr + ); + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::<SocketAddr>(), + ) + .await?; + Ok(()) +} + +async fn wait_for_database(config: &DatabaseConfig) -> Result<Database> { + let retry_delay = config.connection_timeout_wait.max(1); + + loop { + match Database::new(config).await { + Ok(db) => return Ok(db), + Err(DatabaseInitError::Retryable(err)) => { + error!( + "Database unavailable (retrying in {}s): {}", + retry_delay, err + ); + sleep(Duration::from_secs(retry_delay)).await; + } + Err(DatabaseInitError::Fatal(err)) => { + error!("Fatal database configuration error: {}", err); + return Err(err); + } + } + } +} + +fn spawn_database_monitor(database: Database, db_config: DatabaseConfig, logging: AuditLogger) { + let heartbeat_interval = db_config.connection_check.max(5); + let retry_delay = db_config.connection_timeout_wait.max(1); + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(Duration::from_secs(heartbeat_interval)); + loop { + ticker.tick().await; + match sqlx::query("SELECT 1").fetch_one(database.pool()).await { + Ok(_) => { + if !database.is_available() { + info!("Database connectivity restored"); + log_database_event( + &logging, + "database_reconnect", + "Database connectivity restored", + false, + ) + .await; + } + database.mark_available(); + } + Err(err) => { + database.mark_unavailable(); + error!( + "Database heartbeat failed (retrying every {}s): {}", + retry_delay, err + ); + + log_database_event( + &logging, + "database_heartbeat_failure", + &format!("Heartbeat failed: {}", err), + true, + ) + .await; + + loop { + sleep(Duration::from_secs(retry_delay)).await; + match sqlx::query("SELECT 1").fetch_one(database.pool()).await { + Ok(_) => { + database.mark_available(); + info!("Database reconnected successfully"); + log_database_event( + &logging, + "database_reconnect", + "Database reconnected successfully", + false, + ) + .await; + break; + } + Err(retry_err) => { + error!( + "Database still unavailable (retrying in {}s): {}", + retry_delay, retry_err + ); + log_database_event( + &logging, + "database_retry_failure", + &format!("Database retry failed: {}", retry_err), + true, + ) + .await; + } + } + } + } + } + } + }); +} + +async fn log_database_event(logging: &AuditLogger, context: &str, message: &str, as_error: bool) { + let request_id = AuditLogger::generate_request_id(); + let timestamp = Utc::now(); + let result = if as_error { + logging + .log_error( + &request_id, + timestamp, + message, + Some(context), + Some("system"), + None, + ) + .await + } else { + logging + .log_info( + &request_id, + timestamp, + message, + Some(context), + Some("system"), + None, + ) + .await + }; + + if let Err(err) = result { + error!("Failed to record database event ({}): {}", context, err); + } +} +#[derive(Clone)] +pub struct AppState { + pub config: Config, + pub database: Database, + pub session_manager: SessionManager, + pub rbac: RBACManager, + pub logging: AuditLogger, +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..4d9f734 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,294 @@ +// Data models for SeckelAPI +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// Authentication models +#[derive(Debug, Deserialize, Serialize)] +pub struct LoginRequest { + pub method: AuthMethod, + pub username: Option<String>, + pub password: Option<String>, + pub pin: Option<String>, + pub login_string: Option<String>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthMethod { + Password, + Pin, + Token, +} + +#[derive(Debug, Serialize)] +pub struct LoginResponse { + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub token: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<UserInfo>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, +} + +#[derive(Debug, Serialize)] +pub struct UserInfo { + pub id: i32, + pub username: String, + pub name: String, + pub role: String, + pub power: i32, +} + +// Database query models +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct QueryRequest { + // Single query mode (when queries is None) + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option<QueryAction>, + #[serde(skip_serializing_if = "Option::is_none")] + pub table: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub columns: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + // Enhanced WHERE clause - supports both simple and complex conditions + #[serde(rename = "where", skip_serializing_if = "Option::is_none")] + pub where_clause: Option<serde_json::Value>, + // New structured filter for complex queries + #[serde(skip_serializing_if = "Option::is_none")] + pub filter: Option<FilterCondition>, + // JOIN support - allows multi-table queries with permission validation + #[serde(skip_serializing_if = "Option::is_none")] + pub joins: Option<Vec<Join>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub order_by: Option<Vec<OrderBy>>, + + // Batch mode (when queries is Some) - action/table apply to ALL queries in batch + #[serde(skip_serializing_if = "Option::is_none")] + pub queries: Option<Vec<BatchQuery>>, + /// Whether to rollback on error in batch mode (defaults to config setting) + #[serde(skip_serializing_if = "Option::is_none")] + pub rollback_on_error: Option<bool>, +} + +/// Individual query in a batch - inherits action/table from parent QueryRequest +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct BatchQuery { + // Only the variable parts per query - no action/table duplication + #[serde(skip_serializing_if = "Option::is_none")] + pub columns: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + #[serde(rename = "where", skip_serializing_if = "Option::is_none")] + pub where_clause: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub filter: Option<FilterCondition>, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub order_by: Option<Vec<OrderBy>>, +} + +/// JOIN specification for multi-table queries +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Join { + /// Type of join (INNER, LEFT, RIGHT) + #[serde(rename = "type")] + pub join_type: JoinType, + /// Table to join with + pub table: String, + /// Join condition (e.g., "assets.category_id = categories.id") + pub on: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "UPPERCASE")] +pub enum JoinType { + Inner, + Left, + Right, +} + +/// Enhanced filter condition supporting operators and nested logic +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(untagged)] +pub enum FilterCondition { + /// Simple condition: {"column": "name", "op": "=", "value": "John"} + Simple { + column: String, + #[serde(rename = "op")] + operator: FilterOperator, + value: serde_json::Value, + }, + /// Logical AND/OR: {"and": [condition1, condition2]} + Logical { + #[serde(rename = "and")] + and_conditions: Option<Vec<FilterCondition>>, + #[serde(rename = "or")] + or_conditions: Option<Vec<FilterCondition>>, + }, + /// NOT condition: {"not": condition} + Not { not: Box<FilterCondition> }, +} + +/// Supported WHERE clause operators +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum FilterOperator { + #[serde(rename = "=")] + Eq, // Equal + #[serde(rename = "!=")] + Ne, // Not equal + #[serde(rename = ">")] + Gt, // Greater than + #[serde(rename = ">=")] + Gte, // Greater than or equal + #[serde(rename = "<")] + Lt, // Less than + #[serde(rename = "<=")] + Lte, // Less than or equal + Like, // LIKE pattern matching + #[serde(rename = "not_like")] + NotLike, // NOT LIKE + In, // IN (value1, value2, ...) + #[serde(rename = "not_in")] + NotIn, // NOT IN (...) + #[serde(rename = "is_null")] + IsNull, // IS NULL + #[serde(rename = "is_not_null")] + IsNotNull, // IS NOT NULL + Between, // BETWEEN value1 AND value2 +} + +/// ORDER BY clause +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct OrderBy { + pub column: String, + #[serde(default = "default_order_direction")] + pub direction: OrderDirection, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "UPPERCASE")] +pub enum OrderDirection { + ASC, + DESC, +} + +fn default_order_direction() -> OrderDirection { + OrderDirection::ASC +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum QueryAction { + Select, + Insert, + Update, + Delete, + Count, +} + +#[derive(Debug, Serialize, Clone)] +pub struct QueryResponse { + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub rows_affected: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub warning: Option<String>, + // Batch results (when queries field was used) + #[serde(skip_serializing_if = "Option::is_none")] + pub results: Option<Vec<QueryResponse>>, +} + +// Database entities +#[derive(Debug, sqlx::FromRow)] +#[allow(dead_code)] // Fields are used for database serialization +pub struct User { + pub id: i32, + pub name: String, + pub username: String, + pub password: String, + pub pin_code: Option<String>, + pub login_string: Option<String>, + pub role_id: i32, + pub email: Option<String>, + pub phone: Option<String>, + pub notes: Option<String>, + pub active: bool, + pub last_login_date: Option<DateTime<Utc>>, + pub created_date: DateTime<Utc>, + pub password_reset_token: Option<String>, + pub password_reset_expiry: Option<DateTime<Utc>>, +} + +#[derive(Debug, sqlx::FromRow)] +#[allow(dead_code)] // Fields are used for database serialization +pub struct Role { + pub id: i32, + pub name: String, + pub power: i32, + pub created_at: DateTime<Utc>, +} + +// Session management +#[derive(Debug, Clone)] +pub struct Session { + pub user_id: i32, + pub username: String, + pub role_id: i32, + pub role_name: String, + pub power: i32, + pub created_at: DateTime<Utc>, + pub last_accessed: DateTime<Utc>, +} + +// Permission types +#[derive(Debug, Clone, PartialEq)] +pub enum Permission { + Read, + Write, + ReadWrite, + None, +} + +impl Permission { + pub fn from_str(s: &str) -> Self { + match s { + "r" => Permission::Read, + "w" => Permission::Write, + "rw" => Permission::ReadWrite, + _ => Permission::None, + } + } + + pub fn can_read(&self) -> bool { + matches!(self, Permission::Read | Permission::ReadWrite) + } + + pub fn can_write(&self) -> bool { + matches!(self, Permission::Write | Permission::ReadWrite) + } +} + +// Permissions response +#[derive(Debug, Serialize)] +pub struct PermissionsResponse { + pub success: bool, + pub user: UserInfo, + pub permissions: HashMap<String, String>, + pub security_clearance: Option<String>, + pub user_settings_access: String, +} diff --git a/src/permissions/mod.rs b/src/permissions/mod.rs new file mode 100644 index 0000000..4c43932 --- /dev/null +++ b/src/permissions/mod.rs @@ -0,0 +1,3 @@ +pub mod rbac; + +pub use rbac::RBACManager; diff --git a/src/permissions/rbac.rs b/src/permissions/rbac.rs new file mode 100644 index 0000000..c1835e5 --- /dev/null +++ b/src/permissions/rbac.rs @@ -0,0 +1,119 @@ +// Role-based access control module +use crate::config::Config; +use crate::models::{Permission, QueryAction}; +use std::collections::HashMap; + +#[derive(Clone)] +pub struct RBACManager { + permissions: HashMap<i32, HashMap<String, Permission>>, +} + +impl RBACManager { + pub fn new(config: &Config) -> Self { + let mut permissions = HashMap::new(); + + // Parse permissions from new config format + for (power_str, power_perms) in &config.permissions.power_levels { + if let Ok(power) = power_str.parse::<i32>() { + let mut table_permissions = HashMap::new(); + + for perm in &power_perms.basic_rules { + let parts: Vec<&str> = perm.split(':').collect(); + if parts.len() == 2 { + let table = parts[0]; + let permission = Permission::from_str(parts[1]); + + if table == "*" { + // Grant permission to all known tables from config + let all_tables = config.get_known_tables(); + + for table_name in all_tables { + // Don't overwrite existing specific permissions + if !table_permissions.contains_key(&table_name) { + table_permissions.insert(table_name, permission.clone()); + } + } + } else { + table_permissions.insert(table.to_string(), permission); + } + } + } + + permissions.insert(power, table_permissions); + } + } + + Self { permissions } + } + + fn base_table_name(table: &str) -> String { + // Accept formats: "table", "table alias", "table AS alias" (AS case-insensitive) + let parts: Vec<&str> = table.trim().split_whitespace().collect(); + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") { + parts[0].to_string() + } else { + // For simple "table alias" or just "table", take the first token + parts.get(0).cloned().unwrap_or("").to_string() + } + } + + pub fn check_permission( + &self, + config: &Config, + power: i32, + table: &str, + action: &QueryAction, + ) -> bool { + // Normalize potential alias usage to the base table name for permission lookup + let base_table = Self::base_table_name(table); + + if let Some(table_permissions) = self.permissions.get(&power) { + if let Some(permission) = table_permissions.get(&base_table) { + // Check if table is read-only through config and enforce read-only constraint + if config.is_read_only_table(&base_table) && !matches!(action, QueryAction::Select) + { + return false; // Write operations not allowed on read-only tables + } + + match action { + QueryAction::Select | QueryAction::Count => permission.can_read(), + QueryAction::Insert | QueryAction::Update | QueryAction::Delete => { + permission.can_write() + } + } + } else { + false // No permission for this table + } + } else { + false // No permissions for this power level + } + } + + pub fn get_table_permissions(&self, config: &Config, power: i32) -> HashMap<String, String> { + let mut result = HashMap::new(); + + if let Some(table_permissions) = self.permissions.get(&power) { + for (table, permission) in table_permissions { + let perm_str = match permission { + Permission::Read => "r", + Permission::Write => "w", + Permission::ReadWrite => "rw", + Permission::None => "", + }; + + if !perm_str.is_empty() { + result.insert(table.clone(), perm_str.to_string()); + } + } + } + + // Ensure read-only tables from config are marked as read-only + for table in &config.get_known_tables() { + if config.is_read_only_table(table) && result.contains_key(table) { + result.insert(table.clone(), "r".to_string()); + } + } + + result + } +} diff --git a/src/routes/auth.rs b/src/routes/auth.rs new file mode 100644 index 0000000..c124e03 --- /dev/null +++ b/src/routes/auth.rs @@ -0,0 +1,513 @@ +// Authentication routes +use axum::{ + extract::{ConnectInfo, State}, + http::StatusCode, + Json, +}; +use chrono::Utc; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::logging::AuditLogger; +use crate::models::{AuthMethod, LoginRequest, LoginResponse, UserInfo}; +use crate::{auth, AppState}; + +pub async fn login( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + Json(payload): Json<LoginRequest>, +) -> Result<Json<LoginResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/login", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Validate request based on auth method + let (user, role) = match payload.method { + AuthMethod::Password => { + // Password auth - allowed from any IP + if let (Some(username), Some(password)) = (&payload.username, &payload.password) { + match auth::password::authenticate_password( + state.database.pool(), + username, + password, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed password authentication for user: {} - Invalid credentials", + username + ), + Some("password_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during password authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Password auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Password authentication attempted without username or password"); + super::log_warning_async( + &state.logging, + &request_id, + "Password authentication attempted without username or password", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Pin => { + // PIN auth - only from whitelisted IPs + if !state.config.is_pin_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "PIN authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let (Some(username), Some(pin)) = (&payload.username, &payload.pin) { + match auth::pin::authenticate_pin( + state.database.pool(), + username, + pin, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("Failed PIN authentication for user: {} - Invalid PIN or PIN not configured", username), + Some("pin_auth"), + Some(username), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during PIN authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("PIN auth error: {}", e), + Some("authentication"), + payload.username.as_deref(), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("PIN authentication attempted without username or PIN"); + super::log_warning_async( + &state.logging, + &request_id, + "PIN authentication attempted without username or PIN", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + AuthMethod::Token => { + // Token/RFID auth - only from whitelisted IPs + if !state.config.is_string_ip_whitelisted(&client_ip) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Token authentication attempted from non-whitelisted IP: {}", + client_ip + ), + Some("security_violation"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + + if let Some(login_string) = &payload.login_string { + match auth::token::authenticate_token( + state.database.pool(), + login_string, + &state.config.security, + ) + .await + { + Ok(Some((user, role))) => (user, role), + Ok(None) => { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed token authentication for login_string: {} - Invalid token", + login_string + ), + Some("token_auth"), + Some(login_string), + Some(0), + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + Err(e) => { + error!( + "[{}] Database error during token authentication: {}", + request_id, e + ); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Token auth error: {}", e), + Some("authentication"), + Some(login_string), + None, + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } else { + warn!("Token authentication attempted without login_string"); + super::log_warning_async( + &state.logging, + &request_id, + "Token authentication attempted without login_string", + Some("invalid_request"), + None, + None, + ); + return Ok(Json(LoginResponse { + success: false, + token: None, + user: None, + error: Some("Authentication failed".to_string()), + })); + } + } + }; + + // Create session token + let token = state.session_manager.create_session( + user.id, + user.username.clone(), + role.id, + role.name.clone(), + role.power, + ); + + // Log successful login + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Successful login for user: {} ({})", + user.username, user.name + ), + Some("authentication"), + Some(&user.username), + Some(role.power), + ); + + Ok(Json(LoginResponse { + success: true, + token: Some(token), + user: Some(UserInfo { + id: user.id, + username: user.username, + name: user.name, + role: role.name, + power: role.power, + }), + error: None, + })) +} + +pub async fn logout( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token = match headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "error": "No authorization token provided" + }))); + } + }; + + // Get username for logging before removing session + let username = state + .session_manager + .get_session(&token) + .map(|s| s.username.clone()); + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + username.as_deref(), + None, + "/auth/logout", + &serde_json::json!({"action": "logout"}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let removed = state.session_manager.remove_session(&token); + + if removed { + // Log successful logout + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} logged out successfully", + username.as_deref().unwrap_or("unknown") + ), + Some("authentication"), + username.as_deref(), + None, + ); + Ok(Json(serde_json::json!({ + "success": true, + "message": "Logged out successfully" + }))) + } else { + Ok(Json(serde_json::json!({ + "success": false, + "error": "Invalid or expired token" + }))) + } +} + +pub async fn status( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: axum::http::HeaderMap, +) -> Result<Json<serde_json::Value>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract token from Authorization header + let token_opt = headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }); + + let token = match token_opt { + Some(t) => t, + None => { + return Ok(Json(serde_json::json!({ + "success": false, + "valid": false, + "error": "No authorization token provided" + }))); + } + }; + + // Check session validity + match state.session_manager.get_session(&token) { + Some(session) => { + let now = Utc::now(); + let timeout_minutes = state.config.get_session_timeout(session.power); + let elapsed = (now - session.last_accessed).num_seconds(); + let timeout_seconds = (timeout_minutes * 60) as i64; + let remaining_seconds = timeout_seconds - elapsed; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/auth/status", + &serde_json::json!({"token_provided": true}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": true, + "user": { + "id": session.user_id, + "username": session.username, + "name": session.username, + "role": session.role_name, + "power": session.power + }, + "session": { + "created_at": session.created_at.to_rfc3339(), + "last_accessed": session.last_accessed.to_rfc3339(), + "timeout_minutes": timeout_minutes, + "remaining_seconds": remaining_seconds.max(0), + "expires_at": (session.last_accessed + chrono::Duration::minutes(timeout_minutes as i64)).to_rfc3339() + } + }))) + } + None => { + // Log the request for invalid token + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + None, + None, + "/auth/status", + &serde_json::json!({"token_provided": true, "valid": false}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + Ok(Json(serde_json::json!({ + "success": true, + "valid": false, + "message": "Session expired or invalid" + }))) + } + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..b81a0a1 --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,111 @@ +pub mod auth; +pub mod preferences; +pub mod query; + +use crate::logging::logger::AuditLogger; +use tracing::{error, info, warn}; + +/// Helper function to log errors to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_error_async( + logging: &AuditLogger, + request_id: &str, + error_msg: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + error!("[{}] {}", request_id, error_msg); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let error_msg = error_msg.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_error( + &req_id, + chrono::Utc::now(), + &error_msg, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log warnings to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_warning_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + warn!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_warning( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} + +/// Helper function to log info messages to both console (tracing) and file (AuditLogger) +/// This eliminates code duplication across route handlers +pub fn log_info_async( + logging: &AuditLogger, + request_id: &str, + message: &str, + context: Option<&str>, + username: Option<&str>, + power: Option<i32>, +) { + // Log to console immediately + info!("[{}] {}", request_id, message); + + // Clone everything needed for the async task + let logging = logging.clone(); + let req_id = request_id.to_string(); + let message = message.to_string(); + let context = context.map(|s| s.to_string()); + let username = username.map(|s| s.to_string()); + + // Spawn async task to log to file + tokio::spawn(async move { + let _ = logging + .log_info( + &req_id, + chrono::Utc::now(), + &message, + context.as_deref(), + username.as_deref(), + power, + ) + .await; + }); +} diff --git a/src/routes/preferences.rs b/src/routes/preferences.rs new file mode 100644 index 0000000..e58d823 --- /dev/null +++ b/src/routes/preferences.rs @@ -0,0 +1,423 @@ +// User preferences routes +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use tracing::{error, warn}; + +use crate::config::UserSettingsAccess; +use crate::logging::AuditLogger; +use crate::AppState; + +// Request/Response structures matching the query route pattern +#[derive(Debug, Deserialize)] +pub struct PreferencesRequest { + pub action: String, // "get", "set", "reset" + pub user_id: Option<i32>, // For admin access to other users + pub preferences: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize)] +pub struct PreferencesResponse { + pub success: bool, + pub preferences: Option<serde_json::Value>, + pub error: Option<String>, +} + +/// Extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +/// POST /preferences - Handle all preference operations (get, set, reset) +pub async fn handle_preferences( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<PreferencesRequest>, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Session not found".to_string()), + })); + } + }; + + // Determine target user ID + let target_user_id = payload.user_id.unwrap_or(session.user_id); + + // Get user's permission level for preferences + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + // Check permissions for cross-user access + if target_user_id != session.user_id { + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + // Log security warning + super::log_warning_async( + &state.logging, + &request_id, + &format!("User {} (power {}) attempted to access preferences of user {} without permission", + session.username, session.power, target_user_id), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + + // Check write permissions for set/reset actions + if payload.action == "set" || payload.action == "reset" { + if target_user_id == session.user_id { + // Writing own preferences - need at least ReadWriteOwn + if *user_settings_permission == UserSettingsAccess::ReadOwnOnly { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } else { + // Writing others' preferences - need ReadWriteAll + if *user_settings_permission != UserSettingsAccess::ReadWriteAll { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Insufficient permissions".to_string()), + })); + } + } + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/preferences", + &serde_json::json!({"action": payload.action, "target_user_id": target_user_id}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Handle the action + match payload.action.as_str() { + "get" => { + handle_get_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + "set" => { + handle_set_preferences( + state, + request_id, + target_user_id, + payload.preferences, + session.username.clone(), + session.power, + ) + .await + } + "reset" => { + handle_reset_preferences( + state, + request_id, + target_user_id, + session.username.clone(), + session.power, + ) + .await + } + _ => Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some(format!("Invalid action: {}", payload.action)), + })), + } +} + +async fn handle_get_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + // Cast JSON column to CHAR to get string representation + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let preferences = match row { + Some((Some(prefs_str),)) => serde_json::from_str(&prefs_str).unwrap_or_else(|e| { + warn!( + "[{}] Failed to parse preferences JSON for user {}: {}", + request_id, user_id, e + ); + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Failed to parse preferences JSON for user {}: {}", + user_id, e + ), + Some("data_integrity"), + Some(&username), + Some(power), + ); + serde_json::json!({}) + }), + _ => serde_json::json!({}), + }; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!( + "User {} retrieved preferences for user {}", + username, user_id + ), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(preferences), + error: None, + })) +} + +async fn handle_set_preferences( + state: AppState, + request_id: String, + user_id: i32, + new_preferences: Option<serde_json::Value>, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let new_prefs = match new_preferences { + Some(prefs) => prefs, + None => { + return Ok(Json(PreferencesResponse { + success: false, + preferences: None, + error: Some("Missing preferences field".to_string()), + })); + } + }; + + // Get current preferences for merging + let query = "SELECT CAST(preferences AS CHAR) FROM users WHERE id = ? AND active = TRUE"; + let row: Option<(Option<String>,)> = sqlx::query_as(query) + .bind(user_id) + .fetch_optional(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error fetching preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Deep merge the preferences + let mut merged_prefs = match row { + Some((Some(prefs_str),)) => { + serde_json::from_str(&prefs_str).unwrap_or_else(|_| serde_json::json!({})) + } + _ => serde_json::json!({}), + }; + + // Merge function + fn merge_json(base: &mut serde_json::Value, update: &serde_json::Value) { + if let (Some(base_obj), Some(update_obj)) = (base.as_object_mut(), update.as_object()) { + for (key, value) in update_obj { + if let Some(base_value) = base_obj.get_mut(key) { + if base_value.is_object() && value.is_object() { + merge_json(base_value, value); + } else { + *base_value = value.clone(); + } + } else { + base_obj.insert(key.clone(), value.clone()); + } + } + } + } + + merge_json(&mut merged_prefs, &new_prefs); + + // Save to database + let prefs_str = serde_json::to_string(&merged_prefs).map_err(|e| { + error!("[{}] Failed to serialize preferences: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let update_query = "UPDATE users SET preferences = ? WHERE id = ? AND active = TRUE"; + sqlx::query(update_query) + .bind(&prefs_str) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error updating preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} updated preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(merged_prefs), + error: None, + })) +} + +async fn handle_reset_preferences( + state: AppState, + request_id: String, + user_id: i32, + username: String, + power: i32, +) -> Result<Json<PreferencesResponse>, StatusCode> { + let query = "UPDATE users SET preferences = NULL WHERE id = ? AND active = TRUE"; + sqlx::query(query) + .bind(user_id) + .execute(state.database.pool()) + .await + .map_err(|e| { + let error_msg = format!( + "Database error resetting preferences for user {}: {}", + user_id, e + ); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("database"), + Some(&username), + Some(power), + ); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Log user action + super::log_info_async( + &state.logging, + &request_id, + &format!("User {} reset preferences for user {}", username, user_id), + Some("user_action"), + Some(&username), + Some(power), + ); + + Ok(Json(PreferencesResponse { + success: true, + preferences: Some(serde_json::json!({})), + error: None, + })) +} diff --git a/src/routes/query.rs b/src/routes/query.rs new file mode 100644 index 0000000..9814215 --- /dev/null +++ b/src/routes/query.rs @@ -0,0 +1,3255 @@ +// Query routes and execution +use anyhow::{Context, Result}; +use axum::{ + extract::{ConnectInfo, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use chrono::Utc; +use rand::Rng; +use serde_json::Value; +use sqlx::{Column, Row}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tracing::{error, info, warn}; + +use crate::logging::AuditLogger; +use crate::models::{PermissionsResponse, QueryAction, QueryRequest, QueryResponse, UserInfo}; +use crate::sql::{ + build_filter_clause, build_legacy_where_clause, build_order_by_clause, validate_column_name, + validate_column_names, validate_table_name, +}; +use crate::AppState; + +// Helper function to extract token from Authorization header +fn extract_token(headers: &HeaderMap) -> Option<String> { + headers + .get("Authorization") + .and_then(|header| header.to_str().ok()) + .and_then(|auth_str| { + if auth_str.starts_with("Bearer ") { + Some(auth_str[7..].to_string()) + } else { + None + } + }) +} + +fn database_unavailable_response(request_id: &str) -> QueryResponse { + QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database temporarily unavailable, try again in a moment [request_id: {}]", + request_id + )), + warning: None, + results: None, + } +} + +fn database_unavailable_batch_response(request_id: &str) -> QueryResponse { + let mut base = database_unavailable_response(request_id); + base.results = Some(vec![]); + base +} + +async fn log_database_unavailable_event( + logger: &AuditLogger, + request_id: &str, + username: Option<&str>, + power: Option<i32>, + detail: &str, +) { + if let Err(err) = logger + .log_error( + request_id, + Utc::now(), + detail, + Some("database_unavailable"), + username, + power, + ) + .await + { + error!("[{}] Failed to record database outage: {}", request_id, err); + } +} + +pub async fn execute_query( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, + Json(payload): Json<QueryRequest>, +) -> Result<Json<QueryResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some( + "Please stop trying to access this resource without signing in".to_string(), + ), + warning: None, + results: None, + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Session not found".to_string()), + warning: None, + results: None, + })); + } + }; + + // Detect batch mode - if queries field is present, handle as batch operation + if payload.queries.is_some() { + // SECURITY: Check if user has permission to use batch operations + let power_perms = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .ok_or(StatusCode::FORBIDDEN)?; + + if !power_perms.allow_batch_operations { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} (power {}) attempted batch operation without permission", + session.username, session.power + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch operations not permitted for your role".to_string()), + warning: None, + results: None, + })); + } + + return execute_batch_mode(state, session, request_id, timestamp, client_ip, &payload) + .await; + } + + // Validate input for very basic security vulnerabilities (null bytes, etc.) + if let Err(security_error) = validate_input_security(&payload) { + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "Security validation failed for user {}: {}", + session.username, security_error + ), + Some("security_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid input detected, how did you even manage to do that? [request_id: {}]", + request_id + )), + warning: None, + results: None, + })); + } + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(&payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + // Clone payload before extracting fields (to avoid partial move issues) + let payload_clone = payload.clone(); + + // Single query mode - validate required fields + let action = payload.action.ok_or_else(|| { + let error_msg = "Missing action field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + let table = payload.table.ok_or_else(|| { + let error_msg = "Missing table field in single query mode"; + super::log_error_async( + &state.logging, + &request_id, + error_msg, + Some("request_validation"), + Some(&session.username), + Some(session.power), + ); + StatusCode::BAD_REQUEST + })?; + + // SECURITY: Validate table name before any operations + if let Err(e) = validate_table_name(&table, &state.config) { + let error_msg = format!("Invalid table name '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("table_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + // SECURITY: Validate column names if specified + if let Some(ref columns) = payload.columns { + if let Err(e) = validate_column_names(columns) { + let error_msg = format!("Invalid column names on table '{}': {}", table, e); + super::log_error_async( + &state.logging, + &request_id, + &error_msg, + Some("column_validation"), + Some(&session.username), + Some(session.power), + ); + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Invalid column name: {} [request_id: {}]", + e, request_id + )), + warning: None, + results: None, + })); + } + } + + // Check permissions (after validation to avoid leaking table existence) + if !state + .rbac + .check_permission(&state.config, session.power, &table, &action) + { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_warning_async( + &state.logging, + &request_id, + &format!( + "User {} attempted unauthorized {} on table {}", + session.username, action_str, table + ), + Some("authorization"), + Some(&session.username), + Some(session.power), + ); + + // Log security violation + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Permission denied: {} on table {}", action_str, table), + Some("authorization"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log permission denial: {}", + request_id, log_err + ); + } + + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Insufficient permissions for this operation, Might as well give up [request_id: {}]", request_id)), + warning: None, + results: None, + })); + } + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable, returning graceful error", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before transaction", + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // ALL database operations now use transactions with user context set + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + }; + + // Set user context and request ID in transaction - ALL queries have user context now + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_response(&request_id))); + } + + // Execute the query within the transaction + let result = match action { + QueryAction::Select => { + execute_select_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Insert => { + execute_insert_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_update_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_delete_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Count => { + execute_count_with_tx( + &request_id, + tx, + &payload_clone, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + let action_str = match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }; + + super::log_info_async( + &state.logging, + &request_id, + &format!("Query executed successfully: {} on {}", action_str, table), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } + Err(e) => { + error!("[{}] Query execution failed: {}", request_id, e); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!("Query execution error: {}", e), + Some("query_execution"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!("[{}] Failed to log error: {}", request_id, log_err); + } + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Database query failed [request_id: {}]", + request_id + )), + warning: None, + results: None, + })) + } + } +} + +pub async fn get_permissions( + State(state): State<AppState>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, + headers: HeaderMap, +) -> Result<Json<PermissionsResponse>, StatusCode> { + let timestamp = Utc::now(); + let client_ip = addr.ip().to_string(); + let request_id = AuditLogger::generate_request_id(); + + // Extract and validate session token + let token = match extract_token(&headers) { + Some(token) => token, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + let session = match state.session_manager.get_session(&token) { + Some(session) => session, + None => { + return Ok(Json(PermissionsResponse { + success: false, + permissions: HashMap::new(), + user: UserInfo { + id: 0, + username: "".to_string(), + name: "".to_string(), + role: "".to_string(), + power: 0, + }, + security_clearance: None, + user_settings_access: "".to_string(), + })); + } + }; + + // Log the request + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/permissions", + &serde_json::json!({}), + ) + .await + { + error!("[{}] Failed to log request: {}", request_id, e); + } + + let permissions = state + .rbac + .get_table_permissions(&state.config, session.power); + + // Get user settings access permission + let user_settings_permission = state + .config + .permissions + .power_levels + .get(&session.power.to_string()) + .map(|p| &p.user_settings_access) + .unwrap_or(&state.config.security.default_user_settings_access); + + let user_settings_access_str = match user_settings_permission { + crate::config::UserSettingsAccess::ReadOwnOnly => "read-own-only", + crate::config::UserSettingsAccess::ReadWriteOwn => "read-write-own", + crate::config::UserSettingsAccess::ReadWriteAll => "read-write-all", + }; + + Ok(Json(PermissionsResponse { + success: true, + permissions, + user: UserInfo { + id: session.user_id, + username: session.username, + name: "".to_string(), // We don't store name in session, would need to fetch from DB + role: session.role_name, + power: session.power, + }, + security_clearance: None, + user_settings_access: user_settings_access_str.to_string(), + })) +} + +// ===== SAFE SQL QUERY BUILDERS WITH VALIDATION ===== +// All functions validate table/column names to prevent SQL injection + +/// Build WHERE clause with column name validation (legacy simple format) +fn build_where_clause(where_clause: &Value) -> anyhow::Result<(String, Vec<String>)> { + // Use the new validated builder from sql module + build_legacy_where_clause(where_clause) +} + +/// Build INSERT data with column name validation +fn build_insert_data( + data: &Value, +) -> anyhow::Result<(Vec<String>, Vec<String>, Vec<Option<String>>)> { + let mut columns = Vec::new(); + let mut placeholders = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in INSERT: {}", key))?; + + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + columns.push(key.clone()); + placeholders.push("?".to_string()); + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("INSERT data must be a JSON object"); + } + + if columns.is_empty() { + anyhow::bail!("INSERT data cannot be empty"); + } + + Ok((columns, placeholders, values)) +} + +/// Build UPDATE SET clause with column name validation +fn build_update_set_clause(data: &Value) -> anyhow::Result<(String, Vec<Option<String>>)> { + let mut set_clauses = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = data { + for (key, value) in map { + // SECURITY: Validate column name before using it + validate_column_name(key) + .with_context(|| format!("Invalid column name in UPDATE: {}", key))?; + + set_clauses.push(format!("{} = ?", key)); + // Handle special JSON fields like additional_fields + if key == "additional_fields" && value.is_object() { + values.push(Some( + serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()), + )); + } else { + values.push(json_value_to_option_string(value)); + } + } + } else { + anyhow::bail!("UPDATE data must be a JSON object"); + } + + if set_clauses.is_empty() { + anyhow::bail!("UPDATE data cannot be empty"); + } + + Ok((set_clauses.join(", "), values)) +} + +/// Convert JSON value to Option<String> for SQL binding +/// Properly handles booleans (true/false -> "1"/"0" for MySQL TINYINT/BOOLEAN) +/// NULL values return None for proper SQL NULL handling +fn json_value_to_option_string(value: &Value) -> Option<String> { + match value { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + // MySQL uses TINYINT(1) for booleans: true -> 1, false -> 0 + Value::Bool(b) => Some(if *b { "1".to_string() } else { "0".to_string() }), + Value::Null => None, // Return None for proper SQL NULL + // Complex types (objects, arrays) get JSON serialized + _ => Some(serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())), + } +} +/// Convert SQL values from MySQL to JSON +/// Handles ALL MySQL types automatically with proper NULL handling +/// Returns booleans as true/false (MySQL TINYINT(1) -> JSON bool) +fn convert_sql_value_to_json( + row: &sqlx::mysql::MySqlRow, + index: usize, + request_id: Option<&str>, + username: Option<&str>, + power: Option<i32>, + state: Option<&AppState>, +) -> anyhow::Result<Value> { + let column = &row.columns()[index]; + + use sqlx::TypeInfo; + let type_name = column.type_info().name(); + + // Comprehensive MySQL type handling - no need to manually add types anymore! + let result = match type_name { + // ===== String types ===== + "VARCHAR" | "TEXT" | "CHAR" | "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "SET" | "ENUM" => { + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Integer types ===== + "INT" | "BIGINT" | "MEDIUMINT" | "SMALLINT" | "INTEGER" => row + .try_get::<Option<i64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // Unsigned integers + "INT UNSIGNED" | "BIGINT UNSIGNED" | "MEDIUMINT UNSIGNED" | "SMALLINT UNSIGNED" => row + .try_get::<Option<u64>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== Boolean type (MySQL TINYINT(1)) ===== + // Returns proper JSON true/false instead of 1/0 + "TINYINT" | "BOOLEAN" | "BOOL" => { + // Try as bool first (for TINYINT(1)) + if let Ok(opt_bool) = row.try_get::<Option<bool>, _>(index) { + return Ok(opt_bool.map(Value::Bool).unwrap_or(Value::Null)); + } + // Fallback to i8 for regular TINYINT + row.try_get::<Option<i8>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)) + } + + // ===== Decimal/Numeric types ===== + "DECIMAL" | "NUMERIC" => { + row.try_get::<Option<rust_decimal::Decimal>, _>(index) + .map(|opt| { + opt.map(|decimal| { + // Keep precision by converting to string then parsing + let decimal_str = decimal.to_string(); + if let Ok(f) = decimal_str.parse::<f64>() { + serde_json::json!(f) + } else { + Value::String(decimal_str) + } + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Floating point types ===== + "FLOAT" | "DOUBLE" | "REAL" => row + .try_get::<Option<f64>, _>(index) + .map(|opt| opt.map(|v| serde_json::json!(v)).unwrap_or(Value::Null)), + + // ===== Date/Time types ===== + "DATE" => { + use chrono::NaiveDate; + row.try_get::<Option<NaiveDate>, _>(index).map(|opt| { + opt.map(|date| Value::String(date.format("%Y-%m-%d").to_string())) + .unwrap_or(Value::Null) + }) + } + "DATETIME" => { + use chrono::NaiveDateTime; + row.try_get::<Option<NaiveDateTime>, _>(index).map(|opt| { + opt.map(|datetime| Value::String(datetime.format("%Y-%m-%d %H:%M:%S").to_string())) + .unwrap_or(Value::Null) + }) + } + "TIMESTAMP" => { + use chrono::{DateTime, Utc}; + row.try_get::<Option<DateTime<Utc>>, _>(index).map(|opt| { + opt.map(|timestamp| Value::String(timestamp.to_rfc3339())) + .unwrap_or(Value::Null) + }) + } + "TIME" => { + // TIME values come as strings in HH:MM:SS format + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + "YEAR" => row + .try_get::<Option<i32>, _>(index) + .map(|opt| opt.map(|v| Value::Number(v.into())).unwrap_or(Value::Null)), + + // ===== JSON type ===== + "JSON" => row.try_get::<Option<String>, _>(index).map(|opt| { + opt.and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or(Value::Null) + }), + + // ===== Binary types ===== + // Return as base64-encoded strings for safe JSON transmission + "BLOB" | "MEDIUMBLOB" | "LONGBLOB" | "TINYBLOB" | "BINARY" | "VARBINARY" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + use base64::{engine::general_purpose, Engine as _}; + Value::String(general_purpose::STANDARD.encode(&bytes)) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Bit type ===== + "BIT" => { + row.try_get::<Option<Vec<u8>>, _>(index).map(|opt| { + opt.map(|bytes| { + // Convert bit value to number + let mut val: i64 = 0; + for &byte in &bytes { + val = (val << 8) | byte as i64; + } + Value::Number(val.into()) + }) + .unwrap_or(Value::Null) + }) + } + + // ===== Spatial/Geometry types ===== + "GEOMETRY" | "POINT" | "LINESTRING" | "POLYGON" | "MULTIPOINT" | "MULTILINESTRING" + | "MULTIPOLYGON" | "GEOMETRYCOLLECTION" => { + // Return as WKT (Well-Known Text) string + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + + // ===== Catch-all for unknown/new types ===== + // This ensures forward compatibility if MySQL adds new types + _ => { + warn!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!( + "Unknown MySQL type '{}' for column '{}', attempting string fallback", + type_name, + column.name() + ), + Some("data_conversion"), + username, + power, + ); + } + row.try_get::<Option<String>, _>(index) + .map(|opt| opt.map(Value::String).unwrap_or(Value::Null)) + } + }; + + // Robust error handling with fallback + match result { + Ok(value) => Ok(value), + Err(e) => { + // Final fallback: try as string + match row.try_get::<Option<String>, _>(index) { + Ok(opt) => { + warn!("Primary conversion failed for column '{}' (type: {}), used string fallback", + column.name(), type_name); + if let (Some(rid), Some(st)) = (request_id, state) { + super::log_warning_async( + &st.logging, + rid, + &format!("Primary conversion failed for column '{}' (type: {}), used string fallback", column.name(), type_name), + Some("data_conversion"), + username, + power, + ); + } + Ok(opt.map(Value::String).unwrap_or(Value::Null)) + } + Err(_) => { + error!( + "Complete failure to decode column '{}' (index: {}, type: {}): {}", + column.name(), + index, + type_name, + e + ); + // Return NULL instead of failing the entire query + Ok(Value::Null) + } + } + } + } +} + +// Generate auto values based on configuration +async fn generate_auto_value( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + match config.gen_type.as_str() { + "numeric" => generate_unique_numeric_id(state, table, config).await, + _ => Err(anyhow::anyhow!( + "Unsupported auto-generation type: {}", + config.gen_type + )), + } +} + +// Generate a unique numeric ID based on configuration +async fn generate_unique_numeric_id( + state: &AppState, + table: &str, + config: &crate::config::AutoGenerationConfig, +) -> Result<String, anyhow::Error> { + let range_min = config.range_min.unwrap_or(10000000); + let range_max = config.range_max.unwrap_or(99999999); + let max_attempts = config.max_attempts.unwrap_or(10) as usize; + let field_name = &config.field; + + for _attempt in 0..max_attempts { + // Generate random number in specified range + let id = { + let mut rng = rand::rng(); + rng.random_range(range_min..=range_max) + }; + let id_str = id.to_string(); + + // Check if this ID already exists + let query_str = format!( + "SELECT COUNT(*) as count FROM {} WHERE {} = ?", + table, field_name + ); + let exists = sqlx::query(&query_str) + .bind(&id_str) + .fetch_one(state.database.pool()) + .await?; + + let count: i64 = exists.try_get("count")?; + + if count == 0 { + return Ok(id_str); + } + } + + Err(anyhow::anyhow!( + "Failed to generate unique {} for table {} after {} attempts", + field_name, + table, + max_attempts + )) +} + +// Security validation functions +fn validate_input_security(payload: &QueryRequest) -> Result<(), String> { + // Check for null bytes in table name + if let Some(ref table) = payload.table { + if table.contains('\0') { + return Err("Null byte detected in table name".to_string()); + } + } + + // Check for null bytes in column names + if let Some(columns) = &payload.columns { + for column in columns { + if column.contains('\0') { + return Err("Null byte detected in column name".to_string()); + } + } + } + + // Check for null bytes in data values + if let Some(data) = &payload.data { + if contains_null_bytes_in_value(data) { + return Err("Null byte detected in data values".to_string()); + } + } + + // Check for null bytes in WHERE clause + if let Some(where_clause) = &payload.where_clause { + if contains_null_bytes_in_value(where_clause) { + return Err("Null byte detected in WHERE clause".to_string()); + } + } + + Ok(()) +} + +fn contains_null_bytes_in_value(value: &Value) -> bool { + match value { + Value::String(s) => s.contains('\0'), + Value::Array(arr) => arr.iter().any(contains_null_bytes_in_value), + Value::Object(map) => { + map.keys().any(|k| k.contains('\0')) || map.values().any(contains_null_bytes_in_value) + } + _ => false, + } +} + +// Core execution functions that work with mutable transaction references +// These are used by batch operations to execute multiple queries in a single atomic transaction + +async fn execute_select_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + fn count_filter_conditions(filter: &crate::models::FilterCondition) -> usize { + use crate::models::FilterCondition; + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + let requested_columns = if let Some(ref cols) = payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {} for your power level, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, defaulted to {} (max for power level {})", + max_limit, session.power + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut **tx).await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where (same as other functions) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter_conditions(filter: &FilterCondition) -> usize { + match filter { + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + let mut count = 0; + if let Some(conditions) = and_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + if let Some(conditions) = or_conditions { + count += conditions + .iter() + .map(|c| count_filter_conditions(c)) + .sum::<usize>(); + } + count + } + _ => 1, + } + } + + fn count_where_conditions(where_clause: &serde_json::Value) -> usize { + match where_clause { + serde_json::Value::Object(map) => { + if map.contains_key("AND") || map.contains_key("OR") { + if let Some(arr) = map.get("AND").or_else(|| map.get("OR")) { + if let serde_json::Value::Array(conditions) = arr { + return conditions.iter().map(|c| count_where_conditions(c)).sum(); + } + } + } + 1 + } + _ => 1, + } + } + + let mut count = 0; + if let Some(f) = filter { + count += count_filter_conditions(f); + } + if let Some(w) = where_clause { + count += count_where_conditions(w); + } + count + } + + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut **tx).await?; + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +async fn execute_update_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + let writable_columns = data + .as_object() + .unwrap() + .keys() + .cloned() + .collect::<Vec<_>>(); + if writable_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + + let set_clause = writable_columns + .iter() + .map(|col| format!("{} = ?", col)) + .collect::<Vec<_>>() + .join(", "); + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for col in &writable_columns { + if let Some(val) = data.get(col) { + sqlx_query = sqlx_query.bind(val.clone()); + } + } + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_core( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + let max_where = state.config.get_max_where_conditions(session.power); + let condition_count = where_clause.as_object().map(|obj| obj.len()).unwrap_or(0); + if condition_count > max_where as usize { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {}", + condition_count, max_where + )), + warning: None, + results: None, + }); + } + + let (where_sql, where_values) = build_where_clause(where_clause)?; + let query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + let requested_limit = payload.limit; + let max_limit = state.config.get_max_limit(session.power); + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + let query_with_limit = format!("{} LIMIT {}", query, limit); + + let limit_warning = if was_capped { + Some(format!( + "Requested LIMIT {} exceeded maximum {}, capped to {}", + requested_limit.unwrap(), + max_limit, + max_limit + )) + } else { + None + }; + + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query_with_limit, + None, + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + let mut sqlx_query = sqlx::query(&query_with_limit); + for val in where_values { + sqlx_query = sqlx_query.bind(val); + } + + let result = sqlx_query.execute(&mut **tx).await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +// Transaction-based execution functions for user context operations +// These create their own transactions and commit them - used for single query operations + +async fn execute_select_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Helper to count conditions in filter/where + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Apply granular column filtering based on user's power level + let requested_columns = if let Some(cols) = &payload.columns { + cols.clone() + } else { + vec!["*".to_string()] + }; + + let filtered_columns = if requested_columns.len() == 1 && requested_columns[0] == "*" { + "*".to_string() + } else { + let allowed_columns = + state + .config + .filter_readable_columns(session.power, &table, &requested_columns); + if allowed_columns.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No readable columns available for this request".to_string()), + warning: None, + results: None, + }); + } + allowed_columns.join(", ") + }; + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT {} FROM {}", filtered_columns, table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + + // Build and append JOIN SQL + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE clause - support both legacy and new filter format + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Add ORDER BY if provided + if let Some(order_by) = &payload.order_by { + let order_clause = build_order_by_clause(order_by)?; + query.push_str(&order_clause); + } + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Add OFFSET if provided + if let Some(offset) = payload.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, // Row count will be known after execution + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let rows = sqlx_query.fetch_all(&mut *tx).await?; + tx.commit().await?; + + let mut results = Vec::new(); + for row in rows { + let mut result_row = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value = convert_sql_value_to_json( + &row, + i, + Some(request_id), + Some(username), + Some(session.power), + Some(state), + )?; + result_row.insert(column.name().to_string(), value); + } + results.push(Value::Object(result_row)); + } + + Ok(QueryResponse { + success: true, + data: Some(Value::Array(results)), + rows_affected: None, + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_insert_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for INSERT"))? + .clone(); + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + // This validates what the user is trying to write + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation requirements based on config + // Auto-generated fields bypass write protection since they're system-generated + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + // Check if auto-generation is enabled for INSERT action + if auto_config.on_action == "insert" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to insert + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in INSERT data".to_string()), + warning: None, + results: None, + }); + } + } + + let (columns_vec, placeholders_vec, values) = build_insert_data(&data)?; + let columns = columns_vec.join(", "); + let placeholders = placeholders_vec.join(", "); + + let query = format!( + "INSERT INTO {} ({}) VALUES ({})", + table, columns, placeholders + ); + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + let insert_id = result.last_insert_id(); + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(insert_id)), + rows_affected: Some(result.rows_affected()), + error: None, + warning: None, // INSERT queries don't have LIMIT, + results: None, + }) +} + +async fn execute_update_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let mut data = payload + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Data is required for UPDATE"))? + .clone(); + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for UPDATE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + // SECURITY: Apply column-level write filtering FIRST (before auto-generation) + if let Value::Object(ref mut map) = data { + let all_columns: Vec<String> = map.keys().cloned().collect(); + let writable_columns = + state + .config + .filter_writable_columns(session.power, &table, &all_columns); + + // Remove columns that user cannot write to + map.retain(|key, _| writable_columns.contains(key)); + + // Check for auto-generation (system-generated fields bypass write protection) + if let Some(auto_config) = state.config.get_auto_generation_config(&table) { + if auto_config.on_action == "update" || auto_config.on_action == "both" { + let field_name = &auto_config.field; + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + let generated_value = generate_auto_value(&state, &table, auto_config).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + // Final validation: ensure we have columns to update + if let Value::Object(ref map) = data { + if map.is_empty() { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("No writable columns in UPDATE data".to_string()), + warning: None, + results: None, + }); + } + } + + let (set_clause, mut values) = build_update_set_clause(&data)?; + let (where_sql, where_values) = build_where_clause(where_clause)?; + // Convert where_values to Option<String> to match set values + values.extend(where_values.into_iter().map(Some)); + + let mut query = format!("UPDATE {} SET {} WHERE {}", table, set_clause, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_delete_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + let where_clause = payload + .where_clause + .as_ref() + .ok_or_else(|| anyhow::anyhow!("WHERE clause is required for DELETE"))?; + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_limit = state.config.get_max_limit(session.power); + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = if let Some(w) = &payload.where_clause { + if let serde_json::Value::Object(map) = w { + map.len() + } else { + 0 + } + } else { + 0 + }; + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let (where_sql, values) = build_where_clause(where_clause)?; + + let mut query = format!("DELETE FROM {} WHERE {}", table, where_sql); + + // Enforce LIMIT and track if it was capped + let requested_limit = payload.limit; + let limit = requested_limit.unwrap_or(max_limit); + let was_capped = limit > max_limit; + let limit = if limit > max_limit { max_limit } else { limit }; + query.push_str(&format!(" LIMIT {}", limit)); + + let limit_warning = if was_capped { + Some(format!("Requested LIMIT {} exceeded maximum {} for your power level, capped to {} [request_id: {}]", + requested_limit.unwrap(), max_limit, max_limit, request_id)) + } else if requested_limit.is_none() { + Some(format!( + "No LIMIT specified, using default {} based on power level [request_id: {}]", + max_limit, request_id + )) + } else { + None + }; + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + sqlx_query = sqlx_query.bind(value); + } + + let result = sqlx_query.execute(&mut *tx).await?; + tx.commit().await?; + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(result.rows_affected()), + error: None, + warning: limit_warning, + results: None, + }) +} + +async fn execute_count_with_tx( + request_id: &str, + mut tx: sqlx::Transaction<'_, sqlx::MySql>, + payload: &QueryRequest, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let table = payload + .table + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Table is required"))?; + + // Check read permissions for the table + if !state.rbac.check_permission( + &state.config, + session.power, + table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to COUNT from table '{}' [request_id: {}]", + table, request_id + )), + warning: None, + results: None, + }); + } + + // Helper to count conditions in filter/where (same as select function) + fn count_conditions( + filter: &Option<crate::models::FilterCondition>, + where_clause: &Option<serde_json::Value>, + ) -> usize { + use crate::models::FilterCondition; + fn count_filter(cond: &FilterCondition) -> usize { + match cond { + FilterCondition::Simple { .. } => 1, + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + and_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + + or_conditions + .as_ref() + .map_or(0, |conds| conds.iter().map(count_filter).sum()) + } + FilterCondition::Not { not } => count_filter(not), + } + } + let mut count = 0; + if let Some(f) = filter { + count += count_filter(f); + } + if let Some(w) = where_clause { + if let serde_json::Value::Object(map) = w { + count += map.len(); + } + } + count + } + + // Enforce query limits from config (power-level specific with fallback to defaults) + let max_where = state.config.get_max_where_conditions(session.power); + + // Enforce WHERE clause complexity + let condition_count = count_conditions(&payload.filter, &payload.where_clause); + if condition_count > max_where as usize { + // Log security violation + let timestamp = chrono::Utc::now(); + if let Err(log_err) = state + .logging + .log_error( + &request_id, + timestamp, + &format!( + "Too many WHERE conditions: {} exceeds maximum {} for power level {}", + condition_count, max_where, session.power + ), + Some("query_limits"), + Some(&session.username), + Some(session.power), + ) + .await + { + error!( + "[{}] Failed to log WHERE limit violation: {}", + request_id, log_err + ); + } + + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Too many WHERE/filter conditions ({}) > max {} [request_id: {}]", + condition_count, max_where, request_id + )), + warning: None, + results: None, + }); + } + + let mut query = format!("SELECT COUNT(*) as count FROM {}", table); + let mut values = Vec::new(); + + // Add JOIN clauses if provided - validates permissions for all joined tables + if let Some(joins) = &payload.joins { + // Validate user has read permission for all joined tables + for join in joins { + if !state.rbac.check_permission( + &state.config, + session.power, + &join.table, + &crate::models::QueryAction::Select, + ) { + return Ok(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions to JOIN with table '{}'", + join.table + )), + warning: None, + results: None, + }); + } + } + let join_sql = crate::sql::build_join_clause(joins, &state.config)?; + query.push_str(&join_sql); + } + + // Add WHERE conditions (filter takes precedence over where_clause if both are provided) + if let Some(filter) = &payload.filter { + let (where_sql, where_values) = crate::sql::build_filter_clause(filter)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } else if let Some(where_clause) = &payload.where_clause { + let (where_sql, where_values) = build_where_clause(where_clause)?; + query.push_str(&format!(" WHERE {}", where_sql)); + values.extend(where_values.into_iter().map(Some)); + } + + // Log the query + let params_json = serde_json::to_value(&values).ok(); + if let Err(e) = state + .logging + .log_query( + request_id, + chrono::Utc::now(), + username, + Some(session.power), + &query, + params_json.as_ref(), + None, + ) + .await + { + error!("[{}] Failed to log query: {}", request_id, e); + } + + // Execute the query + let mut sqlx_query = sqlx::query(&query); + for value in values { + match value { + Some(v) => sqlx_query = sqlx_query.bind(v), + None => sqlx_query = sqlx_query.bind(Option::<String>::None), + } + } + + let result = sqlx_query.fetch_one(&mut *tx).await?; + tx.commit().await?; + + let count: i64 = result.try_get("count")?; + + Ok(QueryResponse { + success: true, + data: Some(serde_json::json!(count)), + rows_affected: None, + error: None, + warning: None, + results: None, + }) +} + +/// Execute multiple queries in a single transaction +async fn execute_batch_mode( + state: AppState, + session: crate::models::Session, + request_id: String, + timestamp: chrono::DateTime<chrono::Utc>, + client_ip: String, + payload: &QueryRequest, +) -> Result<Json<QueryResponse>, StatusCode> { + let queries = payload.queries.as_ref().unwrap(); + + // Check if batch is empty + if queries.is_empty() { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some("Batch request cannot be empty".to_string()), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Batch query request from user {} (power {}): {} queries", + request_id, + session.username, + session.power, + queries.len() + ); + + // Log the batch request (log full payload including action/table) + if let Err(e) = state + .logging + .log_request( + &request_id, + timestamp, + &client_ip, + Some(&session.username), + Some(session.power), + "/query", + &serde_json::to_value(payload).unwrap_or_default(), + ) + .await + { + error!("[{}] Failed to log batch request: {}", request_id, e); + } + + // Get action and table from parent request (all batch queries share these) + let action = payload + .action + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + let table = payload + .table + .as_ref() + .ok_or_else(|| StatusCode::BAD_REQUEST)?; + + // Validate table name once for the entire batch + if let Err(e) = validate_table_name(table, &state.config) { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Invalid table name: {}", e)), + warning: None, + results: Some(vec![]), + })); + } + + // Check RBAC permission once for the entire batch + if !state + .rbac + .check_permission(&state.config, session.power, table, action) + { + return Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!( + "Insufficient permissions for {} on table '{}'", + match action { + QueryAction::Select => "SELECT", + QueryAction::Insert => "INSERT", + QueryAction::Update => "UPDATE", + QueryAction::Delete => "DELETE", + QueryAction::Count => "COUNT", + }, + table + )), + warning: None, + results: Some(vec![]), + })); + } + + info!( + "[{}] Validated batch: {} x {:?} on table '{}'", + request_id, + queries.len(), + action, + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Validated batch: {} x {:?} on table '{}'", + queries.len(), + action, + table + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + if !state.database.is_available() { + warn!( + "[{}] Database marked unavailable before batch execution", + request_id + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + "Database flagged unavailable before batch", + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + // Start a SINGLE transaction for the batch - proper atomic operation + let mut tx = match state.database.pool().begin().await { + Ok(tx) => { + state.database.mark_available(); + tx + } + Err(e) => { + state.database.mark_unavailable(); + error!("[{}] Failed to begin batch transaction: {}", request_id, e); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to begin batch transaction: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + }; + + // Set user context and request ID in transaction + if let Err(e) = sqlx::query("SET @current_user_id = ?, @request_id = ?") + .bind(session.user_id) + .bind(&request_id) + .execute(&mut *tx) + .await + { + state.database.mark_unavailable(); + error!( + "[{}] Failed to set current user context and request ID: {}", + request_id, e + ); + log_database_unavailable_event( + &state.logging, + &request_id, + Some(&session.username), + Some(session.power), + &format!("Failed to set batch user context: {}", e), + ) + .await; + return Ok(Json(database_unavailable_batch_response(&request_id))); + } + + let rollback_on_error = payload.rollback_on_error.unwrap_or(false); + + info!( + "[{}] Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + request_id, + queries.len(), + action, + table, + rollback_on_error + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Executing NATIVE batch: {} x {:?} on '{}' (rollback_on_error={})", + queries.len(), + action, + table, + rollback_on_error + ), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + // Execute as SINGLE native batch query based on action type + let result = match action { + QueryAction::Insert => { + execute_batch_insert( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Update => { + execute_batch_update( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Delete => { + execute_batch_delete( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &state, + &session, + ) + .await + } + QueryAction::Select => { + // SELECT batches are less common but we'll execute them individually + // (combining SELECTs into one query doesn't make sense as they return different results) + execute_batch_selects( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + QueryAction::Count => { + // COUNT batches execute individually (each returns different results) + execute_batch_counts( + &request_id, + &mut tx, + queries, + table, + action, + &session.username, + &session, + &state, + ) + .await + } + }; + + match result { + Ok(response) => { + if response.success || !rollback_on_error { + // Commit the transaction + tx.commit().await.map_err(|e| { + error!("[{}] Failed to commit batch transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + info!( + "[{}] Native batch committed: {} operations", + request_id, + queries.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Native batch committed: {} operations", queries.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + Ok(Json(response)) + } else { + // Rollback on error + error!( + "[{}] Rolling back batch transaction due to error", + request_id + ); + tx.rollback().await.map_err(|e| { + error!("[{}] Failed to rollback transaction: {}", request_id, e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(Json(response)) + } + } + Err(e) => { + error!("[{}] Batch execution failed: {}", request_id, e); + tx.rollback().await.map_err(|e2| { + error!("[{}] Failed to rollback after error: {}", request_id, e2); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(QueryResponse { + success: false, + data: None, + rows_affected: None, + error: Some(format!("Batch execution failed: {}", e)), + warning: None, + results: Some(vec![]), + })) + } + } +} + +// ===== NATIVE BATCH EXECUTION FUNCTIONS ===== + +/// Execute batch INSERT using MySQL multi-value INSERT +async fn execute_batch_insert( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + _action: &QueryAction, + _username: &str, + state: &AppState, + _session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + use serde_json::Value; + + // Extract all data objects, apply auto-generation, and validate they have the same columns + let mut all_data = Vec::new(); + let mut column_set: Option<std::collections::HashSet<String>> = None; + + // Check for auto-generation config + let auto_config = state.config.get_auto_generation_config(&table); + + for (idx, query) in queries.iter().enumerate() { + let mut data = query + .data + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Query {} missing data field for INSERT", idx + 1))? + .clone(); + + // Apply auto-generation if configured for INSERT + if let Some(ref auto_cfg) = auto_config { + if auto_cfg.on_action == "insert" || auto_cfg.on_action == "both" { + if let Value::Object(ref mut map) = data { + let field_name = &auto_cfg.field; + + if !map.contains_key(field_name) + || map.get(field_name).map_or(true, |v| { + v.is_null() || v.as_str().map_or(true, |s| s.is_empty()) + }) + { + // Generate auto value based on config + let generated_value = generate_auto_value(&state, &table, auto_cfg).await?; + map.insert(field_name.clone(), Value::String(generated_value)); + } + } + } + } + + if let Value::Object(map) = data { + let cols: std::collections::HashSet<String> = map.keys().cloned().collect(); + + if let Some(ref expected_cols) = column_set { + if *expected_cols != cols { + anyhow::bail!("All INSERT queries must have the same columns. Query {} has different columns", idx + 1); + } + } else { + column_set = Some(cols); + } + + all_data.push(map); + } else { + anyhow::bail!("Query {} data must be an object", idx + 1); + } + } + + let columns: Vec<String> = column_set.unwrap().into_iter().collect(); + + // Validate column names + for col in &columns { + validate_column_name(col)?; + } + + // Build multi-value INSERT: INSERT INTO table (col1, col2) VALUES (?, ?), (?, ?), ... + let column_list = columns.join(", "); + let value_placeholder = format!("({})", vec!["?"; columns.len()].join(", ")); + let values_clause = vec![value_placeholder; all_data.len()].join(", "); + + let sql = format!( + "INSERT INTO {} ({}) VALUES {}", + table, column_list, values_clause + ); + + info!( + "[{}] Native batch INSERT: {} rows into {}", + request_id, + all_data.len(), + table + ); + super::log_info_async( + &state.logging, + &request_id, + &format!( + "Native batch INSERT: {} rows into {}", + all_data.len(), + table + ), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + // Bind all values + let mut query = sqlx::query(&sql); + for data_map in &all_data { + for col in &columns { + let value = data_map.get(col).and_then(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + Value::Bool(b) => Some(b.to_string()), + Value::Null => None, + _ => Some(v.to_string()), + }); + query = query.bind(value); + } + } + + // Execute the batch INSERT + let result = query.execute(&mut **tx).await?; + let rows_affected = result.rows_affected(); + + info!( + "[{}] Batch INSERT affected {} rows", + request_id, rows_affected + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch INSERT affected {} rows", rows_affected), + Some("query"), + Some(&_session.username), + Some(_session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(rows_affected), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch UPDATE (executes individually for now, could be optimized with CASE statements) +async fn execute_batch_update( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_update_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch UPDATE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch UPDATE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch DELETE using IN clause when possible +async fn execute_batch_delete( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + state: &AppState, + session: &crate::models::Session, +) -> anyhow::Result<QueryResponse> { + // For now, execute individually + // TODO: Optimize by detecting simple ID-based deletes and combining with IN clause + let mut total_rows = 0u64; + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_delete_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + state, + session, + ) + .await?; + + if let Some(rows) = result.rows_affected { + total_rows += rows; + } + } + + info!( + "[{}] Batch DELETE affected {} total rows", + request_id, total_rows + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch DELETE affected {} total rows", total_rows), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: Some(total_rows), + error: None, + warning: None, + results: None, + }) +} + +/// Execute batch SELECT (executes individually since they return different results) +async fn execute_batch_selects( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_select_core( + &format!("{}-{}", request_id, idx + 1), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch SELECT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch SELECT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} + +/// Execute batch COUNT (executes individually since they return different results) +async fn execute_batch_counts( + request_id: &str, + tx: &mut sqlx::Transaction<'_, sqlx::MySql>, + queries: &Vec<crate::models::BatchQuery>, + table: &str, + action: &QueryAction, + username: &str, + session: &crate::models::Session, + state: &AppState, +) -> anyhow::Result<QueryResponse> { + let mut results = Vec::new(); + + for (idx, batch_query) in queries.iter().enumerate() { + // Convert BatchQuery to QueryRequest by adding inherited action/table + let query_req = QueryRequest { + action: Some(action.clone()), + table: Some(table.to_string()), + columns: batch_query.columns.clone(), + data: batch_query.data.clone(), + where_clause: batch_query.where_clause.clone(), + filter: batch_query.filter.clone(), + joins: None, + limit: batch_query.limit, + offset: batch_query.offset, + order_by: batch_query.order_by.clone(), + queries: None, + rollback_on_error: None, + }; + + let result = execute_count_core( + &format!("{}:{}", request_id, idx), + tx, + &query_req, + username, + session, + state, + ) + .await?; + + results.push(result); + } + + info!( + "[{}] Batch COUNT executed {} queries", + request_id, + results.len() + ); + super::log_info_async( + &state.logging, + &request_id, + &format!("Batch COUNT executed {} queries", results.len()), + Some("query"), + Some(&session.username), + Some(session.power), + ); + + Ok(QueryResponse { + success: true, + data: None, + rows_affected: None, + error: None, + warning: None, + results: Some(results), + }) +} diff --git a/src/scheduler.rs b/src/scheduler.rs new file mode 100644 index 0000000..ec2710b --- /dev/null +++ b/src/scheduler.rs @@ -0,0 +1,185 @@ +// Scheduled query execution module +use crate::config::{Config, ScheduledQueryTask}; +use crate::db::Database; +use crate::logging::AuditLogger; +use chrono::Utc; +use std::sync::Arc; +use tokio::time::{interval, Duration}; +use tracing::{error, info, warn}; + +pub struct QueryScheduler { + config: Arc<Config>, + database: Database, + logging: AuditLogger, +} + +impl QueryScheduler { + pub fn new(config: Arc<Config>, database: Database, logging: AuditLogger) -> Self { + Self { + config, + database, + logging, + } + } + + /// Spawn background tasks for all enabled scheduled queries + pub fn spawn_tasks(&self) { + let Some(scheduled_config) = &self.config.scheduled_queries else { + info!("No scheduled queries configured"); + return; + }; + + if scheduled_config.tasks.is_empty() { + info!("No scheduled query tasks defined"); + return; + } + + let enabled_tasks: Vec<&ScheduledQueryTask> = scheduled_config + .tasks + .iter() + .filter(|task| task.enabled) + .collect(); + + if enabled_tasks.is_empty() { + info!("No enabled scheduled query tasks"); + return; + } + + info!( + "Spawning {} enabled scheduled query task(s)", + enabled_tasks.len() + ); + + for task in enabled_tasks { + let task_clone = task.clone(); + let database = self.database.clone(); + let logging = self.logging.clone(); + + tokio::spawn(async move { + Self::run_scheduled_task(task_clone, database, logging).await; + }); + } + } + + async fn run_scheduled_task( + task: ScheduledQueryTask, + database: Database, + logging: AuditLogger, + ) { + let mut interval_timer = interval(Duration::from_secs(task.interval_minutes * 60)); + let mut first_run = true; + + info!( + "Scheduled task '{}' started (interval: {}min, run_on_startup: {}): {}", + task.name, task.interval_minutes, task.run_on_startup, task.description + ); + + loop { + interval_timer.tick().await; + + if !database.is_available() { + warn!( + "Skipping scheduled task '{}' because database is unavailable", + task.name + ); + log_task_event( + &logging, + "scheduler_skip", + &task.name, + "Database unavailable, task skipped", + false, + ) + .await; + continue; + } + + // Skip first execution if run_on_startup is false + if first_run && !task.run_on_startup { + first_run = false; + info!( + "Scheduled task '{}' skipping initial run (run_on_startup=false)", + task.name + ); + continue; + } + first_run = false; + + match sqlx::query(&task.query).execute(database.pool()).await { + Ok(result) => { + database.mark_available(); + info!( + "Scheduled task '{}' executed successfully (rows affected: {})", + task.name, + result.rows_affected() + ); + log_task_event( + &logging, + "scheduler_success", + &task.name, + &format!( + "Task executed successfully (rows affected: {})", + result.rows_affected() + ), + false, + ) + .await; + } + Err(e) => { + database.mark_unavailable(); + error!( + "Scheduled task '{}' failed: {} (query: {})", + task.name, e, task.query + ); + log_task_event( + &logging, + "scheduler_failure", + &task.name, + &format!("Task failed: {}", e), + true, + ) + .await; + } + } + } + } +} + +async fn log_task_event( + logging: &AuditLogger, + context: &str, + task_name: &str, + message: &str, + as_error: bool, +) { + let request_id = AuditLogger::generate_request_id(); + let timestamp = Utc::now(); + let full_message = format!("{}: {}", task_name, message); + + let result = if as_error { + logging + .log_error( + &request_id, + timestamp, + &full_message, + Some(context), + Some("system"), + None, + ) + .await + } else { + logging + .log_info( + &request_id, + timestamp, + &full_message, + Some(context), + Some("system"), + None, + ) + .await + }; + + if let Err(err) = result { + error!("Failed to record scheduler event ({}): {}", context, err); + } +} diff --git a/src/sql/builder.rs b/src/sql/builder.rs new file mode 100644 index 0000000..7ba84f2 --- /dev/null +++ b/src/sql/builder.rs @@ -0,0 +1,493 @@ +// Enhanced SQL query builder with proper validation and complex WHERE support +use anyhow::{Context, Result}; +use regex::Regex; +use serde_json::Value; + +use crate::config::Config; +use crate::models::{FilterCondition, FilterOperator, OrderBy, OrderDirection}; + +/// Parse a table reference possibly containing an alias into (base_table, alias) +/// Accepts formats like: "table", "table alias", "table AS alias" (AS case-insensitive) +pub fn parse_table_and_alias(input: &str) -> (String, Option<String>) { + let s = input.trim(); + // Normalize multiple spaces + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.is_empty() { + return (String::new(), None); + } + + if parts.len() == 1 { + return (parts[0].to_string(), None); + } + + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case("AS") { + // table AS alias [ignore extra] + return (parts[0].to_string(), Some(parts[2].to_string())); + } + + // table alias + (parts[0].to_string(), Some(parts[1].to_string())) +} + +/// Validates a table name against known tables and SQL injection patterns +pub fn validate_table_name(table: &str, config: &Config) -> Result<()> { + // Check if empty + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty"); + } + + // Check against known tables list + let known_tables = config.get_known_tables(); + if !known_tables.contains(&table.to_string()) { + anyhow::bail!("Table '{}' is not in the known tables list", table); + } + + // Validate format: only alphanumeric and underscores, must start with letter + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table name regex")?; + + if !table_regex.is_match(table) { + anyhow::bail!("Invalid table name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", table); + } + + // Additional security: check for SQL keywords and dangerous patterns + let dangerous_patterns = [ + "--", "/*", "*/", ";", "DROP", "ALTER", "CREATE", "EXEC", "EXECUTE", + ]; + let table_upper = table.to_uppercase(); + for pattern in &dangerous_patterns { + if table_upper.contains(pattern) { + anyhow::bail!("Table name contains forbidden pattern: '{}'", pattern); + } + } + + Ok(()) +} + +/// Validates a column name against SQL injection patterns +/// Supports: column, table.column, table.column as alias, table.* +pub fn validate_column_name(column: &str) -> Result<()> { + // Check if empty + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty"); + } + + // Allow * for SELECT all + if column == "*" { + return Ok(()); + } + + // Check for SQL injection patterns (comments and statement terminators) + let dangerous_chars = ["--", "/*", "*/", ";"]; + for pattern in &dangerous_chars { + if column.contains(pattern) { + anyhow::bail!("Column name contains forbidden pattern: '{}'", pattern); + } + } + + // Note: We don't block SQL keywords like DROP, CREATE, ALTER etc. in column names + // because legitimate columns like "created_date", "created_by" contain these as substrings. + // The regex validation below ensures only alphanumeric + underscore characters are allowed, + // which prevents actual SQL injection while allowing valid column names. + + // Support multiple formats: + // 1. Simple column: column_name + // 2. Qualified column: table.column_name + // 3. Wildcard: table.* + // 4. Alias: column_name as alias or table.column_name as alias + + // Remove AS alias if present (case insensitive) + let column_upper = column.to_uppercase(); + let column_without_alias = if column_upper.contains(" AS ") { + column.split_whitespace().next().unwrap_or("") + } else { + column + }; + + // Check for qualified column (table.column or table.*) + if column_without_alias.contains('.') { + let parts: Vec<&str> = column_without_alias.split('.').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Invalid qualified column format: '{}'. Must be table.column", + column + ); + } + + // Validate table part + let table_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile table regex")?; + if !table_regex.is_match(parts[0]) { + anyhow::bail!("Invalid table name in qualified column: '{}'", parts[0]); + } + + // Validate column part (allow * for table.*) + if parts[1] != "*" { + let column_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile column regex")?; + if !column_regex.is_match(parts[1]) { + anyhow::bail!("Invalid column name in qualified column: '{}'", parts[1]); + } + } + } else { + // Simple column name validation + let column_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$") + .context("Failed to compile column name regex")?; + + if !column_regex.is_match(column_without_alias) { + anyhow::bail!("Invalid column name format: '{}'. Must start with a letter and contain only letters, numbers, and underscores", column); + } + } + + Ok(()) +} + +/// Validates multiple column names +pub fn validate_column_names(columns: &[String]) -> Result<()> { + for column in columns { + validate_column_name(column).with_context(|| format!("Invalid column name: {}", column))?; + } + Ok(()) +} + +/// Build WHERE clause from enhanced FilterCondition +/// Supports complex operators (=, !=, >, <, LIKE, IN, IS NULL, etc.) and nested logic (AND, OR, NOT) +pub fn build_filter_clause(filter: &FilterCondition) -> Result<(String, Vec<String>)> { + match filter { + FilterCondition::Simple { + column, + operator, + value, + } => { + validate_column_name(column)?; + build_simple_condition(column, operator, value) + } + FilterCondition::Logical { + and_conditions, + or_conditions, + } => { + if let Some(and_conds) = and_conditions { + if and_conds.is_empty() { + anyhow::bail!("AND conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in and_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" AND "), all_values)) + } else if let Some(or_conds) = or_conditions { + if or_conds.is_empty() { + anyhow::bail!("OR conditions array cannot be empty"); + } + let mut conditions = Vec::new(); + let mut all_values = Vec::new(); + + for cond in or_conds { + let (sql, values) = build_filter_clause(cond)?; + conditions.push(format!("({})", sql)); + all_values.extend(values); + } + + Ok((conditions.join(" OR "), all_values)) + } else { + anyhow::bail!("Logical condition must have either 'and' or 'or' field"); + } + } + FilterCondition::Not { not } => { + let (sql, values) = build_filter_clause(not)?; + Ok((format!("NOT ({})", sql), values)) + } + } +} + +/// Build a simple condition (column operator value) +fn build_simple_condition( + column: &str, + operator: &FilterOperator, + value: &Value, +) -> Result<(String, Vec<String>)> { + match operator { + FilterOperator::Eq => { + if value.is_null() { + Ok((format!("{} IS NULL", column), vec![])) + } else { + Ok((format!("{} = ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Ne => { + if value.is_null() { + Ok((format!("{} IS NOT NULL", column), vec![])) + } else { + Ok((format!("{} != ?", column), vec![json_to_sql_string(value)])) + } + } + FilterOperator::Gt => Ok((format!("{} > ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Gte => Ok((format!("{} >= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lt => Ok((format!("{} < ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Lte => Ok((format!("{} <= ?", column), vec![json_to_sql_string(value)])), + FilterOperator::Like => Ok(( + format!("{} LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::NotLike => Ok(( + format!("{} NOT LIKE ?", column), + vec![json_to_sql_string(value)], + )), + FilterOperator::In => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("IN operator requires array value"); + } + } + FilterOperator::NotIn => { + if let Value::Array(arr) = value { + if arr.is_empty() { + anyhow::bail!("NOT IN operator requires non-empty array"); + } + let placeholders = vec!["?"; arr.len()].join(", "); + let values: Vec<String> = arr.iter().map(json_to_sql_string).collect(); + Ok((format!("{} NOT IN ({})", column, placeholders), values)) + } else { + anyhow::bail!("NOT IN operator requires array value"); + } + } + FilterOperator::IsNull => { + // Value is ignored for IS NULL + Ok((format!("{} IS NULL", column), vec![])) + } + FilterOperator::IsNotNull => { + // Value is ignored for IS NOT NULL + Ok((format!("{} IS NOT NULL", column), vec![])) + } + FilterOperator::Between => { + if let Value::Array(arr) = value { + if arr.len() != 2 { + anyhow::bail!("BETWEEN operator requires array with exactly 2 values"); + } + let val1 = json_to_sql_string(&arr[0]); + let val2 = json_to_sql_string(&arr[1]); + Ok((format!("{} BETWEEN ? AND ?", column), vec![val1, val2])) + } else { + anyhow::bail!("BETWEEN operator requires array with [min, max] values"); + } + } + } +} + +/// Convert JSON value to SQL string representation +/// Properly handles all JSON types including booleans (true/false -> 1/0) +fn json_to_sql_string(value: &Value) -> String { + match value { + Value::String(s) => s.clone(), + Value::Number(n) => n.to_string(), + Value::Bool(b) => { + if *b { + "1".to_string() + } else { + "0".to_string() + } + } + Value::Null => "NULL".to_string(), + _ => serde_json::to_string(value).unwrap_or_else(|_| "NULL".to_string()), + } +} + +/// Build legacy WHERE clause from simple key-value JSON (for backward compatibility) +pub fn build_legacy_where_clause(where_clause: &Value) -> Result<(String, Vec<String>)> { + let mut conditions = Vec::new(); + let mut values = Vec::new(); + + if let Value::Object(map) = where_clause { + for (key, value) in map { + validate_column_name(key)?; + + if value.is_null() { + // Handle NULL values with IS NULL + conditions.push(format!("{} IS NULL", key)); + // Don't add to values since IS NULL doesn't need a parameter + } else { + conditions.push(format!("{} = ?", key)); + values.push(json_to_sql_string(value)); + } + } + } else { + anyhow::bail!("WHERE clause must be an object"); + } + + if conditions.is_empty() { + anyhow::bail!("WHERE clause cannot be empty"); + } + + Ok((conditions.join(" AND "), values)) +} + +/// Build ORDER BY clause with column validation +pub fn build_order_by_clause(order_by: &[OrderBy]) -> Result<String> { + if order_by.is_empty() { + return Ok(String::new()); + } + + let mut clauses = Vec::new(); + for order in order_by { + validate_column_name(&order.column)?; + let direction = match order.direction { + OrderDirection::ASC => "ASC", + OrderDirection::DESC => "DESC", + }; + clauses.push(format!("{} {}", order.column, direction)); + } + + Ok(format!(" ORDER BY {}", clauses.join(", "))) +} + +/// Build JOIN clause from Join specifications +/// Validates table names and join conditions for security +pub fn build_join_clause(joins: &[crate::models::Join], config: &Config) -> Result<String> { + if joins.is_empty() { + return Ok(String::new()); + } + + let mut join_sql = String::new(); + + for join in joins { + // Extract base table and optional alias + let (base_table, alias) = parse_table_and_alias(&join.table); + + // Validate joined base table name + validate_table_name(&base_table, config)?; + + // Optionally validate alias format (same rules as table/column names) + if let Some(alias_name) = &alias { + let alias_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile alias regex")?; + if !alias_regex.is_match(alias_name) { + anyhow::bail!("Invalid table alias format: '{}'", alias_name); + } + } + + // Validate join condition (must be in format "table1.column1 = table2.column2") + validate_join_condition(&join.on)?; + + // Build JOIN clause based on type + let join_type_str = match join.join_type { + crate::models::JoinType::Inner => "INNER JOIN", + crate::models::JoinType::Left => "LEFT JOIN", + crate::models::JoinType::Right => "RIGHT JOIN", + }; + + // Reconstruct safe table reference + let table_ref = match alias { + Some(a) => format!("{} AS {}", base_table, a), + None => base_table, + }; + join_sql.push_str(&format!(" {} {} ON {}", join_type_str, table_ref, join.on)); + } + + Ok(join_sql) +} + +/// Validate JOIN ON condition +/// Must be in format: "table1.column1 = table2.column2" or similar simple conditions +fn validate_join_condition(condition: &str) -> Result<()> { + // Basic validation: must contain = and table.column references + if !condition.contains('=') { + anyhow::bail!("JOIN condition must contain = operator"); + } + + // Split by = and validate both sides + let parts: Vec<&str> = condition.split('=').map(|s| s.trim()).collect(); + if parts.len() != 2 { + anyhow::bail!("JOIN condition must have exactly one = operator"); + } + + // Validate both sides are table.column format + for part in parts { + validate_table_column_reference(part)?; + } + + Ok(()) +} + +/// Validate table.column reference (e.g., "assets.category_id") +fn validate_table_column_reference(reference: &str) -> Result<()> { + let parts: Vec<&str> = reference.split('.').collect(); + + if parts.len() != 2 { + anyhow::bail!( + "Table column reference must be in format 'table.column', got: {}", + reference + ); + } + + let table = parts[0].trim(); + let column = parts[1].trim(); + + // Validate table and column names follow safe patterns + let name_regex = + Regex::new(r"^[a-zA-Z][a-zA-Z0-9_]*$").context("Failed to compile name regex")?; + + if !name_regex.is_match(table) { + anyhow::bail!("Invalid table name in JOIN condition: {}", table); + } + + if !name_regex.is_match(column) { + anyhow::bail!("Invalid column name in JOIN condition: {}", column); + } + + // Additional guard against comment/terminator tokens (redundant given regex, but safe) + // Intentionally do NOT block SQL keywords like CREATE/ALTER since identifiers like + // 'created_at' are legitimate and already validated by the regex above. + let dangerous_tokens = ["--", "/*", "*/", ";"]; + for token in &dangerous_tokens { + if reference.contains(token) { + anyhow::bail!("JOIN condition contains forbidden token: '{}'", token); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_table_name() { + // Valid names + assert!(validate_column_name("users").is_ok()); + assert!(validate_column_name("asset_id").is_ok()); + assert!(validate_column_name("table123").is_ok()); + + // Invalid names + assert!(validate_column_name("123table").is_err()); // starts with number + assert!(validate_column_name("table-name").is_err()); // contains hyphen + assert!(validate_column_name("table name").is_err()); // contains space + assert!(validate_column_name("table;DROP").is_err()); // SQL injection + assert!(validate_column_name("").is_err()); // empty + } + + #[test] + fn test_validate_column_name() { + // Valid names + assert!(validate_column_name("user_id").is_ok()); + assert!(validate_column_name("firstName").is_ok()); + assert!(validate_column_name("*").is_ok()); // wildcard allowed + + // Invalid names + assert!(validate_column_name("123column").is_err()); + assert!(validate_column_name("col-umn").is_err()); + assert!(validate_column_name("col umn").is_err()); + assert!(validate_column_name("").is_err()); + } +} diff --git a/src/sql/mod.rs b/src/sql/mod.rs new file mode 100644 index 0000000..fd5b959 --- /dev/null +++ b/src/sql/mod.rs @@ -0,0 +1,8 @@ +// SQL query building and validation module +pub mod builder; + +// Re-export commonly used functions +pub use builder::{ + build_filter_clause, build_join_clause, build_legacy_where_clause, build_order_by_clause, + validate_column_name, validate_column_names, validate_table_name, +}; @@ -0,0 +1,2 @@ +- [ ] log rotation + |
