diff options
Diffstat (limited to 'src/protocol.rs')
-rw-r--r-- | src/protocol.rs | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..340bbd3 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,129 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; + +// Information about a multicast group +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MulticastGroupInfo { + pub address: String, + pub port: u16, + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined +} + +// Protocol messages exchanged between client and server +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Message { + // Initial authentication message + AuthRequest { + client_nonce: String, + }, + + // Authentication response + AuthResponse { + server_nonce: String, + auth_token: String, // HMAC(secret, client_nonce + server_nonce) + }, + + // Final auth confirmation + AuthConfirm { + auth_token: String, // HMAC(secret, server_nonce + client_nonce) + }, + + // Request information about available multicast groups + MulticastGroupsRequest, + + // Response with available multicast groups + MulticastGroupsResponse { + groups: HashMap<String, MulticastGroupInfo>, + }, + + // Subscribe to specific multicast groups + Subscribe { + group_ids: Vec<String>, + }, + + // Multicast packet forwarded to client + MulticastPacket { + group_id: String, + source: SocketAddr, + destination: String, // Destination multicast address + port: u16, // Destination port + data: Vec<u8>, + }, + + // Configuration response with current settings + ConfigResponse { + config: ServerConfigInfo, + }, + + // New ping/status message for checking connections + PingStatus { + timestamp: u64, + status: String, + }, +} + +// Server configuration info that can be sent to clients +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfigInfo { + pub available_multicast_addresses: Vec<String>, + pub multicast_port: u16, +} + +// Utility function to format packet data for display in test mode +pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String { + let display_len = std::cmp::min(data.len(), max_bytes); + let mut result = String::new(); + + // Print hexadecimal representation + for (i, byte) in data.iter().take(display_len).enumerate() { + if i > 0 && i % 16 == 0 { + result.push('\n'); + } + result.push_str(&format!("{:02x} ", byte)); + } + + if data.len() > max_bytes { + result.push_str("\n... (truncated)"); + } + + result +} + +// Serialize a message to bytes +pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> { + serde_json::to_vec(message) +} + +// Deserialize bytes to a message with better error handling for partial messages +pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> { + // Find where valid JSON ends to handle cases where multiple messages + // or trailing data might be in the buffer + let mut deserializer = serde_json::Deserializer::from_slice(bytes); + let message = Message::deserialize(&mut deserializer)?; + + // Return the successfully parsed message + Ok(message) +} + +// Add a helper function to handle potential noise in streams +pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> { + // First try the standard method + match deserialize_message(bytes) { + Ok(message) => Ok(message), + Err(e) => { + // If it fails due to trailing characters, try to parse just the valid JSON + if e.is_syntax() && e.to_string().contains("trailing characters") { + // Try to find where valid JSON ends by parsing incrementally + for i in (1..bytes.len()).rev() { + if let Ok(msg) = deserialize_message(&bytes[0..i]) { + return Ok(msg); + } + } + } + // If we couldn't recover, return the original error + Err(e) + } + } +} |