aboutsummaryrefslogtreecommitdiff
path: root/src/db/pool.rs
blob: 43907397fe4aeb760c8fae3077d82d8ed4e94d53 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Database connection pool module
use crate::config::DatabaseConfig;
use anyhow::{Context, Result as AnyResult};
use sqlx::{Error as SqlxError, MySqlPool};
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

/// Error returned when the database pool cannot be initialized.
///
/// The variant records how the underlying failure was classified; the wrapped
/// [`anyhow::Error`] carries the cause together with its context message.
#[derive(Debug)]
pub enum DatabaseInitError {
    /// The failure was not classified as retryable.
    Fatal(anyhow::Error),
    /// The failure was classified as retryable.
    Retryable(anyhow::Error),
}

impl std::fmt::Display for DatabaseInitError {
    /// Writes the classification followed by the full cause chain.
    ///
    /// The alternate (`{:#}`) form of `anyhow::Error` prints every cause on one
    /// line, so nothing is lost even though `source()` is not overridden.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Fatal(err) => write!(f, "fatal database initialization error: {err:#}"),
            Self::Retryable(err) => {
                write!(f, "retryable database initialization error: {err:#}")
            }
        }
    }
}

// The cause chain is already rendered by `Display`, so `source()` keeps its
// default (`None`) to avoid reporting the same messages twice.
impl std::error::Error for DatabaseInitError {}

impl DatabaseInitError {
    fn fatal(err: anyhow::Error) -> Self {
        Self::Fatal(err)
    }

    fn retryable(err: anyhow::Error) -> Self {
        Self::Retryable(err)
    }
}

/// Handle to the MySQL connection pool together with an availability flag.
///
/// Clones share the same availability flag because it lives behind an [`Arc`];
/// flipping it through one clone is observed by all of them.
#[derive(Clone, Debug)]
pub struct Database {
    /// Underlying sqlx connection pool.
    pool: MySqlPool,
    /// Shared flag read by `is_available` and written by
    /// `mark_available` / `mark_unavailable`.
    availability: Arc<AtomicBool>,
}

impl Database {
    pub async fn new(config: &DatabaseConfig) -> Result<Self, DatabaseInitError> {
        let database_url = format!(
            "mysql://{}:{}@{}:{}/{}",
            config.username, config.password, config.host, config.port, config.database
        );

        let pool = sqlx::mysql::MySqlPoolOptions::new()
            .min_connections(config.min_connections)
            .max_connections(config.max_connections)
            .acquire_timeout(std::time::Duration::from_secs(
                config.connection_timeout_seconds,
            ))
            .connect(&database_url)
            .await
            .map_err(|err| map_sqlx_error(err, "Failed to connect to database"))?;

        // Test the connection
        sqlx::query("SELECT 1")
            .execute(&pool)
            .await
            .map_err(|err| map_sqlx_error(err, "Failed to test database connection"))?;

        Ok(Database {
            pool,
            availability: Arc::new(AtomicBool::new(true)),
        })
    }

    pub fn pool(&self) -> &MySqlPool {
        &self.pool
    }

    pub fn is_available(&self) -> bool {
        self.availability.load(Ordering::Relaxed)
    }

    pub fn mark_available(&self) {
        self.availability.store(true, Ordering::Relaxed);
    }

    pub fn mark_unavailable(&self) {
        self.availability.store(false, Ordering::Relaxed);
    }

    pub async fn set_current_user(&self, user_id: i32) -> AnyResult<()> {
        sqlx::query("SET @current_user_id = ?")
            .bind(user_id)
            .execute(&self.pool)
            .await
            .context("Failed to set current user ID")?;

        Ok(())
    }

    pub async fn close(&self) {
        self.pool.close().await;
    }
}

fn map_sqlx_error(err: SqlxError, context: &str) -> DatabaseInitError {
    let retryable = is_retryable_sqlx_error(&err);
    let wrapped = anyhow::Error::new(err).context(context.to_string());
    if retryable {
        DatabaseInitError::retryable(wrapped)
    } else {
        DatabaseInitError::fatal(wrapped)
    }
}

fn is_retryable_sqlx_error(err: &SqlxError) -> bool {
    matches!(
        err,
        SqlxError::Io(_)
            | SqlxError::PoolTimedOut
            | SqlxError::PoolClosed
            | SqlxError::WorkerCrashed
    )
}