diff options
Diffstat (limited to 'src/config.rs')
-rw-r--r-- | src/config.rs | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..c3e7e56 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,243 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +// Structure to define individual authorized clients +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientDefinition { + pub name: String, // Friendly name for the client + pub ip_address: String, // Allowed IP address + pub port: Option<u16>, // Optional specific port (if None, any port is allowed) + pub group_ids: Vec<String>, // Multicast groups this client is allowed to access (empty = all) +} + +// Definition of a multicast group +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MulticastGroup { + pub address: String, // Multicast address + #[serde(skip_serializing_if = "Option::is_none")] + pub port: Option<u16>, // Single port (used if port_range is None) + #[serde(skip_serializing_if = "Option::is_none")] + pub port_range: Option<(u16, u16)>, // Optional port range (start, end) +} + +impl MulticastGroup { + // Helper method to get all ports that should be used + pub fn get_ports(&self) -> Vec<u16> { + if let Some(range) = self.port_range { + // If a range is specified, return all ports in the range + (range.0..=range.1).collect() + } else if let Some(port) = self.port { + // If only a single port is specified, return just that port + vec![port] + } else { + // Default to empty vec if neither is specified (should be validated elsewhere) + vec![] + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub secret: String, + pub listen_ip: String, + pub listen_port: u16, + pub multicast_groups: HashMap<String, MulticastGroup>, + pub authorized_clients: Vec<ClientDefinition>, + #[serde(default)] + pub simple_auth: bool, + #[serde(default)] + pub allow_external_clients: bool, // Flag to allow clients from outside the local network +} + +impl Default for ServerConfig { + fn default() -> Self { + let mut groups = HashMap::new(); + groups.insert("default".to_string(), MulticastGroup { + address: "224.0.0.1".to_string(), + port: Some(5000), + port_range: None, + }); + groups.insert("ssdp".to_string(), MulticastGroup { + address: "239.255.255.250".to_string(), + port: Some(1900), + port_range: None, + }); + // Example with port range + groups.insert("range_example".to_string(), MulticastGroup { + address: "239.192.55.2".to_string(), + port: None, + port_range: Some((1680, 1685)), + }); + + Self { + secret: "changeme".to_string(), + listen_ip: "0.0.0.0".to_string(), + listen_port: 8989, + multicast_groups: groups, + authorized_clients: vec![ + ClientDefinition { + name: "localhost".to_string(), + ip_address: "127.0.0.1".to_string(), + port: None, + group_ids: vec![], // Empty means all groups + }, + ClientDefinition { + name: "example-client".to_string(), + ip_address: "192.168.1.100".to_string(), + port: Some(12345), + group_ids: vec!["default".to_string()], // Only the default group + } + ], + simple_auth: false, // Default to standard auth mode + allow_external_clients: false, // Default to disallow external clients + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientConfig { + pub secret: String, + pub server: String, + pub port: u16, + pub multicast_group_ids: Vec<String>, + pub test_mode: bool, + #[serde(default)] + pub nat_traversal: bool, // Flag to enable NAT traversal features + #[serde(default = "default_reconnect_delay")] + pub reconnect_delay_secs: u64, +} + +fn default_reconnect_delay() -> u64 { + 5 +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + secret: "changeme".to_string(), + server: "127.0.0.1".to_string(), + port: 8989, + multicast_group_ids: vec![], // Empty means subscribe to all groups + test_mode: false, + nat_traversal: false, + reconnect_delay_secs: 5, + } + } +} + +// Load server configuration from a file +pub fn load_server_config<P: AsRef<Path>>(path: P) -> Result<ServerConfig> { + let config_str = fs::read_to_string(path) + .context("Failed to read server configuration file")?; + let config: ServerConfig = toml::from_str(&config_str) + .context("Failed to parse server configuration")?; + Ok(config) +} + +// Save server configuration to a file +pub fn save_server_config<P: AsRef<Path>>(config: &ServerConfig, path: P) -> Result<()> { + let toml_str = toml::to_string_pretty(config) + .context("Failed to serialize server configuration")?; + fs::write(path, toml_str) + .context("Failed to write server configuration file")?; + Ok(()) +} + +// Load client configuration from a file +pub fn load_client_config<P: AsRef<Path>>(path: P) -> Result<ClientConfig> { + let config_str = fs::read_to_string(path) + .context("Failed to read client configuration file")?; + let config: ClientConfig = toml::from_str(&config_str) + .context("Failed to parse client configuration")?; + Ok(config) +} + +// Save client configuration to a file +pub fn save_client_config<P: AsRef<Path>>(config: &ClientConfig, path: P) -> Result<()> { + let toml_str = toml::to_string_pretty(config) + .context("Failed to serialize client configuration")?; + fs::write(path, toml_str) + .context("Failed to write client configuration file")?; + Ok(()) +} + +// Generate default configuration files if they don't exist +pub fn ensure_default_configs() -> Result<()> { + let server_config_path = "server_config.toml"; + if !Path::new(server_config_path).exists() { + let default_config = ServerConfig::default(); + save_server_config(&default_config, server_config_path)?; + println!("Created default server configuration at {}", server_config_path); + } + + let client_config_path = "client_config.toml"; + if !Path::new(client_config_path).exists() { + let default_config = ClientConfig::default(); + save_client_config(&default_config, client_config_path)?; + println!("Created default client configuration at {}", client_config_path); + } + + Ok(()) +} + +// Check if a client is authorized for specified groups +pub fn get_client_authorized_groups( + config: &ServerConfig, + ip: &str, + port: u16 +) -> Option<Vec<String>> { + // If simple_auth is enabled, all clients that know the secret get access to all groups + if config.simple_auth { + return Some(vec![]); // Empty vector means all groups + } + + // Otherwise use the standard group authorization check + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None or matches the client's port + if client.port.is_none() || client.port == Some(port) { + // Return the group IDs or empty vector (meaning all groups) + return Some(client.group_ids.clone()); + } + } + } + + None // Not authorized +} + +// Get client name if authorized +pub fn get_client_name(config: &ServerConfig, ip: &str, port: u16) -> Option<String> { + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None, any port from this IP is allowed + if client.port.is_none() || client.port == Some(port) { + return Some(client.name.clone()); + } + } + } + + None +} + +// Check if a client's IP address and port are authorized +pub fn is_client_authorized(config: &ServerConfig, ip: &str, port: u16) -> bool { + // If simple_auth is enabled, all clients that know the secret are authorized + if config.simple_auth { + return true; + } + + // Otherwise use the standard client authorization check + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None, any port from this IP is allowed + if client.port.is_none() || client.port == Some(port) { + return true; + } + } + } + + false +} |