aboutsummaryrefslogtreecommitdiff
path: root/src/protocol.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol.rs')
-rw-r--r--src/protocol.rs129
1 files changed, 129 insertions, 0 deletions
diff --git a/src/protocol.rs b/src/protocol.rs
new file mode 100644
index 0000000..340bbd3
--- /dev/null
+++ b/src/protocol.rs
@@ -0,0 +1,129 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+
+// Information about a multicast group
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MulticastGroupInfo {
+ pub address: String,
+ pub port: u16,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined
+}
+
+// Protocol messages exchanged between client and server
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum Message {
+ // Initial authentication message
+ AuthRequest {
+ client_nonce: String,
+ },
+
+ // Authentication response
+ AuthResponse {
+ server_nonce: String,
+ auth_token: String, // HMAC(secret, client_nonce + server_nonce)
+ },
+
+ // Final auth confirmation
+ AuthConfirm {
+ auth_token: String, // HMAC(secret, server_nonce + client_nonce)
+ },
+
+ // Request information about available multicast groups
+ MulticastGroupsRequest,
+
+ // Response with available multicast groups
+ MulticastGroupsResponse {
+ groups: HashMap<String, MulticastGroupInfo>,
+ },
+
+ // Subscribe to specific multicast groups
+ Subscribe {
+ group_ids: Vec<String>,
+ },
+
+ // Multicast packet forwarded to client
+ MulticastPacket {
+ group_id: String,
+ source: SocketAddr,
+ destination: String, // Destination multicast address
+ port: u16, // Destination port
+ data: Vec<u8>,
+ },
+
+ // Configuration response with current settings
+ ConfigResponse {
+ config: ServerConfigInfo,
+ },
+
+ // New ping/status message for checking connections
+ PingStatus {
+ timestamp: u64,
+ status: String,
+ },
+}
+
+// Server configuration info that can be sent to clients
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerConfigInfo {
+ pub available_multicast_addresses: Vec<String>,
+ pub multicast_port: u16,
+}
+
+// Utility function to format packet data for display in test mode
+pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String {
+ let display_len = std::cmp::min(data.len(), max_bytes);
+ let mut result = String::new();
+
+ // Print hexadecimal representation
+ for (i, byte) in data.iter().take(display_len).enumerate() {
+ if i > 0 && i % 16 == 0 {
+ result.push('\n');
+ }
+ result.push_str(&format!("{:02x} ", byte));
+ }
+
+ if data.len() > max_bytes {
+ result.push_str("\n... (truncated)");
+ }
+
+ result
+}
+
+// Serialize a message to bytes
+pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> {
+ serde_json::to_vec(message)
+}
+
+// Deserialize bytes to a message with better error handling for partial messages
+pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // Find where valid JSON ends to handle cases where multiple messages
+ // or trailing data might be in the buffer
+ let mut deserializer = serde_json::Deserializer::from_slice(bytes);
+ let message = Message::deserialize(&mut deserializer)?;
+
+ // Return the successfully parsed message
+ Ok(message)
+}
+
+// Add a helper function to handle potential noise in streams
+pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
+ // First try the standard method
+ match deserialize_message(bytes) {
+ Ok(message) => Ok(message),
+ Err(e) => {
+ // If it fails due to trailing characters, try to parse just the valid JSON
+ if e.is_syntax() && e.to_string().contains("trailing characters") {
+ // Try to find where valid JSON ends by parsing incrementally
+ for i in (1..bytes.len()).rev() {
+ if let Ok(msg) = deserialize_message(&bytes[0..i]) {
+ return Ok(msg);
+ }
+ }
+ }
+ // If we couldn't recover, return the original error
+ Err(e)
+ }
+ }
+}