aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKablersalat <crt@adastra7.net>2025-06-05 00:14:27 +0200
committerKablersalat <crt@adastra7.net>2025-06-05 00:14:27 +0200
commita0c754fc727a35d775c25857898986f4273db0a5 (patch)
tree4129d20013ff0a9b80051c0d9a74484cd753b633
parentfd956fde0ba92b4aa7b0f5322cfb93951bb01fbb (diff)
Imported and sanetized dev server to publish on gitterHEADmaster
-rw-r--r--Cargo.toml33
-rw-r--r--README.md237
-rw-r--r--bodeting.sh83
-rw-r--r--src/auth.rs29
-rw-r--r--src/bin/client.rs467
-rw-r--r--src/bin/mcast_test.rs123
-rw-r--r--src/bin/server.rs467
-rw-r--r--src/config.rs243
-rw-r--r--src/lib.rs13
-rw-r--r--src/protocol.rs129
10 files changed, 1824 insertions, 0 deletions
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..b6b6e8e
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "multicast_relay"
+version = "1.0.0"
+edition = "2021"
+description = "Multicast traffic relay system"
+# Note: Requires C compiler toolchain (build-essential on Debian/Ubuntu)
+
+[[bin]]
+name = "castrepeat-server"
+path = "src/bin/server.rs"
+
+[[bin]]
+name = "castrepeat-client"
+path = "src/bin/client.rs"
+
+[[bin]]
+name = "castrepeat-mcast-test"
+path = "src/bin/mcast_test.rs"
+
+[dependencies]
+tokio = { version = "1.28", features = ["full"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+toml = "0.7"
+clap = { version = "4.3", features = ["derive"] }
+anyhow = "1.0"
+hmac = "0.12"
+sha2 = "0.10"
+hex = "0.4"
+log = "0.4"
+env_logger = "0.10"
+socket2 = "0.5"
+rand = "0.8"
diff --git a/README.md b/README.md
index e69de29..f40783f 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,237 @@
+# CastRepeat
+CastRepeat is a somehwhat robust multicast relay tool written in Rust that allows you to capture multicast traffic on one network and repeat it on another. It is designed for environments where multicast routing is not available or practical and you need to bridge multicast traffic across network boundaries via a reliable transfer protocol (TCP instead of UDP for example)
+
+## Features
+- **Being a multicast relay**: Capture and relay multicast packets between networks
+- **Group filtering**: Configure specific multicast groups to relay
+- **Port filtering**: Relay traffic on specific ports only
+- **Authentication**: Secure communication between server and client
+- **NAT Traversal**: Support for clients behind NAT firewalls
+- **Test Mode**: View received multicast packets without relaying them
+- **Authorization Controls**: Control which clients can access which multicast groups
+
+## Architecture
+CastRepeat uses a client-server architecture:
+
+- **Server**: Captures multicast traffic on its network
+- **Client**: Receives and retransmits multicast traffic locally
+
+```
+Network A Network B
++----------------+ +----------------+
+| Multicast | | Multicast |
+| Sources | | Receivers |
+| | | |
+| +--------+ +---------+ |
+| | Server +----TCP-IP----+ Client | |
+| +--------+ +---------+ |
+| | | |
++----------------+ +----------------+
+```
+
+## Prerequisites
+- Rust 1.52.0 or higher
+- `libssl-dev` (OpenSSL development libraries)
+- Network privileges for multicast (may require sudo/root)
+
+## Installing Rust
+
+If you don't have Rust installed, use rustup (the official Rust installer):
+```
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+```
+
+Then, follow the on-screen instructions to complete the installation. After installing, ensure that the Rust `bin` directory is in your system's `PATH`.
+
+## Installation
+
+1. Clone the repository:
+ ```
+ git clone https://git.teleco.ch/crt/castrepeat.git/
+ cd castrepeat
+ ```
+
+2. Build the project:
+ ```
+ cargo build --release
+ ```
+
+3. The binaries will be available in `target/release/`:
+ - `castrepeat-server`: The server component
+ - `castrepeat-client`: The client component
+ - `castrepeat-mcast-test`: A simple tool for generating multicast test traffic
+
+## Configuration
+
+### Server Configuration
+
+Create a `server_config.toml` file:
+
+```
+secret = "your-shared-secret-key"
+listen_ip = "0.0.0.0"
+listen_port = 8989
+simple_auth = true
+allow_external_clients = true
+
+[multicast_groups]
+# Example group for some service
+[multicast_groups.service1]
+address = "239.192.55.1"
+port = 1681
+
+# Another group with a port range
+[multicast_groups.service2]
+address = "239.192.55.2"
+port_range = [1680, 1685]
+
+# Client authorization (optional when simple_auth = true which is needed for NAT traversal scenarios)
+[[authorized_clients]]
+name = "office-client"
+ip_address = "192.168.1.100"
+group_ids = ["service1"] # Only allow access to service1
+```
+
+### Client Configuration
+
+Create a `client_config.toml` file:
+
+```
+secret = "your-shared-secret-key"
+server = "server.example.com" # Server hostname or IP address
+port = 8989
+multicast_group_ids = [] # Empty means subscribe to all available groups
+test_mode = false
+nat_traversal = true # Enable for NAT traversal features
+reconnect_delay_secs = 5
+```
+
+## Usage
+
+### Running the Server
+
+```
+# With default config location
+sudo ./target/release/castrepeat-server
+
+# With custom config path
+sudo ./target/release/castrepeat-server --config /path/to/server_config.toml
+
+# Generate default configs
+./target/release/castrepeat-server --generate-default
+```
+
+### Running the Client
+
+```
+# With default config location
+./target/release/castrepeat-client
+
+# With custom config path
+./target/release/castrepeat-client --config /path/to/client_config.toml
+
+# Generate default configs
+./target/release/castrepeat-client --generate-default
+```
+
+### Debug Mode
+
+Use the client's test mode to see multicast packets without relaying them:
+
+```
+# In client_config.toml:
+test_mode = true
+```
+
+### Using the Helper Script
+
+The `bodeting.sh` script provides a simple way to start a server and debug client:
+
+```
+# Run the script
+./bodeting.sh
+```
+
+This will:
+1. Create configuration files in the `configs/` directory
+2. Start a tmux session with the server and client
+3. Display the logs side-by-side for easy debugging
+
+## Testing with mcast_test
+
+Generate test multicast traffic:
+
+# Basic usage with default parameters
+./target/release/castrepeat-mcast-test
+
+# Specify multicast address and port
+./target/release/castrepeat-mcast-test --multicast-addr 239.192.55.1 --port 1681
+
+# Customize message and timing
+./target/release/castrepeat-mcast-test --message "Custom packet data" --interval-ms 500 --duration-sec 120
+
+# Full example with all parameters
+./target/release/castrepeat-mcast-test --multicast-addr 239.192.55.2 --port 1681 --interval-ms 2000 --duration-sec 300 --message "Test packet"
+
+## NAT Traversal
+
+For clients behind NAT firewalls:
+
+1. Forward server's TCP port (default 8989) to the server's internal IP
+2. Set `allow_external_clients = true` in the server's config
+3. Set `nat_traversal = true` in the client's config
+4. Use the server's public IP address in the client's config
+
+## Troubleshooting
+
+### Multicast Packets Not Being Received
+
+- Ensure multicast routing is enabled on your router/switch
+- Check firewall rules to allow multicast traffic
+- Verify that you're using the correct multicast address and port
+- Run `castrepeat-mcast-test` to verify multicast connectivity
+
+### Client Connection Issues
+
+- Check if the server is running and accessible
+- Verify that the shared secret matches on both ends
+- For clients behind NAT, ensure proper port forwarding is configured
+
+### Binding Errors
+
+- Make sure no other application is using the same ports
+- Run the server with sudo/root privileges
+- Try binding to a specific interface
+
+## Common Use Cases
+
+- Bridging multicast traffic across networks
+- Relaying multicast across VPNs or cloud environments
+- Testing multicast applications in isolated environments
+- Monitoring multicast traffic for debugging
+
+## Command Line Arguments
+
+### Server
+
+- `-c, --config <FILE>`: Path to config file (default: server_config.toml)
+- `-g, --generate-default`: Generate default config files
+
+### Client
+
+- `-c, --config <FILE>`: Path to config file (default: client_config.toml)
+- `-g, --generate-default`: Generate default config files
+
+### castrepeat-mcast-test
+
+- `--multicast-addr <ADDR>`: Multicast address to send to (default: 239.192.55.1)
+- `--port <PORT>`: Port to send on (default: 1681)
+- `--interval-ms <MS>`: Sending interval in milliseconds (default: 1000)
+- `--duration-sec <SEC>`: How long to send packets (default: 60)
+- `--message <MSG>`: Custom message to send (default: "Test packet")
+
+
+## Disclaimer at last
+This is my first more like real world usecase rust project, Therefore I do not guarantee anything.
+All comments and documentation have been sanatized by Claude 3.7 Sonnet Thinking, as my personal ones tend to have some more aggresssive messages in them.
+Alot of issues were also fixed by throwing spaghetti at the wall and hoping it magically achieves the results expected \ No newline at end of file
diff --git a/bodeting.sh b/bodeting.sh
new file mode 100644
index 0000000..4ee1a86
--- /dev/null
+++ b/bodeting.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+# This file serves as as an example of how I launch the debug server and client in a tmux session for what I call "Bodeting" (Bodet Alarm and Gong) testing
+
+# Parse command line arguments
+SUDO_PASSWORD="mlol-no"
+while getopts "p:" opt; do
+ case $opt in
+ p) SUDO_PASSWORD="$OPTARG" ;;
+ *) echo "Usage: $0 [-p sudo_password]" >&2; exit 1 ;;
+ esac
+done
+
+# Create directory for client configs if it doesn't exist
+mkdir -p configs
+
+# Create server config with multicast groups
+cat > configs/server_config.toml << EOF
+secret = "bodeting-secret-key-so-much-secure"
+listen_ip = "0.0.0.0"
+listen_port = 8989
+simple_auth = true
+
+[multicast_groups]
+# Bodet Alarm group for 239.192.55.2
+[multicast_groups.bodetalarm]
+address = "239.192.55.2"
+port = 1681
+
+# Bodet Gong group for 239.192.55.1
+[multicast_groups.bodetgong]
+address = "239.192.55.1"
+port = 1681
+
+# Test client authorizations
+[[authorized_clients]]
+name = "localhost"
+ip_address = "127.0.0.1"
+group_ids = [] # Empty means all groups
+EOF
+
+# Create debug agent config
+cat > configs/debug_agent_config.toml << EOF
+secret = "bodeting-secret-key-so-much-secure"
+server = "127.0.0.1"
+port = 8989
+multicast_group_ids = []
+# Set to true for debug mode
+test_mode = true
+EOF
+
+# Check if tmux is already running
+if [ -z "$TMUX" ]; then
+ # Start a new tmux session
+ tmux new-session -d -s debug_server
+
+ # Split the window horizontally (side by side)
+ tmux split-window -h -t debug_server:0.0
+
+ # Configure the left pane for the server
+ if [ -n "$SUDO_PASSWORD" ]; then
+ tmux send-keys -t debug_server:0.0 "cd $(pwd) && clear && echo 'Starting server...' && echo '$SUDO_PASSWORD' | sudo -S RUST_LOG=debug ./target/release/castrepeat-server --config configs/server_config.toml" C-m
+ else
+ tmux send-keys -t debug_server:0.0 "cd $(pwd) && clear && echo 'Starting server...' && sudo RUST_LOG=debug ./target/release/castrepeat-server --config configs/server_config.toml" C-m
+ fi
+
+ # Give the server a moment to start up
+ sleep 2
+
+ # Configure the right pane for the debug agent
+ tmux send-keys -t debug_server:0.1 "cd $(pwd) && clear && echo 'Starting debug agent...' && RUST_LOG=debug ./target/release/castrepeat-client --config configs/debug_agent_config.toml" C-m
+
+ # If you need to test multicast traffic generation, you can add:
+ # tmux split-window -v -t debug_server:0.1
+ # tmux send-keys -t debug_server:0.2 "cd $(pwd) && clear && echo 'Starting multicast generator...' && RUST_LOG=debug ./target/release/castrepeat-mcast-test" C-m
+
+ # Attach to the session
+ tmux attach-session -t debug_server
+else
+ echo "Already in a tmux session. Please exit current tmux session before running this script."
+ exit 1
+fi
+
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..153f0a9
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,29 @@
+use hmac::{Hmac, Mac};
+use rand::{rngs::OsRng, RngCore};
+use sha2::Sha256;
+
+// Generate a random nonce
+pub fn generate_nonce() -> String {
+ let mut bytes = [0u8; 16];
+ OsRng.fill_bytes(&mut bytes);
+ hex::encode(bytes)
+}
+
+// Calculate HMAC using the shared secret
+pub fn calculate_hmac(secret: &str, data: &str) -> String {
+ type HmacSha256 = Hmac<Sha256>;
+ let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
+ .expect("HMAC can take key of any size");
+
+ mac.update(data.as_bytes());
+ let result = mac.finalize();
+ let code_bytes = result.into_bytes();
+
+ hex::encode(code_bytes)
+}
+
+// Verify an HMAC token
+pub fn verify_hmac(secret: &str, data: &str, expected: &str) -> bool {
+ let calculated = calculate_hmac(secret, data);
+ calculated == expected
+}
diff --git a/src/bin/client.rs b/src/bin/client.rs
new file mode 100644
index 0000000..5c39fb9
--- /dev/null
+++ b/src/bin/client.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_client_config, ensure_default_configs, ClientConfig},
+ protocol::{serialize_message, Message, format_packet_for_display, robust_deserialize_message},
+ DEFAULT_BUFFER_SIZE, TEST_MODE_BANNER, MAX_DISPLAY_BYTES,
+};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ time::Duration,
+ sync::Arc,
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpStream, UdpSocket},
+ signal,
+ sync::Notify,
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "client_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_client_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Client configuration loaded from {:?}", args.config);
+
+ // Create a notification for clean shutdown
+ let shutdown = Arc::new(Notify::new());
+ let shutdown_signal = shutdown.clone();
+
+ // Setup signal handler for Ctrl+C
+ tokio::spawn(async move {
+ if let Err(e) = signal::ctrl_c().await {
+ error!("Failed to listen for Ctrl+C: {}", e);
+ return;
+ }
+ info!("Received Ctrl+C, shutting down...");
+ shutdown_signal.notify_one();
+ });
+
+ let server_addr = format!("{}:{}", config.server, config.port);
+
+ // Main reconnection loop
+ loop {
+ info!("Connecting to server at {}", server_addr);
+
+ // Try to connect with timeout
+ let connect_result = match tokio::time::timeout(
+ Duration::from_secs(5),
+ TcpStream::connect(&server_addr),
+ ).await {
+ Ok(result) => result,
+ Err(_) => {
+ warn!("Connection attempt timed out");
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ continue;
+ }
+ };
+
+ match connect_result {
+ Ok(stream) => {
+ info!("Connected to server");
+
+ // Run the client session
+ match run_client_session(stream, &config, &shutdown).await {
+ Ok(_) => {
+ info!("Client session ended normally");
+ break;
+ },
+ Err(e) => {
+ error!("Client session error: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ },
+ Err(e) => {
+ error!("Failed to connect: {}", e);
+ if !handle_reconnect(&shutdown, &config).await {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+// Helper function to handle reconnection delay
+// Returns false if shutdown was requested
+async fn handle_reconnect(shutdown: &Arc<Notify>, client_config: &ClientConfig) -> bool {
+ let delay = Duration::from_secs(client_config.reconnect_delay_secs);
+ info!("Reconnecting in {} seconds...", client_config.reconnect_delay_secs);
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested during reconnect");
+ return false;
+ }
+ _ = tokio::time::sleep(delay) => {}
+ }
+ true
+}
+
+// Add a new enum to distinguish message types
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ServerHeartbeat,
+ ServerPong,
+ Other
+}
+
+// The main client session function that handles a single connection
+async fn run_client_session(
+ mut stream: TcpStream,
+ config: &ClientConfig,
+ shutdown: &Arc<Notify>,
+) -> Result<()> {
+ // Authenticate
+ if let Err(e) = authenticate(&mut stream, &config.secret).await {
+ return Err(anyhow::anyhow!("Authentication failed: {}", e));
+ }
+ info!("Authentication successful");
+
+ // Check if test mode is enabled in config
+ let mut test_mode = config.test_mode;
+
+ // Request server's multicast group information
+ let groups_request = Message::MulticastGroupsRequest;
+ let request_bytes = serialize_message(&groups_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Wait for response with multicast group information, ignoring other message types
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let mut groups_response = None;
+
+ // Keep reading until we get the groups response or timeout
+ let mut attempts = 0;
+ while groups_response.is_none() && attempts < 10 {
+ match stream.read(&mut buf).await {
+ Ok(n) if n > 0 => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastGroupsResponse { groups }) => {
+ groups_response = Some(groups);
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Handle ping but keep waiting for multicast groups
+ info!("Got server ping: {}", status);
+ },
+ Ok(other_msg) => {
+ debug!("Ignoring unexpected message while waiting for groups: {:?}", other_msg);
+ },
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ }
+ }
+ },
+ Ok(0) => return Err(anyhow::anyhow!("Server closed connection")),
+ Ok(_) => {},
+ Err(e) => return Err(anyhow::anyhow!("Error reading from server: {}", e))
+ }
+
+ // If we didn't get groups yet, wait a bit and try again
+ if groups_response.is_none() {
+ attempts += 1;
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ }
+
+ // Now we either have the groups or we timed out
+ let groups_response = groups_response.ok_or_else(||
+ anyhow::anyhow!("Failed to receive multicast group information from server"))?;
+
+ info!("Available multicast groups from server:");
+ for (id, group) in &groups_response {
+ info!(" - {} -> {}:{}", id, group.address, group.port);
+ }
+
+ // Determine which groups to subscribe to
+ let groups_to_subscribe: Vec<String> = if config.multicast_group_ids.is_empty() {
+ // If no specific groups are requested, subscribe to all
+ info!("No specific groups requested, subscribing to all available groups");
+ groups_response.keys().cloned().collect()
+ } else {
+ // Otherwise, subscribe only to requested groups that exist
+ let mut valid_groups = Vec::new();
+ for group_id in &config.multicast_group_ids {
+ if groups_response.contains_key(group_id.as_str()) {
+ valid_groups.push(group_id.clone());
+ } else {
+ warn!("Requested group '{}' does not exist on server", group_id);
+ }
+ }
+
+ if valid_groups.is_empty() {
+ warn!("None of the requested groups exist on server. No data will be received.");
+ } else {
+ info!("Subscribing to {} groups: {:?}", valid_groups.len(), valid_groups);
+ }
+
+ valid_groups
+ };
+
+ // Send subscription message
+ let subscribe_msg = Message::Subscribe {
+ group_ids: groups_to_subscribe.clone(),
+ };
+ let subscribe_bytes = serialize_message(&subscribe_msg)?;
+ stream.write_all(&subscribe_bytes).await?;
+
+ // Create UDP sockets for local retransmission (skip if in test mode)
+ let mut sockets: HashMap<String, UdpSocket> = HashMap::new();
+ if !test_mode {
+ for group_id in &groups_to_subscribe {
+ if let Some(group_info) = groups_response.get(group_id.as_str()) {
+ info!("Creating socket for group {} ({} on port {})",
+ group_id, group_info.address, group_info.port);
+
+ match UdpSocket::bind("0.0.0.0:0").await {
+ Ok(socket) => {
+ sockets.insert(group_id.clone(), socket);
+ info!("Successfully created UDP socket for group {}", group_id);
+ },
+ Err(e) => {
+ error!("Failed to create UDP socket for group {}: {}", group_id, e);
+ }
+ }
+ }
+ }
+
+ if sockets.is_empty() && !groups_to_subscribe.is_empty() {
+ error!("Failed to create any UDP sockets");
+ warn!("Falling back to test mode due to socket creation failure");
+ test_mode = true;
+ }
+ }
+
+ // Display test mode banner if enabled
+ if test_mode {
+ println!("{}", TEST_MODE_BANNER);
+ info!("Test mode enabled - packets will be displayed but not sent to network");
+ }
+
+ // Main receive loop
+ info!("Listening for multicast traffic from server");
+
+ // Set the read timeout for the stream
+ stream.set_nodelay(true)?;
+
+ // Remove problematic code that uses unsupported methods and the nix crate
+ if config.nat_traversal {
+ info!("NAT traversal mode enabled - using more frequent keepalives");
+ }
+
+ // Calculate appropriate ping interval based on NAT traversal setting
+ let ping_interval = if config.nat_traversal {
+ Duration::from_secs(25) // More frequent for NAT
+ } else {
+ Duration::from_secs(55)
+ };
+
+ // Main receive loop
+ loop {
+ tokio::select! {
+ _ = shutdown.notified() => {
+ info!("Shutdown requested, ending client session");
+ return Ok(());
+ }
+ read_result = stream.read(&mut buf) => {
+ match read_result {
+ Ok(0) => {
+ info!("Server closed connection");
+ return Err(anyhow::anyhow!("Server closed connection"));
+ }
+ Ok(n) => {
+ match robust_deserialize_message(&buf[..n]) {
+ Ok(Message::MulticastPacket { group_id, source, destination, port, data }) => {
+ if test_mode {
+ println!("\n----- MULTICAST PACKET -----");
+ println!("Group: {}", group_id);
+ println!("Source: {}", source);
+ println!("Destination: {}:{}", destination, port);
+ println!("Size: {} bytes", data.len());
+ println!("Data:\n{}", format_packet_for_display(&data, MAX_DISPLAY_BYTES));
+ println!("---------------------------\n");
+ } else {
+ info!("Received multicast packet: group={}, from={}, to={}:{}, size={}bytes",
+ group_id, source, destination, port, data.len());
+
+ // Get socket for this group
+ if let Some(socket) = sockets.get(&group_id) {
+ // Parse destination address directly from the packet
+ match IpAddr::from_str(&destination) {
+ Ok(dest_addr) => {
+ // Create destination socket address using the packet's port
+ let dest = SocketAddr::new(dest_addr, port);
+
+ info!("Forwarding packet to {}:{}", dest_addr, port);
+
+ // Retransmit locally
+ match socket.send_to(&data, dest).await {
+ Ok(sent) => {
+ info!("Successfully forwarded {} of {} bytes to {}:{}",
+ sent, data.len(), destination, port);
+ },
+ Err(e) => {
+ error!("Failed to retransmit packet for group {} to {}:{}: {}",
+ group_id, destination, port, e);
+ }
+ }
+ },
+ Err(e) => {
+ error!("Invalid destination address {}: {}", destination, e);
+ }
+ }
+ } else {
+ warn!("No socket available for group {}", group_id);
+ }
+ }
+ },
+ Ok(Message::ConfigResponse { config: _server_config }) => {
+ info!("Server configuration received");
+ },
+ Ok(Message::PingStatus { timestamp: _, status }) => {
+ // Parse the message type based on the status text
+ let msg_type = if status.starts_with("Server heartbeat to") {
+ StatusMessageType::ServerHeartbeat
+ } else if status.starts_with("Server connection to") {
+ StatusMessageType::ServerPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Only respond with a pong to server heartbeats
+ if msg_type == StatusMessageType::ServerHeartbeat {
+ debug!("Received server heartbeat: {}", status);
+
+ // Send a pong response (but only for heartbeats)
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let pong = Message::PingStatus {
+ timestamp: now,
+ status: "Client pong response".to_string(),
+ };
+
+ if let Ok(bytes) = serialize_message(&pong) {
+ let _ = stream.write_all(&bytes).await;
+ }
+ } else {
+ // Just log other status messages without responding
+ debug!("Connection Status: {}", status);
+ }
+ },
+ Ok(_) => debug!("Received other message type"),
+ Err(e) => {
+ error!("Failed to deserialize message: {}", e);
+ // Don't return/break on deserialization errors - continue reading
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from server: {}", e);
+ return Err(anyhow::anyhow!("Connection error: {}", e));
+ }
+ }
+ },
+ _ = tokio::time::sleep(ping_interval) => {
+ // Send regular ping to keep NAT connection alive
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let ping_msg = Message::PingStatus {
+ timestamp: now,
+ status: if config.nat_traversal {
+ "Client keepalive ping".to_string() // Changed text to be more specific
+ } else {
+ "Client periodic ping".to_string()
+ },
+ };
+
+ match serialize_message(&ping_msg) {
+ Ok(bytes) => {
+ if let Err(e) = stream.write_all(&bytes).await {
+ error!("Failed to ping server: {}", e);
+ return Err(anyhow::anyhow!("Server ping failed: {}", e));
+ }
+ debug!("Connection check sent to server");
+ },
+ Err(e) => error!("Failed to serialize ping message: {}", e),
+ }
+ }
+ }
+ }
+}
+
+async fn authenticate(stream: &mut TcpStream, secret: &str) -> Result<()> {
+ // Generate client nonce
+ let client_nonce = generate_nonce();
+
+ // Send auth request
+ let auth_request = Message::AuthRequest {
+ client_nonce: client_nonce.clone(),
+ };
+ let request_bytes = serialize_message(&auth_request)?;
+ stream.write_all(&request_bytes).await?;
+
+ // Receive response
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let n = stream.read(&mut buf).await?;
+ let response = robust_deserialize_message(&buf[..n])?;
+
+ if let Message::AuthResponse { server_nonce, auth_token } = response {
+ // Verify server's token
+ let expected_data = format!("{}{}", client_nonce, server_nonce);
+ if !verify_hmac(secret, &expected_data, &auth_token) {
+ return Err(anyhow::anyhow!("Server authentication failed"));
+ }
+
+ // Calculate our token
+ let auth_data = format!("{}{}", server_nonce, client_nonce);
+ let client_token = calculate_hmac(secret, &auth_data);
+
+ // Send confirmation
+ let confirm = Message::AuthConfirm {
+ auth_token: client_token,
+ };
+
+ let confirm_bytes = serialize_message(&confirm)?;
+ stream.write_all(&confirm_bytes).await?;
+
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Unexpected response from server"))
+ }
+}
diff --git a/src/bin/mcast_test.rs b/src/bin/mcast_test.rs
new file mode 100644
index 0000000..782610c
--- /dev/null
+++ b/src/bin/mcast_test.rs
@@ -0,0 +1,123 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{error, info};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::str::FromStr;
+use std::time::{Duration, Instant};
+use tokio::net::UdpSocket;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about = "Multicast packet generator for testing CastRepeat")]
+struct Args {
+ #[arg(short, long, default_value = "239.192.55.1")]
+ multicast_addr: String,
+
+ #[arg(short, long, default_value = "1681")]
+ port: u16,
+
+ #[arg(short, long, default_value = "1000")]
+ interval_ms: u64,
+
+ #[arg(short, long, default_value = "60")]
+ duration_sec: u64,
+
+ #[arg(short, long, default_value = "Test packet")]
+ message: String,
+
+ #[arg(short, long)]
+ interface: Option<String>,
+
+ #[arg(short = 'b', long, action)]
+ binary_mode: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ info!("CastRepeat Multicast Packet Generator");
+ info!("--------------------------------");
+ info!("Multicast Address: {}", args.multicast_addr);
+ info!("Port: {}", args.port);
+ info!("Interval: {} ms", args.interval_ms);
+ info!("Duration: {} sec", args.duration_sec);
+ if let Some(interface) = &args.interface {
+ info!("Interface: {}", interface);
+ }
+ info!("Mode: {}", if args.binary_mode { "Binary" } else { "Text" });
+ info!("--------------------------------");
+
+ // Verify multicast address
+ let mcast_addr = match IpAddr::from_str(&args.multicast_addr) {
+ Ok(IpAddr::V4(addr)) if addr.is_multicast() => addr,
+ _ => {
+ error!("Invalid multicast address: {}", args.multicast_addr);
+ return Ok(());
+ }
+ };
+
+ // Create socket
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ socket.set_multicast_ttl_v4(4)?;
+ socket.set_nonblocking(true)?;
+
+ // Set the multicast interface if specified
+ if let Some(if_str) = &args.interface {
+ if let Ok(if_addr) = Ipv4Addr::from_str(if_str) {
+ socket.set_multicast_if_v4(&if_addr)?;
+ info!("Using interface: {}", if_addr);
+ } else {
+ error!("Invalid interface address: {}", if_str);
+ return Ok(());
+ }
+ }
+
+ let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
+ socket.bind(&addr.into())?;
+
+ let socket = UdpSocket::from_std(socket.into())?;
+ let dest_addr = SocketAddr::new(IpAddr::V4(mcast_addr), args.port);
+
+ // Start sending packets
+ info!("Sending packets to {}...", dest_addr);
+
+ let start = Instant::now();
+ let end = start + Duration::from_secs(args.duration_sec);
+ let mut counter: u64 = 0; // Specify counter type as u64
+
+ while Instant::now() < end {
+ counter += 1;
+
+ // Create either text or binary test data
+ let data = if args.binary_mode {
+ // Create binary test data (similar to what we might see in the field)
+ let mut packet = Vec::with_capacity(16);
+ packet.extend_from_slice(b"REL\0"); // 4 bytes header
+ packet.extend_from_slice(&counter.to_be_bytes()[4..]); // 4 bytes counter
+ packet.extend_from_slice(&[0x00, 0x10, 0x02, 0x00]); // 4 bytes
+ packet.extend_from_slice(&[0x00, 0x00, 0x00, 0x90]); // 4 bytes
+ packet
+ } else {
+ // Create text test data
+ format!("{} #{}", args.message, counter).into_bytes()
+ };
+
+ match socket.send_to(&data, &dest_addr).await {
+ Ok(bytes) => {
+ info!("Sent packet #{}: {} bytes", counter, bytes);
+ }
+ Err(e) => {
+ error!("Error sending packet: {}", e);
+ }
+ }
+
+ tokio::time::sleep(Duration::from_millis(args.interval_ms)).await;
+ }
+
+ info!("Done. Sent {} packets in {} seconds.", counter, args.duration_sec);
+ Ok(())
+}
diff --git a/src/bin/server.rs b/src/bin/server.rs
new file mode 100644
index 0000000..7fe0109
--- /dev/null
+++ b/src/bin/server.rs
@@ -0,0 +1,467 @@
+use anyhow::{Context, Result};
+use clap::Parser;
+use log::{debug, error, info, warn};
+use multicast_relay::{
+ auth::{calculate_hmac, generate_nonce, verify_hmac},
+ config::{load_server_config, ensure_default_configs, ServerConfig, MulticastGroup, get_client_authorized_groups},
+ protocol::{deserialize_message, serialize_message, Message, MulticastGroupInfo},
+ DEFAULT_BUFFER_SIZE,
+};
+use socket2::{Domain, Protocol, Socket, Type};
+use std::{
+ collections::HashMap,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::PathBuf,
+ str::FromStr,
+ sync::Arc,
+ time::Duration, // Add this import for Duration
+};
+use tokio::{
+ io::{AsyncReadExt, AsyncWriteExt},
+ net::{TcpListener, TcpStream},
+ sync::{mpsc, Mutex},
+};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ #[arg(short, long, default_value = "server_config.toml")]
+ config: PathBuf,
+
+ #[arg(short, long, action)]
+ generate_default: bool,
+}
+
+type ClientMap = Arc<Mutex<HashMap<SocketAddr, mpsc::Sender<Vec<u8>>>>>;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ env_logger::init();
+ let args = Args::parse();
+
+ // Generate default configs if requested
+ if args.generate_default {
+ ensure_default_configs()?;
+ return Ok(());
+ }
+
+ // Load configuration
+ let config = load_server_config(&args.config)
+ .context(format!("Failed to load config from {:?}", args.config))?;
+
+ info!("Server configuration loaded from {:?}", args.config);
+
+ let listen_addr = format!("{}:{}", config.listen_ip, config.listen_port);
+ let listener = TcpListener::bind(&listen_addr).await
+ .context("Failed to bind TCP listener")?;
+
+ info!("Server listening on {}", listen_addr);
+
+ // Setup multicast receivers
+ let clients: ClientMap = Arc::new(Mutex::new(HashMap::new()));
+
+ // Start multicast listeners for each multicast group
+ for (group_id, group) in &config.multicast_groups {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ error!("No ports defined for group {}", group_id);
+ continue;
+ }
+
+ let display_group_id = group_id.clone();
+ let ports_display = if ports.len() == 1 {
+ format!("port {}", ports[0])
+ } else {
+ format!("ports {}-{}", ports[0], ports.last().unwrap())
+ };
+
+ // Create a listener for each port in the range
+ for port in ports {
+ let clients = clients.clone();
+ let _secret = config.secret.clone();
+ let group_id_clone = group_id.clone();
+ let mut group_info = group.clone();
+
+ // Set the specific port for this listener
+ group_info.port = Some(port);
+ group_info.port_range = None;
+
+ tokio::spawn(async move {
+ if let Err(e) = listen_to_multicast(&group_id_clone, &group_info, clients).await {
+ error!("Multicast listener error for group {} port {}: {}",
+ group_id_clone, port, e);
+ }
+ });
+ }
+
+ info!("Listening for multicast group {} on address {} with {}",
+ display_group_id, group.address, ports_display);
+ }
+
+ // Store config for use in client handlers
+ let config = Arc::new(config);
+
+ // Accept client connections
+ while let Ok((stream, addr)) = listener.accept().await {
+ info!("New client connection from: {}", addr);
+ let secret = config.secret.clone();
+ let clients = clients.clone();
+ let config = config.clone();
+
+ tokio::spawn(async move {
+ if let Err(e) = handle_client(stream, addr, &secret, clients, config).await {
+ error!("Client error: {}: {}", addr, e);
+ }
+ info!("Client disconnected: {}", addr);
+ });
+ }
+
+ Ok(())
+}
+
+async fn listen_to_multicast(
+ group_id: &str,
+ group: &MulticastGroup,
+ clients: ClientMap
+) -> Result<()> {
+ // Get the port to use
+ let port = group.port.ok_or_else(|| anyhow::anyhow!("No port specified"))?;
+
+ // Parse the multicast address
+ let mcast_ip = match IpAddr::from_str(&group.address)
+ .context("Invalid multicast address")? {
+ IpAddr::V4(addr) => addr,
+ _ => return Err(anyhow::anyhow!("Only IPv4 multicast supported"))
+ };
+
+ // Create a UDP socket with more explicit settings
+ let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
+ .context("Failed to create socket")?;
+
+ // Important: Set socket options
+ socket.set_reuse_address(true)?;
+
+ #[cfg(unix)]
+ socket.set_reuse_port(true)?;
+
+ socket.set_nonblocking(true)?;
+ socket.set_multicast_loop_v4(true)?;
+
+ // THIS IS THE KEY CHANGE: Bind to the specific multicast address AND port
+ // Instead of binding to 0.0.0.0:port, bind directly to the multicast address:port
+ let bind_addr = SocketAddr::new(IpAddr::V4(mcast_ip), port);
+ info!("Binding multicast listener to specific address: {:?}", bind_addr);
+ socket.bind(&bind_addr.into())?;
+
+ // Join the multicast group with a specific interface
+ let interface = Ipv4Addr::new(0, 0, 0, 0); // Any interface
+ info!("Joining multicast group {} on interface {:?}", mcast_ip, interface);
+ socket.join_multicast_v4(&mcast_ip, &interface)?;
+
+ // Additional multicast option: set the IP_MULTICAST_IF option
+ socket.set_multicast_if_v4(&interface)?;
+
+ // Convert to tokio socket
+ let udp_socket = tokio::net::UdpSocket::from_std(socket.into())
+ .context("Failed to convert socket to async")?;
+
+ info!("Multicast listener ready and bound specifically to {}:{} (group {})",
+ group.address, port, group_id);
+
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ let group_id = group_id.to_string();
+
+ loop {
+ match udp_socket.recv_from(&mut buf).await {
+ Ok((len, src)) => {
+ // Since we're bound to the exact multicast address, we can be confident
+ // this packet was sent to our specific multicast group
+ let data = buf[..len].to_vec();
+
+ info!("RECEIVED: group={} from={} size={} destination={}:{}",
+ group_id, src, len, mcast_ip, port);
+
+ // Create a message with the packet
+ let message = Message::MulticastPacket {
+ group_id: group_id.clone(),
+ source: src,
+ destination: group.address.clone(),
+ port,
+ data,
+ };
+
+ // Send to clients
+ match serialize_message(&message) {
+ Ok(serialized) => {
+ let clients_lock = clients.lock().await;
+ for (client_addr, sender) in clients_lock.iter() {
+ if sender.send(serialized.clone()).await.is_err() {
+ debug!("Failed to send to client {}", client_addr);
+ } else {
+ debug!("Sent multicast packet to client {}", client_addr);
+ }
+ }
+ }
+ Err(e) => error!("Failed to serialize message: {}", e),
+ }
+ }
+ Err(e) => {
+ if e.kind() != std::io::ErrorKind::WouldBlock {
+ error!("Error receiving from socket: {}", e);
+ }
+ // Small delay to avoid busy waiting on errors
+ tokio::time::sleep(Duration::from_millis(10)).await;
+ }
+ }
+ }
+}
+
+#[derive(PartialEq)]
+enum StatusMessageType {
+ ClientHeartbeat,
+ ClientPong,
+ Other
+}
+
+async fn handle_client(
+ stream: TcpStream,
+ addr: SocketAddr,
+ secret: &str,
+ clients: ClientMap,
+ config: Arc<ServerConfig>,
+) -> Result<()> {
+ // Check if external clients are allowed when client is not from localhost
+ if !config.allow_external_clients &&
+ !addr.ip().is_loopback() &&
+ !addr.ip().to_string().starts_with("192.168.") &&
+ !addr.ip().to_string().starts_with("10.") {
+ warn!("Connection attempt from external address {} rejected - set allow_external_clients=true to allow", addr);
+ return Err(anyhow::anyhow!("External clients not allowed"));
+ }
+
+ // Split the TCP stream into read and write parts once
+ let (mut read_stream, mut write_stream) = tokio::io::split(stream);
+
+ // Authentication using the split streams
+ if !authenticate_client(&mut read_stream, &mut write_stream, addr, secret).await? {
+ return Err(anyhow::anyhow!("Authentication failed"));
+ }
+
+ info!("Client authenticated: {}", addr);
+
+ // Get client info
+ let client_ip = addr.ip().to_string();
+ let client_port = addr.port();
+
+ // Check if client has specific group permissions
+ let authorized_groups = match get_client_authorized_groups(&config, &client_ip, client_port) {
+ Some(groups) => groups,
+ None => return Err(anyhow::anyhow!("Client not authorized for any groups")),
+ };
+
+ // Create channel for sending multicast packets to this client
+ let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
+
+ // Add client to map
+ clients.lock().await.insert(addr, tx.clone());
+
+ // Create HashMap of available groups for this client
+ let mut available_groups = HashMap::new();
+ for (id, group) in &config.multicast_groups {
+ // If client has empty group list (all allowed) or specific group is in list
+ if authorized_groups.is_empty() || authorized_groups.contains(id) {
+ let ports = group.get_ports();
+ if ports.is_empty() {
+ continue;
+ }
+
+ // Primary port is the first one
+ let primary_port = ports[0];
+
+ // Get additional ports if there are any
+ let additional_ports = if ports.len() > 1 {
+ Some(ports[1..].to_vec())
+ } else {
+ None
+ };
+
+ available_groups.insert(id.clone(), MulticastGroupInfo {
+ address: group.address.clone(),
+ port: primary_port,
+ additional_ports,
+ });
+ }
+ }
+
+ // IMPORTANT: Clone tx before moving it into the spawn
+ let tx_for_read = tx.clone();
+
+ // Spawn task to read client messages
+ let clients_clone = clients.clone();
+
+ // Use the already split read_stream
+ let read_task = tokio::spawn(async move {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+ loop {
+ match read_stream.read(&mut buf).await {
+ Ok(0) => break, // Connection closed
+ Ok(n) => {
+ if let Ok(msg) = deserialize_message(&buf[..n]) {
+ match msg {
+ Message::Subscribe { group_ids } => {
+ info!("Client {} subscribing to groups: {:?}", addr, group_ids);
+ // Group subscriptions handled by server
+ },
+ Message::MulticastGroupsRequest => {
+ // Send available groups to client
+ let response = Message::MulticastGroupsResponse {
+ groups: available_groups.clone()
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ Message::PingStatus { timestamp, status } => {
+ // Determine the type of status message
+ let msg_type = if status.starts_with("Client keepalive ping") ||
+ status.starts_with("Client periodic ping") {
+ StatusMessageType::ClientHeartbeat
+ } else if status.starts_with("Client pong response") {
+ StatusMessageType::ClientPong
+ } else {
+ StatusMessageType::Other
+ };
+
+ // Log the message receipt
+ match msg_type {
+ StatusMessageType::ClientHeartbeat => {
+ info!("Heartbeat from client {}: {}", addr, status);
+
+ // Respond only to actual heartbeat pings, not pong responses
+ let response = Message::PingStatus {
+ timestamp,
+ status: format!("Server connection to {} is OK", addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&response) {
+ let _ = tx_for_read.send(bytes).await;
+ }
+ },
+ StatusMessageType::ClientPong => {
+ // Just log pongs without responding to avoid loops
+ debug!("Pong from client {}: {}", addr, status);
+ },
+ StatusMessageType::Other => {
+ info!("Status message from client {}: {}", addr, status);
+ }
+ }
+ },
+ _ => {}
+ }
+ }
+ }
+ Err(e) => {
+ error!("Error reading from client: {}: {}", addr, e);
+ break;
+ }
+ }
+ }
+ // Clean up on disconnect
+ clients_clone.lock().await.remove(&addr);
+ info!("Client reader task ended: {}", addr);
+ });
+
+ // Forward multicast packets to client using the already split write_stream
+ let write_task = tokio::spawn(async move {
+ while let Some(packet) = rx.recv().await {
+ if let Err(e) = write_stream.write_all(&packet).await {
+ error!("Error writing to client {}: {}", addr, e);
+ break;
+ }
+ }
+ info!("Client writer task ended: {}", addr);
+ });
+
+ // Now tx is still valid here - use it for heartbeats, but with a delay
+ let tx_for_heartbeat = tx.clone();
+ let client_addr = addr.clone();
+ tokio::spawn(async move {
+ // Add initial delay before starting heartbeats to avoid interfering with initial setup messages
+ tokio::time::sleep(Duration::from_secs(5)).await;
+
+ let mut interval = tokio::time::interval(Duration::from_secs(30));
+ loop {
+ interval.tick().await;
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ // Send a heartbeat with a clear identifier
+ let msg = Message::PingStatus {
+ timestamp: now,
+ status: format!("Server heartbeat to {}", client_addr),
+ };
+
+ if let Ok(bytes) = serialize_message(&msg) {
+ if tx_for_heartbeat.send(bytes).await.is_err() {
+ break;
+ }
+ }
+ }
+ });
+
+ // Wait for either task to complete
+ tokio::select! {
+ _ = read_task => {},
+ _ = write_task => {},
+ }
+
+ // Clean up
+ clients.lock().await.remove(&addr);
+ Ok(())
+}
+
+async fn authenticate_client(
+ reader: &mut (impl AsyncReadExt + Unpin),
+ writer: &mut (impl AsyncWriteExt + Unpin),
+ _addr: SocketAddr,
+ secret: &str
+) -> Result<bool> {
+ let mut buf = vec![0u8; DEFAULT_BUFFER_SIZE];
+
+ // Receive auth request
+ let n = reader.read(&mut buf).await?;
+ let auth_request = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthRequest { client_nonce } = auth_request {
+ // Generate server nonce
+ let server_nonce = generate_nonce();
+
+ // Calculate auth token
+ let auth_data = format!("{}{}", client_nonce, server_nonce);
+ let auth_token = calculate_hmac(secret, &auth_data);
+
+ // Send response
+ let response = Message::AuthResponse {
+ server_nonce: server_nonce.clone(),
+ auth_token,
+ };
+
+ let response_bytes = serialize_message(&response)?;
+ writer.write_all(&response_bytes).await?;
+
+ // Receive confirmation
+ let n = reader.read(&mut buf).await?;
+ let auth_confirm = deserialize_message(&buf[..n])?;
+
+ if let Message::AuthConfirm { auth_token } = auth_confirm {
+ // Verify token
+ let expected_data = format!("{}{}", server_nonce, client_nonce);
+ return Ok(verify_hmac(secret, &expected_data, &auth_token));
+ }
+ }
+
+ Ok(false)
+}
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..c3e7e56
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,243 @@
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::fs;
+use std::path::Path;
+
+// Structure to define individual authorized clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientDefinition {
+ pub name: String, // Friendly name for the client
+ pub ip_address: String, // Allowed IP address
+ pub port: Option<u16>, // Optional specific port (if None, any port is allowed)
+ pub group_ids: Vec<String>, // Multicast groups this client is allowed to access (empty = all)
+}
+
+// Definition of a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroup {
+ pub address: String, // Multicast address
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port: Option<u16>, // Single port (used if port_range is None)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub port_range: Option<(u16, u16)>, // Optional port range (start, end)
+}
+
+impl MulticastGroup {
+ // Helper method to get all ports that should be used
+ pub fn get_ports(&self) -> Vec<u16> {
+ if let Some(range) = self.port_range {
+ // If a range is specified, return all ports in the range
+ (range.0..=range.1).collect()
+ } else if let Some(port) = self.port {
+ // If only a single port is specified, return just that port
+ vec![port]
+ } else {
+ // Default to empty vec if neither is specified (should be validated elsewhere)
+ vec![]
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfig {
+ pub secret: String,
+ pub listen_ip: String,
+ pub listen_port: u16,
+ pub multicast_groups: HashMap<String, MulticastGroup>,
+ pub authorized_clients: Vec<ClientDefinition>,
+ #[serde(default)]
+ pub simple_auth: bool,
+ #[serde(default)]
+ pub allow_external_clients: bool, // Flag to allow clients from outside the local network
+}
+
+impl Default for ServerConfig {
+ fn default() -> Self {
+ let mut groups = HashMap::new();
+ groups.insert("default".to_string(), MulticastGroup {
+ address: "224.0.0.1".to_string(),
+ port: Some(5000),
+ port_range: None,
+ });
+ groups.insert("ssdp".to_string(), MulticastGroup {
+ address: "239.255.255.250".to_string(),
+ port: Some(1900),
+ port_range: None,
+ });
+ // Example with port range
+ groups.insert("range_example".to_string(), MulticastGroup {
+ address: "239.192.55.2".to_string(),
+ port: None,
+ port_range: Some((1680, 1685)),
+ });
+
+ Self {
+ secret: "changeme".to_string(),
+ listen_ip: "0.0.0.0".to_string(),
+ listen_port: 8989,
+ multicast_groups: groups,
+ authorized_clients: vec![
+ ClientDefinition {
+ name: "localhost".to_string(),
+ ip_address: "127.0.0.1".to_string(),
+ port: None,
+ group_ids: vec![], // Empty means all groups
+ },
+ ClientDefinition {
+ name: "example-client".to_string(),
+ ip_address: "192.168.1.100".to_string(),
+ port: Some(12345),
+ group_ids: vec!["default".to_string()], // Only the default group
+ }
+ ],
+ simple_auth: false, // Default to standard auth mode
+ allow_external_clients: false, // Default to disallow external clients
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ClientConfig {
+ pub secret: String,
+ pub server: String,
+ pub port: u16,
+ pub multicast_group_ids: Vec<String>,
+ pub test_mode: bool,
+ #[serde(default)]
+ pub nat_traversal: bool, // Flag to enable NAT traversal features
+ #[serde(default = "default_reconnect_delay")]
+ pub reconnect_delay_secs: u64,
+}
+
+fn default_reconnect_delay() -> u64 {
+ 5
+}
+
+impl Default for ClientConfig {
+ fn default() -> Self {
+ Self {
+ secret: "changeme".to_string(),
+ server: "127.0.0.1".to_string(),
+ port: 8989,
+ multicast_group_ids: vec![], // Empty means subscribe to all groups
+ test_mode: false,
+ nat_traversal: false,
+ reconnect_delay_secs: 5,
+ }
+ }
+}
+
+// Load server configuration from a file
+pub fn load_server_config<P: AsRef<Path>>(path: P) -> Result<ServerConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read server configuration file")?;
+ let config: ServerConfig = toml::from_str(&config_str)
+ .context("Failed to parse server configuration")?;
+ Ok(config)
+}
+
+// Save server configuration to a file
+pub fn save_server_config<P: AsRef<Path>>(config: &ServerConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize server configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write server configuration file")?;
+ Ok(())
+}
+
+// Load client configuration from a file
+pub fn load_client_config<P: AsRef<Path>>(path: P) -> Result<ClientConfig> {
+ let config_str = fs::read_to_string(path)
+ .context("Failed to read client configuration file")?;
+ let config: ClientConfig = toml::from_str(&config_str)
+ .context("Failed to parse client configuration")?;
+ Ok(config)
+}
+
+// Save client configuration to a file
+pub fn save_client_config<P: AsRef<Path>>(config: &ClientConfig, path: P) -> Result<()> {
+ let toml_str = toml::to_string_pretty(config)
+ .context("Failed to serialize client configuration")?;
+ fs::write(path, toml_str)
+ .context("Failed to write client configuration file")?;
+ Ok(())
+}
+
+// Generate default configuration files if they don't exist
+pub fn ensure_default_configs() -> Result<()> {
+ let server_config_path = "server_config.toml";
+ if !Path::new(server_config_path).exists() {
+ let default_config = ServerConfig::default();
+ save_server_config(&default_config, server_config_path)?;
+ println!("Created default server configuration at {}", server_config_path);
+ }
+
+ let client_config_path = "client_config.toml";
+ if !Path::new(client_config_path).exists() {
+ let default_config = ClientConfig::default();
+ save_client_config(&default_config, client_config_path)?;
+ println!("Created default client configuration at {}", client_config_path);
+ }
+
+ Ok(())
+}
+
+// Check if a client is authorized for specified groups
+pub fn get_client_authorized_groups(
+ config: &ServerConfig,
+ ip: &str,
+ port: u16
+) -> Option<Vec<String>> {
+ // If simple_auth is enabled, all clients that know the secret get access to all groups
+ if config.simple_auth {
+ return Some(vec![]); // Empty vector means all groups
+ }
+
+ // Otherwise use the standard group authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None or matches the client's port
+ if client.port.is_none() || client.port == Some(port) {
+ // Return the group IDs or empty vector (meaning all groups)
+ return Some(client.group_ids.clone());
+ }
+ }
+ }
+
+ None // Not authorized
+}
+
+// Get client name if authorized
+pub fn get_client_name(config: &ServerConfig, ip: &str, port: u16) -> Option<String> {
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return Some(client.name.clone());
+ }
+ }
+ }
+
+ None
+}
+
+// Check if a client's IP address and port are authorized
+pub fn is_client_authorized(config: &ServerConfig, ip: &str, port: u16) -> bool {
+ // If simple_auth is enabled, all clients that know the secret are authorized
+ if config.simple_auth {
+ return true;
+ }
+
+ // Otherwise use the standard client authorization check
+ for client in &config.authorized_clients {
+ if client.ip_address == ip {
+ // If port is None, any port from this IP is allowed
+ if client.port.is_none() || client.port == Some(port) {
+ return true;
+ }
+ }
+ }
+
+ false
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..93f593e
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,13 @@
+pub mod auth;
+pub mod config;
+pub mod protocol;
+
+// Constants
+pub const DEFAULT_BUFFER_SIZE: usize = 65536; // Max UDP packet size
+pub const TEST_MODE_BANNER: &str = "
+╔════════════════════════════════════════════════════╗
+║ TEST MODE ║
+║ Packets will be displayed but not sent to network ║
+╚════════════════════════════════════════════════════╝
+";
+pub const MAX_DISPLAY_BYTES: usize = 64; // Maximum bytes to display in test mode
diff --git a/src/protocol.rs b/src/protocol.rs
new file mode 100644
index 0000000..340bbd3
--- /dev/null
+++ b/src/protocol.rs
@@ -0,0 +1,129 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+
+// Information about a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroupInfo {
+ pub address: String,
+ pub port: u16,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined
+}
+
+// Protocol messages exchanged between client and server
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Message {
+ // Initial authentication message
+ AuthRequest {
+ client_nonce: String,
+ },
+
+ // Authentication response
+ AuthResponse {
+ server_nonce: String,
+ auth_token: String, // HMAC(secret, client_nonce + server_nonce)
+ },
+
+ // Final auth confirmation
+ AuthConfirm {
+ auth_token: String, // HMAC(secret, server_nonce + client_nonce)
+ },
+
+ // Request information about available multicast groups
+ MulticastGroupsRequest,
+
+ // Response with available multicast groups
+ MulticastGroupsResponse {
+ groups: HashMap<String, MulticastGroupInfo>,
+ },
+
+ // Subscribe to specific multicast groups
+ Subscribe {
+ group_ids: Vec<String>,
+ },
+
+ // Multicast packet forwarded to client
+ MulticastPacket {
+ group_id: String,
+ source: SocketAddr,
+ destination: String, // Destination multicast address
+ port: u16, // Destination port
+ data: Vec<u8>,
+ },
+
+ // Configuration response with current settings
+ ConfigResponse {
+ config: ServerConfigInfo,
+ },
+
+ // New ping/status message for checking connections
+ PingStatus {
+ timestamp: u64,
+ status: String,
+ },
+}
+
+// Server configuration info that can be sent to clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfigInfo {
+ pub available_multicast_addresses: Vec<String>,
+ pub multicast_port: u16,
+}
+
+// Utility function to format packet data for display in test mode
+pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String {
+ let display_len = std::cmp::min(data.len(), max_bytes);
+ let mut result = String::new();
+
+ // Print hexadecimal representation
+ for (i, byte) in data.iter().take(display_len).enumerate() {
+ if i > 0 && i % 16 == 0 {
+ result.push('\n');
+ }
+ result.push_str(&format!("{:02x} ", byte));
+ }
+
+ if data.len() > max_bytes {
+ result.push_str("\n... (truncated)");
+ }
+
+ result
+}
+
+// Serialize a message to bytes
+pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> {
+ serde_json::to_vec(message)
+}
+
+// Deserialize bytes to a message with better error handling for partial messages
+pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // Find where valid JSON ends to handle cases where multiple messages
+ // or trailing data might be in the buffer
+ let mut deserializer = serde_json::Deserializer::from_slice(bytes);
+ let message = Message::deserialize(&mut deserializer)?;
+
+ // Return the successfully parsed message
+ Ok(message)
+}
+
+// Add a helper function to handle potential noise in streams
+pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // First try the standard method
+ match deserialize_message(bytes) {
+ Ok(message) => Ok(message),
+ Err(e) => {
+ // If it fails due to trailing characters, try to parse just the valid JSON
+ if e.is_syntax() && e.to_string().contains("trailing characters") {
+ // Try to find where valid JSON ends by parsing incrementally
+ for i in (1..bytes.len()).rev() {
+ if let Ok(msg) = deserialize_message(&bytes[0..i]) {
+ return Ok(msg);
+ }
+ }
+ }
+ // If we couldn't recover, return the original error
+ Err(e)
+ }
+ }
+}