diff --git a/src/api_request_handler.rs b/src/api_request_handler.rs index e670b2b..292bb1d 100644 --- a/src/api_request_handler.rs +++ b/src/api_request_handler.rs @@ -16,7 +16,6 @@ use log; use serde_json::{json, Value}; use std::sync::Arc; use thiserror::Error; -use tokio::sync::Mutex; struct ParsedRequest { uri: String, @@ -41,7 +40,7 @@ struct APIResponse { } pub struct APIHandler { - notification_manager: Arc>, + notification_manager: Arc, base_url: String, } @@ -55,7 +54,7 @@ impl Clone for APIHandler { } impl APIHandler { - pub fn new(notification_manager: Arc>, base_url: String) -> Self { + pub fn new(notification_manager: Arc, base_url: String) -> Self { APIHandler { notification_manager, base_url, @@ -229,8 +228,7 @@ impl APIHandler { let body = req.body_json()?; if let Some(device_token) = body["deviceToken"].as_str() { - let notification_manager = self.notification_manager.lock().await; - notification_manager.save_user_device_info(req.authorized_pubkey, device_token)?; + self.notification_manager.save_user_device_info(req.authorized_pubkey, device_token).await?; return Ok(APIResponse { status: StatusCode::OK, body: json!({ "message": "User info saved successfully" }), @@ -250,8 +248,7 @@ impl APIHandler { let body: Value = req.body_json()?; if let Some(device_token) = body["deviceToken"].as_str() { - let notification_manager = self.notification_manager.lock().await; - notification_manager.remove_user_device_info(req.authorized_pubkey, device_token)?; + self.notification_manager.remove_user_device_info(req.authorized_pubkey, device_token).await?; return Ok(APIResponse { status: StatusCode::OK, body: json!({ "message": "User info removed successfully" }), diff --git a/src/main.rs b/src/main.rs index 4ab9f86..d695f15 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use hyper_util::rt::TokioIo; use std::sync::Arc; use tokio::net::TcpListener; -use tokio::sync::Mutex; mod notification_manager; use env_logger; use log; @@ -31,7 +30,7 @@ async fn main() -> Result<(), Box> { r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); // Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter. // This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections. - let notification_manager = Arc::new(Mutex::new( + let notification_manager = Arc::new( notification_manager::NotificationManager::new( pool, env.relay_url.clone(), @@ -43,7 +42,7 @@ async fn main() -> Result<(), Box> { ) .await .expect("Failed to create notification manager"), - )); + ); let api_handler = Arc::new(api_request_handler::APIHandler::new( notification_manager.clone(), env.api_base_url.clone(), diff --git a/src/notification_manager/notification_manager.rs b/src/notification_manager/notification_manager.rs index 66ff446..df8bc0d 100644 --- a/src/notification_manager/notification_manager.rs +++ b/src/notification_manager/notification_manager.rs @@ -5,7 +5,9 @@ use nostr::key::PublicKey; use nostr::types::Timestamp; use rusqlite; use rusqlite::params; +use tokio::sync::Mutex; use std::collections::HashSet; +use tokio; use super::mute_manager::MuteManager; use super::ExtendedEvent; @@ -18,11 +20,11 @@ use std::fs::File; // MARK: - NotificationManager pub struct NotificationManager { - db: r2d2::Pool, + db: Mutex>, apns_topic: String, - apns_client: Client, + apns_client: Mutex, - mute_manager: MuteManager, + mute_manager: Mutex, } impl NotificationManager { @@ -53,9 +55,9 @@ impl NotificationManager { Ok(Self { apns_topic, - apns_client: client, - db, - mute_manager, + apns_client: Mutex::new(client), + db: Mutex::new(db), + mute_manager: Mutex::new(mute_manager), }) } @@ -146,17 +148,20 @@ impl NotificationManager { for pubkey in pubkeys_to_notify { self.send_event_notifications_to_pubkey(event, &pubkey) .await?; - self.db.get()?.execute( - "INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at) - VALUES (?, ?, ?, ?, ?)", - params![ - format!("{}:{}", event.id, pubkey), - event.id.to_sql_string(), - pubkey.to_sql_string(), - true, - nostr::Timestamp::now().to_sql_string(), - ], - )?; + { + let db_mutex_guard = self.db.lock().await; + db_mutex_guard.get()?.execute( + "INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at) + VALUES (?, ?, ?, ?, ?)", + params![ + format!("{}:{}", event.id, pubkey), + event.id.to_sql_string(), + pubkey.to_sql_string(), + true, + nostr::Timestamp::now().to_sql_string(), + ], + )?; + } } Ok(()) } @@ -165,8 +170,8 @@ impl NotificationManager { &self, event: &Event, ) -> Result, Box> { - let notification_status = self.get_notification_status(event)?; - let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?; + let notification_status = self.get_notification_status(event).await?; + let relevant_pubkeys = self.pubkeys_relevant_to_event(event).await?; let pubkeys_that_received_notification = notification_status.pubkeys_that_received_notification(); let relevant_pubkeys_yet_to_receive: HashSet = relevant_pubkeys @@ -177,10 +182,12 @@ impl NotificationManager { let mut pubkeys_to_notify = HashSet::new(); for pubkey in relevant_pubkeys_yet_to_receive { - let should_mute: bool = self - .mute_manager - .should_mute_notification_for_pubkey(event, &pubkey) - .await; + let should_mute: bool = { + let mute_manager_mutex_guard = self.mute_manager.lock().await; + mute_manager_mutex_guard + .should_mute_notification_for_pubkey(event, &pubkey) + .await + }; if !should_mute { pubkeys_to_notify.insert(pubkey); } @@ -188,7 +195,7 @@ impl NotificationManager { Ok(pubkeys_to_notify) } - fn pubkeys_relevant_to_event( + async fn pubkeys_relevant_to_event( &self, event: &Event, ) -> Result, Box> { @@ -196,17 +203,18 @@ impl NotificationManager { let referenced_event_ids = event.referenced_event_ids(); for referenced_event_id in referenced_event_ids { let pubkeys_relevant_to_referenced_event = - self.pubkeys_subscribed_to_event_id(&referenced_event_id)?; + self.pubkeys_subscribed_to_event_id(&referenced_event_id).await?; relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event); } Ok(relevant_pubkeys) } - fn pubkeys_subscribed_to_event_id( + async fn pubkeys_subscribed_to_event_id( &self, event_id: &EventId, ) -> Result, Box> { - let connection = self.db.get()?; + let db_mutex_guard = self.db.lock().await; + let connection = db_mutex_guard.get()?; let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?; let pubkeys = stmt .query_map([event_id.to_sql_string()], |row| row.get(0))? @@ -221,7 +229,7 @@ impl NotificationManager { event: &Event, pubkey: &PublicKey, ) -> Result<(), Box> { - let user_device_tokens = self.get_user_device_tokens(pubkey)?; + let user_device_tokens = self.get_user_device_tokens(pubkey).await?; for device_token in user_device_tokens { self.send_event_notification_to_device_token(event, &device_token) .await?; @@ -229,11 +237,12 @@ impl NotificationManager { Ok(()) } - fn get_user_device_tokens( + async fn get_user_device_tokens( &self, pubkey: &PublicKey, ) -> Result, Box> { - let connection = self.db.get()?; + let db_mutex_guard = self.db.lock().await; + let connection = db_mutex_guard.get()?; let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?; let device_tokens = stmt .query_map([pubkey.to_sql_string()], |row| row.get(0))? @@ -242,11 +251,12 @@ impl NotificationManager { Ok(device_tokens) } - fn get_notification_status( + async fn get_notification_status( &self, event: &Event, ) -> Result> { - let connection = self.db.get()?; + let db_mutex_guard = self.db.lock().await; + let connection = db_mutex_guard.get()?; let mut stmt = connection.prepare( "SELECT pubkey, received_notification FROM notifications WHERE event_id = ?", )?; @@ -291,7 +301,8 @@ impl NotificationManager { let _ = payload.add_custom_data("nostr_event", event); payload.options.apns_topic = Some(self.apns_topic.as_str()); - let _response = self.apns_client.send(payload).await?; + let apns_client_mutex_guard = self.apns_client.lock().await; + let _response = apns_client_mutex_guard.send(payload).await?; log::info!("Notification sent to device token: {}", device_token); @@ -305,13 +316,14 @@ impl NotificationManager { (title, subtitle, body) } - pub fn save_user_device_info( + pub async fn save_user_device_info( &self, pubkey: nostr::PublicKey, device_token: &str, ) -> Result<(), Box> { let current_time_unix = Timestamp::now(); - self.db.get()?.execute( + let db_mutex_guard = self.db.lock().await; + db_mutex_guard.get()?.execute( "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", params![ format!("{}:{}", pubkey.to_sql_string(), device_token), @@ -323,12 +335,13 @@ impl NotificationManager { Ok(()) } - pub fn remove_user_device_info( + pub async fn remove_user_device_info( &self, pubkey: nostr::PublicKey, device_token: &str, ) -> Result<(), Box> { - self.db.get()?.execute( + let db_mutex_guard = self.db.lock().await; + db_mutex_guard.get()?.execute( "DELETE FROM user_info WHERE pubkey = ? AND device_token = ?", params![pubkey.to_sql_string(), device_token], )?; diff --git a/src/relay_connection.rs b/src/relay_connection.rs index 3fc154b..644f06c 100644 --- a/src/relay_connection.rs +++ b/src/relay_connection.rs @@ -11,20 +11,19 @@ use serde_json::Value; use std::fmt::{self, Debug}; use std::str::FromStr; use std::sync::Arc; -use tokio::sync::Mutex; use tungstenite::{Error, Message}; const MAX_CONSECUTIVE_ERRORS: u32 = 10; pub struct RelayConnection { - notification_manager: Arc>, + notification_manager: Arc, } impl RelayConnection { // MARK: - Initializers pub async fn new( - notification_manager: Arc>, + notification_manager: Arc, ) -> Result> { log::info!("Accepted websocket connection"); Ok(RelayConnection { @@ -34,7 +33,7 @@ impl RelayConnection { pub async fn run( websocket: HyperWebsocket, - notification_manager: Arc>, + notification_manager: Arc, ) -> Result<(), Box> { let mut connection = RelayConnection::new(notification_manager).await?; Ok(connection.run_loop(websocket).await?) @@ -104,11 +103,7 @@ impl RelayConnection { match message { ClientMessage::Event(event) => { log::info!("Received event: {:?}", event); - { - // TODO: Reduce resource contention by reducing the scope of the mutex into NotificationManager logic. - let mutex_guard = self.notification_manager.lock().await; - mutex_guard.send_notifications_if_needed(&event).await?; - }; // Only hold the mutex for as little time as possible. + self.notification_manager.send_notifications_if_needed(&event).await?; let notice_message = format!("blocked: This relay does not store events"); let response = RelayMessage::Ok { event_id: event.id,