diff options
author | Kablersalat <crt@adastra7.net> | 2025-06-05 00:14:27 +0200 |
---|---|---|
committer | Kablersalat <crt@adastra7.net> | 2025-06-05 00:14:27 +0200 |
commit | a0c754fc727a35d775c25857898986f4273db0a5 (patch) | |
tree | 4129d20013ff0a9b80051c0d9a74484cd753b633 /src | |
parent | fd956fde0ba92b4aa7b0f5322cfb93951bb01fbb (diff) |
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 29 | ||||
-rw-r--r-- | src/bin/client.rs | 467 | ||||
-rw-r--r-- | src/bin/mcast_test.rs | 123 | ||||
-rw-r--r-- | src/bin/server.rs | 467 | ||||
-rw-r--r-- | src/config.rs | 243 | ||||
-rw-r--r-- | src/lib.rs | 13 | ||||
-rw-r--r-- | src/protocol.rs | 129 |
7 files changed, 1471 insertions, 0 deletions
diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..153f0a9 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,29 @@ +use hmac::{Hmac, Mac}; +use rand::{rngs::OsRng, RngCore}; +use sha2::Sha256; + +// Generate a random nonce +pub fn generate_nonce() -> String { + let mut bytes = [0u8; 16]; + OsRng.fill_bytes(&mut bytes); + hex::encode(bytes) +} + +// Calculate HMAC using the shared secret +pub fn calculate_hmac(secret: &str, data: &str) -> String { + type HmacSha256 = Hmac<Sha256>; + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) + .expect("HMAC can take key of any size"); + + mac.update(data.as_bytes()); + let result = mac.finalize(); + let code_bytes = result.into_bytes(); + + hex::encode(code_bytes) +} + +// Verify an HMAC token +pub fn verify_hmac(secret: &str, data: &str, expected: &str) -> bool { + let calculated = calculate_hmac(secret, data); + calculated == expected +} diff --git a/src/bin/client.rs b/src/bin/client.rs new file mode 100644 index 0000000..5c39fb9 --- /dev/null +++ b/src/bin/client.rs @@ -0,0 +1,467 @@ +use anyhow::{Context, Result}; +use clap::Parser; +use log::{debug, error, info, warn}; +use multicast_relay::{ + auth::{calculate_hmac, generate_nonce, verify_hmac}, + config::{load_client_config, ensure_default_configs, ClientConfig}, + protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message}, + DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES, +}; +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + path::PathBuf, + str::FromStr, + time::Duration, + sync::Arc, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpStream, UdpSocket}, + signal, + sync::Notify, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long, default_value = "client_config.toml")] + config: PathBuf, + + #[arg(short, long, action)] + generate_default: bool, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let args = Args::parse(); + + // Generate default configs if requested + if args.generate_default { + ensure_default_configs()?; + return Ok(()); + } + + // Load configuration + let config = load_client_config(&args.config) + .context(format!("Failed to load config from {:?}", args.config))?; + + info!("Client configuration loaded from {:?}", args.config); + + // Create a notification for clean shutdown + let shutdown = Arc::new(Notify::new()); + let shutdown_signal = shutdown.clone(); + + // Setup signal handler for Ctrl+C + tokio::spawn(async move { + if let Err(e) = signal::ctrl_c().await { + error!("Failed to listen for Ctrl+C: {}", e); + return; + } + info!("Received Ctrl+C, shutting down..."); + shutdown_signal.notify_one(); + }); + + let server_addr = format!("{}:{}", config.server, config.port); + + // Main reconnection loop + loop { + info!("Connecting to server at {}", server_addr); + + // Try to connect with timeout + let connect_result = match tokio::time::timeout( + Duration::from_secs(5), + TcpStream::connect(&server_addr), + ).await { + Ok(result) => result, + Err(_) => { + warn!("Connection attempt timed out"); + if !handle_reconnect(&shutdown, &config).await { + break; + } + continue; + } + }; + + match connect_result { + Ok(stream) => { + info!("Connected to server"); + + // Run the client session + match run_client_session(stream, &config, &shutdown).await { + Ok(_) => { + info!("Client session ended normally"); + break; + }, + Err(e) => { + error!("Client session error: {}", e); + if !handle_reconnect(&shutdown, &config).await { + break; + } + } + } + }, + Err(e) => { + error!("Failed to connect: {}", e); + if !handle_reconnect(&shutdown, &config).await { + break; + } + } + } + } + + Ok(()) +} + +// Helper function to handle reconnection delay +// Returns false if shutdown was requested +async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool { + let delay = Duration::from_secs(client_config.reconnect_delay_secs); + info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs); + tokio::select! { + _ = shutdown.notified() => { + info!("Shutdown requested during reconnect"); + return false; + } + _ = tokio::time::sleep(delay) => {} + } + true +} + +// Add a new enum to distinguish message types +#[derive(PartialEq)] +enum StatusMessageType { + ServerHeartbeat, + ServerPong, + Other +} + +// The main client session function that handles a single connection +async fn run_client_session( + mut stream: TcpStream, + config: &ClientConfig, + shutdown: &Arc<Notify>, +) -> Result<()> { + // Authenticate + if let Err(e) = authenticate(&mut stream, &config.secret).await { + return Err(anyhow::anyhow!("Authentication failed: {}", e)); + } + info!("Authentication successful"); + + // Check if test mode is enabled in config + let mut test_mode = config.test_mode; + + // Request server's multicast group information + let groups_request = Message::MulticastGroupsRequest; + let request_bytes = serialize_message(&groups_request)?; + stream.write_all(&request_bytes).await?; + + // Wait for response with multicast group information, ignoring other message types + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let mut groups_response = None; + + // Keep reading until we get the groups response or timeout + let mut attempts = 0; + while groups_response.is_none() && attempts < 10 { + match stream.read(&mut buf).await { + Ok(n) if n > 0 => { + match robust_deserialize_message(&buf[..n]) { + Ok(Message::MulticastGroupsResponse { groups }) => { + groups_response = Some(groups); + }, + Ok(Message::PingStatus { timestamp: _, status }) => { + // Handle ping but keep waiting for multicast groups + info!("Got server ping: {}", status); + }, + Ok(other_msg) => { + debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg); + }, + Err(e) => { + error!("Failed to deserialize message: {}", e); + } + } + }, + Ok(0) => return Err(anyhow::anyhow!("Server closed connection")), + Ok(_) => {}, + Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e)) + } + + // If we didn't get groups yet, wait a bit and try again + if groups_response.is_none() { + attempts += 1; + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + // Now we either have the groups or we timed out + let groups_response = groups_response.ok_or_else(|| + anyhow::anyhow!("Failed to receive multicast group information from server"))?; + + info!("Available multicast groups from server:"); + for (id, group) in &groups_response { + info!(" - {} -> {}:{}", id, group.address, group.port); + } + + // Determine which groups to subscribe to + let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() { + // If no specific groups are requested, subscribe to all + info!("No specific groups requested, subscribing to all available groups"); + groups_response.keys().cloned().collect() + } else { + // Otherwise, subscribe only to requested groups that exist + let mut valid_groups = Vec::new(); + for group_id in &config.multicast_group_ids { + if groups_response.contains_key(group_id.as_str()) { + valid_groups.push(group_id.clone()); + } else { + warn!("Requested group '{}' does not exist on server", group_id); + } + } + + if valid_groups.is_empty() { + warn!("None of the requested groups exist on server. No data will be received."); + } else { + info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups); + } + + valid_groups + }; + + // Send subscription message + let subscribe_msg = Message::Subscribe { + group_ids: groups_to_subscribe.clone(), + }; + let subscribe_bytes = serialize_message(&subscribe_msg)?; + stream.write_all(&subscribe_bytes).await?; + + // Create UDP sockets for local retransmission (skip if in test mode) + let mut sockets: HashMap<String, UdpSocket> = HashMap::new(); + if !test_mode { + for group_id in &groups_to_subscribe { + if let Some(group_info) = groups_response.get(group_id.as_str()) { + info!("Creating socket for group {} ({} on port {})", + group_id, group_info.address, group_info.port); + + match UdpSocket::bind("0.0.0.0:0").await { + Ok(socket) => { + sockets.insert(group_id.clone(), socket); + info!("Successfully created UDP socket for group {}", group_id); + }, + Err(e) => { + error!("Failed to create UDP socket for group {}: {}", group_id, e); + } + } + } + } + + if sockets.is_empty() && !groups_to_subscribe.is_empty() { + error!("Failed to create any UDP sockets"); + warn!("Falling back to test mode due to socket creation failure"); + test_mode = true; + } + } + + // Display test mode banner if enabled + if test_mode { + println!("{}", TEST_MODE_BANNER); + info!("Test mode enabled - packets will be displayed but not sent to network"); + } + + // Main receive loop + info!("Listening for multicast traffic from server"); + + // Set the read timeout for the stream + stream.set_nodelay(true)?; + + // Remove problematic code that uses unsupported methods and the nix crate + if config.nat_traversal { + info!("NAT traversal mode enabled - using more frequent keepalives"); + } + + // Calculate appropriate ping interval based on NAT traversal setting + let ping_interval = if config.nat_traversal { + Duration::from_secs(25) // More frequent for NAT + } else { + Duration::from_secs(55) + }; + + // Main receive loop + loop { + tokio::select! { + _ = shutdown.notified() => { + info!("Shutdown requested, ending client session"); + return Ok(()); + } + read_result = stream.read(&mut buf) => { + match read_result { + Ok(0) => { + info!("Server closed connection"); + return Err(anyhow::anyhow!("Server closed connection")); + } + Ok(n) => { + match robust_deserialize_message(&buf[..n]) { + Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => { + if test_mode { + println!("\n----- MULTICAST PACKET -----"); + println!("Group: {}", group_id); + println!("Source: {}", source); + println!("Destination: {}:{}", destination, port); + println!("Size: {} bytes", data.len()); + println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES)); + println!("---------------------------\n"); + } else { + info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes", + group_id, source, destination, port, data.len()); + + // Get socket for this group + if let Some(socket) = sockets.get(&group_id) { + // Parse destination address directly from the packet + match IpAddr::from_str(&destination) { + Ok(dest_addr) => { + // Create destination socket address using the packet's port + let dest = SocketAddr::new(dest_addr, port); + + info!("Forwarding packet to {}:{}", dest_addr, port); + + // Retransmit locally + match socket.send_to(&data, dest).await { + Ok(sent) => { + info!("Successfully forwarded {} of {} bytes to {}:{}", + sent, data.len(), destination, port); + }, + Err(e) => { + error!("Failed to retransmit packet for group {} to {}:{}: {}", + group_id, destination, port, e); + } + } + }, + Err(e) => { + error!("Invalid destination address {}: {}", destination, e); + } + } + } else { + warn!("No socket available for group {}", group_id); + } + } + }, + Ok(Message::ConfigResponse { config: _server_config }) => { + info!("Server configuration received"); + }, + Ok(Message::PingStatus { timestamp: _, status }) => { + // Parse the message type based on the status text + let msg_type = if status.starts_with("Server heartbeat to") { + StatusMessageType::ServerHeartbeat + } else if status.starts_with("Server connection to") { + StatusMessageType::ServerPong + } else { + StatusMessageType::Other + }; + + // Only respond with a pong to server heartbeats + if msg_type == StatusMessageType::ServerHeartbeat { + debug!("Received server heartbeat: {}", status); + + // Send a pong response (but only for heartbeats) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let pong = Message::PingStatus { + timestamp: now, + status: "Client pong response".to_string(), + }; + + if let Ok(bytes) = serialize_message(&pong) { + let _ = stream.write_all(&bytes).await; + } + } else { + // Just log other status messages without responding + debug!("Connection Status: {}", status); + } + }, + Ok(_) => debug!("Received other message type"), + Err(e) => { + error!("Failed to deserialize message: {}", e); + // Don't return/break on deserialization errors - continue reading + } + } + } + Err(e) => { + error!("Error reading from server: {}", e); + return Err(anyhow::anyhow!("Connection error: {}", e)); + } + } + }, + _ = tokio::time::sleep(ping_interval) => { + // Send regular ping to keep NAT connection alive + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let ping_msg = Message::PingStatus { + timestamp: now, + status: if config.nat_traversal { + "Client keepalive ping".to_string() // Changed text to be more specific + } else { + "Client periodic ping".to_string() + }, + }; + + match serialize_message(&ping_msg) { + Ok(bytes) => { + if let Err(e) = stream.write_all(&bytes).await { + error!("Failed to ping server: {}", e); + return Err(anyhow::anyhow!("Server ping failed: {}", e)); + } + debug!("Connection check sent to server"); + }, + Err(e) => error!("Failed to serialize ping message: {}", e), + } + } + } + } +} + +async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> { + // Generate client nonce + let client_nonce = generate_nonce(); + + // Send auth request + let auth_request = Message::AuthRequest { + client_nonce: client_nonce.clone(), + }; + let request_bytes = serialize_message(&auth_request)?; + stream.write_all(&request_bytes).await?; + + // Receive response + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let n = stream.read(&mut buf).await?; + let response = robust_deserialize_message(&buf[..n])?; + + if let Message::AuthResponse { server_nonce, auth_token } = response { + // Verify server's token + let expected_data = format!("{}{}", client_nonce, server_nonce); + if !verify_hmac(secret, &expected_data, &auth_token) { + return Err(anyhow::anyhow!("Server authentication failed")); + } + + // Calculate our token + let auth_data = format!("{}{}", server_nonce, client_nonce); + let client_token = calculate_hmac(secret, &auth_data); + + // Send confirmation + let confirm = Message::AuthConfirm { + auth_token: client_token, + }; + + let confirm_bytes = serialize_message(&confirm)?; + stream.write_all(&confirm_bytes).await?; + + Ok(()) + } else { + Err(anyhow::anyhow!("Unexpected response from server")) + } +} diff --git a/src/bin/mcast_test.rs b/src/bin/mcast_test.rs new file mode 100644 index 0000000..782610c --- /dev/null +++ b/src/bin/mcast_test.rs @@ -0,0 +1,123 @@ +use anyhow::{Context, Result}; +use clap::Parser; +use log::{error, info}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::str::FromStr; +use std::time::{Duration, Instant}; +use tokio::net::UdpSocket; + +#[derive(Parser, Debug)] +#[command(author, version, about = "Multicast packet generator for testing CastRepeat")] +struct Args { + #[arg(short, long, default_value = "239.192.55.1")] + multicast_addr: String, + + #[arg(short, long, default_value = "1681")] + port: u16, + + #[arg(short, long, default_value = "1000")] + interval_ms: u64, + + #[arg(short, long, default_value = "60")] + duration_sec: u64, + + #[arg(short, long, default_value = "Test packet")] + message: String, + + #[arg(short, long)] + interface: Option<String>, + + #[arg(short = 'b', long, action)] + binary_mode: bool, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let args = Args::parse(); + + info!("CastRepeat Multicast Packet Generator"); + info!("--------------------------------"); + info!("Multicast Address: {}", args.multicast_addr); + info!("Port: {}", args.port); + info!("Interval: {} ms", args.interval_ms); + info!("Duration: {} sec", args.duration_sec); + if let Some(interface) = &args.interface { + info!("Interface: {}", interface); + } + info!("Mode: {}", if args.binary_mode { "Binary" } else { "Text" }); + info!("--------------------------------"); + + // Verify multicast address + let mcast_addr = match IpAddr::from_str(&args.multicast_addr) { + Ok(IpAddr::V4(addr)) if addr.is_multicast() => addr, + _ => { + error!("Invalid multicast address: {}", args.multicast_addr); + return Ok(()); + } + }; + + // Create socket + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .context("Failed to create socket")?; + + socket.set_multicast_ttl_v4(4)?; + socket.set_nonblocking(true)?; + + // Set the multicast interface if specified + if let Some(if_str) = &args.interface { + if let Ok(if_addr) = Ipv4Addr::from_str(if_str) { + socket.set_multicast_if_v4(&if_addr)?; + info!("Using interface: {}", if_addr); + } else { + error!("Invalid interface address: {}", if_str); + return Ok(()); + } + } + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + socket.bind(&addr.into())?; + + let socket = UdpSocket::from_std(socket.into())?; + let dest_addr = SocketAddr::new(IpAddr::V4(mcast_addr), args.port); + + // Start sending packets + info!("Sending packets to {}...", dest_addr); + + let start = Instant::now(); + let end = start + Duration::from_secs(args.duration_sec); + let mut counter: u64 = 0; // Specify counter type as u64 + + while Instant::now() < end { + counter += 1; + + // Create either text or binary test data + let data = if args.binary_mode { + // Create binary test data (similar to what we might see in the field) + let mut packet = Vec::with_capacity(16); + packet.extend_from_slice(b"REL\0"); // 4 bytes header + packet.extend_from_slice(&counter.to_be_bytes()[4..]); // 4 bytes counter + packet.extend_from_slice(&[0x00, 0x10, 0x02, 0x00]); // 4 bytes + packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x90]); // 4 bytes + packet + } else { + // Create text test data + format!("{} #{}", args.message, counter).into_bytes() + }; + + match socket.send_to(&data, &dest_addr).await { + Ok(bytes) => { + info!("Sent packet #{}: {} bytes", counter, bytes); + } + Err(e) => { + error!("Error sending packet: {}", e); + } + } + + tokio::time::sleep(Duration::from_millis(args.interval_ms)).await; + } + + info!("Done. Sent {} packets in {} seconds.", counter, args.duration_sec); + Ok(()) +} diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 0000000..7fe0109 --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,467 @@ +use anyhow::{Context, Result}; +use clap::Parser; +use log::{debug, error, info, warn}; +use multicast_relay::{ + auth::{calculate_hmac, generate_nonce, verify_hmac}, + config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups}, + protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo}, + DEFAULT_BUFFER_SIZE, +}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::PathBuf, + str::FromStr, + sync::Arc, + time::Duration, // Add this import for Duration +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + sync::{mpsc, Mutex}, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long, default_value = "server_config.toml")] + config: PathBuf, + + #[arg(short, long, action)] + generate_default: bool, +} + +type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let args = Args::parse(); + + // Generate default configs if requested + if args.generate_default { + ensure_default_configs()?; + return Ok(()); + } + + // Load configuration + let config = load_server_config(&args.config) + .context(format!("Failed to load config from {:?}", args.config))?; + + info!("Server configuration loaded from {:?}", args.config); + + let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port); + let listener = TcpListener::bind(&listen_addr).await + .context("Failed to bind TCP listener")?; + + info!("Server listening on {}", listen_addr); + + // Setup multicast receivers + let clients: ClientMap = Arc::new(Mutex::new(HashMap::new())); + + // Start multicast listeners for each multicast group + for (group_id, group) in &config.multicast_groups { + let ports = group.get_ports(); + if ports.is_empty() { + error!("No ports defined for group {}", group_id); + continue; + } + + let display_group_id = group_id.clone(); + let ports_display = if ports.len() == 1 { + format!("port {}", ports[0]) + } else { + format!("ports {}-{}", ports[0], ports.last().unwrap()) + }; + + // Create a listener for each port in the range + for port in ports { + let clients = clients.clone(); + let _secret = config.secret.clone(); + let group_id_clone = group_id.clone(); + let mut group_info = group.clone(); + + // Set the specific port for this listener + group_info.port = Some(port); + group_info.port_range = None; + + tokio::spawn(async move { + if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await { + error!("Multicast listener error for group {} port {}: {}", + group_id_clone, port, e); + } + }); + } + + info!("Listening for multicast group {} on address {} with {}", + display_group_id, group.address, ports_display); + } + + // Store config for use in client handlers + let config = Arc::new(config); + + // Accept client connections + while let Ok((stream, addr)) = listener.accept().await { + info!("New client connection from: {}", addr); + let secret = config.secret.clone(); + let clients = clients.clone(); + let config = config.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_client(stream, addr, &secret, clients, config).await { + error!("Client error: {}: {}", addr, e); + } + info!("Client disconnected: {}", addr); + }); + } + + Ok(()) +} + +async fn listen_to_multicast( + group_id: &str, + group: &MulticastGroup, + clients: ClientMap +) -> Result<()> { + // Get the port to use + let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?; + + // Parse the multicast address + let mcast_ip = match IpAddr::from_str(&group.address) + .context("Invalid multicast address")? { + IpAddr::V4(addr) => addr, + _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported")) + }; + + // Create a UDP socket with more explicit settings + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .context("Failed to create socket")?; + + // Important: Set socket options + socket.set_reuse_address(true)?; + + #[cfg(unix)] + socket.set_reuse_port(true)?; + + socket.set_nonblocking(true)?; + socket.set_multicast_loop_v4(true)?; + + // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port + // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port + let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port); + info!("Binding multicast listener to specific address: {:?}", bind_addr); + socket.bind(&bind_addr.into())?; + + // Join the multicast group with a specific interface + let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface + info!("Joining multicast group {} on interface {:?}", mcast_ip, interface); + socket.join_multicast_v4(&mcast_ip, &interface)?; + + // Additional multicast option: set the IP_MULTICAST_IF option + socket.set_multicast_if_v4(&interface)?; + + // Convert to tokio socket + let udp_socket = tokio::net::UdpSocket::from_std(socket.into()) + .context("Failed to convert socket to async")?; + + info!("Multicast listener ready and bound specifically to {}:{} (group {})", + group.address, port, group_id); + + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let group_id = group_id.to_string(); + + loop { + match udp_socket.recv_from(&mut buf).await { + Ok((len, src)) => { + // Since we're bound to the exact multicast address, we can be confident + // this packet was sent to our specific multicast group + let data = buf[..len].to_vec(); + + info!("RECEIVED: group={} from={} size={} destination={}:{}", + group_id, src, len, mcast_ip, port); + + // Create a message with the packet + let message = Message::MulticastPacket { + group_id: group_id.clone(), + source: src, + destination: group.address.clone(), + port, + data, + }; + + // Send to clients + match serialize_message(&message) { + Ok(serialized) => { + let clients_lock = clients.lock().await; + for (client_addr, sender) in clients_lock.iter() { + if sender.send(serialized.clone()).await.is_err() { + debug!("Failed to send to client {}", client_addr); + } else { + debug!("Sent multicast packet to client {}", client_addr); + } + } + } + Err(e) => error!("Failed to serialize message: {}", e), + } + } + Err(e) => { + if e.kind() != std::io::ErrorKind::WouldBlock { + error!("Error receiving from socket: {}", e); + } + // Small delay to avoid busy waiting on errors + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + } +} + +#[derive(PartialEq)] +enum StatusMessageType { + ClientHeartbeat, + ClientPong, + Other +} + +async fn handle_client( + stream: TcpStream, + addr: SocketAddr, + secret: &str, + clients: ClientMap, + config: Arc<ServerConfig>, +) -> Result<()> { + // Check if external clients are allowed when client is not from localhost + if !config.allow_external_clients && + !addr.ip().is_loopback() && + !addr.ip().to_string().starts_with("192.168.") && + !addr.ip().to_string().starts_with("10.") { + warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr); + return Err(anyhow::anyhow!("External clients not allowed")); + } + + // Split the TCP stream into read and write parts once + let (mut read_stream, mut write_stream) = tokio::io::split(stream); + + // Authentication using the split streams + if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? { + return Err(anyhow::anyhow!("Authentication failed")); + } + + info!("Client authenticated: {}", addr); + + // Get client info + let client_ip = addr.ip().to_string(); + let client_port = addr.port(); + + // Check if client has specific group permissions + let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) { + Some(groups) => groups, + None => return Err(anyhow::anyhow!("Client not authorized for any groups")), + }; + + // Create channel for sending multicast packets to this client + let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100); + + // Add client to map + clients.lock().await.insert(addr, tx.clone()); + + // Create HashMap of available groups for this client + let mut available_groups = HashMap::new(); + for (id, group) in &config.multicast_groups { + // If client has empty group list (all allowed) or specific group is in list + if authorized_groups.is_empty() || authorized_groups.contains(id) { + let ports = group.get_ports(); + if ports.is_empty() { + continue; + } + + // Primary port is the first one + let primary_port = ports[0]; + + // Get additional ports if there are any + let additional_ports = if ports.len() > 1 { + Some(ports[1..].to_vec()) + } else { + None + }; + + available_groups.insert(id.clone(), MulticastGroupInfo { + address: group.address.clone(), + port: primary_port, + additional_ports, + }); + } + } + + // IMPORTANT: Clone tx before moving it into the spawn + let tx_for_read = tx.clone(); + + // Spawn task to read client messages + let clients_clone = clients.clone(); + + // Use the already split read_stream + let read_task = tokio::spawn(async move { + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + loop { + match read_stream.read(&mut buf).await { + Ok(0) => break, // Connection closed + Ok(n) => { + if let Ok(msg) = deserialize_message(&buf[..n]) { + match msg { + Message::Subscribe { group_ids } => { + info!("Client {} subscribing to groups: {:?}", addr, group_ids); + // Group subscriptions handled by server + }, + Message::MulticastGroupsRequest => { + // Send available groups to client + let response = Message::MulticastGroupsResponse { + groups: available_groups.clone() + }; + + if let Ok(bytes) = serialize_message(&response) { + let _ = tx_for_read.send(bytes).await; + } + }, + Message::PingStatus { timestamp, status } => { + // Determine the type of status message + let msg_type = if status.starts_with("Client keepalive ping") || + status.starts_with("Client periodic ping") { + StatusMessageType::ClientHeartbeat + } else if status.starts_with("Client pong response") { + StatusMessageType::ClientPong + } else { + StatusMessageType::Other + }; + + // Log the message receipt + match msg_type { + StatusMessageType::ClientHeartbeat => { + info!("Heartbeat from client {}: {}", addr, status); + + // Respond only to actual heartbeat pings, not pong responses + let response = Message::PingStatus { + timestamp, + status: format!("Server connection to {} is OK", addr), + }; + + if let Ok(bytes) = serialize_message(&response) { + let _ = tx_for_read.send(bytes).await; + } + }, + StatusMessageType::ClientPong => { + // Just log pongs without responding to avoid loops + debug!("Pong from client {}: {}", addr, status); + }, + StatusMessageType::Other => { + info!("Status message from client {}: {}", addr, status); + } + } + }, + _ => {} + } + } + } + Err(e) => { + error!("Error reading from client: {}: {}", addr, e); + break; + } + } + } + // Clean up on disconnect + clients_clone.lock().await.remove(&addr); + info!("Client reader task ended: {}", addr); + }); + + // Forward multicast packets to client using the already split write_stream + let write_task = tokio::spawn(async move { + while let Some(packet) = rx.recv().await { + if let Err(e) = write_stream.write_all(&packet).await { + error!("Error writing to client {}: {}", addr, e); + break; + } + } + info!("Client writer task ended: {}", addr); + }); + + // Now tx is still valid here - use it for heartbeats, but with a delay + let tx_for_heartbeat = tx.clone(); + let client_addr = addr.clone(); + tokio::spawn(async move { + // Add initial delay before starting heartbeats to avoid interfering with initial setup messages + tokio::time::sleep(Duration::from_secs(5)).await; + + let mut interval = tokio::time::interval(Duration::from_secs(30)); + loop { + interval.tick().await; + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Send a heartbeat with a clear identifier + let msg = Message::PingStatus { + timestamp: now, + status: format!("Server heartbeat to {}", client_addr), + }; + + if let Ok(bytes) = serialize_message(&msg) { + if tx_for_heartbeat.send(bytes).await.is_err() { + break; + } + } + } + }); + + // Wait for either task to complete + tokio::select! { + _ = read_task => {}, + _ = write_task => {}, + } + + // Clean up + clients.lock().await.remove(&addr); + Ok(()) +} + +async fn authenticate_client( + reader: &mut (impl AsyncReadExt + Unpin), + writer: &mut (impl AsyncWriteExt + Unpin), + _addr: SocketAddr, + secret: &str +) -> Result<bool> { + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + + // Receive auth request + let n = reader.read(&mut buf).await?; + let auth_request = deserialize_message(&buf[..n])?; + + if let Message::AuthRequest { client_nonce } = auth_request { + // Generate server nonce + let server_nonce = generate_nonce(); + + // Calculate auth token + let auth_data = format!("{}{}", client_nonce, server_nonce); + let auth_token = calculate_hmac(secret, &auth_data); + + // Send response + let response = Message::AuthResponse { + server_nonce: server_nonce.clone(), + auth_token, + }; + + let response_bytes = serialize_message(&response)?; + writer.write_all(&response_bytes).await?; + + // Receive confirmation + let n = reader.read(&mut buf).await?; + let auth_confirm = deserialize_message(&buf[..n])?; + + if let Message::AuthConfirm { auth_token } = auth_confirm { + // Verify token + let expected_data = format!("{}{}", server_nonce, client_nonce); + return Ok(verify_hmac(secret, &expected_data, &auth_token)); + } + } + + Ok(false) +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..c3e7e56 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,243 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +// Structure to define individual authorized clients +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientDefinition { + pub name: String, // Friendly name for the client + pub ip_address: String, // Allowed IP address + pub port: Option<u16>, // Optional specific port (if None, any port is allowed) + pub group_ids: Vec<String>, // Multicast groups this client is allowed to access (empty = all) +} + +// Definition of a multicast group +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MulticastGroup { + pub address: String, // Multicast address + #[serde(skip_serializing_if = "Option::is_none")] + pub port: Option<u16>, // Single port (used if port_range is None) + #[serde(skip_serializing_if = "Option::is_none")] + pub port_range: Option<(u16, u16)>, // Optional port range (start, end) +} + +impl MulticastGroup { + // Helper method to get all ports that should be used + pub fn get_ports(&self) -> Vec<u16> { + if let Some(range) = self.port_range { + // If a range is specified, return all ports in the range + (range.0..=range.1).collect() + } else if let Some(port) = self.port { + // If only a single port is specified, return just that port + vec![port] + } else { + // Default to empty vec if neither is specified (should be validated elsewhere) + vec![] + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub secret: String, + pub listen_ip: String, + pub listen_port: u16, + pub multicast_groups: HashMap<String, MulticastGroup>, + pub authorized_clients: Vec<ClientDefinition>, + #[serde(default)] + pub simple_auth: bool, + #[serde(default)] + pub allow_external_clients: bool, // Flag to allow clients from outside the local network +} + +impl Default for ServerConfig { + fn default() -> Self { + let mut groups = HashMap::new(); + groups.insert("default".to_string(), MulticastGroup { + address: "224.0.0.1".to_string(), + port: Some(5000), + port_range: None, + }); + groups.insert("ssdp".to_string(), MulticastGroup { + address: "239.255.255.250".to_string(), + port: Some(1900), + port_range: None, + }); + // Example with port range + groups.insert("range_example".to_string(), MulticastGroup { + address: "239.192.55.2".to_string(), + port: None, + port_range: Some((1680, 1685)), + }); + + Self { + secret: "changeme".to_string(), + listen_ip: "0.0.0.0".to_string(), + listen_port: 8989, + multicast_groups: groups, + authorized_clients: vec![ + ClientDefinition { + name: "localhost".to_string(), + ip_address: "127.0.0.1".to_string(), + port: None, + group_ids: vec![], // Empty means all groups + }, + ClientDefinition { + name: "example-client".to_string(), + ip_address: "192.168.1.100".to_string(), + port: Some(12345), + group_ids: vec!["default".to_string()], // Only the default group + } + ], + simple_auth: false, // Default to standard auth mode + allow_external_clients: false, // Default to disallow external clients + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientConfig { + pub secret: String, + pub server: String, + pub port: u16, + pub multicast_group_ids: Vec<String>, + pub test_mode: bool, + #[serde(default)] + pub nat_traversal: bool, // Flag to enable NAT traversal features + #[serde(default = "default_reconnect_delay")] + pub reconnect_delay_secs: u64, +} + +fn default_reconnect_delay() -> u64 { + 5 +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + secret: "changeme".to_string(), + server: "127.0.0.1".to_string(), + port: 8989, + multicast_group_ids: vec![], // Empty means subscribe to all groups + test_mode: false, + nat_traversal: false, + reconnect_delay_secs: 5, + } + } +} + +// Load server configuration from a file +pub fn load_server_config<P: AsRef<Path>>(path: P) -> Result<ServerConfig> { + let config_str = fs::read_to_string(path) + .context("Failed to read server configuration file")?; + let config: ServerConfig = toml::from_str(&config_str) + .context("Failed to parse server configuration")?; + Ok(config) +} + +// Save server configuration to a file +pub fn save_server_config<P: AsRef<Path>>(config: &ServerConfig, path: P) -> Result<()> { + let toml_str = toml::to_string_pretty(config) + .context("Failed to serialize server configuration")?; + fs::write(path, toml_str) + .context("Failed to write server configuration file")?; + Ok(()) +} + +// Load client configuration from a file +pub fn load_client_config<P: AsRef<Path>>(path: P) -> Result<ClientConfig> { + let config_str = fs::read_to_string(path) + .context("Failed to read client configuration file")?; + let config: ClientConfig = toml::from_str(&config_str) + .context("Failed to parse client configuration")?; + Ok(config) +} + +// Save client configuration to a file +pub fn save_client_config<P: AsRef<Path>>(config: &ClientConfig, path: P) -> Result<()> { + let toml_str = toml::to_string_pretty(config) + .context("Failed to serialize client configuration")?; + fs::write(path, toml_str) + .context("Failed to write client configuration file")?; + Ok(()) +} + +// Generate default configuration files if they don't exist +pub fn ensure_default_configs() -> Result<()> { + let server_config_path = "server_config.toml"; + if !Path::new(server_config_path).exists() { + let default_config = ServerConfig::default(); + save_server_config(&default_config, server_config_path)?; + println!("Created default server configuration at {}", server_config_path); + } + + let client_config_path = "client_config.toml"; + if !Path::new(client_config_path).exists() { + let default_config = ClientConfig::default(); + save_client_config(&default_config, client_config_path)?; + println!("Created default client configuration at {}", client_config_path); + } + + Ok(()) +} + +// Check if a client is authorized for specified groups +pub fn get_client_authorized_groups( + config: &ServerConfig, + ip: &str, + port: u16 +) -> Option<Vec<String>> { + // If simple_auth is enabled, all clients that know the secret get access to all groups + if config.simple_auth { + return Some(vec![]); // Empty vector means all groups + } + + // Otherwise use the standard group authorization check + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None or matches the client's port + if client.port.is_none() || client.port == Some(port) { + // Return the group IDs or empty vector (meaning all groups) + return Some(client.group_ids.clone()); + } + } + } + + None // Not authorized +} + +// Get client name if authorized +pub fn get_client_name(config: &ServerConfig, ip: &str, port: u16) -> Option<String> { + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None, any port from this IP is allowed + if client.port.is_none() || client.port == Some(port) { + return Some(client.name.clone()); + } + } + } + + None +} + +// Check if a client's IP address and port are authorized +pub fn is_client_authorized(config: &ServerConfig, ip: &str, port: u16) -> bool { + // If simple_auth is enabled, all clients that know the secret are authorized + if config.simple_auth { + return true; + } + + // Otherwise use the standard client authorization check + for client in &config.authorized_clients { + if client.ip_address == ip { + // If port is None, any port from this IP is allowed + if client.port.is_none() || client.port == Some(port) { + return true; + } + } + } + + false +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..93f593e --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,13 @@ +pub mod auth; +pub mod config; +pub mod protocol; + +// Constants +pub const DEFAULT_BUFFER_SIZE: usize = 65536; // Max UDP packet size +pub const TEST_MODE_BANNER: &str = " +╔════════════════════════════════════════════════════╗ +║ TEST MODE ║ +║ Packets will be displayed but not sent to network ║ +╚════════════════════════════════════════════════════╝ +"; +pub const MAX_DISPLAY_BYTES: usize = 64; // Maximum bytes to display in test mode diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..340bbd3 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,129 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; + +// Information about a multicast group +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MulticastGroupInfo { + pub address: String, + pub port: u16, + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined +} + +// Protocol messages exchanged between client and server +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Message { + // Initial authentication message + AuthRequest { + client_nonce: String, + }, + + // Authentication response + AuthResponse { + server_nonce: String, + auth_token: String, // HMAC(secret, client_nonce + server_nonce) + }, + + // Final auth confirmation + AuthConfirm { + auth_token: String, // HMAC(secret, server_nonce + client_nonce) + }, + + // Request information about available multicast groups + MulticastGroupsRequest, + + // Response with available multicast groups + MulticastGroupsResponse { + groups: HashMap<String, MulticastGroupInfo>, + }, + + // Subscribe to specific multicast groups + Subscribe { + group_ids: Vec<String>, + }, + + // Multicast packet forwarded to client + MulticastPacket { + group_id: String, + source: SocketAddr, + destination: String, // Destination multicast address + port: u16, // Destination port + data: Vec<u8>, + }, + + // Configuration response with current settings + ConfigResponse { + config: ServerConfigInfo, + }, + + // New ping/status message for checking connections + PingStatus { + timestamp: u64, + status: String, + }, +} + +// Server configuration info that can be sent to clients +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfigInfo { + pub available_multicast_addresses: Vec<String>, + pub multicast_port: u16, +} + +// Utility function to format packet data for display in test mode +pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String { + let display_len = std::cmp::min(data.len(), max_bytes); + let mut result = String::new(); + + // Print hexadecimal representation + for (i, byte) in data.iter().take(display_len).enumerate() { + if i > 0 && i % 16 == 0 { + result.push('\n'); + } + result.push_str(&format!("{:02x} ", byte)); + } + + if data.len() > max_bytes { + result.push_str("\n... (truncated)"); + } + + result +} + +// Serialize a message to bytes +pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> { + serde_json::to_vec(message) +} + +// Deserialize bytes to a message with better error handling for partial messages +pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> { + // Find where valid JSON ends to handle cases where multiple messages + // or trailing data might be in the buffer + let mut deserializer = serde_json::Deserializer::from_slice(bytes); + let message = Message::deserialize(&mut deserializer)?; + + // Return the successfully parsed message + Ok(message) +} + +// Add a helper function to handle potential noise in streams +pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> { + // First try the standard method + match deserialize_message(bytes) { + Ok(message) => Ok(message), + Err(e) => { + // If it fails due to trailing characters, try to parse just the valid JSON + if e.is_syntax() && e.to_string().contains("trailing characters") { + // Try to find where valid JSON ends by parsing incrementally + for i in (1..bytes.len()).rev() { + if let Ok(msg) = deserialize_message(&bytes[0..i]) { + return Ok(msg); + } + } + } + // If we couldn't recover, return the original error + Err(e) + } + } +} |