diff options
Diffstat (limited to 'src/bin/client.rs')
-rw-r--r-- | src/bin/client.rs | 467 |
1 files changed, 467 insertions, 0 deletions
diff --git a/src/bin/client.rs b/src/bin/client.rs new file mode 100644 index 0000000..5c39fb9 --- /dev/null +++ b/src/bin/client.rs @@ -0,0 +1,467 @@ +use anyhow::{Context, Result}; +use clap::Parser; +use log::{debug, error, info, warn}; +use multicast_relay::{ + auth::{calculate_hmac, generate_nonce, verify_hmac}, + config::{load_client_config, ensure_default_configs, ClientConfig}, + protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message}, + DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES, +}; +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + path::PathBuf, + str::FromStr, + time::Duration, + sync::Arc, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpStream, UdpSocket}, + signal, + sync::Notify, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long, default_value = "client_config.toml")] + config: PathBuf, + + #[arg(short, long, action)] + generate_default: bool, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let args = Args::parse(); + + // Generate default configs if requested + if args.generate_default { + ensure_default_configs()?; + return Ok(()); + } + + // Load configuration + let config = load_client_config(&args.config) + .context(format!("Failed to load config from {:?}", args.config))?; + + info!("Client configuration loaded from {:?}", args.config); + + // Create a notification for clean shutdown + let shutdown = Arc::new(Notify::new()); + let shutdown_signal = shutdown.clone(); + + // Setup signal handler for Ctrl+C + tokio::spawn(async move { + if let Err(e) = signal::ctrl_c().await { + error!("Failed to listen for Ctrl+C: {}", e); + return; + } + info!("Received Ctrl+C, shutting down..."); + shutdown_signal.notify_one(); + }); + + let server_addr = format!("{}:{}", config.server, config.port); + + // Main reconnection loop + loop { + info!("Connecting to server at {}", server_addr); + + // Try to connect with timeout + let connect_result = match tokio::time::timeout( + Duration::from_secs(5), + TcpStream::connect(&server_addr), + ).await { + Ok(result) => result, + Err(_) => { + warn!("Connection attempt timed out"); + if !handle_reconnect(&shutdown, &config).await { + break; + } + continue; + } + }; + + match connect_result { + Ok(stream) => { + info!("Connected to server"); + + // Run the client session + match run_client_session(stream, &config, &shutdown).await { + Ok(_) => { + info!("Client session ended normally"); + break; + }, + Err(e) => { + error!("Client session error: {}", e); + if !handle_reconnect(&shutdown, &config).await { + break; + } + } + } + }, + Err(e) => { + error!("Failed to connect: {}", e); + if !handle_reconnect(&shutdown, &config).await { + break; + } + } + } + } + + Ok(()) +} + +// Helper function to handle reconnection delay +// Returns false if shutdown was requested +async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool { + let delay = Duration::from_secs(client_config.reconnect_delay_secs); + info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs); + tokio::select! { + _ = shutdown.notified() => { + info!("Shutdown requested during reconnect"); + return false; + } + _ = tokio::time::sleep(delay) => {} + } + true +} + +// Add a new enum to distinguish message types +#[derive(PartialEq)] +enum StatusMessageType { + ServerHeartbeat, + ServerPong, + Other +} + +// The main client session function that handles a single connection +async fn run_client_session( + mut stream: TcpStream, + config: &ClientConfig, + shutdown: &Arc<Notify>, +) -> Result<()> { + // Authenticate + if let Err(e) = authenticate(&mut stream, &config.secret).await { + return Err(anyhow::anyhow!("Authentication failed: {}", e)); + } + info!("Authentication successful"); + + // Check if test mode is enabled in config + let mut test_mode = config.test_mode; + + // Request server's multicast group information + let groups_request = Message::MulticastGroupsRequest; + let request_bytes = serialize_message(&groups_request)?; + stream.write_all(&request_bytes).await?; + + // Wait for response with multicast group information, ignoring other message types + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let mut groups_response = None; + + // Keep reading until we get the groups response or timeout + let mut attempts = 0; + while groups_response.is_none() && attempts < 10 { + match stream.read(&mut buf).await { + Ok(n) if n > 0 => { + match robust_deserialize_message(&buf[..n]) { + Ok(Message::MulticastGroupsResponse { groups }) => { + groups_response = Some(groups); + }, + Ok(Message::PingStatus { timestamp: _, status }) => { + // Handle ping but keep waiting for multicast groups + info!("Got server ping: {}", status); + }, + Ok(other_msg) => { + debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg); + }, + Err(e) => { + error!("Failed to deserialize message: {}", e); + } + } + }, + Ok(0) => return Err(anyhow::anyhow!("Server closed connection")), + Ok(_) => {}, + Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e)) + } + + // If we didn't get groups yet, wait a bit and try again + if groups_response.is_none() { + attempts += 1; + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + // Now we either have the groups or we timed out + let groups_response = groups_response.ok_or_else(|| + anyhow::anyhow!("Failed to receive multicast group information from server"))?; + + info!("Available multicast groups from server:"); + for (id, group) in &groups_response { + info!(" - {} -> {}:{}", id, group.address, group.port); + } + + // Determine which groups to subscribe to + let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() { + // If no specific groups are requested, subscribe to all + info!("No specific groups requested, subscribing to all available groups"); + groups_response.keys().cloned().collect() + } else { + // Otherwise, subscribe only to requested groups that exist + let mut valid_groups = Vec::new(); + for group_id in &config.multicast_group_ids { + if groups_response.contains_key(group_id.as_str()) { + valid_groups.push(group_id.clone()); + } else { + warn!("Requested group '{}' does not exist on server", group_id); + } + } + + if valid_groups.is_empty() { + warn!("None of the requested groups exist on server. No data will be received."); + } else { + info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups); + } + + valid_groups + }; + + // Send subscription message + let subscribe_msg = Message::Subscribe { + group_ids: groups_to_subscribe.clone(), + }; + let subscribe_bytes = serialize_message(&subscribe_msg)?; + stream.write_all(&subscribe_bytes).await?; + + // Create UDP sockets for local retransmission (skip if in test mode) + let mut sockets: HashMap<String, UdpSocket> = HashMap::new(); + if !test_mode { + for group_id in &groups_to_subscribe { + if let Some(group_info) = groups_response.get(group_id.as_str()) { + info!("Creating socket for group {} ({} on port {})", + group_id, group_info.address, group_info.port); + + match UdpSocket::bind("0.0.0.0:0").await { + Ok(socket) => { + sockets.insert(group_id.clone(), socket); + info!("Successfully created UDP socket for group {}", group_id); + }, + Err(e) => { + error!("Failed to create UDP socket for group {}: {}", group_id, e); + } + } + } + } + + if sockets.is_empty() && !groups_to_subscribe.is_empty() { + error!("Failed to create any UDP sockets"); + warn!("Falling back to test mode due to socket creation failure"); + test_mode = true; + } + } + + // Display test mode banner if enabled + if test_mode { + println!("{}", TEST_MODE_BANNER); + info!("Test mode enabled - packets will be displayed but not sent to network"); + } + + // Main receive loop + info!("Listening for multicast traffic from server"); + + // Set the read timeout for the stream + stream.set_nodelay(true)?; + + // Remove problematic code that uses unsupported methods and the nix crate + if config.nat_traversal { + info!("NAT traversal mode enabled - using more frequent keepalives"); + } + + // Calculate appropriate ping interval based on NAT traversal setting + let ping_interval = if config.nat_traversal { + Duration::from_secs(25) // More frequent for NAT + } else { + Duration::from_secs(55) + }; + + // Main receive loop + loop { + tokio::select! { + _ = shutdown.notified() => { + info!("Shutdown requested, ending client session"); + return Ok(()); + } + read_result = stream.read(&mut buf) => { + match read_result { + Ok(0) => { + info!("Server closed connection"); + return Err(anyhow::anyhow!("Server closed connection")); + } + Ok(n) => { + match robust_deserialize_message(&buf[..n]) { + Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => { + if test_mode { + println!("\n----- MULTICAST PACKET -----"); + println!("Group: {}", group_id); + println!("Source: {}", source); + println!("Destination: {}:{}", destination, port); + println!("Size: {} bytes", data.len()); + println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES)); + println!("---------------------------\n"); + } else { + info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes", + group_id, source, destination, port, data.len()); + + // Get socket for this group + if let Some(socket) = sockets.get(&group_id) { + // Parse destination address directly from the packet + match IpAddr::from_str(&destination) { + Ok(dest_addr) => { + // Create destination socket address using the packet's port + let dest = SocketAddr::new(dest_addr, port); + + info!("Forwarding packet to {}:{}", dest_addr, port); + + // Retransmit locally + match socket.send_to(&data, dest).await { + Ok(sent) => { + info!("Successfully forwarded {} of {} bytes to {}:{}", + sent, data.len(), destination, port); + }, + Err(e) => { + error!("Failed to retransmit packet for group {} to {}:{}: {}", + group_id, destination, port, e); + } + } + }, + Err(e) => { + error!("Invalid destination address {}: {}", destination, e); + } + } + } else { + warn!("No socket available for group {}", group_id); + } + } + }, + Ok(Message::ConfigResponse { config: _server_config }) => { + info!("Server configuration received"); + }, + Ok(Message::PingStatus { timestamp: _, status }) => { + // Parse the message type based on the status text + let msg_type = if status.starts_with("Server heartbeat to") { + StatusMessageType::ServerHeartbeat + } else if status.starts_with("Server connection to") { + StatusMessageType::ServerPong + } else { + StatusMessageType::Other + }; + + // Only respond with a pong to server heartbeats + if msg_type == StatusMessageType::ServerHeartbeat { + debug!("Received server heartbeat: {}", status); + + // Send a pong response (but only for heartbeats) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let pong = Message::PingStatus { + timestamp: now, + status: "Client pong response".to_string(), + }; + + if let Ok(bytes) = serialize_message(&pong) { + let _ = stream.write_all(&bytes).await; + } + } else { + // Just log other status messages without responding + debug!("Connection Status: {}", status); + } + }, + Ok(_) => debug!("Received other message type"), + Err(e) => { + error!("Failed to deserialize message: {}", e); + // Don't return/break on deserialization errors - continue reading + } + } + } + Err(e) => { + error!("Error reading from server: {}", e); + return Err(anyhow::anyhow!("Connection error: {}", e)); + } + } + }, + _ = tokio::time::sleep(ping_interval) => { + // Send regular ping to keep NAT connection alive + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let ping_msg = Message::PingStatus { + timestamp: now, + status: if config.nat_traversal { + "Client keepalive ping".to_string() // Changed text to be more specific + } else { + "Client periodic ping".to_string() + }, + }; + + match serialize_message(&ping_msg) { + Ok(bytes) => { + if let Err(e) = stream.write_all(&bytes).await { + error!("Failed to ping server: {}", e); + return Err(anyhow::anyhow!("Server ping failed: {}", e)); + } + debug!("Connection check sent to server"); + }, + Err(e) => error!("Failed to serialize ping message: {}", e), + } + } + } + } +} + +async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> { + // Generate client nonce + let client_nonce = generate_nonce(); + + // Send auth request + let auth_request = Message::AuthRequest { + client_nonce: client_nonce.clone(), + }; + let request_bytes = serialize_message(&auth_request)?; + stream.write_all(&request_bytes).await?; + + // Receive response + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let n = stream.read(&mut buf).await?; + let response = robust_deserialize_message(&buf[..n])?; + + if let Message::AuthResponse { server_nonce, auth_token } = response { + // Verify server's token + let expected_data = format!("{}{}", client_nonce, server_nonce); + if !verify_hmac(secret, &expected_data, &auth_token) { + return Err(anyhow::anyhow!("Server authentication failed")); + } + + // Calculate our token + let auth_data = format!("{}{}", server_nonce, client_nonce); + let client_token = calculate_hmac(secret, &auth_data); + + // Send confirmation + let confirm = Message::AuthConfirm { + auth_token: client_token, + }; + + let confirm_bytes = serialize_message(&confirm)?; + stream.write_all(&confirm_bytes).await?; + + Ok(()) + } else { + Err(anyhow::anyhow!("Unexpected response from server")) + } +} |