aboutsummaryrefslogtreecommitdiff
path: root/src/protocol.rs
blob: 340bbd308dd027be395c8ad4caee8d0225d68103 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;

/// Information about a multicast group.
///
/// Derives `Serialize`/`Deserialize` so it can be carried inside protocol
/// messages; it is the value type of `Message::MulticastGroupsResponse::groups`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MulticastGroupInfo {
    /// Address of the group, stored as a plain string (no validation happens in this type).
    pub address: String,
    /// Port of the group.
    pub port: u16,
    /// Extra ports for the group; the field is left out of the serialized
    /// form entirely when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_ports: Option<Vec<u16>>, // Additional ports if a range was defined
}

/// Protocol messages exchanged between client and server.
///
/// No serde representation attribute is set on the enum, so serde's default
/// enum representation is used when a message is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    /// Initial authentication message.
    AuthRequest {
        /// Nonce chosen by the client; it is an input to both HMAC tokens below.
        client_nonce: String,
    },
    
    /// Authentication response.
    AuthResponse {
        /// Nonce chosen by the server; it is an input to both HMAC tokens.
        server_nonce: String,
        auth_token: String,  // HMAC(secret, client_nonce + server_nonce)
    },
    
    /// Final auth confirmation.
    ///
    /// Note the nonce order is reversed relative to `AuthResponse`.
    AuthConfirm {
        auth_token: String,  // HMAC(secret, server_nonce + client_nonce)
    },
    
    /// Request information about available multicast groups.
    MulticastGroupsRequest,
    
    /// Response with available multicast groups.
    MulticastGroupsResponse {
        /// Available groups keyed by a string.
        // NOTE(review): keys are presumably the group ids accepted by `Subscribe` — confirm.
        groups: HashMap<String, MulticastGroupInfo>,
    },
    
    /// Subscribe to specific multicast groups.
    Subscribe {
        /// Identifiers of the groups to subscribe to.
        group_ids: Vec<String>,
    },
    
    /// Multicast packet forwarded to client.
    MulticastPacket {
        /// Identifier of the group the packet belongs to.
        group_id: String,
        /// Socket address recorded as the packet's source.
        // NOTE(review): presumably the original sender on the multicast network — confirm.
        source: SocketAddr,
        destination: String,  // Destination multicast address
        port: u16,            // Destination port
        /// Raw packet payload.
        data: Vec<u8>,
    },

    /// Configuration response with current settings.
    ConfigResponse {
        /// Settings reported by the server.
        config: ServerConfigInfo,
    },
    
    /// Ping/status message for checking connections.
    PingStatus {
        /// Timestamp attached to the ping.
        // NOTE(review): unit and epoch are not defined in this file — confirm with the sender.
        timestamp: u64,
        /// Free-form status text.
        status: String,
    },
}

/// Server configuration info that can be sent to clients.
///
/// Carried as the payload of `Message::ConfigResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfigInfo {
    /// Multicast addresses the server reports as available, as plain strings.
    pub available_multicast_addresses: Vec<String>,
    /// Multicast port reported by the server.
    pub multicast_port: u16,
}

/// Formats packet data as a lowercase hex dump for display in test mode.
///
/// At most `max_bytes` bytes of `data` are shown. Every byte is rendered as
/// two hex digits followed by one space, and a line break is inserted before
/// each group of 16 bytes after the first. When `data` holds more than
/// `max_bytes` bytes, the dump ends with `"\n... (truncated)"`.
///
/// Empty input yields an empty string.
pub fn format_packet_for_display(data: &[u8], max_bytes: usize) -> String {
    // `write!` on a `String` needs the `fmt::Write` trait in scope; importing
    // it here keeps the dependency local to this function.
    use std::fmt::Write as _;

    const BYTES_PER_ROW: usize = 16;
    const TRUNCATION_MARKER: &str = "\n... (truncated)";

    let shown = &data[..data.len().min(max_bytes)];
    let is_truncated = data.len() > max_bytes;

    // Size the buffer once: three characters per byte, at most one newline
    // per row, plus the marker when it is going to be appended.
    let marker_len = if is_truncated { TRUNCATION_MARKER.len() } else { 0 };
    let mut result =
        String::with_capacity(shown.len() * 3 + shown.len() / BYTES_PER_ROW + marker_len);

    for (row_index, row) in shown.chunks(BYTES_PER_ROW).enumerate() {
        if row_index > 0 {
            result.push('\n');
        }
        for byte in row {
            // Format straight into `result` instead of building a temporary
            // `String` for every byte.
            write!(result, "{:02x} ", byte).expect("formatting into a String cannot fail");
        }
    }

    if is_truncated {
        result.push_str(TRUNCATION_MARKER);
    }

    result
}

/// Serializes a message to bytes.
///
/// The bytes are the JSON encoding produced by `serde_json::to_vec`.
///
/// # Errors
///
/// Returns whatever `serde_json::Error` `serde_json::to_vec` reports for `message`.
pub fn serialize_message(message: &Message) -> Result<Vec<u8>, serde_json::Error> {
    serde_json::to_vec(message)
}

/// Decodes the JSON-encoded [`Message`] at the start of `bytes`.
///
/// The JSON deserializer is driven directly and is never asked to check that
/// the input has been fully consumed, so anything following the first complete
/// value (further messages, trailing data) is left unread rather than
/// reported as an error.
///
/// # Errors
///
/// Returns the `serde_json::Error` raised while decoding, for example when
/// the buffer holds only part of a message or does not describe a `Message`.
pub fn deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
    let mut json_reader = serde_json::Deserializer::from_slice(bytes);
    Message::deserialize(&mut json_reader)
}

/// Deserializes the first message in `bytes`, tolerating data that follows it.
///
/// This is a thin wrapper around [`deserialize_message`], which already stops
/// after the first complete JSON value and never checks that the buffer is
/// exhausted. Trailing bytes behind a valid message therefore do not cause an
/// error.
///
/// No "retry with a shorter prefix" recovery is attempted. A "trailing
/// characters" syntax error from `deserialize_message` can only originate
/// inside a malformed value (for example an array with surplus elements), and
/// the parser reads strictly left to right, so any shorter prefix of the same
/// bytes fails as well. Such a retry loop cannot succeed; it only turns
/// malformed input into quadratic work before returning the same error.
///
/// # Errors
///
/// Returns the `serde_json::Error` from [`deserialize_message`] unchanged.
pub fn robust_deserialize_message(bytes: &[u8]) -> Result<Message, serde_json::Error> {
    deserialize_message(bytes)
}