aboutsummaryrefslogtreecommitdiff
path: root/src/bin/client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/client.rs')
-rw-r--r--src/bin/client.rs467
1 files changed, 467 insertions, 0 deletions
diff --git a/src/bin/client.rs b/src/bin/client.rs
new file mode 100644
index 0000000..5c39fb9
--- /dev/null
+++ b/src/bin/client.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_client_config, ensure_default_configs, ClientConfig},
+ protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message},
+ DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES,
+};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ time::Duration,
+ sync::Arc,
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpStream, UdpSocket},
+ signal,
+ sync::Notify,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "client_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_client_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Client configuration loaded from {:?}", args.config);
+
+ // Create a notification for clean shutdown
+ let shutdown = Arc::new(Notify::new());
+ let shutdown_signal = shutdown.clone();
+
+ // Setup signal handler for Ctrl+C
+ tokio::spawn(async move {
+ if let Err(e) = signal::ctrl_c().await {
+ error!("Failed to listen for Ctrl+C: {}", e);
+ return;
+ }
+ info!("Received Ctrl+C, shutting down...");
+ shutdown_signal.notify_one();
+ });
+
+ let server_addr = format!("{}:{}", config.server, config.port);
+
+ // Main reconnection loop
+ loop {
+ info!("Connecting to server at {}", server_addr);
+
+ // Try to connect with timeout
+ let connect_result = match tokio::time::timeout(
+ Duration::from_secs(5),
+ TcpStream::connect(&server_addr),
+ ).await {
+ Ok(result) => result,
+ Err(_) => {
+ warn!("Connection attempt timed out");
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ continue;
+ }
+ };
+
+ match connect_result {
+ Ok(stream) => {
+ info!("Connected to server");
+
+ // Run the client session
+ match run_client_session(stream, &config, &shutdown).await {
+ Ok(_) => {
+ info!("Client session ended normally");
+ break;
+ },
+ Err(e) => {
+ error!("Client session error: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ },
+ Err(e) => {
+ error!("Failed to connect: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+// Helper function to handle reconnection delay
+// Returns false if shutdown was requested
+async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool {
+ let delay = Duration::from_secs(client_config.reconnect_delay_secs);
+ info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs);
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested during reconnect");
+ return false;
+ }
+ _ = tokio::time::sleep(delay) => {}
+ }
+ true
+}
+
+// Add a new enum to distinguish message types
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ServerHeartbeat,
+ ServerPong,
+ Other
+}
+
+// The main client session function that handles a single connection
+async fn run_client_session(
+ mut stream: TcpStream,
+ config: &ClientConfig,
+ shutdown: &Arc<Notify>,
+) -> Result<()> {
+ // Authenticate
+ if let Err(e) = authenticate(&mut stream, &config.secret).await {
+ return Err(anyhow::anyhow!("Authentication failed: {}", e));
+ }
+ info!("Authentication successful");
+
+ // Check if test mode is enabled in config
+ let mut test_mode = config.test_mode;
+
+ // Request server's multicast group information
+ let groups_request = Message::MulticastGroupsRequest;
+ let request_bytes = serialize_message(&groups_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Wait for response with multicast group information, ignoring other message types
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let mut groups_response = None;
+
+ // Keep reading until we get the groups response or timeout
+ let mut attempts = 0;
+ while groups_response.is_none() && attempts < 10 {
+ match stream.read(&mut buf).await {
+ Ok(n) if n > 0 => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastGroupsResponse { groups }) => {
+ groups_response = Some(groups);
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Handle ping but keep waiting for multicast groups
+ info!("Got server ping: {}", status);
+ },
+ Ok(other_msg) => {
+ debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg);
+ },
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ }
+ }
+ },
+ Ok(0) => return Err(anyhow::anyhow!("Server closed connection")),
+ Ok(_) => {},
+ Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e))
+ }
+
+ // If we didn't get groups yet, wait a bit and try again
+ if groups_response.is_none() {
+ attempts += 1;
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ }
+
+ // Now we either have the groups or we timed out
+ let groups_response = groups_response.ok_or_else(||
+ anyhow::anyhow!("Failed to receive multicast group information from server"))?;
+
+ info!("Available multicast groups from server:");
+ for (id, group) in &groups_response {
+ info!(" - {} -> {}:{}", id, group.address, group.port);
+ }
+
+ // Determine which groups to subscribe to
+ let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() {
+ // If no specific groups are requested, subscribe to all
+ info!("No specific groups requested, subscribing to all available groups");
+ groups_response.keys().cloned().collect()
+ } else {
+ // Otherwise, subscribe only to requested groups that exist
+ let mut valid_groups = Vec::new();
+ for group_id in &config.multicast_group_ids {
+ if groups_response.contains_key(group_id.as_str()) {
+ valid_groups.push(group_id.clone());
+ } else {
+ warn!("Requested group '{}' does not exist on server", group_id);
+ }
+ }
+
+ if valid_groups.is_empty() {
+ warn!("None of the requested groups exist on server. No data will be received.");
+ } else {
+ info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups);
+ }
+
+ valid_groups
+ };
+
+ // Send subscription message
+ let subscribe_msg = Message::Subscribe {
+ group_ids: groups_to_subscribe.clone(),
+ };
+ let subscribe_bytes = serialize_message(&subscribe_msg)?;
+ stream.write_all(&subscribe_bytes).await?;
+
+ // Create UDP sockets for local retransmission (skip if in test mode)
+ let mut sockets: HashMap<String, UdpSocket> = HashMap::new();
+ if !test_mode {
+ for group_id in &groups_to_subscribe {
+ if let Some(group_info) = groups_response.get(group_id.as_str()) {
+ info!("Creating socket for group {} ({} on port {})",
+ group_id, group_info.address, group_info.port);
+
+ match UdpSocket::bind("0.0.0.0:0").await {
+ Ok(socket) => {
+ sockets.insert(group_id.clone(), socket);
+ info!("Successfully created UDP socket for group {}", group_id);
+ },
+ Err(e) => {
+ error!("Failed to create UDP socket for group {}: {}", group_id, e);
+ }
+ }
+ }
+ }
+
+ if sockets.is_empty() && !groups_to_subscribe.is_empty() {
+ error!("Failed to create any UDP sockets");
+ warn!("Falling back to test mode due to socket creation failure");
+ test_mode = true;
+ }
+ }
+
+ // Display test mode banner if enabled
+ if test_mode {
+ println!("{}", TEST_MODE_BANNER);
+ info!("Test mode enabled - packets will be displayed but not sent to network");
+ }
+
+ // Main receive loop
+ info!("Listening for multicast traffic from server");
+
+ // Set the read timeout for the stream
+ stream.set_nodelay(true)?;
+
+ // Remove problematic code that uses unsupported methods and the nix crate
+ if config.nat_traversal {
+ info!("NAT traversal mode enabled - using more frequent keepalives");
+ }
+
+ // Calculate appropriate ping interval based on NAT traversal setting
+ let ping_interval = if config.nat_traversal {
+ Duration::from_secs(25) // More frequent for NAT
+ } else {
+ Duration::from_secs(55)
+ };
+
+ // Main receive loop
+ loop {
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested, ending client session");
+ return Ok(());
+ }
+ read_result = stream.read(&mut buf) => {
+ match read_result {
+ Ok(0) => {
+ info!("Server closed connection");
+ return Err(anyhow::anyhow!("Server closed connection"));
+ }
+ Ok(n) => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => {
+ if test_mode {
+ println!("\n----- MULTICAST PACKET -----");
+ println!("Group: {}", group_id);
+ println!("Source: {}", source);
+ println!("Destination: {}:{}", destination, port);
+ println!("Size: {} bytes", data.len());
+ println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES));
+ println!("---------------------------\n");
+ } else {
+ info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes",
+ group_id, source, destination, port, data.len());
+
+ // Get socket for this group
+ if let Some(socket) = sockets.get(&group_id) {
+ // Parse destination address directly from the packet
+ match IpAddr::from_str(&destination) {
+ Ok(dest_addr) => {
+ // Create destination socket address using the packet's port
+ let dest = SocketAddr::new(dest_addr, port);
+
+ info!("Forwarding packet to {}:{}", dest_addr, port);
+
+ // Retransmit locally
+ match socket.send_to(&data, dest).await {
+ Ok(sent) => {
+ info!("Successfully forwarded {} of {} bytes to {}:{}",
+ sent, data.len(), destination, port);
+ },
+ Err(e) => {
+ error!("Failed to retransmit packet for group {} to {}:{}: {}",
+ group_id, destination, port, e);
+ }
+ }
+ },
+ Err(e) => {
+ error!("Invalid destination address {}: {}", destination, e);
+ }
+ }
+ } else {
+ warn!("No socket available for group {}", group_id);
+ }
+ }
+ },
+ Ok(Message::ConfigResponse { config: _server_config }) => {
+ info!("Server configuration received");
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Parse the message type based on the status text
+ let msg_type = if status.starts_with("Server heartbeat to") {
+ StatusMessageType::ServerHeartbeat
+ } else if status.starts_with("Server connection to") {
+ StatusMessageType::ServerPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Only respond with a pong to server heartbeats
+ if msg_type == StatusMessageType::ServerHeartbeat {
+ debug!("Received server heartbeat: {}", status);
+
+ // Send a pong response (but only for heartbeats)
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let pong = Message::PingStatus {
+ timestamp: now,
+ status: "Client pong response".to_string(),
+ };
+
+ if let Ok(bytes) = serialize_message(&pong) {
+ let _ = stream.write_all(&bytes).await;
+ }
+ } else {
+ // Just log other status messages without responding
+ debug!("Connection Status: {}", status);
+ }
+ },
+ Ok(_) => debug!("Received other message type"),
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ // Don't return/break on deserialization errors - continue reading
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from server: {}", e);
+ return Err(anyhow::anyhow!("Connection error: {}", e));
+ }
+ }
+ },
+ _ = tokio::time::sleep(ping_interval) => {
+ // Send regular ping to keep NAT connection alive
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let ping_msg = Message::PingStatus {
+ timestamp: now,
+ status: if config.nat_traversal {
+ "Client keepalive ping".to_string() // Changed text to be more specific
+ } else {
+ "Client periodic ping".to_string()
+ },
+ };
+
+ match serialize_message(&ping_msg) {
+ Ok(bytes) => {
+ if let Err(e) = stream.write_all(&bytes).await {
+ error!("Failed to ping server: {}", e);
+ return Err(anyhow::anyhow!("Server ping failed: {}", e));
+ }
+ debug!("Connection check sent to server");
+ },
+ Err(e) => error!("Failed to serialize ping message: {}", e),
+ }
+ }
+ }
+ }
+}
+
+async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> {
+ // Generate client nonce
+ let client_nonce = generate_nonce();
+
+ // Send auth request
+ let auth_request = Message::AuthRequest {
+ client_nonce: client_nonce.clone(),
+ };
+ let request_bytes = serialize_message(&auth_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Receive response
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let n = stream.read(&mut buf).await?;
+ let response = robust_deserialize_message(&buf[..n])?;
+
+ if let Message::AuthResponse { server_nonce, auth_token } = response {
+ // Verify server's token
+ let expected_data = format!("{}{}", client_nonce, server_nonce);
+ if !verify_hmac(secret, &expected_data, &auth_token) {
+ return Err(anyhow::anyhow!("Server authentication failed"));
+ }
+
+ // Calculate our token
+ let auth_data = format!("{}{}", server_nonce, client_nonce);
+ let client_token = calculate_hmac(secret, &auth_data);
+
+ // Send confirmation
+ let confirm = Message::AuthConfirm {
+ auth_token: client_token,
+ };
+
+ let confirm_bytes = serialize_message(&confirm)?;
+ stream.write_all(&confirm_bytes).await?;
+
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Unexpected response from server"))
+ }
+}