diff options
| author | UMTS at Teleco <crt@teleco.ch> | 2025-12-13 02:48:13 +0100 |
|---|---|---|
| committer | UMTS at Teleco <crt@teleco.ch> | 2025-12-13 02:48:13 +0100 |
| commit | e52b8e1c2e110d0feb74feb7905c2ff064b51d55 (patch) | |
| tree | 3090814e422250e07e72cf1c83241ffd95cf20f7 /src/config.rs | |
Diffstat (limited to 'src/config.rs')
| -rw-r--r-- | src/config.rs | 845 |
1 files changed, 845 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..602c0d8 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,845 @@ +// Configuration module for SeckelAPI +use anyhow::{Context, Result}; +use ipnet::IpNet; +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::str::FromStr; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub database: DatabaseConfig, + pub server: ServerConfig, + pub security: SecurityConfig, + pub permissions: PermissionsConfig, + pub logging: LoggingConfig, + pub auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + pub scheduled_queries: Option<ScheduledQueriesConfig>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PermissionsConfig { + #[serde(flatten)] + pub power_levels: HashMap<String, PowerLevelPermissions>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PowerLevelPermissions { + pub basic_rules: Vec<String>, + pub advanced_rules: Option<Vec<String>>, + pub max_limit: Option<u32>, + pub max_where_conditions: Option<u32>, + pub session_timeout_minutes: Option<u64>, + pub max_concurrent_sessions: Option<u32>, + #[serde(default = "default_true")] + pub rollback_on_error: bool, + #[serde(default = "default_false")] + pub allow_batch_operations: bool, + #[serde(default = "default_user_settings_access")] + pub user_settings_access: UserSettingsAccess, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum UserSettingsAccess { + ReadOwnOnly, + ReadWriteOwn, + ReadWriteAll, +} + +fn default_user_settings_access() -> UserSettingsAccess { + UserSettingsAccess::ReadWriteOwn +} + +#[derive(Debug, Clone, Deserialize)] +pub struct DatabaseConfig { + pub host: String, + pub port: u16, + pub database: String, + pub username: String, + pub password: String, + #[serde(default = "default_min_connections")] + pub min_connections: u32, + #[serde(default = "default_max_connections")] + pub max_connections: u32, + #[serde(default = "default_connection_timeout_seconds")] + pub connection_timeout_seconds: u64, + #[serde(default = "default_connection_timeout_wait")] + pub connection_timeout_wait: u64, + #[serde(default = "default_connection_check_interval")] + pub connection_check: u64, +} + +fn default_min_connections() -> u32 { + 1 +} + +fn default_max_connections() -> u32 { + 10 +} + +fn default_connection_timeout_seconds() -> u64 { + 30 +} + +fn default_connection_timeout_wait() -> u64 { + 5 +} + +fn default_connection_check_interval() -> u64 { + 30 +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + pub host: String, + pub port: u16, + #[serde(default = "default_request_body_limit_mb")] + pub request_body_limit_mb: usize, +} + +fn default_request_body_limit_mb() -> usize { + 10 // 10 MB default +} + +#[derive(Debug, Clone, Deserialize)] +pub struct SecurityConfig { + pub whitelisted_pin_ips: Vec<String>, + pub whitelisted_string_ips: Vec<String>, + pub session_timeout_minutes: u64, + pub refresh_session_on_activity: bool, + pub max_concurrent_sessions: u32, + pub session_cleanup_interval_minutes: u64, + pub default_max_limit: u32, + pub default_max_where_conditions: u32, + #[serde(default = "default_hash_pins")] + pub hash_pins: bool, // Whether to use bcrypt for PINs (false = plaintext) + #[serde(default = "default_hash_tokens")] + pub hash_tokens: bool, // Whether to use bcrypt for login_strings (false = plaintext) + #[serde(default = "default_pin_column")] + pub pin_column: String, // Database column name for PINs + #[serde(default = "default_token_column")] + pub token_column: String, // Database column name for login strings + // Rate limiting + #[serde(default = "default_enable_rate_limiting")] + pub enable_rate_limiting: bool, // Master switch for rate limiting (disable for debugging) + + // Auth rate limiting + #[serde(default = "default_auth_rate_limit_per_minute")] + pub auth_rate_limit_per_minute: u32, // Max auth requests per IP per minute + #[serde(default = "default_auth_rate_limit_per_second")] + pub auth_rate_limit_per_second: u32, // Max auth requests per IP per second (burst protection) + + // API rate limiting + #[serde(default = "default_api_rate_limit_per_minute")] + pub api_rate_limit_per_minute: u32, // Max API calls per user per minute + #[serde(default = "default_api_rate_limit_per_second")] + pub api_rate_limit_per_second: u32, // Max API calls per user per second (burst protection) + + // Table configuration (moved from basics.toml) + #[serde(default = "default_known_tables")] + pub known_tables: Vec<String>, + #[serde(default = "default_read_only_tables")] + pub read_only_tables: Vec<String>, + #[serde(default = "default_global_write_protected_columns")] + pub global_write_protected_columns: Vec<String>, + + // User preferences access control + #[serde(default = "default_user_settings_access")] + pub default_user_settings_access: UserSettingsAccess, +} + +fn default_known_tables() -> Vec<String> { + vec![] // Empty by default, must be configured +} + +fn default_read_only_tables() -> Vec<String> { + vec![] // Empty by default +} + +fn default_global_write_protected_columns() -> Vec<String> { + vec!["id".to_string()] // Protect 'id' by default at minimum +} + +fn default_hash_pins() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_hash_tokens() -> bool { + false // Default to plaintext (must be explicitly enabled for bcrypt hashing) +} + +fn default_pin_column() -> String { + "pin_code".to_string() +} + +fn default_token_column() -> String { + "login_string".to_string() +} + +fn default_enable_rate_limiting() -> bool { + true // Enable by default for security +} + +fn default_auth_rate_limit_per_minute() -> u32 { + 10 // 10 login attempts per IP per minute (prevents brute force) +} + +fn default_auth_rate_limit_per_second() -> u32 { + 5 // Max 5 login attempts per second (burst protection) +} + +fn default_api_rate_limit_per_minute() -> u32 { + 60 // 60 API calls per user per minute +} + +fn default_api_rate_limit_per_second() -> u32 { + 10 // Max 10 API calls per second (burst protection) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct LoggingConfig { + pub request_log: Option<String>, + pub query_log: Option<String>, + pub error_log: Option<String>, + pub warning_log: Option<String>, // Warning messages + pub info_log: Option<String>, // Info messages + pub combined_log: Option<String>, // Unified log with request IDs + pub level: String, + pub mask_passwords: bool, + #[serde(default)] + pub sensitive_fields: Vec<String>, // Fields to mask beyond password/pin + #[serde(default)] + pub custom_filters: Vec<CustomLogFilter>, // Regex-based log routing +} + +#[derive(Debug, Clone, Deserialize)] +pub struct CustomLogFilter { + pub name: String, + pub output_file: String, + pub pattern: String, // Regex pattern to match + #[serde(default = "default_filter_enabled")] + pub enabled: bool, +} + +fn default_filter_enabled() -> bool { + true +} + +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AutoGenerationConfig { + pub field: String, + #[serde(rename = "type")] + pub gen_type: String, + pub length: Option<u32>, + pub range_min: Option<u64>, + pub range_max: Option<u64>, + pub max_attempts: Option<u32>, + #[serde(default = "default_on_action")] + pub on_action: String, // "insert", "update", or "both" +} + +fn default_on_action() -> String { + "insert".to_string() +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueriesConfig { + #[serde(default)] + pub tasks: Vec<ScheduledQueryTask>, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ScheduledQueryTask { + pub name: String, + pub description: String, + pub query: String, + pub interval_minutes: u64, + #[serde(default = "default_enabled")] + pub enabled: bool, + #[serde(default = "default_run_on_startup")] + pub run_on_startup: bool, +} + +fn default_enabled() -> bool { + true +} + +fn default_run_on_startup() -> bool { + true // Run immediately on startup by default +} + +impl Config { + pub fn load() -> Result<Self> { + Self::load_from_folder() + } + + fn load_from_folder() -> Result<Self> { + // Load consolidated config files + let basics_content = + fs::read_to_string("config/basics.toml").context("Can't read config/basics.toml")?; + let security_content = fs::read_to_string("config/security.toml") + .context("Can't read config/security.toml")?; + let logging_content = + fs::read_to_string("config/logging.toml").context("Can't read config/logging.toml")?; + + // Parse individual sections + #[derive(Deserialize)] + struct BasicsWrapper { + server: ServerConfig, + database: DatabaseConfig, + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + } + #[derive(Deserialize)] + struct SecurityWrapper { + security: SecurityConfig, + permissions: PermissionsConfig, + } + #[derive(Deserialize)] + struct LoggingWrapper { + logging: LoggingConfig, + } + + #[derive(Deserialize)] + struct FunctionsWrapper { + #[serde(default)] + auto_generation: Option<HashMap<String, AutoGenerationConfig>>, + #[serde(default)] + scheduled_queries: Option<ScheduledQueriesConfig>, + } + + let basics: BasicsWrapper = + toml::from_str(&basics_content).context("Failed to parse basics.toml")?; + let security: SecurityWrapper = + toml::from_str(&security_content).context("Failed to parse security.toml")?; + let logging: LoggingWrapper = + toml::from_str(&logging_content).context("Failed to parse logging.toml")?; + + // Load functions.toml if it exists, otherwise use basics fallback + let functions: FunctionsWrapper = if fs::metadata("config/functions.toml").is_ok() { + let functions_content = fs::read_to_string("config/functions.toml") + .context("Can't read config/functions.toml")?; + toml::from_str(&functions_content).context("Failed to parse functions.toml")? + } else { + // Fallback to basics.toml for auto_generation + FunctionsWrapper { + auto_generation: basics.auto_generation.clone(), + scheduled_queries: None, + } + }; + + let config = Config { + database: basics.database, + server: basics.server, + security: security.security, + permissions: security.permissions, + logging: logging.logging, + auto_generation: functions.auto_generation, + scheduled_queries: functions.scheduled_queries, + }; + + // Validate configuration + config.validate()?; + + Ok(config) + } + + /// Validate configuration values + fn validate(&self) -> Result<()> { + // Validate server port (u16 is already limited to 0-65535, just check for 0) + if self.server.port == 0 { + anyhow::bail!( + "Invalid server.port: {} (must be 1-65535)", + self.server.port + ); + } + + // Validate database port (u16 is already limited to 0-65535, just check for 0) + if self.database.port == 0 { + anyhow::bail!( + "Invalid database.port: {} (must be 1-65535)", + self.database.port + ); + } + + // Validate database connection details + if self.database.host.trim().is_empty() { + anyhow::bail!("database.host cannot be empty"); + } + if self.database.database.trim().is_empty() { + anyhow::bail!("database.database cannot be empty"); + } + if self.database.username.trim().is_empty() { + anyhow::bail!("database.username cannot be empty"); + } + + // Validate PIN whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_pin_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_pin_ips: {}", ip); + } + } + + // Validate string auth whitelist IPs (can be CIDR or single IPs) + for ip in &self.security.whitelisted_string_ips { + if IpNet::from_str(ip).is_err() && ip.parse::<std::net::IpAddr>().is_err() { + anyhow::bail!("Invalid IP/CIDR in whitelisted_string_ips: {}", ip); + } + } + + // Validate session timeout + if self.security.session_timeout_minutes == 0 { + anyhow::bail!("security.session_timeout_minutes must be greater than 0"); + } + + // Validate permission syntax + for (power_level, perms) in &self.permissions.power_levels { + // Validate power level is numeric + if power_level.parse::<i32>().is_err() { + anyhow::bail!("Invalid power level '{}': must be numeric", power_level); + } + + // Validate basic rules format + for rule in &perms.basic_rules { + Self::validate_permission_rule(rule).with_context(|| { + format!("Invalid basic rule for power level {}", power_level) + })?; + } + + // Validate advanced rules format + if let Some(advanced_rules) = &perms.advanced_rules { + for rule in advanced_rules { + Self::validate_advanced_rule(rule).with_context(|| { + format!("Invalid advanced rule for power level {}", power_level) + })?; + } + } + } + + // Validate logging configuration (paths are now optional) + // No validation needed for optional log paths + + // Validate log level + let valid_levels = ["trace", "debug", "info", "warn", "error"]; + if !valid_levels.contains(&self.logging.level.to_lowercase().as_str()) { + anyhow::bail!( + "Invalid logging.level '{}': must be one of: {}", + self.logging.level, + valid_levels.join(", ") + ); + } + + Ok(()) + } + + /// Validate basic permission rule format (table:permission) + fn validate_permission_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Permission rule '{}' must be in format 'table:permission'", + rule + ); + } + + let table = parts[0]; + let permission = parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "rw", "rwd"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + /// Validate advanced permission rule format (table.column:permission) + fn validate_advanced_rule(rule: &str) -> Result<()> { + let parts: Vec<&str> = rule.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Advanced rule '{}' must be in format 'table.column:permission'", + rule + ); + } + + let table_col = parts[0]; + let permission = parts[1]; + + // Validate table.column format + let col_parts: Vec<&str> = table_col.split('.').collect(); + if col_parts.len() != 2 { + anyhow::bail!("Advanced rule '{}' must have 'table.column' format", rule); + } + + let table = col_parts[0]; + let column = col_parts[1]; + + if table.trim().is_empty() { + anyhow::bail!("Table name cannot be empty in rule '{}'", rule); + } + if column.trim().is_empty() { + anyhow::bail!("Column name cannot be empty in rule '{}'", rule); + } + + // Validate permission type + let valid_permissions = ["r", "w", "rw", "block"]; + if !valid_permissions.contains(&permission) { + anyhow::bail!( + "Invalid permission '{}' in rule '{}': must be one of: {}", + permission, + rule, + valid_permissions.join(", ") + ); + } + + Ok(()) + } + + pub fn get_database_url(&self) -> String { + format!( + "mysql://{}:{}@{}:{}/{}", + self.database.username, + self.database.password, + self.database.host, + self.database.port, + self.database.database + ) + } + + pub fn is_pin_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_pin_ips) + } + + pub fn is_string_ip_whitelisted(&self, ip: &str) -> bool { + self.is_ip_in_whitelist(ip, &self.security.whitelisted_string_ips) + } + + fn is_ip_in_whitelist(&self, ip: &str, whitelist: &[String]) -> bool { + let client_ip = match ip.parse::<std::net::IpAddr>() { + Ok(addr) => addr, + Err(_) => return false, + }; + + for allowed in whitelist { + // Try to parse as a network (CIDR notation) + if let Ok(network) = IpNet::from_str(allowed) { + if network.contains(&client_ip) { + return true; + } + } + // Try to parse as a single IP address + else if let Ok(allowed_ip) = allowed.parse::<std::net::IpAddr>() { + if client_ip == allowed_ip { + return true; + } + } + } + false + } + + pub fn get_role_permissions(&self, power: i32) -> Vec<(String, String)> { + let power_str = power.to_string(); + + if let Some(power_perms) = self.permissions.power_levels.get(&power_str) { + power_perms + .basic_rules + .iter() + .filter_map(|perm| { + let parts: Vec<&str> = perm.split(':').collect(); + if parts.len() == 2 { + Some((parts[0].to_string(), parts[1].to_string())) + } else { + None + } + }) + .collect() + } else { + Vec::new() + } + } + + // Helper methods for new configuration sections + pub fn get_known_tables(&self) -> Vec<String> { + if self.security.known_tables.is_empty() { + tracing::warn!("No known_tables configured in security.toml - returning empty list. Wildcard permissions (*) will not work."); + } + self.security.known_tables.clone() + } + + pub fn is_read_only_table(&self, table: &str) -> bool { + self.security.read_only_tables.contains(&table.to_string()) + } + + pub fn get_auto_generation_config(&self, table: &str) -> Option<&AutoGenerationConfig> { + self.auto_generation + .as_ref() + .and_then(|configs| configs.get(table)) + } + + pub fn get_basic_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .map(|p| &p.basic_rules) + } + + pub fn get_advanced_permissions(&self, power: i32) -> Option<&Vec<String>> { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.advanced_rules.as_ref()) + } + + pub fn filter_readable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "r" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + _ => {} + } + } + } + } + } + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, return all requested columns + requested_columns.to_vec() + } + } + + pub fn filter_writable_columns( + &self, + power: i32, + table: &str, + requested_columns: &[String], + ) -> Vec<String> { + // First, apply global write-protected columns (these override everything) + let mut globally_blocked: Vec<String> = + self.security.global_write_protected_columns.clone(); + + if let Some(advanced_rules) = self.get_advanced_permissions(power) { + let mut allowed_columns = Vec::new(); + let mut blocked_columns = Vec::new(); + let mut has_wildcard_block = false; + let mut has_wildcard_allow = false; + + // Parse advanced rules for this table + for rule in advanced_rules { + if let Some((table_col, permission)) = rule.split_once(':') { + if let Some((rule_table, column)) = table_col.split_once('.') { + if rule_table == table { + match permission { + "block" => { + if column == "*" { + has_wildcard_block = true; + } else { + blocked_columns.push(column.to_string()); + } + } + "w" | "rw" => { + if column == "*" { + has_wildcard_allow = true; + } else { + allowed_columns.push(column.to_string()); + } + } + "r" => { + // Read-only: block from writing but not from reading + blocked_columns.push(column.to_string()); + } + _ => {} + } + } + } + } + } + + // Merge advanced_rules blocked columns with globally protected + blocked_columns.append(&mut globally_blocked); + + // Filter requested columns based on rules + let mut result = Vec::new(); + for column in requested_columns { + let allow = if has_wildcard_block { + // If wildcard block, only allow specifically allowed columns + allowed_columns.contains(column) + } else if has_wildcard_allow { + // If wildcard allow, block only specifically blocked columns + !blocked_columns.contains(column) + } else { + // No wildcard rules, block specifically blocked columns + !blocked_columns.contains(column) + }; + + if allow { + result.push(column.clone()); + } + } + + result + } else { + // No advanced rules, just filter out globally protected columns + requested_columns + .iter() + .filter(|col| !globally_blocked.contains(col)) + .cloned() + .collect() + } + } + + /// Get the max_limit for a specific power level (with fallback to next lower power level) + pub fn get_max_limit(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(limit) = perms.max_limit { + return limit; + } + } + } + + // Ultimate fallback to default + self.security.default_max_limit + } + + /// Get the max_where_conditions for a specific power level (with fallback to next lower power level) + pub fn get_max_where_conditions(&self, power: i32) -> u32 { + // Try exact match first + if let Some(perms) = self.permissions.power_levels.get(&power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + + // Find next lower power level + let fallback_power = self.find_fallback_power_level(power); + if let Some(fb_power) = fallback_power { + tracing::warn!( + "Power level {} not found in config, falling back to power level {}", + power, + fb_power + ); + if let Some(perms) = self.permissions.power_levels.get(&fb_power.to_string()) { + if let Some(max_where) = perms.max_where_conditions { + return max_where; + } + } + } + + // Ultimate fallback to default + self.security.default_max_where_conditions + } + + /// Find the next lower configured power level (e.g., power=60 → fallback to 50) + fn find_fallback_power_level(&self, power: i32) -> Option<i32> { + let mut available_powers: Vec<i32> = self + .permissions + .power_levels + .keys() + .filter_map(|k| k.parse::<i32>().ok()) + .filter(|&p| p < power) // Only consider lower power levels + .collect(); + + available_powers.sort_by(|a, b| b.cmp(a)); // Sort descending + available_powers.first().copied() + } + + /// Get the session_timeout_minutes for a specific power level (with fallback to default) + pub fn get_session_timeout(&self, power: i32) -> u64 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.session_timeout_minutes) + .unwrap_or(self.security.session_timeout_minutes) + } + + /// Get the max_concurrent_sessions for a specific power level (with fallback to default) + pub fn get_max_concurrent_sessions(&self, power: i32) -> u32 { + self.permissions + .power_levels + .get(&power.to_string()) + .and_then(|p| p.max_concurrent_sessions) + .unwrap_or(self.security.max_concurrent_sessions) + } +} |
