aboutsummaryrefslogtreecommitdiff
path: root/src/bin/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/server.rs')
-rw-r--r--src/bin/server.rs467
1 files changed, 467 insertions, 0 deletions
diff --git a/src/bin/server.rs b/src/bin/server.rs
new file mode 100644
index 0000000..7fe0109
--- /dev/null
+++ b/src/bin/server.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+ protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+ DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ sync::Arc,
+ time::Duration, // Add this import for Duration
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream},
+ sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "server_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_server_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Server configuration loaded from {:?}", args.config);
+
+ let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+ let listener = TcpListener::bind(&listen_addr).await
+ .context("Failed to bind TCP listener")?;
+
+ info!("Server listening on {}", listen_addr);
+
+ // Setup multicast receivers
+ let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+ // Start multicast listeners for each multicast group
+ for (group_id, group) in &config.multicast_groups {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ error!("No ports defined for group {}", group_id);
+ continue;
+ }
+
+ let display_group_id = group_id.clone();
+ let ports_display = if ports.len() == 1 {
+ format!("port {}", ports[0])
+ } else {
+ format!("ports {}-{}", ports[0], ports.last().unwrap())
+ };
+
+ // Create a listener for each port in the range
+ for port in ports {
+ let clients = clients.clone();
+ let _secret = config.secret.clone();
+ let group_id_clone = group_id.clone();
+ let mut group_info = group.clone();
+
+ // Set the specific port for this listener
+ group_info.port = Some(port);
+ group_info.port_range = None;
+
+ tokio::spawn(async move {
+ if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+ error!("Multicast listener error for group {} port {}: {}",
+ group_id_clone, port, e);
+ }
+ });
+ }
+
+ info!("Listening for multicast group {} on address {} with {}",
+ display_group_id, group.address, ports_display);
+ }
+
+ // Store config for use in client handlers
+ let config = Arc::new(config);
+
+ // Accept client connections
+ while let Ok((stream, addr)) = listener.accept().await {
+ info!("New client connection from: {}", addr);
+ let secret = config.secret.clone();
+ let clients = clients.clone();
+ let config = config.clone();
+
+ tokio::spawn(async move {
+ if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+ error!("Client error: {}: {}", addr, e);
+ }
+ info!("Client disconnected: {}", addr);
+ });
+ }
+
+ Ok(())
+}
+
+async fn listen_to_multicast(
+ group_id: &str,
+ group: &MulticastGroup,
+ clients: ClientMap
+) -> Result<()> {
+ // Get the port to use
+ let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+ // Parse the multicast address
+ let mcast_ip = match IpAddr::from_str(&group.address)
+ .context("Invalid multicast address")? {
+ IpAddr::V4(addr) => addr,
+ _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+ };
+
+ // Create a UDP socket with more explicit settings
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ // Important: Set socket options
+ socket.set_reuse_address(true)?;
+
+ #[cfg(unix)]
+ socket.set_reuse_port(true)?;
+
+ socket.set_nonblocking(true)?;
+ socket.set_multicast_loop_v4(true)?;
+
+ // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port
+ // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port
+ let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+ info!("Binding multicast listener to specific address: {:?}", bind_addr);
+ socket.bind(&bind_addr.into())?;
+
+ // Join the multicast group with a specific interface
+ let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+ info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+ socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+ // Additional multicast option: set the IP_MULTICAST_IF option
+ socket.set_multicast_if_v4(&interface)?;
+
+ // Convert to tokio socket
+ let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+ .context("Failed to convert socket to async")?;
+
+ info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+ group.address, port, group_id);
+
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let group_id = group_id.to_string();
+
+ loop {
+ match udp_socket.recv_from(&mut buf).await {
+ Ok((len, src)) => {
+ // Since we're bound to the exact multicast address, we can be confident
+ // this packet was sent to our specific multicast group
+ let data = buf[..len].to_vec();
+
+ info!("RECEIVED: group={} from={} size={} destination={}:{}",
+ group_id, src, len, mcast_ip, port);
+
+ // Create a message with the packet
+ let message = Message::MulticastPacket {
+ group_id: group_id.clone(),
+ source: src,
+ destination: group.address.clone(),
+ port,
+ data,
+ };
+
+ // Send to clients
+ match serialize_message(&message) {
+ Ok(serialized) => {
+ let clients_lock = clients.lock().await;
+ for (client_addr, sender) in clients_lock.iter() {
+ if sender.send(serialized.clone()).await.is_err() {
+ debug!("Failed to send to client {}", client_addr);
+ } else {
+ debug!("Sent multicast packet to client {}", client_addr);
+ }
+ }
+ }
+ Err(e) => error!("Failed to serialize message: {}", e),
+ }
+ }
+ Err(e) => {
+ if e.kind() != std::io::ErrorKind::WouldBlock {
+ error!("Error receiving from socket: {}", e);
+ }
+ // Small delay to avoid busy waiting on errors
+ tokio::time::sleep(Duration::from_millis(10)).await;
+ }
+ }
+ }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ClientHeartbeat,
+ ClientPong,
+ Other
+}
+
+async fn handle_client(
+ stream: TcpStream,
+ addr: SocketAddr,
+ secret: &str,
+ clients: ClientMap,
+ config: Arc<ServerConfig>,
+) -> Result<()> {
+ // Check if external clients are allowed when client is not from localhost
+ if !config.allow_external_clients &&
+ !addr.ip().is_loopback() &&
+ !addr.ip().to_string().starts_with("192.168.") &&
+ !addr.ip().to_string().starts_with("10.") {
+ warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+ return Err(anyhow::anyhow!("External clients not allowed"));
+ }
+
+ // Split the TCP stream into read and write parts once
+ let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+ // Authentication using the split streams
+ if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+ return Err(anyhow::anyhow!("Authentication failed"));
+ }
+
+ info!("Client authenticated: {}", addr);
+
+ // Get client info
+ let client_ip = addr.ip().to_string();
+ let client_port = addr.port();
+
+ // Check if client has specific group permissions
+ let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+ Some(groups) => groups,
+ None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+ };
+
+ // Create channel for sending multicast packets to this client
+ let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+ // Add client to map
+ clients.lock().await.insert(addr, tx.clone());
+
+ // Create HashMap of available groups for this client
+ let mut available_groups = HashMap::new();
+ for (id, group) in &config.multicast_groups {
+ // If client has empty group list (all allowed) or specific group is in list
+ if authorized_groups.is_empty() || authorized_groups.contains(id) {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ continue;
+ }
+
+ // Primary port is the first one
+ let primary_port = ports[0];
+
+ // Get additional ports if there are any
+ let additional_ports = if ports.len() > 1 {
+ Some(ports[1..].to_vec())
+ } else {
+ None
+ };
+
+ available_groups.insert(id.clone(), MulticastGroupInfo {
+ address: group.address.clone(),
+ port: primary_port,
+ additional_ports,
+ });
+ }
+ }
+
+ // IMPORTANT: Clone tx before moving it into the spawn
+ let tx_for_read = tx.clone();
+
+ // Spawn task to read client messages
+ let clients_clone = clients.clone();
+
+ // Use the already split read_stream
+ let read_task = tokio::spawn(async move {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ loop {
+ match read_stream.read(&mut buf).await {
+ Ok(0) => break, // Connection closed
+ Ok(n) => {
+ if let Ok(msg) = deserialize_message(&buf[..n]) {
+ match msg {
+ Message::Subscribe { group_ids } => {
+ info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+ // Group subscriptions handled by server
+ },
+ Message::MulticastGroupsRequest => {
+ // Send available groups to client
+ let response = Message::MulticastGroupsResponse {
+ groups: available_groups.clone()
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ Message::PingStatus { timestamp, status } => {
+ // Determine the type of status message
+ let msg_type = if status.starts_with("Client keepalive ping") ||
+ status.starts_with("Client periodic ping") {
+ StatusMessageType::ClientHeartbeat
+ } else if status.starts_with("Client pong response") {
+ StatusMessageType::ClientPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Log the message receipt
+ match msg_type {
+ StatusMessageType::ClientHeartbeat => {
+ info!("Heartbeat from client {}: {}", addr, status);
+
+ // Respond only to actual heartbeat pings, not pong responses
+ let response = Message::PingStatus {
+ timestamp,
+ status: format!("Server connection to {} is OK", addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ StatusMessageType::ClientPong => {
+ // Just log pongs without responding to avoid loops
+ debug!("Pong from client {}: {}", addr, status);
+ },
+ StatusMessageType::Other => {
+ info!("Status message from client {}: {}", addr, status);
+ }
+ }
+ },
+ _ => {}
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from client: {}: {}", addr, e);
+ break;
+ }
+ }
+ }
+ // Clean up on disconnect
+ clients_clone.lock().await.remove(&addr);
+ info!("Client reader task ended: {}", addr);
+ });
+
+ // Forward multicast packets to client using the already split write_stream
+ let write_task = tokio::spawn(async move {
+ while let Some(packet) = rx.recv().await {
+ if let Err(e) = write_stream.write_all(&packet).await {
+ error!("Error writing to client {}: {}", addr, e);
+ break;
+ }
+ }
+ info!("Client writer task ended: {}", addr);
+ });
+
+ // Now tx is still valid here - use it for heartbeats, but with a delay
+ let tx_for_heartbeat = tx.clone();
+ let client_addr = addr.clone();
+ tokio::spawn(async move {
+ // Add initial delay before starting heartbeats to avoid interfering with initial setup messages
+ tokio::time::sleep(Duration::from_secs(5)).await;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(30));
+ loop {
+ interval.tick().await;
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ // Send a heartbeat with a clear identifier
+ let msg = Message::PingStatus {
+ timestamp: now,
+ status: format!("Server heartbeat to {}", client_addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&msg) {
+ if tx_for_heartbeat.send(bytes).await.is_err() {
+ break;
+ }
+ }
+ }
+ });
+
+ // Wait for either task to complete
+ tokio::select! {
+ _ = read_task => {},
+ _ = write_task => {},
+ }
+
+ // Clean up
+ clients.lock().await.remove(&addr);
+ Ok(())
+}
+
+async fn authenticate_client(
+ reader: &mut (impl AsyncReadExt + Unpin),
+ writer: &mut (impl AsyncWriteExt + Unpin),
+ _addr: SocketAddr,
+ secret: &str
+) -> Result<bool> {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+ // Receive auth request
+ let n = reader.read(&mut buf).await?;
+ let auth_request = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthRequest { client_nonce } = auth_request {
+ // Generate server nonce
+ let server_nonce = generate_nonce();
+
+ // Calculate auth token
+ let auth_data = format!("{}{}", client_nonce, server_nonce);
+ let auth_token = calculate_hmac(secret, &auth_data);
+
+ // Send response
+ let response = Message::AuthResponse {
+ server_nonce: server_nonce.clone(),
+ auth_token,
+ };
+
+ let response_bytes = serialize_message(&response)?;
+ writer.write_all(&response_bytes).await?;
+
+ // Receive confirmation
+ let n = reader.read(&mut buf).await?;
+ let auth_confirm = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthConfirm { auth_token } = auth_confirm {
+ // Verify token
+ let expected_data = format!("{}{}", server_nonce, client_nonce);
+ return Ok(verify_hmac(secret, &expected_data, &auth_token));
+ }
+ }
+
+ Ok(false)
+}