use anyhow::{Context, Result}; use clap::Parser; use log::{debug, error, info, warn}; use multicast_relay::{ auth::{calculate_hmac, generate_nonce, verify_hmac}, config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups}, protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo}, DEFAULT_BUFFER_SIZE, }; use socket2::{Domain, Protocol, Socket, Type}; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, str::FromStr, sync::Arc, time::Duration, // Add this import for Duration }; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::{mpsc, Mutex}, }; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { #[arg(short, long, default_value = "server_config.toml")] config: PathBuf, #[arg(short, long, action)] generate_default: bool, } type ClientMap = Arc>>>>; #[tokio::main] async fn main() -> Result<()> { env_logger::init(); let args = Args::parse(); // Generate default configs if requested if args.generate_default { ensure_default_configs()?; return Ok(()); } // Load configuration let config = load_server_config(&args.config) .context(format!("Failed to load config from {:?}", args.config))?; info!("Server configuration loaded from {:?}", args.config); let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port); let listener = TcpListener::bind(&listen_addr).await .context("Failed to bind TCP listener")?; info!("Server listening on {}", listen_addr); // Setup multicast receivers let clients: ClientMap = Arc::new(Mutex::new(HashMap::new())); // Start multicast listeners for each multicast group for (group_id, group) in &config.multicast_groups { let ports = group.get_ports(); if ports.is_empty() { error!("No ports defined for group {}", group_id); continue; } let display_group_id = group_id.clone(); let ports_display = if ports.len() == 1 { format!("port {}", ports[0]) } else { format!("ports {}-{}", ports[0], ports.last().unwrap()) }; // Create a listener for each port in the range for port in ports { let clients = clients.clone(); let _secret = config.secret.clone(); let group_id_clone = group_id.clone(); let mut group_info = group.clone(); // Set the specific port for this listener group_info.port = Some(port); group_info.port_range = None; tokio::spawn(async move { if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await { error!("Multicast listener error for group {} port {}: {}", group_id_clone, port, e); } }); } info!("Listening for multicast group {} on address {} with {}", display_group_id, group.address, ports_display); } // Store config for use in client handlers let config = Arc::new(config); // Accept client connections while let Ok((stream, addr)) = listener.accept().await { info!("New client connection from: {}", addr); let secret = config.secret.clone(); let clients = clients.clone(); let config = config.clone(); tokio::spawn(async move { if let Err(e) = handle_client(stream, addr, &secret, clients, config).await { error!("Client error: {}: {}", addr, e); } info!("Client disconnected: {}", addr); }); } Ok(()) } async fn listen_to_multicast( group_id: &str, group: &MulticastGroup, clients: ClientMap ) -> Result<()> { // Get the port to use let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?; // Parse the multicast address let mcast_ip = match IpAddr::from_str(&group.address) .context("Invalid multicast address")? { IpAddr::V4(addr) => addr, _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported")) }; // Create a UDP socket with more explicit settings let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) .context("Failed to create socket")?; // Important: Set socket options socket.set_reuse_address(true)?; #[cfg(unix)] socket.set_reuse_port(true)?; socket.set_nonblocking(true)?; socket.set_multicast_loop_v4(true)?; // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port); info!("Binding multicast listener to specific address: {:?}", bind_addr); socket.bind(&bind_addr.into())?; // Join the multicast group with a specific interface let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface info!("Joining multicast group {} on interface {:?}", mcast_ip, interface); socket.join_multicast_v4(&mcast_ip, &interface)?; // Additional multicast option: set the IP_MULTICAST_IF option socket.set_multicast_if_v4(&interface)?; // Convert to tokio socket let udp_socket = tokio::net::UdpSocket::from_std(socket.into()) .context("Failed to convert socket to async")?; info!("Multicast listener ready and bound specifically to {}:{} (group {})", group.address, port, group_id); let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; let group_id = group_id.to_string(); loop { match udp_socket.recv_from(&mut buf).await { Ok((len, src)) => { // Since we're bound to the exact multicast address, we can be confident // this packet was sent to our specific multicast group let data = buf[..len].to_vec(); info!("RECEIVED: group={} from={} size={} destination={}:{}", group_id, src, len, mcast_ip, port); // Create a message with the packet let message = Message::MulticastPacket { group_id: group_id.clone(), source: src, destination: group.address.clone(), port, data, }; // Send to clients match serialize_message(&message) { Ok(serialized) => { let clients_lock = clients.lock().await; for (client_addr, sender) in clients_lock.iter() { if sender.send(serialized.clone()).await.is_err() { debug!("Failed to send to client {}", client_addr); } else { debug!("Sent multicast packet to client {}", client_addr); } } } Err(e) => error!("Failed to serialize message: {}", e), } } Err(e) => { if e.kind() != std::io::ErrorKind::WouldBlock { error!("Error receiving from socket: {}", e); } // Small delay to avoid busy waiting on errors tokio::time::sleep(Duration::from_millis(10)).await; } } } } #[derive(PartialEq)] enum StatusMessageType { ClientHeartbeat, ClientPong, Other } async fn handle_client( stream: TcpStream, addr: SocketAddr, secret: &str, clients: ClientMap, config: Arc, ) -> Result<()> { // Check if external clients are allowed when client is not from localhost if !config.allow_external_clients && !addr.ip().is_loopback() && !addr.ip().to_string().starts_with("192.168.") && !addr.ip().to_string().starts_with("10.") { warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr); return Err(anyhow::anyhow!("External clients not allowed")); } // Split the TCP stream into read and write parts once let (mut read_stream, mut write_stream) = tokio::io::split(stream); // Authentication using the split streams if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? { return Err(anyhow::anyhow!("Authentication failed")); } info!("Client authenticated: {}", addr); // Get client info let client_ip = addr.ip().to_string(); let client_port = addr.port(); // Check if client has specific group permissions let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) { Some(groups) => groups, None => return Err(anyhow::anyhow!("Client not authorized for any groups")), }; // Create channel for sending multicast packets to this client let (tx, mut rx) = mpsc::channel::>(100); // Add client to map clients.lock().await.insert(addr, tx.clone()); // Create HashMap of available groups for this client let mut available_groups = HashMap::new(); for (id, group) in &config.multicast_groups { // If client has empty group list (all allowed) or specific group is in list if authorized_groups.is_empty() || authorized_groups.contains(id) { let ports = group.get_ports(); if ports.is_empty() { continue; } // Primary port is the first one let primary_port = ports[0]; // Get additional ports if there are any let additional_ports = if ports.len() > 1 { Some(ports[1..].to_vec()) } else { None }; available_groups.insert(id.clone(), MulticastGroupInfo { address: group.address.clone(), port: primary_port, additional_ports, }); } } // IMPORTANT: Clone tx before moving it into the spawn let tx_for_read = tx.clone(); // Spawn task to read client messages let clients_clone = clients.clone(); // Use the already split read_stream let read_task = tokio::spawn(async move { let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; loop { match read_stream.read(&mut buf).await { Ok(0) => break, // Connection closed Ok(n) => { if let Ok(msg) = deserialize_message(&buf[..n]) { match msg { Message::Subscribe { group_ids } => { info!("Client {} subscribing to groups: {:?}", addr, group_ids); // Group subscriptions handled by server }, Message::MulticastGroupsRequest => { // Send available groups to client let response = Message::MulticastGroupsResponse { groups: available_groups.clone() }; if let Ok(bytes) = serialize_message(&response) { let _ = tx_for_read.send(bytes).await; } }, Message::PingStatus { timestamp, status } => { // Determine the type of status message let msg_type = if status.starts_with("Client keepalive ping") || status.starts_with("Client periodic ping") { StatusMessageType::ClientHeartbeat } else if status.starts_with("Client pong response") { StatusMessageType::ClientPong } else { StatusMessageType::Other }; // Log the message receipt match msg_type { StatusMessageType::ClientHeartbeat => { info!("Heartbeat from client {}: {}", addr, status); // Respond only to actual heartbeat pings, not pong responses let response = Message::PingStatus { timestamp, status: format!("Server connection to {} is OK", addr), }; if let Ok(bytes) = serialize_message(&response) { let _ = tx_for_read.send(bytes).await; } }, StatusMessageType::ClientPong => { // Just log pongs without responding to avoid loops debug!("Pong from client {}: {}", addr, status); }, StatusMessageType::Other => { info!("Status message from client {}: {}", addr, status); } } }, _ => {} } } } Err(e) => { error!("Error reading from client: {}: {}", addr, e); break; } } } // Clean up on disconnect clients_clone.lock().await.remove(&addr); info!("Client reader task ended: {}", addr); }); // Forward multicast packets to client using the already split write_stream let write_task = tokio::spawn(async move { while let Some(packet) = rx.recv().await { if let Err(e) = write_stream.write_all(&packet).await { error!("Error writing to client {}: {}", addr, e); break; } } info!("Client writer task ended: {}", addr); }); // Now tx is still valid here - use it for heartbeats, but with a delay let tx_for_heartbeat = tx.clone(); let client_addr = addr.clone(); tokio::spawn(async move { // Add initial delay before starting heartbeats to avoid interfering with initial setup messages tokio::time::sleep(Duration::from_secs(5)).await; let mut interval = tokio::time::interval(Duration::from_secs(30)); loop { interval.tick().await; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); // Send a heartbeat with a clear identifier let msg = Message::PingStatus { timestamp: now, status: format!("Server heartbeat to {}", client_addr), }; if let Ok(bytes) = serialize_message(&msg) { if tx_for_heartbeat.send(bytes).await.is_err() { break; } } } }); // Wait for either task to complete tokio::select! { _ = read_task => {}, _ = write_task => {}, } // Clean up clients.lock().await.remove(&addr); Ok(()) } async fn authenticate_client( reader: &mut (impl AsyncReadExt + Unpin), writer: &mut (impl AsyncWriteExt + Unpin), _addr: SocketAddr, secret: &str ) -> Result { let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; // Receive auth request let n = reader.read(&mut buf).await?; let auth_request = deserialize_message(&buf[..n])?; if let Message::AuthRequest { client_nonce } = auth_request { // Generate server nonce let server_nonce = generate_nonce(); // Calculate auth token let auth_data = format!("{}{}", client_nonce, server_nonce); let auth_token = calculate_hmac(secret, &auth_data); // Send response let response = Message::AuthResponse { server_nonce: server_nonce.clone(), auth_token, }; let response_bytes = serialize_message(&response)?; writer.write_all(&response_bytes).await?; // Receive confirmation let n = reader.read(&mut buf).await?; let auth_confirm = deserialize_message(&buf[..n])?; if let Message::AuthConfirm { auth_token } = auth_confirm { // Verify token let expected_data = format!("{}{}", server_nonce, client_nonce); return Ok(verify_hmac(secret, &expected_data, &auth_token)); } } Ok(false) }