From a0c754fc727a35d775c25857898986f4273db0a5 Mon Sep 17 00:00:00 2001
From: Kablersalat
Date: Thu, 5 Jun 2025 00:14:27 +0200
Subject: Imported and sanitized dev server to publish on gitter

---
 src/bin/client.rs     | 467 ++++++++++++++++++++++++++++++++++++++++++++++++++
 src/bin/mcast_test.rs | 123 +++++++++++++
 src/bin/server.rs     | 467 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 1057 insertions(+)
 create mode 100644 src/bin/client.rs
 create mode 100644 src/bin/mcast_test.rs
 create mode 100644 src/bin/server.rs
(limited to 'src/bin')

diff --git a/src/bin/client.rs b/src/bin/client.rs
new file mode 100644
index 0000000..5c39fb9
--- /dev/null
+++ b/src/bin/client.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+    auth::{calculate_hmac, generate_nonce, verify_hmac},
+    config::{load_client_config, ensure_default_configs, ClientConfig},
+    protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message},
+    DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES,
+};
+use std::{
+    collections::HashMap,
+    net::{IpAddr, SocketAddr},
+    path::PathBuf,
+    str::FromStr,
+    time::Duration,
+    sync::Arc,
+};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    net::{TcpStream, UdpSocket},
+    signal,
+    sync::Notify,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(short, long, default_value = "client_config.toml")]
+    config: PathBuf,
+
+    #[arg(short, long, action)]
+    generate_default: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::init();
+    let args = Args::parse();
+
+    // Generate default configs if requested
+    if args.generate_default {
+        ensure_default_configs()?;
+        return Ok(());
+    }
+
+    // Load configuration
+    let config = load_client_config(&args.config)
+        .context(format!("Failed to load config from {:?}", args.config))?;
+
+    info!("Client configuration loaded from {:?}", args.config);
+
+    // Create a notification for clean shutdown
+    let shutdown = Arc::new(Notify::new());
+    let shutdown_signal = shutdown.clone();
+
+    // Setup signal handler for Ctrl+C
+    tokio::spawn(async move {
+        if let Err(e) = signal::ctrl_c().await {
+            error!("Failed to listen for Ctrl+C: {}", e);
+            return;
+        }
+        info!("Received Ctrl+C, shutting down...");
+        shutdown_signal.notify_one();
+    });
+
+    let server_addr = format!("{}:{}", config.server, config.port);
+
+    // Main reconnection loop
+    loop {
+        info!("Connecting to server at {}", server_addr);
+
+        // Try to connect with timeout
+        let connect_result = match tokio::time::timeout(
+            Duration::from_secs(5),
+            TcpStream::connect(&server_addr),
+        ).await {
+            Ok(result) => result,
+            Err(_) => {
+                warn!("Connection attempt timed out");
+                if !handle_reconnect(&shutdown, &config).await {
+                    break;
+                }
+                continue;
+            }
+        };
+
+        match connect_result {
+            Ok(stream) => {
+                info!("Connected to server");
+
+                // Run the client session
+                match run_client_session(stream, &config, &shutdown).await {
+                    Ok(_) => {
+                        info!("Client session ended normally");
+                        break;
+                    },
+                    Err(e) => {
+                        error!("Client session error: {}", e);
+                        if !handle_reconnect(&shutdown, &config).await {
+                            break;
+                        }
+                    }
+                }
+            },
+            Err(e) => {
+                error!("Failed to connect: {}", e);
+                if !handle_reconnect(&shutdown, &config).await {
+                    break;
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+// Helper function to handle reconnection delay
+// Returns false if shutdown was requested
+async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool {
+    let delay = Duration::from_secs(client_config.reconnect_delay_secs);
+    info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs);
+    tokio::select! {
+        _ = shutdown.notified() => {
+            info!("Shutdown requested during reconnect");
+            return false;
+        }
+        _ = tokio::time::sleep(delay) => {}
+    }
+    true
+}
+
+// Add a new enum to distinguish message types
+#[derive(PartialEq)]
+enum StatusMessageType {
+    ServerHeartbeat,
+    ServerPong,
+    Other
+}
+
+// The main client session function that handles a single connection
+async fn run_client_session(
+    mut stream: TcpStream,
+    config: &ClientConfig,
+    shutdown: &Arc<Notify>,
+) -> Result<()> {
+    // Authenticate
+    if let Err(e) = authenticate(&mut stream, &config.secret).await {
+        return Err(anyhow::anyhow!("Authentication failed: {}", e));
+    }
+    info!("Authentication successful");
+
+    // Check if test mode is enabled in config
+    let mut test_mode = config.test_mode;
+
+    // Request server's multicast group information
+    let groups_request = Message::MulticastGroupsRequest;
+    let request_bytes = serialize_message(&groups_request)?;
+    stream.write_all(&request_bytes).await?;
+
+    // Wait for response with multicast group information, ignoring other message types
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+    let mut groups_response = None;
+
+    // Keep reading until we get the groups response or timeout
+    let mut attempts = 0;
+    while groups_response.is_none() && attempts < 10 {
+        match stream.read(&mut buf).await {
+            Ok(n) if n > 0 => {
+                match robust_deserialize_message(&buf[..n]) {
+                    Ok(Message::MulticastGroupsResponse { groups }) => {
+                        groups_response = Some(groups);
+                    },
+                    Ok(Message::PingStatus { timestamp: _, status }) => {
+                        // Handle ping but keep waiting for multicast groups
+                        info!("Got server ping: {}", status);
+                    },
+                    Ok(other_msg) => {
+                        debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg);
+                    },
+                    Err(e) => {
+                        error!("Failed to deserialize message: {}", e);
+                    }
+                }
+            },
+            Ok(0) => return Err(anyhow::anyhow!("Server closed connection")),
+            Ok(_) => {},
+            Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e))
+        }
+
+        // If we didn't get groups yet, wait a bit and try again
+        if groups_response.is_none() {
+            attempts += 1;
+            tokio::time::sleep(Duration::from_millis(100)).await;
+        }
+    }
+
+    // Now we either have the groups or we timed out
+    let groups_response = groups_response.ok_or_else(||
+        anyhow::anyhow!("Failed to receive multicast group information from server"))?;
+
+    info!("Available multicast groups from server:");
+    for (id, group) in &groups_response {
+        info!("  - {} -> {}:{}", id, group.address, group.port);
+    }
+
+    // Determine which groups to subscribe to
+    let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() {
+        // If no specific groups are requested, subscribe to all
+        info!("No specific groups requested, subscribing to all available groups");
+        groups_response.keys().cloned().collect()
+    } else {
+        // Otherwise, subscribe only to requested groups that exist
+        let mut valid_groups = Vec::new();
+        for group_id in &config.multicast_group_ids {
+            if groups_response.contains_key(group_id.as_str()) {
+                valid_groups.push(group_id.clone());
+            } else {
+                warn!("Requested group '{}' does not exist on server", group_id);
+            }
+        }
+
+        if valid_groups.is_empty() {
+            warn!("None of the requested groups exist on server. No data will be received.");
+        } else {
+            info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups);
+        }
+
+        valid_groups
+    };
+
+    // Send subscription message
+    let subscribe_msg = Message::Subscribe {
+        group_ids: groups_to_subscribe.clone(),
+    };
+    let subscribe_bytes = serialize_message(&subscribe_msg)?;
+    stream.write_all(&subscribe_bytes).await?;
+
+    // Create UDP sockets for local retransmission (skip if in test mode)
+    let mut sockets: HashMap<String, UdpSocket> = HashMap::new();
+    if !test_mode {
+        for group_id in &groups_to_subscribe {
+            if let Some(group_info) = groups_response.get(group_id.as_str()) {
+                info!("Creating socket for group {} ({} on port {})",
+                    group_id, group_info.address, group_info.port);
+
+                match UdpSocket::bind("0.0.0.0:0").await {
+                    Ok(socket) => {
+                        sockets.insert(group_id.clone(), socket);
+                        info!("Successfully created UDP socket for group {}", group_id);
+                    },
+                    Err(e) => {
+                        error!("Failed to create UDP socket for group {}: {}", group_id, e);
+                    }
+                }
+            }
+        }
+
+        if sockets.is_empty() && !groups_to_subscribe.is_empty() {
+            error!("Failed to create any UDP sockets");
+            warn!("Falling back to test mode due to socket creation failure");
+            test_mode = true;
+        }
+    }
+
+    // Display test mode banner if enabled
+    if test_mode {
+        println!("{}", TEST_MODE_BANNER);
+        info!("Test mode enabled - packets will be displayed but not sent to network");
+    }
+
+    // Main receive loop
+    info!("Listening for multicast traffic from server");
+
+    // Set the read timeout for the stream
+    stream.set_nodelay(true)?;
+
+    // Remove problematic code that uses unsupported methods and the nix crate
+    if config.nat_traversal {
+        info!("NAT traversal mode enabled - using more frequent keepalives");
+    }
+
+    // Calculate appropriate ping interval based on NAT traversal setting
+    let ping_interval = if config.nat_traversal {
+        Duration::from_secs(25) // More frequent for NAT
+    } else {
+        Duration::from_secs(55)
+    };
+
+    // Main receive loop
+    loop {
+        tokio::select! {
+            _ = shutdown.notified() => {
+                info!("Shutdown requested, ending client session");
+                return Ok(());
+            }
+            read_result = stream.read(&mut buf) => {
+                match read_result {
+                    Ok(0) => {
+                        info!("Server closed connection");
+                        return Err(anyhow::anyhow!("Server closed connection"));
+                    }
+                    Ok(n) => {
+                        match robust_deserialize_message(&buf[..n]) {
+                            Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => {
+                                if test_mode {
+                                    println!("\n----- MULTICAST PACKET -----");
+                                    println!("Group: {}", group_id);
+                                    println!("Source: {}", source);
+                                    println!("Destination: {}:{}", destination, port);
+                                    println!("Size: {} bytes", data.len());
+                                    println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES));
+                                    println!("---------------------------\n");
+                                } else {
+                                    info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes",
+                                        group_id, source, destination, port, data.len());
+
+                                    // Get socket for this group
+                                    if let Some(socket) = sockets.get(&group_id) {
+                                        // Parse destination address directly from the packet
+                                        match IpAddr::from_str(&destination) {
+                                            Ok(dest_addr) => {
+                                                // Create destination socket address using the packet's port
+                                                let dest = SocketAddr::new(dest_addr, port);
+
+                                                info!("Forwarding packet to {}:{}", dest_addr, port);
+
+                                                // Retransmit locally
+                                                match socket.send_to(&data, dest).await {
+                                                    Ok(sent) => {
+                                                        info!("Successfully forwarded {} of {} bytes to {}:{}",
+                                                            sent, data.len(), destination, port);
+                                                    },
+                                                    Err(e) => {
+                                                        error!("Failed to retransmit packet for group {} to {}:{}: {}",
+                                                            group_id, destination, port, e);
+                                                    }
+                                                }
+                                            },
+                                            Err(e) => {
+                                                error!("Invalid destination address {}: {}", destination, e);
+                                            }
+                                        }
+                                    } else {
+                                        warn!("No socket available for group {}", group_id);
+                                    }
+                                }
+                            },
+                            Ok(Message::ConfigResponse { config: _server_config }) => {
+                                info!("Server configuration received");
+                            },
+                            Ok(Message::PingStatus { timestamp: _, status }) => {
+                                // Parse the message type based on the status text
+                                let msg_type = if status.starts_with("Server heartbeat to") {
+                                    StatusMessageType::ServerHeartbeat
+                                } else if status.starts_with("Server connection to") {
+                                    StatusMessageType::ServerPong
+                                } else {
+                                    StatusMessageType::Other
+                                };
+
+                                // Only respond with a pong to server heartbeats
+                                if msg_type == StatusMessageType::ServerHeartbeat {
+                                    debug!("Received server heartbeat: {}", status);
+
+                                    // Send a pong response (but only for heartbeats)
+                                    let now = std::time::SystemTime::now()
+                                        .duration_since(std::time::UNIX_EPOCH)
+                                        .unwrap()
+                                        .as_secs();
+
+                                    let pong = Message::PingStatus {
+                                        timestamp: now,
+                                        status: "Client pong response".to_string(),
+                                    };
+
+                                    if let Ok(bytes) = serialize_message(&pong) {
+                                        let _ = stream.write_all(&bytes).await;
+                                    }
+                                } else {
+                                    // Just log other status messages without responding
+                                    debug!("Connection Status: {}", status);
+                                }
+                            },
+                            Ok(_) => debug!("Received other message type"),
+                            Err(e) => {
+                                error!("Failed to deserialize message: {}", e);
+                                // Don't return/break on deserialization errors - continue reading
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        error!("Error reading from server: {}", e);
+                        return Err(anyhow::anyhow!("Connection error: {}", e));
+                    }
+                }
+            },
+            _ = tokio::time::sleep(ping_interval) => {
+                // Send regular ping to keep NAT connection alive
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap()
+                    .as_secs();
+
+                let ping_msg = Message::PingStatus {
+                    timestamp: now,
+                    status: if config.nat_traversal {
+                        "Client keepalive ping".to_string() // Changed text to be more specific
+                    } else {
+                        "Client periodic ping".to_string()
+                    },
+                };
+
+                match serialize_message(&ping_msg) {
+                    Ok(bytes) => {
+                        if let Err(e) = stream.write_all(&bytes).await {
+                            error!("Failed to ping server: {}", e);
+                            return Err(anyhow::anyhow!("Server ping failed: {}", e));
+                        }
+                        debug!("Connection check sent to server");
+                    },
+                    Err(e) => error!("Failed to serialize ping message: {}", e),
+                }
+            }
+        }
+    }
+}
+
+async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> {
+    // Generate client nonce
+    let client_nonce = generate_nonce();
+
+    // Send auth request
+    let auth_request = Message::AuthRequest {
+        client_nonce: client_nonce.clone(),
+    };
+    let request_bytes = serialize_message(&auth_request)?;
+    stream.write_all(&request_bytes).await?;
+
+    // Receive response
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+    let n = stream.read(&mut buf).await?;
+    let response = robust_deserialize_message(&buf[..n])?;
+
+    if let Message::AuthResponse { server_nonce, auth_token } = response {
+        // Verify server's token
+        let expected_data = format!("{}{}", client_nonce, server_nonce);
+        if !verify_hmac(secret, &expected_data, &auth_token) {
+            return Err(anyhow::anyhow!("Server authentication failed"));
+        }
+
+        // Calculate our token
+        let auth_data = format!("{}{}", server_nonce, client_nonce);
+        let client_token = calculate_hmac(secret, &auth_data);
+
+        // Send confirmation
+        let confirm = Message::AuthConfirm {
+            auth_token: client_token,
+        };
+
+        let confirm_bytes = serialize_message(&confirm)?;
+        stream.write_all(&confirm_bytes).await?;
+
+        Ok(())
+    } else {
+        Err(anyhow::anyhow!("Unexpected response from server"))
+    }
+}
diff --git a/src/bin/mcast_test.rs b/src/bin/mcast_test.rs
new file mode 100644
index 0000000..782610c
--- /dev/null
+++ b/src/bin/mcast_test.rs
@@ -0,0 +1,123 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{error, info};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::str::FromStr;
+use std::time::{Duration, Instant};
+use tokio::net::UdpSocket;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about = "Multicast packet generator for testing CastRepeat")]
+struct Args {
+    #[arg(short, long, default_value = "239.192.55.1")]
+    multicast_addr: String,
+
+    #[arg(short, long, default_value = "1681")]
+    port: u16,
+
+    #[arg(short, long, default_value = "1000")]
+    interval_ms: u64,
+
+    #[arg(short, long, default_value = "60")]
+    duration_sec: u64,
+
+    #[arg(short, long, default_value = "Test packet")]
+    message: String,
+
+    #[arg(short, long)]
+    interface: Option<String>,
+
+    #[arg(short = 'b', long, action)]
+    binary_mode: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::init();
+    let args = Args::parse();
+
+    info!("CastRepeat Multicast Packet Generator");
+    info!("--------------------------------");
+    info!("Multicast Address: {}", args.multicast_addr);
+    info!("Port: {}", args.port);
+    info!("Interval: {} ms", args.interval_ms);
+    info!("Duration: {} sec", args.duration_sec);
+    if let Some(interface) = &args.interface {
+        info!("Interface: {}", interface);
+    }
+    info!("Mode: {}", if args.binary_mode { "Binary" } else { "Text" });
+    info!("--------------------------------");
+
+    // Verify multicast address
+    let mcast_addr = match IpAddr::from_str(&args.multicast_addr) {
+        Ok(IpAddr::V4(addr)) if addr.is_multicast() => addr,
+        _ => {
+            error!("Invalid multicast address: {}", args.multicast_addr);
+            return Ok(());
+        }
+    };
+
+    // Create socket
+    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+        .context("Failed to create socket")?;
+
+    socket.set_multicast_ttl_v4(4)?;
+    socket.set_nonblocking(true)?;
+
+    // Set the multicast interface if specified
+    if let Some(if_str) = &args.interface {
+        if let Ok(if_addr) = Ipv4Addr::from_str(if_str) {
+            socket.set_multicast_if_v4(&if_addr)?;
+            info!("Using interface: {}", if_addr);
+        } else {
+            error!("Invalid interface address: {}", if_str);
+            return Ok(());
+        }
+    }
+
+    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
+    socket.bind(&addr.into())?;
+
+    let socket = UdpSocket::from_std(socket.into())?;
+    let dest_addr = SocketAddr::new(IpAddr::V4(mcast_addr), args.port);
+
+    // Start sending packets
+    info!("Sending packets to {}...", dest_addr);
+
+    let start = Instant::now();
+    let end = start + Duration::from_secs(args.duration_sec);
+    let mut counter: u64 = 0; // Specify counter type as u64
+
+    while Instant::now() < end {
+        counter += 1;
+
+        // Create either text or binary test data
+        let data = if args.binary_mode {
+            // Create binary test data (similar to what we might see in the field)
+            let mut packet = Vec::with_capacity(16);
+            packet.extend_from_slice(b"REL\0"); // 4 bytes header
+            packet.extend_from_slice(&counter.to_be_bytes()[4..]); // 4 bytes counter
+            packet.extend_from_slice(&[0x00, 0x10, 0x02, 0x00]); // 4 bytes
+            packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x90]); // 4 bytes
+            packet
+        } else {
+            // Create text test data
+            format!("{} #{}", args.message, counter).into_bytes()
+        };
+
+        match socket.send_to(&data, &dest_addr).await {
+            Ok(bytes) => {
+                info!("Sent packet #{}: {} bytes", counter, bytes);
+            }
+            Err(e) => {
+                error!("Error sending packet: {}", e);
+            }
+        }
+
+        tokio::time::sleep(Duration::from_millis(args.interval_ms)).await;
+    }
+
+    info!("Done. Sent {} packets in {} seconds.", counter, args.duration_sec);
+    Ok(())
+}
diff --git a/src/bin/server.rs b/src/bin/server.rs
new file mode 100644
index 0000000..7fe0109
--- /dev/null
+++ b/src/bin/server.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+    auth::{calculate_hmac, generate_nonce, verify_hmac},
+    config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+    protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+    DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+    collections::HashMap,
+    net::{IpAddr, Ipv4Addr, SocketAddr},
+    path::PathBuf,
+    str::FromStr,
+    sync::Arc,
+    time::Duration, // Add this import for Duration
+};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    net::{TcpListener, TcpStream},
+    sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(short, long, default_value = "server_config.toml")]
+    config: PathBuf,
+
+    #[arg(short, long, action)]
+    generate_default: bool,
+}
+
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    env_logger::init();
+    let args = Args::parse();
+
+    // Generate default configs if requested
+    if args.generate_default {
+        ensure_default_configs()?;
+        return Ok(());
+    }
+
+    // Load configuration
+    let config = load_server_config(&args.config)
+        .context(format!("Failed to load config from {:?}", args.config))?;
+
+    info!("Server configuration loaded from {:?}", args.config);
+
+    let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+    let listener = TcpListener::bind(&listen_addr).await
+        .context("Failed to bind TCP listener")?;
+
+    info!("Server listening on {}", listen_addr);
+
+    // Setup multicast receivers
+    let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+    // Start multicast listeners for each multicast group
+    for (group_id, group) in &config.multicast_groups {
+        let ports = group.get_ports();
+        if ports.is_empty() {
+            error!("No ports defined for group {}", group_id);
+            continue;
+        }
+
+        let display_group_id = group_id.clone();
+        let ports_display = if ports.len() == 1 {
+            format!("port {}", ports[0])
+        } else {
+            format!("ports {}-{}", ports[0], ports.last().unwrap())
+        };
+
+        // Create a listener for each port in the range
+        for port in ports {
+            let clients = clients.clone();
+            let _secret = config.secret.clone();
+            let group_id_clone = group_id.clone();
+            let mut group_info = group.clone();
+
+            // Set the specific port for this listener
+            group_info.port = Some(port);
+            group_info.port_range = None;
+
+            tokio::spawn(async move {
+                if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+                    error!("Multicast listener error for group {} port {}: {}",
+                        group_id_clone, port, e);
+                }
+            });
+        }
+
+        info!("Listening for multicast group {} on address {} with {}",
+            display_group_id, group.address, ports_display);
+    }
+
+    // Store config for use in client handlers
+    let config = Arc::new(config);
+
+    // Accept client connections
+    while let Ok((stream, addr)) = listener.accept().await {
+        info!("New client connection from: {}", addr);
+        let secret = config.secret.clone();
+        let clients = clients.clone();
+        let config = config.clone();
+
+        tokio::spawn(async move {
+            if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+                error!("Client error: {}: {}",
+                    addr, e);
+            }
+            info!("Client disconnected: {}", addr);
+        });
+    }
+
+    Ok(())
+}
+
+async fn listen_to_multicast(
+    group_id: &str,
+    group: &MulticastGroup,
+    clients: ClientMap
+) -> Result<()> {
+    // Get the port to use
+    let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+    // Parse the multicast address
+    let mcast_ip = match IpAddr::from_str(&group.address)
+        .context("Invalid multicast address")? {
+        IpAddr::V4(addr) => addr,
+        _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+    };
+
+    // Create a UDP socket with more explicit settings
+    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+        .context("Failed to create socket")?;
+
+    // Important: Set socket options
+    socket.set_reuse_address(true)?;
+
+    #[cfg(unix)]
+    socket.set_reuse_port(true)?;
+
+    socket.set_nonblocking(true)?;
+    socket.set_multicast_loop_v4(true)?;
+
+    // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port
+    // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port
+    let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+    info!("Binding multicast listener to specific address: {:?}", bind_addr);
+    socket.bind(&bind_addr.into())?;
+
+    // Join the multicast group with a specific interface
+    let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+    info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+    socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+    // Additional multicast option: set the IP_MULTICAST_IF option
+    socket.set_multicast_if_v4(&interface)?;
+
+    // Convert to tokio socket
+    let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+        .context("Failed to convert socket to async")?;
+
+    info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+        group.address, port, group_id);
+
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+    let group_id = group_id.to_string();
+
+    loop {
+        match udp_socket.recv_from(&mut buf).await {
+            Ok((len, src)) => {
+                // Since we're bound to the exact multicast address, we can be confident
+                // this packet was sent to our specific multicast group
+                let data = buf[..len].to_vec();
+
+                info!("RECEIVED: group={} from={} size={} destination={}:{}",
+                    group_id, src, len, mcast_ip, port);
+
+                // Create a message with the packet
+                let message = Message::MulticastPacket {
+                    group_id: group_id.clone(),
+                    source: src,
+                    destination: group.address.clone(),
+                    port,
+                    data,
+                };
+
+                // Send to clients
+                match serialize_message(&message) {
+                    Ok(serialized) => {
+                        let clients_lock = clients.lock().await;
+                        for (client_addr, sender) in clients_lock.iter() {
+                            if sender.send(serialized.clone()).await.is_err() {
+                                debug!("Failed to send to client {}", client_addr);
+                            } else {
+                                debug!("Sent multicast packet to client {}", client_addr);
+                            }
+                        }
+                    }
+                    Err(e) => error!("Failed to serialize message: {}", e),
+                }
+            }
+            Err(e) => {
+                if e.kind() != std::io::ErrorKind::WouldBlock {
+                    error!("Error receiving from socket: {}", e);
+                }
+                // Small delay to avoid busy waiting on errors
+                tokio::time::sleep(Duration::from_millis(10)).await;
+            }
+        }
+    }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+    ClientHeartbeat,
+    ClientPong,
+    Other
+}
+
+async fn handle_client(
+    stream: TcpStream,
+    addr: SocketAddr,
+    secret: &str,
+    clients: ClientMap,
+    config: Arc<ServerConfig>,
+) -> Result<()> {
+    // Check if external clients are allowed when client is not from localhost
+    if !config.allow_external_clients &&
+       !addr.ip().is_loopback() &&
+       !addr.ip().to_string().starts_with("192.168.") &&
+       !addr.ip().to_string().starts_with("10.") {
+        warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+        return Err(anyhow::anyhow!("External clients not allowed"));
+    }
+
+    // Split the TCP stream into read and write parts once
+    let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+    // Authentication using the split streams
+    if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+        return Err(anyhow::anyhow!("Authentication failed"));
+    }
+
+    info!("Client authenticated: {}", addr);
+
+    // Get client info
+    let client_ip = addr.ip().to_string();
+    let client_port = addr.port();
+
+    // Check if client has specific group permissions
+    let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+        Some(groups) => groups,
+        None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+    };
+
+    // Create channel for sending multicast packets to this client
+    let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+    // Add client to map
+    clients.lock().await.insert(addr, tx.clone());
+
+    // Create HashMap of available groups for this client
+    let mut available_groups = HashMap::new();
+    for (id, group) in &config.multicast_groups {
+        // If client has empty group list (all allowed) or specific group is in list
+        if authorized_groups.is_empty() || authorized_groups.contains(id) {
+            let ports = group.get_ports();
+            if ports.is_empty() {
+                continue;
+            }
+
+            // Primary port is the first one
+            let primary_port = ports[0];
+
+            // Get additional ports if there are any
+            let additional_ports = if ports.len() > 1 {
+                Some(ports[1..].to_vec())
+            } else {
+                None
+            };
+
+            available_groups.insert(id.clone(), MulticastGroupInfo {
+                address: group.address.clone(),
+                port: primary_port,
+                additional_ports,
+            });
+        }
+    }
+
+    // IMPORTANT: Clone tx before moving it into the spawn
+    let tx_for_read = tx.clone();
+
+    // Spawn task to read client messages
+    let clients_clone = clients.clone();
+
+    // Use the already split read_stream
+    let read_task = tokio::spawn(async move {
+        let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+        loop {
+            match read_stream.read(&mut buf).await {
+                Ok(0) => break, // Connection closed
+                Ok(n) => {
+                    if let Ok(msg) = deserialize_message(&buf[..n]) {
+                        match msg {
+                            Message::Subscribe { group_ids } => {
+                                info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+                                // Group subscriptions handled by server
+                            },
+                            Message::MulticastGroupsRequest => {
+                                // Send available groups to client
+                                let response = Message::MulticastGroupsResponse {
+                                    groups: available_groups.clone()
+                                };
+
+                                if let Ok(bytes) = serialize_message(&response) {
+                                    let _ = tx_for_read.send(bytes).await;
+                                }
+                            },
+                            Message::PingStatus { timestamp, status } => {
+                                // Determine the type of status message
+                                let msg_type = if status.starts_with("Client keepalive ping") ||
+                                    status.starts_with("Client periodic ping") {
+                                    StatusMessageType::ClientHeartbeat
+                                } else if status.starts_with("Client pong response") {
+                                    StatusMessageType::ClientPong
+                                } else {
+                                    StatusMessageType::Other
+                                };
+
+                                // Log the message receipt
+                                match msg_type {
+                                    StatusMessageType::ClientHeartbeat => {
+                                        info!("Heartbeat from client {}: {}", addr, status);
+
+                                        // Respond only to actual heartbeat pings, not pong responses
+                                        let response = Message::PingStatus {
+                                            timestamp,
+                                            status: format!("Server connection to {} is OK", addr),
+                                        };
+
+                                        if let Ok(bytes) = serialize_message(&response) {
+                                            let _ = tx_for_read.send(bytes).await;
+                                        }
+                                    },
+                                    StatusMessageType::ClientPong => {
+                                        // Just log pongs without responding to avoid loops
+                                        debug!("Pong from client {}: {}", addr, status);
+                                    },
+                                    StatusMessageType::Other => {
+                                        info!("Status message from client {}: {}", addr, status);
+                                    }
+                                }
+                            },
+                            _ => {}
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Error reading from client: {}: {}", addr, e);
+                    break;
+                }
+            }
+        }
+        // Clean up on disconnect
+        clients_clone.lock().await.remove(&addr);
+        info!("Client reader task ended: {}", addr);
+    });
+
+    // Forward multicast packets to client using the already split write_stream
+    let write_task = tokio::spawn(async move {
+        while let Some(packet) = rx.recv().await {
+            if let Err(e) = write_stream.write_all(&packet).await {
+                error!("Error writing to client {}: {}", addr, e);
+                break;
+            }
+        }
+        info!("Client writer task ended: {}", addr);
+    });
+
+    // Now tx is still valid here - use it for heartbeats, but with a delay
+    let tx_for_heartbeat = tx.clone();
+    let client_addr = addr.clone();
+    tokio::spawn(async move {
+        // Add initial delay before starting heartbeats to avoid interfering with initial setup messages
+        tokio::time::sleep(Duration::from_secs(5)).await;
+
+        let mut interval = tokio::time::interval(Duration::from_secs(30));
+        loop {
+            interval.tick().await;
+            let now = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap()
+                .as_secs();
+
+            // Send a heartbeat with a clear identifier
+            let msg = Message::PingStatus {
+                timestamp: now,
+                status: format!("Server heartbeat to {}", client_addr),
+            };
+
+            if let Ok(bytes) = serialize_message(&msg) {
+                if tx_for_heartbeat.send(bytes).await.is_err() {
+                    break;
+                }
+            }
+        }
+    });
+
+    // Wait for either task to complete
+    tokio::select! {
+        _ = read_task => {},
+        _ = write_task => {},
+    }
+
+    // Clean up
+    clients.lock().await.remove(&addr);
+    Ok(())
+}
+
+async fn authenticate_client(
+    reader: &mut (impl AsyncReadExt + Unpin),
+    writer: &mut (impl AsyncWriteExt + Unpin),
+    _addr: SocketAddr,
+    secret: &str
+) -> Result<bool> {
+    let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+    // Receive auth request
+    let n = reader.read(&mut buf).await?;
+    let auth_request = deserialize_message(&buf[..n])?;
+
+    if let Message::AuthRequest { client_nonce } = auth_request {
+        // Generate server nonce
+        let server_nonce = generate_nonce();
+
+        // Calculate auth token
+        let auth_data = format!("{}{}", client_nonce, server_nonce);
+        let auth_token = calculate_hmac(secret, &auth_data);
+
+        // Send response
+        let response = Message::AuthResponse {
+            server_nonce: server_nonce.clone(),
+            auth_token,
+        };
+
+        let response_bytes = serialize_message(&response)?;
+        writer.write_all(&response_bytes).await?;
+
+        // Receive confirmation
+        let n = reader.read(&mut buf).await?;
+        let auth_confirm = deserialize_message(&buf[..n])?;
+
+        if let Message::AuthConfirm { auth_token } = auth_confirm {
+            // Verify token
+            let expected_data = format!("{}{}", server_nonce, client_nonce);
+            return Ok(verify_hmac(secret, &expected_data, &auth_token));
+        }
+    }
+
+    Ok(false)
+}
--
cgit v1.2.3-70-g09d2