diff options
Diffstat (limited to 'src/bin/server.rs')
-rw-r--r-- | src/bin/server.rs | 467 |
1 files changed, 467 insertions, 0 deletions
diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 0000000..7fe0109 --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,467 @@ +use anyhow::{Context, Result}; +use clap::Parser; +use log::{debug, error, info, warn}; +use multicast_relay::{ + auth::{calculate_hmac, generate_nonce, verify_hmac}, + config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups}, + protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo}, + DEFAULT_BUFFER_SIZE, +}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::PathBuf, + str::FromStr, + sync::Arc, + time::Duration, // Add this import for Duration +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + sync::{mpsc, Mutex}, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long, default_value = "server_config.toml")] + config: PathBuf, + + #[arg(short, long, action)] + generate_default: bool, +} + +type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let args = Args::parse(); + + // Generate default configs if requested + if args.generate_default { + ensure_default_configs()?; + return Ok(()); + } + + // Load configuration + let config = load_server_config(&args.config) + .context(format!("Failed to load config from {:?}", args.config))?; + + info!("Server configuration loaded from {:?}", args.config); + + let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port); + let listener = TcpListener::bind(&listen_addr).await + .context("Failed to bind TCP listener")?; + + info!("Server listening on {}", listen_addr); + + // Setup multicast receivers + let clients: ClientMap = Arc::new(Mutex::new(HashMap::new())); + + // Start multicast listeners for each multicast group + for (group_id, group) in &config.multicast_groups { + let ports = group.get_ports(); + if ports.is_empty() { + error!("No ports defined for group {}", group_id); + continue; + } + + let display_group_id = group_id.clone(); + let ports_display = if ports.len() == 1 { + format!("port {}", ports[0]) + } else { + format!("ports {}-{}", ports[0], ports.last().unwrap()) + }; + + // Create a listener for each port in the range + for port in ports { + let clients = clients.clone(); + let _secret = config.secret.clone(); + let group_id_clone = group_id.clone(); + let mut group_info = group.clone(); + + // Set the specific port for this listener + group_info.port = Some(port); + group_info.port_range = None; + + tokio::spawn(async move { + if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await { + error!("Multicast listener error for group {} port {}: {}", + group_id_clone, port, e); + } + }); + } + + info!("Listening for multicast group {} on address {} with {}", + display_group_id, group.address, ports_display); + } + + // Store config for use in client handlers + let config = Arc::new(config); + + // Accept client connections + while let Ok((stream, addr)) = listener.accept().await { + info!("New client connection from: {}", addr); + let secret = config.secret.clone(); + let clients = clients.clone(); + let config = config.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_client(stream, addr, &secret, clients, config).await { + error!("Client error: {}: {}", addr, e); + } + info!("Client disconnected: {}", addr); + }); + } + + Ok(()) +} + +async fn listen_to_multicast( + group_id: &str, + group: &MulticastGroup, + clients: ClientMap +) -> Result<()> { + // Get the port to use + let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?; + + // Parse the multicast address + let mcast_ip = match IpAddr::from_str(&group.address) + .context("Invalid multicast address")? { + IpAddr::V4(addr) => addr, + _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported")) + }; + + // Create a UDP socket with more explicit settings + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .context("Failed to create socket")?; + + // Important: Set socket options + socket.set_reuse_address(true)?; + + #[cfg(unix)] + socket.set_reuse_port(true)?; + + socket.set_nonblocking(true)?; + socket.set_multicast_loop_v4(true)?; + + // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port + // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port + let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port); + info!("Binding multicast listener to specific address: {:?}", bind_addr); + socket.bind(&bind_addr.into())?; + + // Join the multicast group with a specific interface + let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface + info!("Joining multicast group {} on interface {:?}", mcast_ip, interface); + socket.join_multicast_v4(&mcast_ip, &interface)?; + + // Additional multicast option: set the IP_MULTICAST_IF option + socket.set_multicast_if_v4(&interface)?; + + // Convert to tokio socket + let udp_socket = tokio::net::UdpSocket::from_std(socket.into()) + .context("Failed to convert socket to async")?; + + info!("Multicast listener ready and bound specifically to {}:{} (group {})", + group.address, port, group_id); + + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + let group_id = group_id.to_string(); + + loop { + match udp_socket.recv_from(&mut buf).await { + Ok((len, src)) => { + // Since we're bound to the exact multicast address, we can be confident + // this packet was sent to our specific multicast group + let data = buf[..len].to_vec(); + + info!("RECEIVED: group={} from={} size={} destination={}:{}", + group_id, src, len, mcast_ip, port); + + // Create a message with the packet + let message = Message::MulticastPacket { + group_id: group_id.clone(), + source: src, + destination: group.address.clone(), + port, + data, + }; + + // Send to clients + match serialize_message(&message) { + Ok(serialized) => { + let clients_lock = clients.lock().await; + for (client_addr, sender) in clients_lock.iter() { + if sender.send(serialized.clone()).await.is_err() { + debug!("Failed to send to client {}", client_addr); + } else { + debug!("Sent multicast packet to client {}", client_addr); + } + } + } + Err(e) => error!("Failed to serialize message: {}", e), + } + } + Err(e) => { + if e.kind() != std::io::ErrorKind::WouldBlock { + error!("Error receiving from socket: {}", e); + } + // Small delay to avoid busy waiting on errors + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + } +} + +#[derive(PartialEq)] +enum StatusMessageType { + ClientHeartbeat, + ClientPong, + Other +} + +async fn handle_client( + stream: TcpStream, + addr: SocketAddr, + secret: &str, + clients: ClientMap, + config: Arc<ServerConfig>, +) -> Result<()> { + // Check if external clients are allowed when client is not from localhost + if !config.allow_external_clients && + !addr.ip().is_loopback() && + !addr.ip().to_string().starts_with("192.168.") && + !addr.ip().to_string().starts_with("10.") { + warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr); + return Err(anyhow::anyhow!("External clients not allowed")); + } + + // Split the TCP stream into read and write parts once + let (mut read_stream, mut write_stream) = tokio::io::split(stream); + + // Authentication using the split streams + if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? { + return Err(anyhow::anyhow!("Authentication failed")); + } + + info!("Client authenticated: {}", addr); + + // Get client info + let client_ip = addr.ip().to_string(); + let client_port = addr.port(); + + // Check if client has specific group permissions + let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) { + Some(groups) => groups, + None => return Err(anyhow::anyhow!("Client not authorized for any groups")), + }; + + // Create channel for sending multicast packets to this client + let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100); + + // Add client to map + clients.lock().await.insert(addr, tx.clone()); + + // Create HashMap of available groups for this client + let mut available_groups = HashMap::new(); + for (id, group) in &config.multicast_groups { + // If client has empty group list (all allowed) or specific group is in list + if authorized_groups.is_empty() || authorized_groups.contains(id) { + let ports = group.get_ports(); + if ports.is_empty() { + continue; + } + + // Primary port is the first one + let primary_port = ports[0]; + + // Get additional ports if there are any + let additional_ports = if ports.len() > 1 { + Some(ports[1..].to_vec()) + } else { + None + }; + + available_groups.insert(id.clone(), MulticastGroupInfo { + address: group.address.clone(), + port: primary_port, + additional_ports, + }); + } + } + + // IMPORTANT: Clone tx before moving it into the spawn + let tx_for_read = tx.clone(); + + // Spawn task to read client messages + let clients_clone = clients.clone(); + + // Use the already split read_stream + let read_task = tokio::spawn(async move { + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + loop { + match read_stream.read(&mut buf).await { + Ok(0) => break, // Connection closed + Ok(n) => { + if let Ok(msg) = deserialize_message(&buf[..n]) { + match msg { + Message::Subscribe { group_ids } => { + info!("Client {} subscribing to groups: {:?}", addr, group_ids); + // Group subscriptions handled by server + }, + Message::MulticastGroupsRequest => { + // Send available groups to client + let response = Message::MulticastGroupsResponse { + groups: available_groups.clone() + }; + + if let Ok(bytes) = serialize_message(&response) { + let _ = tx_for_read.send(bytes).await; + } + }, + Message::PingStatus { timestamp, status } => { + // Determine the type of status message + let msg_type = if status.starts_with("Client keepalive ping") || + status.starts_with("Client periodic ping") { + StatusMessageType::ClientHeartbeat + } else if status.starts_with("Client pong response") { + StatusMessageType::ClientPong + } else { + StatusMessageType::Other + }; + + // Log the message receipt + match msg_type { + StatusMessageType::ClientHeartbeat => { + info!("Heartbeat from client {}: {}", addr, status); + + // Respond only to actual heartbeat pings, not pong responses + let response = Message::PingStatus { + timestamp, + status: format!("Server connection to {} is OK", addr), + }; + + if let Ok(bytes) = serialize_message(&response) { + let _ = tx_for_read.send(bytes).await; + } + }, + StatusMessageType::ClientPong => { + // Just log pongs without responding to avoid loops + debug!("Pong from client {}: {}", addr, status); + }, + StatusMessageType::Other => { + info!("Status message from client {}: {}", addr, status); + } + } + }, + _ => {} + } + } + } + Err(e) => { + error!("Error reading from client: {}: {}", addr, e); + break; + } + } + } + // Clean up on disconnect + clients_clone.lock().await.remove(&addr); + info!("Client reader task ended: {}", addr); + }); + + // Forward multicast packets to client using the already split write_stream + let write_task = tokio::spawn(async move { + while let Some(packet) = rx.recv().await { + if let Err(e) = write_stream.write_all(&packet).await { + error!("Error writing to client {}: {}", addr, e); + break; + } + } + info!("Client writer task ended: {}", addr); + }); + + // Now tx is still valid here - use it for heartbeats, but with a delay + let tx_for_heartbeat = tx.clone(); + let client_addr = addr.clone(); + tokio::spawn(async move { + // Add initial delay before starting heartbeats to avoid interfering with initial setup messages + tokio::time::sleep(Duration::from_secs(5)).await; + + let mut interval = tokio::time::interval(Duration::from_secs(30)); + loop { + interval.tick().await; + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Send a heartbeat with a clear identifier + let msg = Message::PingStatus { + timestamp: now, + status: format!("Server heartbeat to {}", client_addr), + }; + + if let Ok(bytes) = serialize_message(&msg) { + if tx_for_heartbeat.send(bytes).await.is_err() { + break; + } + } + } + }); + + // Wait for either task to complete + tokio::select! { + _ = read_task => {}, + _ = write_task => {}, + } + + // Clean up + clients.lock().await.remove(&addr); + Ok(()) +} + +async fn authenticate_client( + reader: &mut (impl AsyncReadExt + Unpin), + writer: &mut (impl AsyncWriteExt + Unpin), + _addr: SocketAddr, + secret: &str +) -> Result<bool> { + let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE]; + + // Receive auth request + let n = reader.read(&mut buf).await?; + let auth_request = deserialize_message(&buf[..n])?; + + if let Message::AuthRequest { client_nonce } = auth_request { + // Generate server nonce + let server_nonce = generate_nonce(); + + // Calculate auth token + let auth_data = format!("{}{}", client_nonce, server_nonce); + let auth_token = calculate_hmac(secret, &auth_data); + + // Send response + let response = Message::AuthResponse { + server_nonce: server_nonce.clone(), + auth_token, + }; + + let response_bytes = serialize_message(&response)?; + writer.write_all(&response_bytes).await?; + + // Receive confirmation + let n = reader.read(&mut buf).await?; + let auth_confirm = deserialize_message(&buf[..n])?; + + if let Message::AuthConfirm { auth_token } = auth_confirm { + // Verify token + let expected_data = format!("{}{}", server_nonce, client_nonce); + return Ok(verify_hmac(secret, &expected_data, &auth_token)); + } + } + + Ok(false) +} |