aboutsummaryrefslogtreecommitdiff
path: root/src/bin
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin')
-rw-r--r--src/bin/client.rs467
-rw-r--r--src/bin/mcast_test.rs123
-rw-r--r--src/bin/server.rs467
3 files changed, 1057 insertions, 0 deletions
diff --git a/src/bin/client.rs b/src/bin/client.rs
new file mode 100644
index 0000000..5c39fb9
--- /dev/null
+++ b/src/bin/client.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_client_config, ensure_default_configs, ClientConfig},
+ protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message},
+ DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES,
+};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ time::Duration,
+ sync::Arc,
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpStream, UdpSocket},
+ signal,
+ sync::Notify,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "client_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_client_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Client configuration loaded from {:?}", args.config);
+
+ // Create a notification for clean shutdown
+ let shutdown = Arc::new(Notify::new());
+ let shutdown_signal = shutdown.clone();
+
+ // Setup signal handler for Ctrl+C
+ tokio::spawn(async move {
+ if let Err(e) = signal::ctrl_c().await {
+ error!("Failed to listen for Ctrl+C: {}", e);
+ return;
+ }
+ info!("Received Ctrl+C, shutting down...");
+ shutdown_signal.notify_one();
+ });
+
+ let server_addr = format!("{}:{}", config.server, config.port);
+
+ // Main reconnection loop
+ loop {
+ info!("Connecting to server at {}", server_addr);
+
+ // Try to connect with timeout
+ let connect_result = match tokio::time::timeout(
+ Duration::from_secs(5),
+ TcpStream::connect(&server_addr),
+ ).await {
+ Ok(result) => result,
+ Err(_) => {
+ warn!("Connection attempt timed out");
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ continue;
+ }
+ };
+
+ match connect_result {
+ Ok(stream) => {
+ info!("Connected to server");
+
+ // Run the client session
+ match run_client_session(stream, &config, &shutdown).await {
+ Ok(_) => {
+ info!("Client session ended normally");
+ break;
+ },
+ Err(e) => {
+ error!("Client session error: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ },
+ Err(e) => {
+ error!("Failed to connect: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+// Helper function to handle reconnection delay
+// Returns false if shutdown was requested
+async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool {
+ let delay = Duration::from_secs(client_config.reconnect_delay_secs);
+ info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs);
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested during reconnect");
+ return false;
+ }
+ _ = tokio::time::sleep(delay) => {}
+ }
+ true
+}
+
+// Add a new enum to distinguish message types
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ServerHeartbeat,
+ ServerPong,
+ Other
+}
+
+// The main client session function that handles a single connection
+async fn run_client_session(
+ mut stream: TcpStream,
+ config: &ClientConfig,
+ shutdown: &Arc<Notify>,
+) -> Result<()> {
+ // Authenticate
+ if let Err(e) = authenticate(&mut stream, &config.secret).await {
+ return Err(anyhow::anyhow!("Authentication failed: {}", e));
+ }
+ info!("Authentication successful");
+
+ // Check if test mode is enabled in config
+ let mut test_mode = config.test_mode;
+
+ // Request server's multicast group information
+ let groups_request = Message::MulticastGroupsRequest;
+ let request_bytes = serialize_message(&groups_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Wait for response with multicast group information, ignoring other message types
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let mut groups_response = None;
+
+ // Keep reading until we get the groups response or timeout
+ let mut attempts = 0;
+ while groups_response.is_none() && attempts < 10 {
+ match stream.read(&mut buf).await {
+ Ok(n) if n > 0 => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastGroupsResponse { groups }) => {
+ groups_response = Some(groups);
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Handle ping but keep waiting for multicast groups
+ info!("Got server ping: {}", status);
+ },
+ Ok(other_msg) => {
+ debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg);
+ },
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ }
+ }
+ },
+ Ok(0) => return Err(anyhow::anyhow!("Server closed connection")),
+ Ok(_) => {},
+ Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e))
+ }
+
+ // If we didn't get groups yet, wait a bit and try again
+ if groups_response.is_none() {
+ attempts += 1;
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ }
+
+ // Now we either have the groups or we timed out
+ let groups_response = groups_response.ok_or_else(||
+ anyhow::anyhow!("Failed to receive multicast group information from server"))?;
+
+ info!("Available multicast groups from server:");
+ for (id, group) in &groups_response {
+ info!(" - {} -> {}:{}", id, group.address, group.port);
+ }
+
+ // Determine which groups to subscribe to
+ let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() {
+ // If no specific groups are requested, subscribe to all
+ info!("No specific groups requested, subscribing to all available groups");
+ groups_response.keys().cloned().collect()
+ } else {
+ // Otherwise, subscribe only to requested groups that exist
+ let mut valid_groups = Vec::new();
+ for group_id in &config.multicast_group_ids {
+ if groups_response.contains_key(group_id.as_str()) {
+ valid_groups.push(group_id.clone());
+ } else {
+ warn!("Requested group '{}' does not exist on server", group_id);
+ }
+ }
+
+ if valid_groups.is_empty() {
+ warn!("None of the requested groups exist on server. No data will be received.");
+ } else {
+ info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups);
+ }
+
+ valid_groups
+ };
+
+ // Send subscription message
+ let subscribe_msg = Message::Subscribe {
+ group_ids: groups_to_subscribe.clone(),
+ };
+ let subscribe_bytes = serialize_message(&subscribe_msg)?;
+ stream.write_all(&subscribe_bytes).await?;
+
+ // Create UDP sockets for local retransmission (skip if in test mode)
+ let mut sockets: HashMap<String, UdpSocket> = HashMap::new();
+ if !test_mode {
+ for group_id in &groups_to_subscribe {
+ if let Some(group_info) = groups_response.get(group_id.as_str()) {
+ info!("Creating socket for group {} ({} on port {})",
+ group_id, group_info.address, group_info.port);
+
+ match UdpSocket::bind("0.0.0.0:0").await {
+ Ok(socket) => {
+ sockets.insert(group_id.clone(), socket);
+ info!("Successfully created UDP socket for group {}", group_id);
+ },
+ Err(e) => {
+ error!("Failed to create UDP socket for group {}: {}", group_id, e);
+ }
+ }
+ }
+ }
+
+ if sockets.is_empty() && !groups_to_subscribe.is_empty() {
+ error!("Failed to create any UDP sockets");
+ warn!("Falling back to test mode due to socket creation failure");
+ test_mode = true;
+ }
+ }
+
+ // Display test mode banner if enabled
+ if test_mode {
+ println!("{}", TEST_MODE_BANNER);
+ info!("Test mode enabled - packets will be displayed but not sent to network");
+ }
+
+ // Main receive loop
+ info!("Listening for multicast traffic from server");
+
+ // Set the read timeout for the stream
+ stream.set_nodelay(true)?;
+
+ // Remove problematic code that uses unsupported methods and the nix crate
+ if config.nat_traversal {
+ info!("NAT traversal mode enabled - using more frequent keepalives");
+ }
+
+ // Calculate appropriate ping interval based on NAT traversal setting
+ let ping_interval = if config.nat_traversal {
+ Duration::from_secs(25) // More frequent for NAT
+ } else {
+ Duration::from_secs(55)
+ };
+
+ // Main receive loop
+ loop {
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested, ending client session");
+ return Ok(());
+ }
+ read_result = stream.read(&mut buf) => {
+ match read_result {
+ Ok(0) => {
+ info!("Server closed connection");
+ return Err(anyhow::anyhow!("Server closed connection"));
+ }
+ Ok(n) => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => {
+ if test_mode {
+ println!("\n----- MULTICAST PACKET -----");
+ println!("Group: {}", group_id);
+ println!("Source: {}", source);
+ println!("Destination: {}:{}", destination, port);
+ println!("Size: {} bytes", data.len());
+ println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES));
+ println!("---------------------------\n");
+ } else {
+ info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes",
+ group_id, source, destination, port, data.len());
+
+ // Get socket for this group
+ if let Some(socket) = sockets.get(&group_id) {
+ // Parse destination address directly from the packet
+ match IpAddr::from_str(&destination) {
+ Ok(dest_addr) => {
+ // Create destination socket address using the packet's port
+ let dest = SocketAddr::new(dest_addr, port);
+
+ info!("Forwarding packet to {}:{}", dest_addr, port);
+
+ // Retransmit locally
+ match socket.send_to(&data, dest).await {
+ Ok(sent) => {
+ info!("Successfully forwarded {} of {} bytes to {}:{}",
+ sent, data.len(), destination, port);
+ },
+ Err(e) => {
+ error!("Failed to retransmit packet for group {} to {}:{}: {}",
+ group_id, destination, port, e);
+ }
+ }
+ },
+ Err(e) => {
+ error!("Invalid destination address {}: {}", destination, e);
+ }
+ }
+ } else {
+ warn!("No socket available for group {}", group_id);
+ }
+ }
+ },
+ Ok(Message::ConfigResponse { config: _server_config }) => {
+ info!("Server configuration received");
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Parse the message type based on the status text
+ let msg_type = if status.starts_with("Server heartbeat to") {
+ StatusMessageType::ServerHeartbeat
+ } else if status.starts_with("Server connection to") {
+ StatusMessageType::ServerPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Only respond with a pong to server heartbeats
+ if msg_type == StatusMessageType::ServerHeartbeat {
+ debug!("Received server heartbeat: {}", status);
+
+ // Send a pong response (but only for heartbeats)
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let pong = Message::PingStatus {
+ timestamp: now,
+ status: "Client pong response".to_string(),
+ };
+
+ if let Ok(bytes) = serialize_message(&pong) {
+ let _ = stream.write_all(&bytes).await;
+ }
+ } else {
+ // Just log other status messages without responding
+ debug!("Connection Status: {}", status);
+ }
+ },
+ Ok(_) => debug!("Received other message type"),
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ // Don't return/break on deserialization errors - continue reading
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from server: {}", e);
+ return Err(anyhow::anyhow!("Connection error: {}", e));
+ }
+ }
+ },
+ _ = tokio::time::sleep(ping_interval) => {
+ // Send regular ping to keep NAT connection alive
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let ping_msg = Message::PingStatus {
+ timestamp: now,
+ status: if config.nat_traversal {
+ "Client keepalive ping".to_string() // Changed text to be more specific
+ } else {
+ "Client periodic ping".to_string()
+ },
+ };
+
+ match serialize_message(&ping_msg) {
+ Ok(bytes) => {
+ if let Err(e) = stream.write_all(&bytes).await {
+ error!("Failed to ping server: {}", e);
+ return Err(anyhow::anyhow!("Server ping failed: {}", e));
+ }
+ debug!("Connection check sent to server");
+ },
+ Err(e) => error!("Failed to serialize ping message: {}", e),
+ }
+ }
+ }
+ }
+}
+
+async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> {
+ // Generate client nonce
+ let client_nonce = generate_nonce();
+
+ // Send auth request
+ let auth_request = Message::AuthRequest {
+ client_nonce: client_nonce.clone(),
+ };
+ let request_bytes = serialize_message(&auth_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Receive response
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let n = stream.read(&mut buf).await?;
+ let response = robust_deserialize_message(&buf[..n])?;
+
+ if let Message::AuthResponse { server_nonce, auth_token } = response {
+ // Verify server's token
+ let expected_data = format!("{}{}", client_nonce, server_nonce);
+ if !verify_hmac(secret, &expected_data, &auth_token) {
+ return Err(anyhow::anyhow!("Server authentication failed"));
+ }
+
+ // Calculate our token
+ let auth_data = format!("{}{}", server_nonce, client_nonce);
+ let client_token = calculate_hmac(secret, &auth_data);
+
+ // Send confirmation
+ let confirm = Message::AuthConfirm {
+ auth_token: client_token,
+ };
+
+ let confirm_bytes = serialize_message(&confirm)?;
+ stream.write_all(&confirm_bytes).await?;
+
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Unexpected response from server"))
+ }
+}
diff --git a/src/bin/mcast_test.rs b/src/bin/mcast_test.rs
new file mode 100644
index 0000000..782610c
--- /dev/null
+++ b/src/bin/mcast_test.rs
@@ -0,0 +1,123 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{error, info};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::str::FromStr;
+use std::time::{Duration, Instant};
+use tokio::net::UdpSocket;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about = "Multicast packet generator for testing CastRepeat")]
+struct Args {
+ #[arg(short, long, default_value = "239.192.55.1")]
+ multicast_addr: String,
+
+ #[arg(short, long, default_value = "1681")]
+ port: u16,
+
+ #[arg(short, long, default_value = "1000")]
+ interval_ms: u64,
+
+ #[arg(short, long, default_value = "60")]
+ duration_sec: u64,
+
+ #[arg(short, long, default_value = "Test packet")]
+ message: String,
+
+ #[arg(short, long)]
+ interface: Option<String>,
+
+ #[arg(short = 'b', long, action)]
+ binary_mode: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ info!("CastRepeat Multicast Packet Generator");
+ info!("--------------------------------");
+ info!("Multicast Address: {}", args.multicast_addr);
+ info!("Port: {}", args.port);
+ info!("Interval: {} ms", args.interval_ms);
+ info!("Duration: {} sec", args.duration_sec);
+ if let Some(interface) = &args.interface {
+ info!("Interface: {}", interface);
+ }
+ info!("Mode: {}", if args.binary_mode { "Binary" } else { "Text" });
+ info!("--------------------------------");
+
+ // Verify multicast address
+ let mcast_addr = match IpAddr::from_str(&args.multicast_addr) {
+ Ok(IpAddr::V4(addr)) if addr.is_multicast() => addr,
+ _ => {
+ error!("Invalid multicast address: {}", args.multicast_addr);
+ return Ok(());
+ }
+ };
+
+ // Create socket
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ socket.set_multicast_ttl_v4(4)?;
+ socket.set_nonblocking(true)?;
+
+ // Set the multicast interface if specified
+ if let Some(if_str) = &args.interface {
+ if let Ok(if_addr) = Ipv4Addr::from_str(if_str) {
+ socket.set_multicast_if_v4(&if_addr)?;
+ info!("Using interface: {}", if_addr);
+ } else {
+ error!("Invalid interface address: {}", if_str);
+ return Ok(());
+ }
+ }
+
+ let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
+ socket.bind(&addr.into())?;
+
+ let socket = UdpSocket::from_std(socket.into())?;
+ let dest_addr = SocketAddr::new(IpAddr::V4(mcast_addr), args.port);
+
+ // Start sending packets
+ info!("Sending packets to {}...", dest_addr);
+
+ let start = Instant::now();
+ let end = start + Duration::from_secs(args.duration_sec);
+ let mut counter: u64 = 0; // Specify counter type as u64
+
+ while Instant::now() < end {
+ counter += 1;
+
+ // Create either text or binary test data
+ let data = if args.binary_mode {
+ // Create binary test data (similar to what we might see in the field)
+ let mut packet = Vec::with_capacity(16);
+ packet.extend_from_slice(b"REL\0"); // 4 bytes header
+ packet.extend_from_slice(&counter.to_be_bytes()[4..]); // 4 bytes counter
+ packet.extend_from_slice(&[0x00, 0x10, 0x02, 0x00]); // 4 bytes
+ packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x90]); // 4 bytes
+ packet
+ } else {
+ // Create text test data
+ format!("{} #{}", args.message, counter).into_bytes()
+ };
+
+ match socket.send_to(&data, &dest_addr).await {
+ Ok(bytes) => {
+ info!("Sent packet #{}: {} bytes", counter, bytes);
+ }
+ Err(e) => {
+ error!("Error sending packet: {}", e);
+ }
+ }
+
+ tokio::time::sleep(Duration::from_millis(args.interval_ms)).await;
+ }
+
+ info!("Done. Sent {} packets in {} seconds.", counter, args.duration_sec);
+ Ok(())
+}
diff --git a/src/bin/server.rs b/src/bin/server.rs
new file mode 100644
index 0000000..7fe0109
--- /dev/null
+++ b/src/bin/server.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+ protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+ DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ sync::Arc,
+ time::Duration, // Add this import for Duration
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream},
+ sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "server_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_server_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Server configuration loaded from {:?}", args.config);
+
+ let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+ let listener = TcpListener::bind(&listen_addr).await
+ .context("Failed to bind TCP listener")?;
+
+ info!("Server listening on {}", listen_addr);
+
+ // Setup multicast receivers
+ let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+ // Start multicast listeners for each multicast group
+ for (group_id, group) in &config.multicast_groups {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ error!("No ports defined for group {}", group_id);
+ continue;
+ }
+
+ let display_group_id = group_id.clone();
+ let ports_display = if ports.len() == 1 {
+ format!("port {}", ports[0])
+ } else {
+ format!("ports {}-{}", ports[0], ports.last().unwrap())
+ };
+
+ // Create a listener for each port in the range
+ for port in ports {
+ let clients = clients.clone();
+ let _secret = config.secret.clone();
+ let group_id_clone = group_id.clone();
+ let mut group_info = group.clone();
+
+ // Set the specific port for this listener
+ group_info.port = Some(port);
+ group_info.port_range = None;
+
+ tokio::spawn(async move {
+ if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+ error!("Multicast listener error for group {} port {}: {}",
+ group_id_clone, port, e);
+ }
+ });
+ }
+
+ info!("Listening for multicast group {} on address {} with {}",
+ display_group_id, group.address, ports_display);
+ }
+
+ // Store config for use in client handlers
+ let config = Arc::new(config);
+
+ // Accept client connections
+ while let Ok((stream, addr)) = listener.accept().await {
+ info!("New client connection from: {}", addr);
+ let secret = config.secret.clone();
+ let clients = clients.clone();
+ let config = config.clone();
+
+ tokio::spawn(async move {
+ if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+ error!("Client error: {}: {}", addr, e);
+ }
+ info!("Client disconnected: {}", addr);
+ });
+ }
+
+ Ok(())
+}
+
+async fn listen_to_multicast(
+ group_id: &str,
+ group: &MulticastGroup,
+ clients: ClientMap
+) -> Result<()> {
+ // Get the port to use
+ let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+ // Parse the multicast address
+ let mcast_ip = match IpAddr::from_str(&group.address)
+ .context("Invalid multicast address")? {
+ IpAddr::V4(addr) => addr,
+ _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+ };
+
+ // Create a UDP socket with more explicit settings
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ // Important: Set socket options
+ socket.set_reuse_address(true)?;
+
+ #[cfg(unix)]
+ socket.set_reuse_port(true)?;
+
+ socket.set_nonblocking(true)?;
+ socket.set_multicast_loop_v4(true)?;
+
+ // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port
+ // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port
+ let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+ info!("Binding multicast listener to specific address: {:?}", bind_addr);
+ socket.bind(&bind_addr.into())?;
+
+ // Join the multicast group with a specific interface
+ let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+ info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+ socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+ // Additional multicast option: set the IP_MULTICAST_IF option
+ socket.set_multicast_if_v4(&interface)?;
+
+ // Convert to tokio socket
+ let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+ .context("Failed to convert socket to async")?;
+
+ info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+ group.address, port, group_id);
+
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let group_id = group_id.to_string();
+
+ loop {
+ match udp_socket.recv_from(&mut buf).await {
+ Ok((len, src)) => {
+ // Since we're bound to the exact multicast address, we can be confident
+ // this packet was sent to our specific multicast group
+ let data = buf[..len].to_vec();
+
+ info!("RECEIVED: group={} from={} size={} destination={}:{}",
+ group_id, src, len, mcast_ip, port);
+
+ // Create a message with the packet
+ let message = Message::MulticastPacket {
+ group_id: group_id.clone(),
+ source: src,
+ destination: group.address.clone(),
+ port,
+ data,
+ };
+
+ // Send to clients
+ match serialize_message(&message) {
+ Ok(serialized) => {
+ let clients_lock = clients.lock().await;
+ for (client_addr, sender) in clients_lock.iter() {
+ if sender.send(serialized.clone()).await.is_err() {
+ debug!("Failed to send to client {}", client_addr);
+ } else {
+ debug!("Sent multicast packet to client {}", client_addr);
+ }
+ }
+ }
+ Err(e) => error!("Failed to serialize message: {}", e),
+ }
+ }
+ Err(e) => {
+ if e.kind() != std::io::ErrorKind::WouldBlock {
+ error!("Error receiving from socket: {}", e);
+ }
+ // Small delay to avoid busy waiting on errors
+ tokio::time::sleep(Duration::from_millis(10)).await;
+ }
+ }
+ }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ClientHeartbeat,
+ ClientPong,
+ Other
+}
+
+async fn handle_client(
+ stream: TcpStream,
+ addr: SocketAddr,
+ secret: &str,
+ clients: ClientMap,
+ config: Arc<ServerConfig>,
+) -> Result<()> {
+ // Check if external clients are allowed when client is not from localhost
+ if !config.allow_external_clients &&
+ !addr.ip().is_loopback() &&
+ !addr.ip().to_string().starts_with("192.168.") &&
+ !addr.ip().to_string().starts_with("10.") {
+ warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+ return Err(anyhow::anyhow!("External clients not allowed"));
+ }
+
+ // Split the TCP stream into read and write parts once
+ let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+ // Authentication using the split streams
+ if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+ return Err(anyhow::anyhow!("Authentication failed"));
+ }
+
+ info!("Client authenticated: {}", addr);
+
+ // Get client info
+ let client_ip = addr.ip().to_string();
+ let client_port = addr.port();
+
+ // Check if client has specific group permissions
+ let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+ Some(groups) => groups,
+ None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+ };
+
+ // Create channel for sending multicast packets to this client
+ let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+ // Add client to map
+ clients.lock().await.insert(addr, tx.clone());
+
+ // Create HashMap of available groups for this client
+ let mut available_groups = HashMap::new();
+ for (id, group) in &config.multicast_groups {
+ // If client has empty group list (all allowed) or specific group is in list
+ if authorized_groups.is_empty() || authorized_groups.contains(id) {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ continue;
+ }
+
+ // Primary port is the first one
+ let primary_port = ports[0];
+
+ // Get additional ports if there are any
+ let additional_ports = if ports.len() > 1 {
+ Some(ports[1..].to_vec())
+ } else {
+ None
+ };
+
+ available_groups.insert(id.clone(), MulticastGroupInfo {
+ address: group.address.clone(),
+ port: primary_port,
+ additional_ports,
+ });
+ }
+ }
+
+ // IMPORTANT: Clone tx before moving it into the spawn
+ let tx_for_read = tx.clone();
+
+ // Spawn task to read client messages
+ let clients_clone = clients.clone();
+
+ // Use the already split read_stream
+ let read_task = tokio::spawn(async move {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ loop {
+ match read_stream.read(&mut buf).await {
+ Ok(0) => break, // Connection closed
+ Ok(n) => {
+ if let Ok(msg) = deserialize_message(&buf[..n]) {
+ match msg {
+ Message::Subscribe { group_ids } => {
+ info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+ // Group subscriptions handled by server
+ },
+ Message::MulticastGroupsRequest => {
+ // Send available groups to client
+ let response = Message::MulticastGroupsResponse {
+ groups: available_groups.clone()
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ Message::PingStatus { timestamp, status } => {
+ // Determine the type of status message
+ let msg_type = if status.starts_with("Client keepalive ping") ||
+ status.starts_with("Client periodic ping") {
+ StatusMessageType::ClientHeartbeat
+ } else if status.starts_with("Client pong response") {
+ StatusMessageType::ClientPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Log the message receipt
+ match msg_type {
+ StatusMessageType::ClientHeartbeat => {
+ info!("Heartbeat from client {}: {}", addr, status);
+
+ // Respond only to actual heartbeat pings, not pong responses
+ let response = Message::PingStatus {
+ timestamp,
+ status: format!("Server connection to {} is OK", addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ StatusMessageType::ClientPong => {
+ // Just log pongs without responding to avoid loops
+ debug!("Pong from client {}: {}", addr, status);
+ },
+ StatusMessageType::Other => {
+ info!("Status message from client {}: {}", addr, status);
+ }
+ }
+ },
+ _ => {}
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from client: {}: {}", addr, e);
+ break;
+ }
+ }
+ }
+ // Clean up on disconnect
+ clients_clone.lock().await.remove(&addr);
+ info!("Client reader task ended: {}", addr);
+ });
+
+ // Forward multicast packets to client using the already split write_stream
+ let write_task = tokio::spawn(async move {
+ while let Some(packet) = rx.recv().await {
+ if let Err(e) = write_stream.write_all(&packet).await {
+ error!("Error writing to client {}: {}", addr, e);
+ break;
+ }
+ }
+ info!("Client writer task ended: {}", addr);
+ });
+
+ // Now tx is still valid here - use it for heartbeats, but with a delay
+ let tx_for_heartbeat = tx.clone();
+ let client_addr = addr.clone();
+ tokio::spawn(async move {
+ // Add initial delay before starting heartbeats to avoid interfering with initial setup messages
+ tokio::time::sleep(Duration::from_secs(5)).await;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(30));
+ loop {
+ interval.tick().await;
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ // Send a heartbeat with a clear identifier
+ let msg = Message::PingStatus {
+ timestamp: now,
+ status: format!("Server heartbeat to {}", client_addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&msg) {
+ if tx_for_heartbeat.send(bytes).await.is_err() {
+ break;
+ }
+ }
+ }
+ });
+
+ // Wait for either task to complete
+ tokio::select! {
+ _ = read_task => {},
+ _ = write_task => {},
+ }
+
+ // Clean up
+ clients.lock().await.remove(&addr);
+ Ok(())
+}
+
+async fn authenticate_client(
+ reader: &mut (impl AsyncReadExt + Unpin),
+ writer: &mut (impl AsyncWriteExt + Unpin),
+ _addr: SocketAddr,
+ secret: &str
+) -> Result<bool> {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+ // Receive auth request
+ let n = reader.read(&mut buf).await?;
+ let auth_request = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthRequest { client_nonce } = auth_request {
+ // Generate server nonce
+ let server_nonce = generate_nonce();
+
+ // Calculate auth token
+ let auth_data = format!("{}{}", client_nonce, server_nonce);
+ let auth_token = calculate_hmac(secret, &auth_data);
+
+ // Send response
+ let response = Message::AuthResponse {
+ server_nonce: server_nonce.clone(),
+ auth_token,
+ };
+
+ let response_bytes = serialize_message(&response)?;
+ writer.write_all(&response_bytes).await?;
+
+ // Receive confirmation
+ let n = reader.read(&mut buf).await?;
+ let auth_confirm = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthConfirm { auth_token } = auth_confirm {
+ // Verify token
+ let expected_data = format!("{}{}", server_nonce, client_nonce);
+ return Ok(verify_hmac(secret, &expected_data, &auth_token));
+ }
+ }
+
+ Ok(false)
+}