aboutsummaryrefslogtreecommitdiff
path: root/src/config.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/config.rs')
-rw-r--r--src/config.rs243
1 files changed, 243 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..c3e7e56
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,243 @@
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::fs;
+use std::path::Path;
+
+// Structure to define individual authorized clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientDefinition {
+ pub name: String, // Friendly name for the client
+ pub ip_address: String, // Allowed IP address
+ pub port: Option<u16>, // Optional specific port (if None, any port is allowed)
+ pub group_ids: Vec<String>, // Multicast groups this client is allowed to access (empty = all)
+}
+
+// Definition of a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroup {
+ pub address: String, // Multicast address
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port: Option<u16>, // Single port (used if port_range is None)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port_range: Option<(u16, u16)>, // Optional port range (start, end)
+}
+
+impl MulticastGroup {
+ // Helper method to get all ports that should be used
+ pub fn get_ports(&self) -> Vec<u16> {
+ if let Some(range) = self.port_range {
+ // If a range is specified, return all ports in the range
+ (range.0..=range.1).collect()
+ } else if let Some(port) = self.port {
+ // If only a single port is specified, return just that port
+ vec![port]
+ } else {
+ // Default to empty vec if neither is specified (should be validated elsewhere)
+ vec![]
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfig {
+ pub secret: String,
+ pub listen_ip: String,
+ pub listen_port: u16,
+ pub multicast_groups: HashMap<String, MulticastGroup>,
+ pub authorized_clients: Vec<ClientDefinition>,
+ #[serde(default)]
+ pub simple_auth: bool,
+ #[serde(default)]
+ pub allow_external_clients: bool, // Flag to allow clients from outside the local network
+}
+
+impl Default for ServerConfig {
+ fn default() -> Self {
+ let mut groups = HashMap::new();
+ groups.insert("default".to_string(), MulticastGroup {
+ address: "224.0.0.1".to_string(),
+ port: Some(5000),
+ port_range: None,
+ });
+ groups.insert("ssdp".to_string(), MulticastGroup {
+ address: "239.255.255.250".to_string(),
+ port: Some(1900),
+ port_range: None,
+ });
+ // Example with port range
+ groups.insert("range_example".to_string(), MulticastGroup {
+ address: "239.192.55.2".to_string(),
+ port: None,
+ port_range: Some((1680, 1685)),
+ });
+
+ Self {
+ secret: "changeme".to_string(),
+ listen_ip: "0.0.0.0".to_string(),
+ listen_port: 8989,
+ multicast_groups: groups,
+ authorized_clients: vec![
+ ClientDefinition {
+ name: "localhost".to_string(),
+ ip_address: "127.0.0.1".to_string(),
+ port: None,
+ group_ids: vec![], // Empty means all groups
+ },
+ ClientDefinition {
+ name: "example-client".to_string(),
+ ip_address: "192.168.1.100".to_string(),
+ port: Some(12345),
+ group_ids: vec!["default".to_string()], // Only the default group
+ }
+ ],
+ simple_auth: false, // Default to standard auth mode
+ allow_external_clients: false, // Default to disallow external clients
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientConfig {
+ pub secret: String,
+ pub server: String,
+ pub port: u16,
+ pub multicast_group_ids: Vec<String>,
+ pub test_mode: bool,
+ #[serde(default)]
+ pub nat_traversal: bool, // Flag to enable NAT traversal features
+ #[serde(default = "default_reconnect_delay")]
+ pub reconnect_delay_secs: u64,
+}
+
+fn default_reconnect_delay() -> u64 {
+ 5
+}
+
+impl Default for ClientConfig {
+ fn default() -> Self {
+ Self {
+ secret: "changeme".to_string(),
+ server: "127.0.0.1".to_string(),
+ port: 8989,
+ multicast_group_ids: vec![], // Empty means subscribe to all groups
+ test_mode: false,
+ nat_traversal: false,
+ reconnect_delay_secs: 5,
+ }
+ }
+}
+
+// Load server configuration from a file
+pub fn load_server_config<P: AsRef<Path>>(path: P) -> Result<ServerConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read server configuration file")?;
+ let config: ServerConfig = toml::from_str(&config_str)
+ .context("Failed to parse server configuration")?;
+ Ok(config)
+}
+
+// Save server configuration to a file
+pub fn save_server_config<P: AsRef<Path>>(config: &ServerConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize server configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write server configuration file")?;
+ Ok(())
+}
+
+// Load client configuration from a file
+pub fn load_client_config<P: AsRef<Path>>(path: P) -> Result<ClientConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read client configuration file")?;
+ let config: ClientConfig = toml::from_str(&config_str)
+ .context("Failed to parse client configuration")?;
+ Ok(config)
+}
+
+// Save client configuration to a file
+pub fn save_client_config<P: AsRef<Path>>(config: &ClientConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize client configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write client configuration file")?;
+ Ok(())
+}
+
+// Generate default configuration files if they don't exist
+pub fn ensure_default_configs() -> Result<()> {
+ let server_config_path = "server_config.toml";
+ if !Path::new(server_config_path).exists() {
+ let default_config = ServerConfig::default();
+ save_server_config(&default_config, server_config_path)?;
+ println!("Created default server configuration at {}", server_config_path);
+ }
+
+ let client_config_path = "client_config.toml";
+ if !Path::new(client_config_path).exists() {
+ let default_config = ClientConfig::default();
+ save_client_config(&default_config, client_config_path)?;
+ println!("Created default client configuration at {}", client_config_path);
+ }
+
+ Ok(())
+}
+
+// Check if a client is authorized for specified groups
+pub fn get_client_authorized_groups(
+ config: &ServerConfig,
+ ip: &str,
+ port: u16
+) -> Option<Vec<String>> {
+ // If simple_auth is enabled, all clients that know the secret get access to all groups
+ if config.simple_auth {
+ return Some(vec![]); // Empty vector means all groups
+ }
+
+ // Otherwise use the standard group authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None or matches the client's port
+ if client.port.is_none() || client.port == Some(port) {
+ // Return the group IDs or empty vector (meaning all groups)
+ return Some(client.group_ids.clone());
+ }
+ }
+ }
+
+ None // Not authorized
+}
+
+// Get client name if authorized
+pub fn get_client_name(config: &ServerConfig, ip: &str, port: u16) -> Option<String> {
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return Some(client.name.clone());
+ }
+ }
+ }
+
+ None
+}
+
+// Check if a client's IP address and port are authorized
+pub fn is_client_authorized(config: &ServerConfig, ip: &str, port: u16) -> bool {
+ // If simple_auth is enabled, all clients that know the secret are authorized
+ if config.simple_auth {
+ return true;
+ }
+
+ // Otherwise use the standard client authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return true;
+ }
+ }
+ }
+
+ false
+}