use anyhow::{Context, Result}; use clap::Parser; use log::{debug, error, info, warn}; use multicast_relay::{ auth::{calculate_hmac, generate_nonce, verify_hmac}, config::{load_client_config, ensure_default_configs, ClientConfig}, protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message}, DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES, }; use std::{ collections::HashMap, net::{IpAddr, SocketAddr}, path::PathBuf, str::FromStr, time::Duration, sync::Arc, }; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpStream, UdpSocket}, signal, sync::Notify, }; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { #[arg(short, long, default_value = "client_config.toml")] config: PathBuf, #[arg(short, long, action)] generate_default: bool, } #[tokio::main] async fn main() -> Result<()> { env_logger::init(); let args = Args::parse(); // Generate default configs if requested if args.generate_default { ensure_default_configs()?; return Ok(()); } // Load configuration let config = load_client_config(&args.config) .context(format!("Failed to load config from {:?}", args.config))?; info!("Client configuration loaded from {:?}", args.config); // Create a notification for clean shutdown let shutdown = Arc::new(Notify::new()); let shutdown_signal = shutdown.clone(); // Setup signal handler for Ctrl+C tokio::spawn(async move { if let Err(e) = signal::ctrl_c().await { error!("Failed to listen for Ctrl+C: {}", e); return; } info!("Received Ctrl+C, shutting down..."); shutdown_signal.notify_one(); }); let server_addr = format!("{}:{}", config.server, config.port); // Main reconnection loop loop { info!("Connecting to server at {}", server_addr); // Try to connect with timeout let connect_result = match tokio::time::timeout( Duration::from_secs(5), TcpStream::connect(&server_addr), ).await { Ok(result) => result, Err(_) => { warn!("Connection attempt timed out"); if !handle_reconnect(&shutdown, &config).await { break; } continue; } }; match connect_result { Ok(stream) => { info!("Connected to server"); // Run the client session match run_client_session(stream, &config, &shutdown).await { Ok(_) => { info!("Client session ended normally"); break; }, Err(e) => { error!("Client session error: {}", e); if !handle_reconnect(&shutdown, &config).await { break; } } } }, Err(e) => { error!("Failed to connect: {}", e); if !handle_reconnect(&shutdown, &config).await { break; } } } } Ok(()) } // Helper function to handle reconnection delay // Returns false if shutdown was requested async fn handle_reconnect(shutdown: &Arc, client_config: &ClientConfig) -> bool { let delay = Duration::from_secs(client_config.reconnect_delay_secs); info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs); tokio::select! { _ = shutdown.notified() => { info!("Shutdown requested during reconnect"); return false; } _ = tokio::time::sleep(delay) => {} } true } // Add a new enum to distinguish message types #[derive(PartialEq)] enum StatusMessageType { ServerHeartbeat, ServerPong, Other } // The main client session function that handles a single connection async fn run_client_session( mut stream: TcpStream, config: &ClientConfig, shutdown: &Arc, ) -> Result<()> { // Authenticate if let Err(e) = authenticate(&mut stream, &config.secret).await { return Err(anyhow::anyhow!("Authentication failed: {}", e)); } info!("Authentication successful"); // Check if test mode is enabled in config let mut test_mode = config.test_mode; // Request server's multicast group information let groups_request = Message::MulticastGroupsRequest; let request_bytes = serialize_message(&groups_request)?; stream.write_all(&request_bytes).await?; // Wait for response with multicast group information, ignoring other message types let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; let mut groups_response = None; // Keep reading until we get the groups response or timeout let mut attempts = 0; while groups_response.is_none() && attempts < 10 { match stream.read(&mut buf).await { Ok(n) if n > 0 => { match robust_deserialize_message(&buf[..n]) { Ok(Message::MulticastGroupsResponse { groups }) => { groups_response = Some(groups); }, Ok(Message::PingStatus { timestamp: _, status }) => { // Handle ping but keep waiting for multicast groups info!("Got server ping: {}", status); }, Ok(other_msg) => { debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg); }, Err(e) => { error!("Failed to deserialize message: {}", e); } } }, Ok(0) => return Err(anyhow::anyhow!("Server closed connection")), Ok(_) => {}, Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e)) } // If we didn't get groups yet, wait a bit and try again if groups_response.is_none() { attempts += 1; tokio::time::sleep(Duration::from_millis(100)).await; } } // Now we either have the groups or we timed out let groups_response = groups_response.ok_or_else(|| anyhow::anyhow!("Failed to receive multicast group information from server"))?; info!("Available multicast groups from server:"); for (id, group) in &groups_response { info!(" - {} -> {}:{}", id, group.address, group.port); } // Determine which groups to subscribe to let groups_to_subscribe: Vec = if config.multicast_group_ids.is_empty() { // If no specific groups are requested, subscribe to all info!("No specific groups requested, subscribing to all available groups"); groups_response.keys().cloned().collect() } else { // Otherwise, subscribe only to requested groups that exist let mut valid_groups = Vec::new(); for group_id in &config.multicast_group_ids { if groups_response.contains_key(group_id.as_str()) { valid_groups.push(group_id.clone()); } else { warn!("Requested group '{}' does not exist on server", group_id); } } if valid_groups.is_empty() { warn!("None of the requested groups exist on server. No data will be received."); } else { info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups); } valid_groups }; // Send subscription message let subscribe_msg = Message::Subscribe { group_ids: groups_to_subscribe.clone(), }; let subscribe_bytes = serialize_message(&subscribe_msg)?; stream.write_all(&subscribe_bytes).await?; // Create UDP sockets for local retransmission (skip if in test mode) let mut sockets: HashMap = HashMap::new(); if !test_mode { for group_id in &groups_to_subscribe { if let Some(group_info) = groups_response.get(group_id.as_str()) { info!("Creating socket for group {} ({} on port {})", group_id, group_info.address, group_info.port); match UdpSocket::bind("0.0.0.0:0").await { Ok(socket) => { sockets.insert(group_id.clone(), socket); info!("Successfully created UDP socket for group {}", group_id); }, Err(e) => { error!("Failed to create UDP socket for group {}: {}", group_id, e); } } } } if sockets.is_empty() && !groups_to_subscribe.is_empty() { error!("Failed to create any UDP sockets"); warn!("Falling back to test mode due to socket creation failure"); test_mode = true; } } // Display test mode banner if enabled if test_mode { println!("{}", TEST_MODE_BANNER); info!("Test mode enabled - packets will be displayed but not sent to network"); } // Main receive loop info!("Listening for multicast traffic from server"); // Set the read timeout for the stream stream.set_nodelay(true)?; // Remove problematic code that uses unsupported methods and the nix crate if config.nat_traversal { info!("NAT traversal mode enabled - using more frequent keepalives"); } // Calculate appropriate ping interval based on NAT traversal setting let ping_interval = if config.nat_traversal { Duration::from_secs(25) // More frequent for NAT } else { Duration::from_secs(55) }; // Main receive loop loop { tokio::select! { _ = shutdown.notified() => { info!("Shutdown requested, ending client session"); return Ok(()); } read_result = stream.read(&mut buf) => { match read_result { Ok(0) => { info!("Server closed connection"); return Err(anyhow::anyhow!("Server closed connection")); } Ok(n) => { match robust_deserialize_message(&buf[..n]) { Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => { if test_mode { println!("\n----- MULTICAST PACKET -----"); println!("Group: {}", group_id); println!("Source: {}", source); println!("Destination: {}:{}", destination, port); println!("Size: {} bytes", data.len()); println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES)); println!("---------------------------\n"); } else { info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes", group_id, source, destination, port, data.len()); // Get socket for this group if let Some(socket) = sockets.get(&group_id) { // Parse destination address directly from the packet match IpAddr::from_str(&destination) { Ok(dest_addr) => { // Create destination socket address using the packet's port let dest = SocketAddr::new(dest_addr, port); info!("Forwarding packet to {}:{}", dest_addr, port); // Retransmit locally match socket.send_to(&data, dest).await { Ok(sent) => { info!("Successfully forwarded {} of {} bytes to {}:{}", sent, data.len(), destination, port); }, Err(e) => { error!("Failed to retransmit packet for group {} to {}:{}: {}", group_id, destination, port, e); } } }, Err(e) => { error!("Invalid destination address {}: {}", destination, e); } } } else { warn!("No socket available for group {}", group_id); } } }, Ok(Message::ConfigResponse { config: _server_config }) => { info!("Server configuration received"); }, Ok(Message::PingStatus { timestamp: _, status }) => { // Parse the message type based on the status text let msg_type = if status.starts_with("Server heartbeat to") { StatusMessageType::ServerHeartbeat } else if status.starts_with("Server connection to") { StatusMessageType::ServerPong } else { StatusMessageType::Other }; // Only respond with a pong to server heartbeats if msg_type == StatusMessageType::ServerHeartbeat { debug!("Received server heartbeat: {}", status); // Send a pong response (but only for heartbeats) let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); let pong = Message::PingStatus { timestamp: now, status: "Client pong response".to_string(), }; if let Ok(bytes) = serialize_message(&pong) { let _ = stream.write_all(&bytes).await; } } else { // Just log other status messages without responding debug!("Connection Status: {}", status); } }, Ok(_) => debug!("Received other message type"), Err(e) => { error!("Failed to deserialize message: {}", e); // Don't return/break on deserialization errors - continue reading } } } Err(e) => { error!("Error reading from server: {}", e); return Err(anyhow::anyhow!("Connection error: {}", e)); } } }, _ = tokio::time::sleep(ping_interval) => { // Send regular ping to keep NAT connection alive let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); let ping_msg = Message::PingStatus { timestamp: now, status: if config.nat_traversal { "Client keepalive ping".to_string() // Changed text to be more specific } else { "Client periodic ping".to_string() }, }; match serialize_message(&ping_msg) { Ok(bytes) => { if let Err(e) = stream.write_all(&bytes).await { error!("Failed to ping server: {}", e); return Err(anyhow::anyhow!("Server ping failed: {}", e)); } debug!("Connection check sent to server"); }, Err(e) => error!("Failed to serialize ping message: {}", e), } } } } } async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> { // Generate client nonce let client_nonce = generate_nonce(); // Send auth request let auth_request = Message::AuthRequest { client_nonce: client_nonce.clone(), }; let request_bytes = serialize_message(&auth_request)?; stream.write_all(&request_bytes).await?; // Receive response let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; let n = stream.read(&mut buf).await?; let response = robust_deserialize_message(&buf[..n])?; if let Message::AuthResponse { server_nonce, auth_token } = response { // Verify server's token let expected_data = format!("{}{}", client_nonce, server_nonce); if !verify_hmac(secret, &expected_data, &auth_token) { return Err(anyhow::anyhow!("Server authentication failed")); } // Calculate our token let auth_data = format!("{}{}", server_nonce, client_nonce); let client_token = calculate_hmac(secret, &auth_data); // Send confirmation let confirm = Message::AuthConfirm { auth_token: client_token, }; let confirm_bytes = serialize_message(&confirm)?; stream.write_all(&confirm_bytes).await?; Ok(()) } else { Err(anyhow::anyhow!("Unexpected response from server")) } }