aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKablersalat <crt@adastra7.net>2025-06-05 00:14:27 +0200
committerKablersalat <crt@adastra7.net>2025-06-05 00:14:27 +0200
commita0c754fc727a35d775c25857898986f4273db0a5 (patch)
tree4129d20013ff0a9b80051c0d9a74484cd753b633 /src
parentfd956fde0ba92b4aa7b0f5322cfb93951bb01fbb (diff)
Imported and sanetized dev server to publish on gitterHEADmaster
Diffstat (limited to 'src')
-rw-r--r--src/auth.rs29
-rw-r--r--src/bin/client.rs467
-rw-r--r--src/bin/mcast_test.rs123
-rw-r--r--src/bin/server.rs467
-rw-r--r--src/config.rs243
-rw-r--r--src/lib.rs13
-rw-r--r--src/protocol.rs129
7 files changed, 1471 insertions, 0 deletions
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..153f0a9
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,29 @@
+use hmac::{Hmac, Mac};
+use rand::{rngs::OsRng, RngCore};
+use sha2::Sha256;
+
+// Generate a random nonce
+pub fn generate_nonce() -> String {
+ let mut bytes = [0u8; 16];
+ OsRng.fill_bytes(&mut bytes);
+ hex::encode(bytes)
+}
+
+// Calculate HMAC using the shared secret
+pub fn calculate_hmac(secret: &str, data: &str) -> String {
+ type HmacSha256 = Hmac<Sha256>;
+ let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
+ .expect("HMAC can take key of any size");
+
+ mac.update(data.as_bytes());
+ let result = mac.finalize();
+ let code_bytes = result.into_bytes();
+
+ hex::encode(code_bytes)
+}
+
+// Verify an HMAC token
+pub fn verify_hmac(secret: &str, data: &str, expected: &str) -> bool {
+ let calculated = calculate_hmac(secret, data);
+ calculated == expected
+}
diff --git a/src/bin/client.rs b/src/bin/client.rs
new file mode 100644
index 0000000..5c39fb9
--- /dev/null
+++ b/src/bin/client.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_client_config, ensure_default_configs, ClientConfig},
+ protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message},
+ DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES,
+};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ time::Duration,
+ sync::Arc,
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpStream, UdpSocket},
+ signal,
+ sync::Notify,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "client_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_client_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Client configuration loaded from {:?}", args.config);
+
+ // Create a notification for clean shutdown
+ let shutdown = Arc::new(Notify::new());
+ let shutdown_signal = shutdown.clone();
+
+ // Setup signal handler for Ctrl+C
+ tokio::spawn(async move {
+ if let Err(e) = signal::ctrl_c().await {
+ error!("Failed to listen for Ctrl+C: {}", e);
+ return;
+ }
+ info!("Received Ctrl+C, shutting down...");
+ shutdown_signal.notify_one();
+ });
+
+ let server_addr = format!("{}:{}", config.server, config.port);
+
+ // Main reconnection loop
+ loop {
+ info!("Connecting to server at {}", server_addr);
+
+ // Try to connect with timeout
+ let connect_result = match tokio::time::timeout(
+ Duration::from_secs(5),
+ TcpStream::connect(&server_addr),
+ ).await {
+ Ok(result) => result,
+ Err(_) => {
+ warn!("Connection attempt timed out");
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ continue;
+ }
+ };
+
+ match connect_result {
+ Ok(stream) => {
+ info!("Connected to server");
+
+ // Run the client session
+ match run_client_session(stream, &config, &shutdown).await {
+ Ok(_) => {
+ info!("Client session ended normally");
+ break;
+ },
+ Err(e) => {
+ error!("Client session error: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ },
+ Err(e) => {
+ error!("Failed to connect: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+// Helper function to handle reconnection delay
+// Returns false if shutdown was requested
+async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool {
+ let delay = Duration::from_secs(client_config.reconnect_delay_secs);
+ info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs);
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested during reconnect");
+ return false;
+ }
+ _ = tokio::time::sleep(delay) => {}
+ }
+ true
+}
+
+// Add a new enum to distinguish message types
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ServerHeartbeat,
+ ServerPong,
+ Other
+}
+
+// The main client session function that handles a single connection
+async fn run_client_session(
+ mut stream: TcpStream,
+ config: &ClientConfig,
+ shutdown: &Arc<Notify>,
+) -> Result<()> {
+ // Authenticate
+ if let Err(e) = authenticate(&mut stream, &config.secret).await {
+ return Err(anyhow::anyhow!("Authentication failed: {}", e));
+ }
+ info!("Authentication successful");
+
+ // Check if test mode is enabled in config
+ let mut test_mode = config.test_mode;
+
+ // Request server's multicast group information
+ let groups_request = Message::MulticastGroupsRequest;
+ let request_bytes = serialize_message(&groups_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Wait for response with multicast group information, ignoring other message types
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let mut groups_response = None;
+
+ // Keep reading until we get the groups response or timeout
+ let mut attempts = 0;
+ while groups_response.is_none() && attempts < 10 {
+ match stream.read(&mut buf).await {
+ Ok(n) if n > 0 => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastGroupsResponse { groups }) => {
+ groups_response = Some(groups);
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Handle ping but keep waiting for multicast groups
+ info!("Got server ping: {}", status);
+ },
+ Ok(other_msg) => {
+ debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg);
+ },
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ }
+ }
+ },
+ Ok(0) => return Err(anyhow::anyhow!("Server closed connection")),
+ Ok(_) => {},
+ Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e))
+ }
+
+ // If we didn't get groups yet, wait a bit and try again
+ if groups_response.is_none() {
+ attempts += 1;
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ }
+
+ // Now we either have the groups or we timed out
+ let groups_response = groups_response.ok_or_else(||
+ anyhow::anyhow!("Failed to receive multicast group information from server"))?;
+
+ info!("Available multicast groups from server:");
+ for (id, group) in &groups_response {
+ info!(" - {} -> {}:{}", id, group.address, group.port);
+ }
+
+ // Determine which groups to subscribe to
+ let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() {
+ // If no specific groups are requested, subscribe to all
+ info!("No specific groups requested, subscribing to all available groups");
+ groups_response.keys().cloned().collect()
+ } else {
+ // Otherwise, subscribe only to requested groups that exist
+ let mut valid_groups = Vec::new();
+ for group_id in &config.multicast_group_ids {
+ if groups_response.contains_key(group_id.as_str()) {
+ valid_groups.push(group_id.clone());
+ } else {
+ warn!("Requested group '{}' does not exist on server", group_id);
+ }
+ }
+
+ if valid_groups.is_empty() {
+ warn!("None of the requested groups exist on server. No data will be received.");
+ } else {
+ info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups);
+ }
+
+ valid_groups
+ };
+
+ // Send subscription message
+ let subscribe_msg = Message::Subscribe {
+ group_ids: groups_to_subscribe.clone(),
+ };
+ let subscribe_bytes = serialize_message(&subscribe_msg)?;
+ stream.write_all(&subscribe_bytes).await?;
+
+ // Create UDP sockets for local retransmission (skip if in test mode)
+ let mut sockets: HashMap<String, UdpSocket> = HashMap::new();
+ if !test_mode {
+ for group_id in &groups_to_subscribe {
+ if let Some(group_info) = groups_response.get(group_id.as_str()) {
+ info!("Creating socket for group {} ({} on port {})",
+ group_id, group_info.address, group_info.port);
+
+ match UdpSocket::bind("0.0.0.0:0").await {
+ Ok(socket) => {
+ sockets.insert(group_id.clone(), socket);
+ info!("Successfully created UDP socket for group {}", group_id);
+ },
+ Err(e) => {
+ error!("Failed to create UDP socket for group {}: {}", group_id, e);
+ }
+ }
+ }
+ }
+
+ if sockets.is_empty() && !groups_to_subscribe.is_empty() {
+ error!("Failed to create any UDP sockets");
+ warn!("Falling back to test mode due to socket creation failure");
+ test_mode = true;
+ }
+ }
+
+ // Display test mode banner if enabled
+ if test_mode {
+ println!("{}", TEST_MODE_BANNER);
+ info!("Test mode enabled - packets will be displayed but not sent to network");
+ }
+
+ // Main receive loop
+ info!("Listening for multicast traffic from server");
+
+ // Set the read timeout for the stream
+ stream.set_nodelay(true)?;
+
+ // Remove problematic code that uses unsupported methods and the nix crate
+ if config.nat_traversal {
+ info!("NAT traversal mode enabled - using more frequent keepalives");
+ }
+
+ // Calculate appropriate ping interval based on NAT traversal setting
+ let ping_interval = if config.nat_traversal {
+ Duration::from_secs(25) // More frequent for NAT
+ } else {
+ Duration::from_secs(55)
+ };
+
+ // Main receive loop
+ loop {
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested, ending client session");
+ return Ok(());
+ }
+ read_result = stream.read(&mut buf) => {
+ match read_result {
+ Ok(0) => {
+ info!("Server closed connection");
+ return Err(anyhow::anyhow!("Server closed connection"));
+ }
+ Ok(n) => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => {
+ if test_mode {
+ println!("\n----- MULTICAST PACKET -----");
+ println!("Group: {}", group_id);
+ println!("Source: {}", source);
+ println!("Destination: {}:{}", destination, port);
+ println!("Size: {} bytes", data.len());
+ println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES));
+ println!("---------------------------\n");
+ } else {
+ info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes",
+ group_id, source, destination, port, data.len());
+
+ // Get socket for this group
+ if let Some(socket) = sockets.get(&group_id) {
+ // Parse destination address directly from the packet
+ match IpAddr::from_str(&destination) {
+ Ok(dest_addr) => {
+ // Create destination socket address using the packet's port
+ let dest = SocketAddr::new(dest_addr, port);
+
+ info!("Forwarding packet to {}:{}", dest_addr, port);
+
+ // Retransmit locally
+ match socket.send_to(&data, dest).await {
+ Ok(sent) => {
+ info!("Successfully forwarded {} of {} bytes to {}:{}",
+ sent, data.len(), destination, port);
+ },
+ Err(e) => {
+ error!("Failed to retransmit packet for group {} to {}:{}: {}",
+ group_id, destination, port, e);
+ }
+ }
+ },
+ Err(e) => {
+ error!("Invalid destination address {}: {}", destination, e);
+ }
+ }
+ } else {
+ warn!("No socket available for group {}", group_id);
+ }
+ }
+ },
+ Ok(Message::ConfigResponse { config: _server_config }) => {
+ info!("Server configuration received");
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Parse the message type based on the status text
+ let msg_type = if status.starts_with("Server heartbeat to") {
+ StatusMessageType::ServerHeartbeat
+ } else if status.starts_with("Server connection to") {
+ StatusMessageType::ServerPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Only respond with a pong to server heartbeats
+ if msg_type == StatusMessageType::ServerHeartbeat {
+ debug!("Received server heartbeat: {}", status);
+
+ // Send a pong response (but only for heartbeats)
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let pong = Message::PingStatus {
+ timestamp: now,
+ status: "Client pong response".to_string(),
+ };
+
+ if let Ok(bytes) = serialize_message(&pong) {
+ let _ = stream.write_all(&bytes).await;
+ }
+ } else {
+ // Just log other status messages without responding
+ debug!("Connection Status: {}", status);
+ }
+ },
+ Ok(_) => debug!("Received other message type"),
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ // Don't return/break on deserialization errors - continue reading
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from server: {}", e);
+ return Err(anyhow::anyhow!("Connection error: {}", e));
+ }
+ }
+ },
+ _ = tokio::time::sleep(ping_interval) => {
+ // Send regular ping to keep NAT connection alive
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let ping_msg = Message::PingStatus {
+ timestamp: now,
+ status: if config.nat_traversal {
+ "Client keepalive ping".to_string() // Changed text to be more specific
+ } else {
+ "Client periodic ping".to_string()
+ },
+ };
+
+ match serialize_message(&ping_msg) {
+ Ok(bytes) => {
+ if let Err(e) = stream.write_all(&bytes).await {
+ error!("Failed to ping server: {}", e);
+ return Err(anyhow::anyhow!("Server ping failed: {}", e));
+ }
+ debug!("Connection check sent to server");
+ },
+ Err(e) => error!("Failed to serialize ping message: {}", e),
+ }
+ }
+ }
+ }
+}
+
+async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> {
+ // Generate client nonce
+ let client_nonce = generate_nonce();
+
+ // Send auth request
+ let auth_request = Message::AuthRequest {
+ client_nonce: client_nonce.clone(),
+ };
+ let request_bytes = serialize_message(&auth_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Receive response
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let n = stream.read(&mut buf).await?;
+ let response = robust_deserialize_message(&buf[..n])?;
+
+ if let Message::AuthResponse { server_nonce, auth_token } = response {
+ // Verify server's token
+ let expected_data = format!("{}{}", client_nonce, server_nonce);
+ if !verify_hmac(secret, &expected_data, &auth_token) {
+ return Err(anyhow::anyhow!("Server authentication failed"));
+ }
+
+ // Calculate our token
+ let auth_data = format!("{}{}", server_nonce, client_nonce);
+ let client_token = calculate_hmac(secret, &auth_data);
+
+ // Send confirmation
+ let confirm = Message::AuthConfirm {
+ auth_token: client_token,
+ };
+
+ let confirm_bytes = serialize_message(&confirm)?;
+ stream.write_all(&confirm_bytes).await?;
+
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Unexpected response from server"))
+ }
+}
diff --git a/src/bin/mcast_test.rs b/src/bin/mcast_test.rs
new file mode 100644
index 0000000..782610c
--- /dev/null
+++ b/src/bin/mcast_test.rs
@@ -0,0 +1,123 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{error, info};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::str::FromStr;
+use std::time::{Duration, Instant};
+use tokio::net::UdpSocket;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about = "Multicast packet generator for testing CastRepeat")]
+struct Args {
+ #[arg(short, long, default_value = "239.192.55.1")]
+ multicast_addr: String,
+
+ #[arg(short, long, default_value = "1681")]
+ port: u16,
+
+ #[arg(short, long, default_value = "1000")]
+ interval_ms: u64,
+
+ #[arg(short, long, default_value = "60")]
+ duration_sec: u64,
+
+ #[arg(short, long, default_value = "Test packet")]
+ message: String,
+
+ #[arg(short, long)]
+ interface: Option<String>,
+
+ #[arg(short = 'b', long, action)]
+ binary_mode: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ info!("CastRepeat Multicast Packet Generator");
+ info!("--------------------------------");
+ info!("Multicast Address: {}", args.multicast_addr);
+ info!("Port: {}", args.port);
+ info!("Interval: {} ms", args.interval_ms);
+ info!("Duration: {} sec", args.duration_sec);
+ if let Some(interface) = &args.interface {
+ info!("Interface: {}", interface);
+ }
+ info!("Mode: {}", if args.binary_mode { "Binary" } else { "Text" });
+ info!("--------------------------------");
+
+ // Verify multicast address
+ let mcast_addr = match IpAddr::from_str(&args.multicast_addr) {
+ Ok(IpAddr::V4(addr)) if addr.is_multicast() => addr,
+ _ => {
+ error!("Invalid multicast address: {}", args.multicast_addr);
+ return Ok(());
+ }
+ };
+
+ // Create socket
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ socket.set_multicast_ttl_v4(4)?;
+ socket.set_nonblocking(true)?;
+
+ // Set the multicast interface if specified
+ if let Some(if_str) = &args.interface {
+ if let Ok(if_addr) = Ipv4Addr::from_str(if_str) {
+ socket.set_multicast_if_v4(&if_addr)?;
+ info!("Using interface: {}", if_addr);
+ } else {
+ error!("Invalid interface address: {}", if_str);
+ return Ok(());
+ }
+ }
+
+ let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
+ socket.bind(&addr.into())?;
+
+ let socket = UdpSocket::from_std(socket.into())?;
+ let dest_addr = SocketAddr::new(IpAddr::V4(mcast_addr), args.port);
+
+ // Start sending packets
+ info!("Sending packets to {}...", dest_addr);
+
+ let start = Instant::now();
+ let end = start + Duration::from_secs(args.duration_sec);
+ let mut counter: u64 = 0; // Specify counter type as u64
+
+ while Instant::now() < end {
+ counter += 1;
+
+ // Create either text or binary test data
+ let data = if args.binary_mode {
+ // Create binary test data (similar to what we might see in the field)
+ let mut packet = Vec::with_capacity(16);
+ packet.extend_from_slice(b"REL\0"); // 4 bytes header
+ packet.extend_from_slice(&counter.to_be_bytes()[4..]); // 4 bytes counter
+ packet.extend_from_slice(&[0x00, 0x10, 0x02, 0x00]); // 4 bytes
+ packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x90]); // 4 bytes
+ packet
+ } else {
+ // Create text test data
+ format!("{} #{}", args.message, counter).into_bytes()
+ };
+
+ match socket.send_to(&data, &dest_addr).await {
+ Ok(bytes) => {
+ info!("Sent packet #{}: {} bytes", counter, bytes);
+ }
+ Err(e) => {
+ error!("Error sending packet: {}", e);
+ }
+ }
+
+ tokio::time::sleep(Duration::from_millis(args.interval_ms)).await;
+ }
+
+ info!("Done. Sent {} packets in {} seconds.", counter, args.duration_sec);
+ Ok(())
+}
diff --git a/src/bin/server.rs b/src/bin/server.rs
new file mode 100644
index 0000000..7fe0109
--- /dev/null
+++ b/src/bin/server.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+ protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+ DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ sync::Arc,
+ time::Duration, // Add this import for Duration
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream},
+ sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "server_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_server_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Server configuration loaded from {:?}", args.config);
+
+ let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+ let listener = TcpListener::bind(&listen_addr).await
+ .context("Failed to bind TCP listener")?;
+
+ info!("Server listening on {}", listen_addr);
+
+ // Setup multicast receivers
+ let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+ // Start multicast listeners for each multicast group
+ for (group_id, group) in &config.multicast_groups {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ error!("No ports defined for group {}", group_id);
+ continue;
+ }
+
+ let display_group_id = group_id.clone();
+ let ports_display = if ports.len() == 1 {
+ format!("port {}", ports[0])
+ } else {
+ format!("ports {}-{}", ports[0], ports.last().unwrap())
+ };
+
+ // Create a listener for each port in the range
+ for port in ports {
+ let clients = clients.clone();
+ let _secret = config.secret.clone();
+ let group_id_clone = group_id.clone();
+ let mut group_info = group.clone();
+
+ // Set the specific port for this listener
+ group_info.port = Some(port);
+ group_info.port_range = None;
+
+ tokio::spawn(async move {
+ if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+ error!("Multicast listener error for group {} port {}: {}",
+ group_id_clone, port, e);
+ }
+ });
+ }
+
+ info!("Listening for multicast group {} on address {} with {}",
+ display_group_id, group.address, ports_display);
+ }
+
+ // Store config for use in client handlers
+ let config = Arc::new(config);
+
+ // Accept client connections
+ while let Ok((stream, addr)) = listener.accept().await {
+ info!("New client connection from: {}", addr);
+ let secret = config.secret.clone();
+ let clients = clients.clone();
+ let config = config.clone();
+
+ tokio::spawn(async move {
+ if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+ error!("Client error: {}: {}", addr, e);
+ }
+ info!("Client disconnected: {}", addr);
+ });
+ }
+
+ Ok(())
+}
+
+async fn listen_to_multicast(
+ group_id: &str,
+ group: &MulticastGroup,
+ clients: ClientMap
+) -> Result<()> {
+ // Get the port to use
+ let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+ // Parse the multicast address
+ let mcast_ip = match IpAddr::from_str(&group.address)
+ .context("Invalid multicast address")? {
+ IpAddr::V4(addr) => addr,
+ _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+ };
+
+ // Create a UDP socket with more explicit settings
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ // Important: Set socket options
+ socket.set_reuse_address(true)?;
+
+ #[cfg(unix)]
+ socket.set_reuse_port(true)?;
+
+ socket.set_nonblocking(true)?;
+ socket.set_multicast_loop_v4(true)?;
+
+ // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port
+ // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port
+ let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+ info!("Binding multicast listener to specific address: {:?}", bind_addr);
+ socket.bind(&bind_addr.into())?;
+
+ // Join the multicast group with a specific interface
+ let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+ info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+ socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+ // Additional multicast option: set the IP_MULTICAST_IF option
+ socket.set_multicast_if_v4(&interface)?;
+
+ // Convert to tokio socket
+ let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+ .context("Failed to convert socket to async")?;
+
+ info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+ group.address, port, group_id);
+
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let group_id = group_id.to_string();
+
+ loop {
+ match udp_socket.recv_from(&mut buf).await {
+ Ok((len, src)) => {
+ // Since we're bound to the exact multicast address, we can be confident
+ // this packet was sent to our specific multicast group
+ let data = buf[..len].to_vec();
+
+ info!("RECEIVED: group={} from={} size={} destination={}:{}",
+ group_id, src, len, mcast_ip, port);
+
+ // Create a message with the packet
+ let message = Message::MulticastPacket {
+ group_id: group_id.clone(),
+ source: src,
+ destination: group.address.clone(),
+ port,
+ data,
+ };
+
+ // Send to clients
+ match serialize_message(&message) {
+ Ok(serialized) => {
+ let clients_lock = clients.lock().await;
+ for (client_addr, sender) in clients_lock.iter() {
+ if sender.send(serialized.clone()).await.is_err() {
+ debug!("Failed to send to client {}", client_addr);
+ } else {
+ debug!("Sent multicast packet to client {}", client_addr);
+ }
+ }
+ }
+ Err(e) => error!("Failed to serialize message: {}", e),
+ }
+ }
+ Err(e) => {
+ if e.kind() != std::io::ErrorKind::WouldBlock {
+ error!("Error receiving from socket: {}", e);
+ }
+ // Small delay to avoid busy waiting on errors
+ tokio::time::sleep(Duration::from_millis(10)).await;
+ }
+ }
+ }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ClientHeartbeat,
+ ClientPong,
+ Other
+}
+
+async fn handle_client(
+ stream: TcpStream,
+ addr: SocketAddr,
+ secret: &str,
+ clients: ClientMap,
+ config: Arc<ServerConfig>,
+) -> Result<()> {
+ // Check if external clients are allowed when client is not from localhost
+ if !config.allow_external_clients &&
+ !addr.ip().is_loopback() &&
+ !addr.ip().to_string().starts_with("192.168.") &&
+ !addr.ip().to_string().starts_with("10.") {
+ warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+ return Err(anyhow::anyhow!("External clients not allowed"));
+ }
+
+ // Split the TCP stream into read and write parts once
+ let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+ // Authentication using the split streams
+ if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+ return Err(anyhow::anyhow!("Authentication failed"));
+ }
+
+ info!("Client authenticated: {}", addr);
+
+ // Get client info
+ let client_ip = addr.ip().to_string();
+ let client_port = addr.port();
+
+ // Check if client has specific group permissions
+ let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+ Some(groups) => groups,
+ None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+ };
+
+ // Create channel for sending multicast packets to this client
+ let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+ // Add client to map
+ clients.lock().await.insert(addr, tx.clone());
+
+ // Create HashMap of available groups for this client
+ let mut available_groups = HashMap::new();
+ for (id, group) in &config.multicast_groups {
+ // If client has empty group list (all allowed) or specific group is in list
+ if authorized_groups.is_empty() || authorized_groups.contains(id) {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ continue;
+ }
+
+ // Primary port is the first one
+ let primary_port = ports[0];
+
+ // Get additional ports if there are any
+ let additional_ports = if ports.len() > 1 {
+ Some(ports[1..].to_vec())
+ } else {
+ None
+ };
+
+ available_groups.insert(id.clone(), MulticastGroupInfo {
+ address: group.address.clone(),
+ port: primary_port,
+ additional_ports,
+ });
+ }
+ }
+
+ // IMPORTANT: Clone tx before moving it into the spawn
+ let tx_for_read = tx.clone();
+
+ // Spawn task to read client messages
+ let clients_clone = clients.clone();
+
+ // Use the already split read_stream
+ let read_task = tokio::spawn(async move {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ loop {
+ match read_stream.read(&mut buf).await {
+ Ok(0) => break, // Connection closed
+ Ok(n) => {
+ if let Ok(msg) = deserialize_message(&buf[..n]) {
+ match msg {
+ Message::Subscribe { group_ids } => {
+ info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+ // Group subscriptions handled by server
+ },
+ Message::MulticastGroupsRequest => {
+ // Send available groups to client
+ let response = Message::MulticastGroupsResponse {
+ groups: available_groups.clone()
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ Message::PingStatus { timestamp, status } => {
+ // Determine the type of status message
+ let msg_type = if status.starts_with("Client keepalive ping") ||
+ status.starts_with("Client periodic ping") {
+ StatusMessageType::ClientHeartbeat
+ } else if status.starts_with("Client pong response") {
+ StatusMessageType::ClientPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Log the message receipt
+ match msg_type {
+ StatusMessageType::ClientHeartbeat => {
+ info!("Heartbeat from client {}: {}", addr, status);
+
+ // Respond only to actual heartbeat pings, not pong responses
+ let response = Message::PingStatus {
+ timestamp,
+ status: format!("Server connection to {} is OK", addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ StatusMessageType::ClientPong => {
+ // Just log pongs without responding to avoid loops
+ debug!("Pong from client {}: {}", addr, status);
+ },
+ StatusMessageType::Other => {
+ info!("Status message from client {}: {}", addr, status);
+ }
+ }
+ },
+ _ => {}
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from client: {}: {}", addr, e);
+ break;
+ }
+ }
+ }
+ // Clean up on disconnect
+ clients_clone.lock().await.remove(&addr);
+ info!("Client reader task ended: {}", addr);
+ });
+
+ // Forward multicast packets to client using the already split write_stream
+ let write_task = tokio::spawn(async move {
+ while let Some(packet) = rx.recv().await {
+ if let Err(e) = write_stream.write_all(&packet).await {
+ error!("Error writing to client {}: {}", addr, e);
+ break;
+ }
+ }
+ info!("Client writer task ended: {}", addr);
+ });
+
+ // Now tx is still valid here - use it for heartbeats, but with a delay
+ let tx_for_heartbeat = tx.clone();
+ let client_addr = addr.clone();
+ tokio::spawn(async move {
+ // Add initial delay before starting heartbeats to avoid interfering with initial setup messages
+ tokio::time::sleep(Duration::from_secs(5)).await;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(30));
+ loop {
+ interval.tick().await;
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ // Send a heartbeat with a clear identifier
+ let msg = Message::PingStatus {
+ timestamp: now,
+ status: format!("Server heartbeat to {}", client_addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&msg) {
+ if tx_for_heartbeat.send(bytes).await.is_err() {
+ break;
+ }
+ }
+ }
+ });
+
+ // Wait for either task to complete
+ tokio::select! {
+ _ = read_task => {},
+ _ = write_task => {},
+ }
+
+ // Clean up
+ clients.lock().await.remove(&addr);
+ Ok(())
+}
+
+async fn authenticate_client(
+ reader: &mut (impl AsyncReadExt + Unpin),
+ writer: &mut (impl AsyncWriteExt + Unpin),
+ _addr: SocketAddr,
+ secret: &str
+) -> Result<bool> {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+ // Receive auth request
+ let n = reader.read(&mut buf).await?;
+ let auth_request = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthRequest { client_nonce } = auth_request {
+ // Generate server nonce
+ let server_nonce = generate_nonce();
+
+ // Calculate auth token
+ let auth_data = format!("{}{}", client_nonce, server_nonce);
+ let auth_token = calculate_hmac(secret, &auth_data);
+
+ // Send response
+ let response = Message::AuthResponse {
+ server_nonce: server_nonce.clone(),
+ auth_token,
+ };
+
+ let response_bytes = serialize_message(&response)?;
+ writer.write_all(&response_bytes).await?;
+
+ // Receive confirmation
+ let n = reader.read(&mut buf).await?;
+ let auth_confirm = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthConfirm { auth_token } = auth_confirm {
+ // Verify token
+ let expected_data = format!("{}{}", server_nonce, client_nonce);
+ return Ok(verify_hmac(secret, &expected_data, &auth_token));
+ }
+ }
+
+ Ok(false)
+}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..c3e7e56
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,243 @@
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::fs;
+use std::path::Path;
+
+// Structure to define individual authorized clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientDefinition {
+ pub name: String, // Friendly name for the client
+ pub ip_address: String, // Allowed IP address
+ pub port: Option<u16>, // Optional specific port (if None, any port is allowed)
+ pub group_ids: Vec<String>, // Multicast groups this client is allowed to access (empty = all)
+}
+
+// Definition of a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroup {
+ pub address: String, // Multicast address
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port: Option<u16>, // Single port (used if port_range is None)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port_range: Option<(u16, u16)>, // Optional port range (start, end)
+}
+
+impl MulticastGroup {
+ // Helper method to get all ports that should be used
+ pub fn get_ports(&self) -> Vec<u16> {
+ if let Some(range) = self.port_range {
+ // If a range is specified, return all ports in the range
+ (range.0..=range.1).collect()
+ } else if let Some(port) = self.port {
+ // If only a single port is specified, return just that port
+ vec![port]
+ } else {
+ // Default to empty vec if neither is specified (should be validated elsewhere)
+ vec![]
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfig {
+ pub secret: String,
+ pub listen_ip: String,
+ pub listen_port: u16,
+ pub multicast_groups: HashMap<String, MulticastGroup>,
+ pub authorized_clients: Vec<ClientDefinition>,
+ #[serde(default)]
+ pub simple_auth: bool,
+ #[serde(default)]
+ pub allow_external_clients: bool, // Flag to allow clients from outside the local network
+}
+
+impl Default for ServerConfig {
+ fn default() -> Self {
+ let mut groups = HashMap::new();
+ groups.insert("default".to_string(), MulticastGroup {
+ address: "224.0.0.1".to_string(),
+ port: Some(5000),
+ port_range: None,
+ });
+ groups.insert("ssdp".to_string(), MulticastGroup {
+ address: "239.255.255.250".to_string(),
+ port: Some(1900),
+ port_range: None,
+ });
+ // Example with port range
+ groups.insert("range_example".to_string(), MulticastGroup {
+ address: "239.192.55.2".to_string(),
+ port: None,
+ port_range: Some((1680, 1685)),
+ });
+
+ Self {
+ secret: "changeme".to_string(),
+ listen_ip: "0.0.0.0".to_string(),
+ listen_port: 8989,
+ multicast_groups: groups,
+ authorized_clients: vec![
+ ClientDefinition {
+ name: "localhost".to_string(),
+ ip_address: "127.0.0.1".to_string(),
+ port: None,
+ group_ids: vec![], // Empty means all groups
+ },
+ ClientDefinition {
+ name: "example-client".to_string(),
+ ip_address: "192.168.1.100".to_string(),
+ port: Some(12345),
+ group_ids: vec!["default".to_string()], // Only the default group
+ }
+ ],
+ simple_auth: false, // Default to standard auth mode
+ allow_external_clients: false, // Default to disallow external clients
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientConfig {
+ pub secret: String,
+ pub server: String,
+ pub port: u16,
+ pub multicast_group_ids: Vec<String>,
+ pub test_mode: bool,
+ #[serde(default)]
+ pub nat_traversal: bool, // Flag to enable NAT traversal features
+ #[serde(default = "default_reconnect_delay")]
+ pub reconnect_delay_secs: u64,
+}
+
+fn default_reconnect_delay() -> u64 {
+ 5
+}
+
+impl Default for ClientConfig {
+ fn default() -> Self {
+ Self {
+ secret: "changeme".to_string(),
+ server: "127.0.0.1".to_string(),
+ port: 8989,
+ multicast_group_ids: vec![], // Empty means subscribe to all groups
+ test_mode: false,
+ nat_traversal: false,
+ reconnect_delay_secs: 5,
+ }
+ }
+}
+
+// Load server configuration from a file
+pub fn load_server_config<P: AsRef<Path>>(path: P) -> Result<ServerConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read server configuration file")?;
+ let config: ServerConfig = toml::from_str(&config_str)
+ .context("Failed to parse server configuration")?;
+ Ok(config)
+}
+
+// Save server configuration to a file
+pub fn save_server_config<P: AsRef<Path>>(config: &ServerConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize server configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write server configuration file")?;
+ Ok(())
+}
+
+// Load client configuration from a file
+pub fn load_client_config<P: AsRef<Path>>(path: P) -> Result<ClientConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read client configuration file")?;
+ let config: ClientConfig = toml::from_str(&config_str)
+ .context("Failed to parse client configuration")?;
+ Ok(config)
+}
+
+// Save client configuration to a file
+pub fn save_client_config<P: AsRef<Path>>(config: &ClientConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize client configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write client configuration file")?;
+ Ok(())
+}
+
+// Generate default configuration files if they don't exist
+pub fn ensure_default_configs() -> Result<()> {
+ let server_config_path = "server_config.toml";
+ if !Path::new(server_config_path).exists() {
+ let default_config = ServerConfig::default();
+ save_server_config(&default_config, server_config_path)?;
+ println!("Created default server configuration at {}", server_config_path);
+ }
+
+ let client_config_path = "client_config.toml";
+ if !Path::new(client_config_path).exists() {
+ let default_config = ClientConfig::default();
+ save_client_config(&default_config, client_config_path)?;
+ println!("Created default client configuration at {}", client_config_path);
+ }
+
+ Ok(())
+}
+
+// Check if a client is authorized for specified groups
+pub fn get_client_authorized_groups(
+ config: &ServerConfig,
+ ip: &str,
+ port: u16
+) -> Option<Vec<String>> {
+ // If simple_auth is enabled, all clients that know the secret get access to all groups
+ if config.simple_auth {
+ return Some(vec![]); // Empty vector means all groups
+ }
+
+ // Otherwise use the standard group authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None or matches the client's port
+ if client.port.is_none() || client.port == Some(port) {
+ // Return the group IDs or empty vector (meaning all groups)
+ return Some(client.group_ids.clone());
+ }
+ }
+ }
+
+ None // Not authorized
+}
+
+// Get client name if authorized
+pub fn get_client_name(config: &ServerConfig, ip: &str, port: u16) -> Option<String> {
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return Some(client.name.clone());
+ }
+ }
+ }
+
+ None
+}
+
+// Check if a client's IP address and port are authorized
+pub fn is_client_authorized(config: &ServerConfig, ip: &str, port: u16) -> bool {
+ // If simple_auth is enabled, all clients that know the secret are authorized
+ if config.simple_auth {
+ return true;
+ }
+
+ // Otherwise use the standard client authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return true;
+ }
+ }
+ }
+
+ false
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..93f593e
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,13 @@
+pub mod auth;
+pub mod config;
+pub mod protocol;
+
+// Constants
+pub const DEFAULT_BUFFER_SIZE: usize = 65536; // Max UDP packet size
+pub const TEST_MODE_BANNER: &str = "
+╔════════════════════════════════════════════════════╗
+║ TEST MODE ║
+║ Packets will be displayed but not sent to network ║
+╚════════════════════════════════════════════════════╝
+";
+pub const MAX_DISPLAY_BYTES: usize = 64; // Maximum bytes to display in test mode
diff --git a/src/protocol.rs b/src/protocol.rs
new file mode 100644
index 0000000..340bbd3
--- /dev/null
+++ b/src/protocol.rs
@@ -0,0 +1,129 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+
+// Information about a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroupInfo {
+ pub address: String,
+ pub port: u16,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined
+}
+
+// Protocol messages exchanged between client and server
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Message {
+ // Initial authentication message
+ AuthRequest {
+ client_nonce: String,
+ },
+
+ // Authentication response
+ AuthResponse {
+ server_nonce: String,
+ auth_token: String, // HMAC(secret, client_nonce + server_nonce)
+ },
+
+ // Final auth confirmation
+ AuthConfirm {
+ auth_token: String, // HMAC(secret, server_nonce + client_nonce)
+ },
+
+ // Request information about available multicast groups
+ MulticastGroupsRequest,
+
+ // Response with available multicast groups
+ MulticastGroupsResponse {
+ groups: HashMap<String, MulticastGroupInfo>,
+ },
+
+ // Subscribe to specific multicast groups
+ Subscribe {
+ group_ids: Vec<String>,
+ },
+
+ // Multicast packet forwarded to client
+ MulticastPacket {
+ group_id: String,
+ source: SocketAddr,
+ destination: String, // Destination multicast address
+ port: u16, // Destination port
+ data: Vec<u8>,
+ },
+
+ // Configuration response with current settings
+ ConfigResponse {
+ config: ServerConfigInfo,
+ },
+
+ // New ping/status message for checking connections
+ PingStatus {
+ timestamp: u64,
+ status: String,
+ },
+}
+
+// Server configuration info that can be sent to clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfigInfo {
+ pub available_multicast_addresses: Vec<String>,
+ pub multicast_port: u16,
+}
+
+// Utility function to format packet data for display in test mode
+pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String {
+ let display_len = std::cmp::min(data.len(), max_bytes);
+ let mut result = String::new();
+
+ // Print hexadecimal representation
+ for (i, byte) in data.iter().take(display_len).enumerate() {
+ if i > 0 && i % 16 == 0 {
+ result.push('\n');
+ }
+ result.push_str(&format!("{:02x} ", byte));
+ }
+
+ if data.len() > max_bytes {
+ result.push_str("\n... (truncated)");
+ }
+
+ result
+}
+
+// Serialize a message to bytes
+pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> {
+ serde_json::to_vec(message)
+}
+
+// Deserialize bytes to a message with better error handling for partial messages
+pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // Find where valid JSON ends to handle cases where multiple messages
+ // or trailing data might be in the buffer
+ let mut deserializer = serde_json::Deserializer::from_slice(bytes);
+ let message = Message::deserialize(&mut deserializer)?;
+
+ // Return the successfully parsed message
+ Ok(message)
+}
+
+// Add a helper function to handle potential noise in streams
+pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // First try the standard method
+ match deserialize_message(bytes) {
+ Ok(message) => Ok(message),
+ Err(e) => {
+ // If it fails due to trailing characters, try to parse just the valid JSON
+ if e.is_syntax() && e.to_string().contains("trailing characters") {
+ // Try to find where valid JSON ends by parsing incrementally
+ for i in (1..bytes.len()).rev() {
+ if let Ok(msg) = deserialize_message(&bytes[0..i]) {
+ return Ok(msg);
+ }
+ }
+ }
+ // If we couldn't recover, return the original error
+ Err(e)
+ }
+ }
+}