diff --git a/Cargo.lock b/Cargo.lock index ca5e842..d971bf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1562,9 +1562,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", diff --git a/src/api_request_handler.rs b/src/api_request_handler.rs index 292bb1d..b55f646 100644 --- a/src/api_request_handler.rs +++ b/src/api_request_handler.rs @@ -1,4 +1,5 @@ use crate::nip98_auth; +use crate::notification_manager::notification_manager::UserNotificationSettings; use crate::relay_connection::RelayConnection; use http_body_util::Full; use hyper::body::Buf; @@ -9,50 +10,21 @@ use hyper_tungstenite; use http_body_util::BodyExt; use nostr; +use serde_json::from_value; use crate::notification_manager::NotificationManager; use hyper::Method; use log; use serde_json::{json, Value}; +use std::collections::HashMap; use std::sync::Arc; use thiserror::Error; -struct ParsedRequest { - uri: String, - method: Method, - body_bytes: Option>, - authorized_pubkey: nostr::PublicKey, -} - -impl ParsedRequest { - fn body_json(&self) -> Result> { - if let Some(body_bytes) = &self.body_bytes { - Ok(serde_json::from_slice(body_bytes)?) - } else { - Ok(json!({})) - } - } -} - -struct APIResponse { - status: StatusCode, - body: Value, -} - pub struct APIHandler { notification_manager: Arc, base_url: String, } -impl Clone for APIHandler { - fn clone(&self) -> Self { - APIHandler { - notification_manager: self.notification_manager.clone(), - base_url: self.base_url.clone(), - } - } -} - impl APIHandler { pub fn new(notification_manager: Arc, base_url: String) -> Self { APIHandler { @@ -60,6 +32,8 @@ impl APIHandler { base_url, } } + + // MARK: - HTTP handling pub async fn handle_http_request( &self, @@ -185,22 +159,37 @@ impl APIHandler { authorized_pubkey, }) } + + // MARK: - Router async fn handle_parsed_http_request( &self, parsed_request: &ParsedRequest, ) -> Result> { - match (&parsed_request.method, parsed_request.uri.as_str()) { - (&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await, - (&Method::POST, "/user-info/remove") => { - self.handle_user_info_remove(parsed_request).await - } - _ => Ok(APIResponse { - status: StatusCode::NOT_FOUND, - body: json!({ "error": "Not found" }), - }), + + if let Some(url_params) = route_match(&Method::PUT, "/user-info/:pubkey/:deviceToken", &parsed_request) { + return self.handle_user_info(parsed_request, &url_params).await; } + + if let Some(url_params) = route_match(&Method::DELETE, "/user-info/:pubkey/:deviceToken", &parsed_request) { + return self.handle_user_info_remove(parsed_request, &url_params).await; + } + + if let Some(url_params) = route_match(&Method::GET, "/user-info/:pubkey/:deviceToken/preferences", &parsed_request) { + return self.get_user_settings(parsed_request, &url_params).await; + } + + if let Some(url_params) = route_match(&Method::PUT, "/user-info/:pubkey/:deviceToken/preferences", &parsed_request) { + return self.set_user_settings(parsed_request, &url_params).await; + } + + Ok(APIResponse { + status: StatusCode::NOT_FOUND, + body: json!({ "error": "Not found" }), + }) } + + // MARK: - Authentication async fn authenticate( &self, @@ -220,51 +209,283 @@ impl APIHandler { ) .await) } + + // MARK: - Endpoint handlers async fn handle_user_info( &self, req: &ParsedRequest, + url_params: &HashMap<&str, String>, ) -> Result> { - let body = req.body_json()?; - - if let Some(device_token) = body["deviceToken"].as_str() { - self.notification_manager.save_user_device_info(req.authorized_pubkey, device_token).await?; - return Ok(APIResponse { - status: StatusCode::OK, - body: json!({ "message": "User info saved successfully" }), - }); - } else { - return Ok(APIResponse { + // Early return if `deviceToken` is missing + let device_token = match url_params.get("deviceToken") { + Some(token) => token, + None => return Ok(APIResponse { status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required" }), + body: json!({ "error": "deviceToken is required on the URL" }), + }), + }; + + // Early return if `pubkey` is missing + let pubkey = match url_params.get("pubkey") { + Some(key) => key, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "pubkey is required on the URL" }), + }), + }; + + // Validate the `pubkey` and prepare it for use + let pubkey = match nostr::PublicKey::from_hex(pubkey) { + Ok(key) => key, + Err(_) => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Invalid pubkey" }), + }), + }; + + // Early return if `pubkey` does not match `req.authorized_pubkey` + if pubkey != req.authorized_pubkey { + return Ok(APIResponse { + status: StatusCode::FORBIDDEN, + body: json!({ "error": "Forbidden" }), }); } + + // Proceed with the main logic after passing all checks + self.notification_manager.save_user_device_info(pubkey, device_token).await?; + Ok(APIResponse { + status: StatusCode::OK, + body: json!({ "message": "User info saved successfully" }), + }) } async fn handle_user_info_remove( &self, req: &ParsedRequest, + url_params: &HashMap<&str, String>, ) -> Result> { - let body: Value = req.body_json()?; - - if let Some(device_token) = body["deviceToken"].as_str() { - self.notification_manager.remove_user_device_info(req.authorized_pubkey, device_token).await?; - return Ok(APIResponse { - status: StatusCode::OK, - body: json!({ "message": "User info removed successfully" }), - }); - } else { - return Ok(APIResponse { + // Early return if `deviceToken` is missing + let device_token = match url_params.get("deviceToken") { + Some(token) => token, + None => return Ok(APIResponse { status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required" }), + body: json!({ "error": "deviceToken is required on the URL" }), + }), + }; + + // Early return if `pubkey` is missing + let pubkey = match url_params.get("pubkey") { + Some(key) => key, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "pubkey is required on the URL" }), + }), + }; + + // Validate the `pubkey` and prepare it for use + let pubkey = match nostr::PublicKey::from_hex(pubkey) { + Ok(key) => key, + Err(_) => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Invalid pubkey" }), + }), + }; + + // Early return if `pubkey` does not match `req.authorized_pubkey` + if pubkey != req.authorized_pubkey { + return Ok(APIResponse { + status: StatusCode::FORBIDDEN, + body: json!({ "error": "Forbidden" }), }); } + + // Proceed with the main logic after passing all checks + self.notification_manager.remove_user_device_info(pubkey, device_token).await?; + + Ok(APIResponse { + status: StatusCode::OK, + body: json!({ "message": "User info removed successfully" }), + }) + } + + async fn set_user_settings( + &self, + req: &ParsedRequest, + url_params: &HashMap<&str, String>, + ) -> Result> { + // Early return if `deviceToken` is missing + let device_token = match url_params.get("deviceToken") { + Some(token) => token, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "deviceToken is required on the URL" }), + }), + }; + + // Early return if `pubkey` is missing + let pubkey = match url_params.get("pubkey") { + Some(key) => key, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "pubkey is required on the URL" }), + }), + }; + + // Validate the `pubkey` and prepare it for use + let pubkey = match nostr::PublicKey::from_hex(pubkey) { + Ok(key) => key, + Err(_) => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Invalid pubkey" }), + }), + }; + + // Early return if `pubkey` does not match `req.authorized_pubkey` + if pubkey != req.authorized_pubkey { + return Ok(APIResponse { + status: StatusCode::FORBIDDEN, + body: json!({ "error": "Forbidden" }), + }); + } + + // Proceed with the main logic after passing all checks + let body = req.body_json()?; + + let settings: UserNotificationSettings = match from_value(body.clone()) { + Ok(settings) => settings, + Err(_) => { + return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Invalid settings" }), + }); + } + }; + + self.notification_manager.save_user_notification_settings(&req.authorized_pubkey, device_token.to_string(), settings).await?; + return Ok(APIResponse { + status: StatusCode::OK, + body: json!({ "message": "User settings saved successfully" }), + }); + } + + async fn get_user_settings( + &self, + req: &ParsedRequest, + url_params: &HashMap<&str, String>, + ) -> Result> { + // Early return if `deviceToken` is missing + let device_token = match url_params.get("deviceToken") { + Some(token) => token, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "deviceToken is required on the URL" }), + }), + }; + + // Early return if `pubkey` is missing + let pubkey = match url_params.get("pubkey") { + Some(key) => key, + None => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "pubkey is required on the URL" }), + }), + }; + + // Validate the `pubkey` and prepare it for use + let pubkey = match nostr::PublicKey::from_hex(pubkey) { + Ok(key) => key, + Err(_) => return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Invalid pubkey" }), + }), + }; + + // Early return if `pubkey` does not match `req.authorized_pubkey` + if pubkey != req.authorized_pubkey { + return Ok(APIResponse { + status: StatusCode::FORBIDDEN, + body: json!({ "error": "Forbidden" }), + }); + } + + // Proceed with the main logic after passing all checks + let settings = self.notification_manager.get_user_notification_settings(&req.authorized_pubkey, device_token.to_string()).await?; + + Ok(APIResponse { + status: StatusCode::OK, + body: json!(settings), + }) + } +} + +// MARK: - Extensions + +impl Clone for APIHandler { + fn clone(&self) -> Self { + APIHandler { + notification_manager: self.notification_manager.clone(), + base_url: self.base_url.clone(), + } } } +// MARK: - Helper types + // Define enum error types including authentication error #[derive(Debug, Error)] enum APIError { #[error("Authentication error: {0}")] AuthenticationError(String), } + +struct ParsedRequest { + uri: String, + method: Method, + body_bytes: Option>, + authorized_pubkey: nostr::PublicKey, +} + +impl ParsedRequest { + fn body_json(&self) -> Result> { + if let Some(body_bytes) = &self.body_bytes { + Ok(serde_json::from_slice(body_bytes)?) + } else { + Ok(json!({})) + } + } +} + +struct APIResponse { + status: StatusCode, + body: Value, +} + +// MARK: - Helper functions + +/// Matches the request to a specified route, returning a hashmap of the route parameters +/// e.g. GET /user/:id/info route against request GET /user/123/info matches to { "id": "123" } +fn route_match<'a>(method: &Method, path: &'a str, req: &ParsedRequest) -> Option> { + if method != req.method { + return None; + } + let mut params = HashMap::new(); + let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); + let req_segments: Vec<&str> = req.uri.split('/').filter(|s| !s.is_empty()).collect(); + + if path_segments.len() != req_segments.len() { + return None; + } + + for (i, segment) in path_segments.iter().enumerate() { + if segment.starts_with(':') { + let key = &segment[1..]; + let value = req_segments[i].to_string(); + params.insert(key, value); + } else if segment != &req_segments[i] { + return None; + } + } + + Some(params) +} diff --git a/src/lib.rs b/src/lib.rs index f4339cb..293ef64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,2 @@ pub mod notification_manager; +mod utils; diff --git a/src/main.rs b/src/main.rs index d695f15..b0d7c3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ mod notepush_env; use notepush_env::NotePushEnv; mod api_request_handler; mod nip98_auth; +mod utils; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/src/nip98_auth.rs b/src/nip98_auth.rs index 396c7ca..9c83d7b 100644 --- a/src/nip98_auth.rs +++ b/src/nip98_auth.rs @@ -3,8 +3,8 @@ use nostr; use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash; use nostr::bitcoin::hashes::Hash; use nostr::util::hex; -use nostr::Timestamp; use serde_json::Value; +use super::utils::time_delta::TimeDelta; pub async fn nip98_verify_auth_header( auth_header: String, @@ -106,36 +106,3 @@ pub async fn nip98_verify_auth_header( Ok(note.pubkey) } - -struct TimeDelta { - delta_abs_seconds: u64, - negative: bool, -} - -impl TimeDelta { - /// Safely calculate the difference between two timestamps in seconds - /// This function is safer against overflows than subtracting the timestamps directly - fn subtracting(t1: Timestamp, t2: Timestamp) -> TimeDelta { - if t1 > t2 { - TimeDelta { - delta_abs_seconds: (t1 - t2).as_u64(), - negative: false, - } - } else { - TimeDelta { - delta_abs_seconds: (t2 - t1).as_u64(), - negative: true, - } - } - } -} - -impl std::fmt::Display for TimeDelta { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - if self.negative { - write!(f, "-{}", self.delta_abs_seconds) - } else { - write!(f, "{}", self.delta_abs_seconds) - } - } -} diff --git a/src/notification_manager/mod.rs b/src/notification_manager/mod.rs index 653c48e..10b6ab8 100644 --- a/src/notification_manager/mod.rs +++ b/src/notification_manager/mod.rs @@ -1,7 +1,8 @@ -pub mod mute_manager; +pub mod nostr_network_helper; mod nostr_event_extensions; +mod nostr_event_cache; pub mod notification_manager; -pub use mute_manager::MuteManager; +pub use nostr_network_helper::NostrNetworkHelper; use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible}; pub use notification_manager::NotificationManager; diff --git a/src/notification_manager/mute_manager.rs b/src/notification_manager/mute_manager.rs deleted file mode 100644 index 8b34b97..0000000 --- a/src/notification_manager/mute_manager.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::ExtendedEvent; -use nostr_sdk::prelude::*; -use tokio::time::{timeout, Duration}; - -pub struct MuteManager { - client: Client, -} - -impl MuteManager { - pub async fn new(relay_url: String) -> Result> { - let client = Client::new(&Keys::generate()); - client.add_relay(relay_url.clone()).await?; - client.connect().await; - Ok(MuteManager { client }) - } - - pub async fn should_mute_notification_for_pubkey( - &self, - event: &Event, - pubkey: &PublicKey, - ) -> bool { - log::debug!( - "Checking if event {:?} should be muted for pubkey {:?}", - event, - pubkey - ); - if let Some(mute_list) = self.get_public_mute_list(pubkey).await { - for tag in mute_list.tags() { - match tag.kind() { - TagKind::SingleLetter(SingleLetterTag { - character: Alphabet::P, - uppercase: false, - }) => { - let tagged_pubkey: Option = - tag.content().and_then(|h| PublicKey::from_hex(h).ok()); - if let Some(tagged_pubkey) = tagged_pubkey { - if event.pubkey == tagged_pubkey { - return true; - } - } - } - TagKind::SingleLetter(SingleLetterTag { - character: Alphabet::E, - uppercase: false, - }) => { - let tagged_event_id: Option = - tag.content().and_then(|h| EventId::from_hex(h).ok()); - if let Some(tagged_event_id) = tagged_event_id { - if event.id == tagged_event_id - || event.referenced_event_ids().contains(&tagged_event_id) - { - return true; - } - } - } - TagKind::SingleLetter(SingleLetterTag { - character: Alphabet::T, - uppercase: false, - }) => { - let tagged_hashtag: Option = tag.content().map(|h| h.to_string()); - if let Some(tagged_hashtag) = tagged_hashtag { - let tags_content = - event.get_tags_content(TagKind::SingleLetter(SingleLetterTag { - character: Alphabet::T, - uppercase: false, - })); - let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag); - return should_mute; - } - } - TagKind::Word => { - let tagged_word: Option = tag.content().map(|h| h.to_string()); - if let Some(tagged_word) = tagged_word { - if event - .content - .to_lowercase() - .contains(&tagged_word.to_lowercase()) - { - return true; - } - } - } - _ => {} - } - } - } - false - } - - pub async fn get_public_mute_list(&self, pubkey: &PublicKey) -> Option { - let subscription_filter = Filter::new() - .kinds(vec![Kind::MuteList]) - .authors(vec![pubkey.clone()]) - .limit(1); - - let this_subscription_id = self - .client - .subscribe(Vec::from([subscription_filter]), None) - .await; - - let mut mute_list: Option = None; - let mut notifications = self.client.notifications(); - - let timeout_duration = Duration::from_secs(10); - while let Ok(result) = timeout(timeout_duration, notifications.recv()).await { - if let Ok(notification) = result { - if let RelayPoolNotification::Event { - subscription_id, - event, - .. - } = notification - { - if this_subscription_id == subscription_id && event.kind == Kind::MuteList { - mute_list = Some((*event).clone()); - break; - } - } - } - } - - if mute_list.is_none() { - log::debug!("Mute list not found for pubkey {:?}", pubkey); - } - - self.client.unsubscribe(this_subscription_id).await; - mute_list - } -} diff --git a/src/notification_manager/nostr_event_cache.rs b/src/notification_manager/nostr_event_cache.rs new file mode 100644 index 0000000..b2b9840 --- /dev/null +++ b/src/notification_manager/nostr_event_cache.rs @@ -0,0 +1,144 @@ +use crate::utils::time_delta::TimeDelta; +use tokio::time::Duration; +use nostr_sdk::prelude::*; +use std::collections::HashMap; +use std::sync::Arc; +use log; + +use super::nostr_event_extensions::MaybeConvertibleToMuteList; + +struct CacheEntry { + event: Option, // `None` means the event does not exist as far as we know (It does NOT mean expired) + added_at: nostr::Timestamp, +} + +impl CacheEntry { + fn is_expired(&self, max_age: Duration) -> bool { + let time_delta = TimeDelta::subtracting(nostr::Timestamp::now(), self.added_at); + time_delta.negative || (time_delta.delta_abs_seconds > max_age.as_secs()) + } +} + +pub struct Cache { + entries: HashMap>, + mute_lists: HashMap>, + contact_lists: HashMap>, + max_age: Duration, +} + +impl Cache { + // MARK: - Initialization + + pub fn new(max_age: Duration) -> Self { + Cache { + entries: HashMap::new(), + mute_lists: HashMap::new(), + contact_lists: HashMap::new(), + max_age, + } + } + + // MARK: - Adding items to the cache + + pub fn add_optional_mute_list_with_author(&mut self, author: &PublicKey, mute_list: Option) { + if let Some(mute_list) = mute_list { + self.add_event(mute_list); + } else { + self.mute_lists.insert( + author.clone(), + Arc::new(CacheEntry { + event: None, + added_at: nostr::Timestamp::now(), + }), + ); + } + } + + pub fn add_optional_contact_list_with_author(&mut self, author: &PublicKey, contact_list: Option) { + if let Some(contact_list) = contact_list { + self.add_event(contact_list); + } else { + self.contact_lists.insert( + author.clone(), + Arc::new(CacheEntry { + event: None, + added_at: nostr::Timestamp::now(), + }), + ); + } + } + + pub fn add_event(&mut self, event: Event) { + let entry = Arc::new(CacheEntry { + event: Some(event.clone()), + added_at: nostr::Timestamp::now(), + }); + self.entries.insert(event.id.clone(), entry.clone()); + + match event.kind { + Kind::MuteList => { + self.mute_lists.insert(event.pubkey.clone(), entry.clone()); + log::debug!("Added mute list to the cache. Event ID: {}", event.id.to_hex()); + } + Kind::ContactList => { + self.contact_lists + .insert(event.pubkey.clone(), entry.clone()); + log::debug!("Added contact list to the cache. Event ID: {}", event.id.to_hex()); + } + _ => { + log::debug!("Added event to the cache. Event ID: {}", event.id.to_hex()); + } + } + } + + // MARK: - Fetching items from the cache + + pub fn get_mute_list(&mut self, pubkey: &PublicKey) -> Result, CacheError> { + if let Some(entry) = self.mute_lists.get(pubkey) { + let entry = entry.clone(); // Clone the Arc to avoid borrowing issues + if !entry.is_expired(self.max_age) { + if let Some(event) = entry.event.clone() { + return Ok(event.to_mute_list()); + } + } else { + log::debug!("Mute list for pubkey {} is expired, removing it from the cache", pubkey.to_hex()); + self.mute_lists.remove(pubkey); + self.remove_event_from_all_maps(&entry.event); + } + } + Err(CacheError::NotFound) + } + + pub fn get_contact_list(&mut self, pubkey: &PublicKey) -> Result, CacheError> { + if let Some(entry) = self.contact_lists.get(pubkey) { + let entry = entry.clone(); // Clone the Arc to avoid borrowing issues + if !entry.is_expired(self.max_age) { + return Ok(entry.event.clone()); + } else { + log::debug!("Contact list for pubkey {} is expired, removing it from the cache", pubkey.to_hex()); + self.contact_lists.remove(pubkey); + self.remove_event_from_all_maps(&entry.event); + } + } + Err(CacheError::NotFound) + } + + // MARK: - Removing items from the cache + + fn remove_event_from_all_maps(&mut self, event: &Option) { + if let Some(event) = event { + let event_id = event.id.clone(); + let pubkey = event.pubkey.clone(); + self.entries.remove(&event_id); + self.mute_lists.remove(&pubkey); + self.contact_lists.remove(&pubkey); + } + // We can't remove an event from all maps if the event does not exist + } +} + +// Error type +#[derive(Debug)] +pub enum CacheError { + NotFound, +} diff --git a/src/notification_manager/nostr_event_extensions.rs b/src/notification_manager/nostr_event_extensions.rs index ad9fca9..b17ea71 100644 --- a/src/notification_manager/nostr_event_extensions.rs +++ b/src/notification_manager/nostr_event_extensions.rs @@ -1,4 +1,5 @@ -use nostr::{self, key::PublicKey, Alphabet, SingleLetterTag, TagKind::SingleLetter}; +use nostr::{self, key::PublicKey, nips::nip51::MuteList, Alphabet, SingleLetterTag, TagKind::SingleLetter}; +use nostr_sdk::{Kind, TagKind}; /// Temporary scaffolding of old methods that have not been ported to use native Event methods pub trait ExtendedEvent { @@ -13,6 +14,9 @@ pub trait ExtendedEvent { /// Retrieves a set of event IDs referenced by the note fn referenced_event_ids(&self) -> std::collections::HashSet; + + /// Retrieves a set of hashtags (t tags) referenced by the note + fn referenced_hashtags(&self) -> std::collections::HashSet; } // This is a wrapper around the Event type from strfry-policies, which adds some useful methods @@ -44,6 +48,14 @@ impl ExtendedEvent for nostr::Event { .filter_map(|tag| nostr::EventId::from_hex(tag).ok()) .collect() } + + /// Retrieves a set of hashtags (t tags) referenced by the note + fn referenced_hashtags(&self) -> std::collections::HashSet { + self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::T))) + .iter() + .map(|tag| tag.to_string()) + .collect() + } } // MARK: - SQL String Convertible @@ -85,3 +97,21 @@ impl SqlStringConvertible for nostr::Timestamp { Ok(nostr::Timestamp::from(u64_timestamp)) } } + +pub trait MaybeConvertibleToMuteList { + fn to_mute_list(&self) -> Option; +} + +impl MaybeConvertibleToMuteList for nostr::Event { + fn to_mute_list(&self) -> Option { + if self.kind != Kind::MuteList { + return None; + } + Some(MuteList { + public_keys: self.referenced_pubkeys().iter().map(|pk| pk.clone()).collect(), + hashtags: self.referenced_hashtags().iter().map(|tag| tag.clone()).collect(), + event_ids: self.referenced_event_ids().iter().map(|id| id.clone()).collect(), + words: self.get_tags_content(TagKind::Word).iter().map(|tag| tag.to_string()).collect(), + }) + } +} diff --git a/src/notification_manager/nostr_network_helper.rs b/src/notification_manager/nostr_network_helper.rs new file mode 100644 index 0000000..8930a4e --- /dev/null +++ b/src/notification_manager/nostr_network_helper.rs @@ -0,0 +1,161 @@ +use tokio::sync::Mutex; +use super::nostr_event_extensions::MaybeConvertibleToMuteList; +use super::ExtendedEvent; +use nostr_sdk::prelude::*; +use super::nostr_event_cache::Cache; +use tokio::time::{timeout, Duration}; + +const NOTE_FETCH_TIMEOUT: Duration = Duration::from_secs(5); +const CACHE_MAX_AGE: Duration = Duration::from_secs(60); + +pub struct NostrNetworkHelper { + client: Client, + cache: Mutex, +} + +impl NostrNetworkHelper { + // MARK: - Initialization + + pub async fn new(relay_url: String) -> Result> { + let client = Client::new(&Keys::generate()); + client.add_relay(relay_url.clone()).await?; + client.connect().await; + + Ok(NostrNetworkHelper { client, cache: Mutex::new(Cache::new(CACHE_MAX_AGE)) }) + } + + // MARK: - Answering questions about a user + + pub async fn should_mute_notification_for_pubkey( + &self, + event: &Event, + pubkey: &PublicKey, + ) -> bool { + log::debug!( + "Checking if event {:?} should be muted for pubkey {:?}", + event, + pubkey + ); + if let Some(mute_list) = self.get_public_mute_list(pubkey).await { + for muted_public_key in mute_list.public_keys { + if event.pubkey == muted_public_key { + return true; + } + } + for muted_event_id in mute_list.event_ids { + if event.id == muted_event_id + || event.referenced_event_ids().contains(&muted_event_id) + { + return true; + } + } + for muted_hashtag in mute_list.hashtags { + if event + .referenced_hashtags() + .iter() + .any(|t| t == &muted_hashtag) + { + return true; + } + } + for muted_word in mute_list.words { + if event + .content + .to_lowercase() + .contains(&muted_word.to_lowercase()) + { + return true; + } + } + } + false + } + + pub async fn does_pubkey_follow_pubkey( + &self, + source_pubkey: &PublicKey, + target_pubkey: &PublicKey, + ) -> bool { + log::debug!( + "Checking if pubkey {:?} follows pubkey {:?}", + source_pubkey, + target_pubkey + ); + if let Some(contact_list) = self.get_contact_list(source_pubkey).await { + return contact_list.referenced_pubkeys().contains(target_pubkey); + } + false + } + + // MARK: - Getting specific event types with caching + + pub async fn get_public_mute_list(&self, pubkey: &PublicKey) -> Option { + { + let mut cache_mutex_guard = self.cache.lock().await; + if let Ok(optional_mute_list) = cache_mutex_guard.get_mute_list(pubkey) { + return optional_mute_list; + } + } // Release the lock here for improved performance + + // We don't have an answer from the cache, so we need to fetch it + let mute_list_event = self.fetch_single_event(pubkey, Kind::MuteList).await; + let mut cache_mutex_guard = self.cache.lock().await; + cache_mutex_guard.add_optional_mute_list_with_author(pubkey, mute_list_event.clone()); + mute_list_event?.to_mute_list() + } + + pub async fn get_contact_list(&self, pubkey: &PublicKey) -> Option { + { + let mut cache_mutex_guard = self.cache.lock().await; + if let Ok(optional_contact_list) = cache_mutex_guard.get_contact_list(pubkey) { + return optional_contact_list; + } + } // Release the lock here for improved performance + + // We don't have an answer from the cache, so we need to fetch it + let contact_list_event = self.fetch_single_event(pubkey, Kind::ContactList).await; + let mut cache_mutex_guard = self.cache.lock().await; + cache_mutex_guard.add_optional_contact_list_with_author(pubkey, contact_list_event.clone()); + contact_list_event + } + + // MARK: - Lower level fetching functions + + async fn fetch_single_event(&self, author: &PublicKey, kind: Kind) -> Option { + let subscription_filter = Filter::new() + .kinds(vec![kind]) + .authors(vec![author.clone()]) + .limit(1); + + let mut notifications = self.client.notifications(); + let this_subscription_id = self + .client + .subscribe(Vec::from([subscription_filter]), None) + .await; + + let mut event: Option = None; + + while let Ok(result) = timeout(NOTE_FETCH_TIMEOUT, notifications.recv()).await { + if let Ok(notification) = result { + if let RelayPoolNotification::Event { + subscription_id, + event: event_option, + .. + } = notification + { + if this_subscription_id == subscription_id && event_option.kind == kind { + event = Some((*event_option).clone()); + break; + } + } + } + } + + if event.is_none() { + log::info!("Event of kind {:?} not found for pubkey {:?}", kind, author); + } + + self.client.unsubscribe(this_subscription_id).await; + event + } +} diff --git a/src/notification_manager/notification_manager.rs b/src/notification_manager/notification_manager.rs index 8583de7..db93c12 100644 --- a/src/notification_manager/notification_manager.rs +++ b/src/notification_manager/notification_manager.rs @@ -4,13 +4,16 @@ use nostr::event::EventId; use nostr::key::PublicKey; use nostr::types::Timestamp; use nostr_sdk::JsonUtil; +use nostr_sdk::Kind; use rusqlite; use rusqlite::params; +use serde::Deserialize; +use serde::Serialize; use tokio::sync::Mutex; use std::collections::HashSet; use tokio; -use super::mute_manager::MuteManager; +use super::nostr_network_helper::NostrNetworkHelper; use super::ExtendedEvent; use super::SqlStringConvertible; use nostr::Event; @@ -24,8 +27,7 @@ pub struct NotificationManager { db: Mutex>, apns_topic: String, apns_client: Mutex, - - mute_manager: Mutex, + nostr_network_helper: NostrNetworkHelper, } impl NotificationManager { @@ -40,8 +42,6 @@ impl NotificationManager { apns_environment: a2::client::Endpoint, apns_topic: String, ) -> Result> { - let mute_manager = MuteManager::new(relay_url.clone()).await?; - let connection = db.get()?; Self::setup_database(&connection)?; @@ -58,13 +58,15 @@ impl NotificationManager { apns_topic, apns_client: Mutex::new(client), db: Mutex::new(db), - mute_manager: Mutex::new(mute_manager), + nostr_network_helper: NostrNetworkHelper::new(relay_url.clone()).await?, }) } // MARK: - Database setup operations pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> { + // Initial schema setup + db.execute( "CREATE TABLE IF NOT EXISTS notifications ( id TEXT PRIMARY KEY, @@ -94,8 +96,17 @@ impl NotificationManager { [], )?; - Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER")?; - Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER")?; + Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER", None)?; + Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER", None)?; + + // Notification settings migration (https://github.com/damus-io/damus/issues/2360) + + Self::add_column_if_not_exists(&db, "user_info", "zap_notifications_enabled", "BOOLEAN", Some("true"))?; + Self::add_column_if_not_exists(&db, "user_info", "mention_notifications_enabled", "BOOLEAN", Some("true"))?; + Self::add_column_if_not_exists(&db, "user_info", "repost_notifications_enabled", "BOOLEAN", Some("true"))?; + Self::add_column_if_not_exists(&db, "user_info", "reaction_notifications_enabled", "BOOLEAN", Some("true"))?; + Self::add_column_if_not_exists(&db, "user_info", "dm_notifications_enabled", "BOOLEAN", Some("true"))?; + Self::add_column_if_not_exists(&db, "user_info", "only_notifications_from_following_enabled", "BOOLEAN", Some("false"))?; Ok(()) } @@ -105,6 +116,7 @@ impl NotificationManager { table_name: &str, column_name: &str, column_type: &str, + default_value: Option<&str>, ) -> Result<(), rusqlite::Error> { let query = format!("PRAGMA table_info({})", table_name); let mut stmt = db.prepare(&query)?; @@ -115,8 +127,11 @@ impl NotificationManager { if !column_names.contains(&column_name.to_string()) { let query = format!( - "ALTER TABLE {} ADD COLUMN {} {}", - table_name, column_name, column_type + "ALTER TABLE {} ADD COLUMN {} {} {}", + table_name, column_name, column_type, match default_value { + Some(value) => format!("DEFAULT {}", value), + None => "".to_string(), + }, ); db.execute(&query, [])?; } @@ -203,8 +218,7 @@ impl NotificationManager { let mut pubkeys_to_notify = HashSet::new(); for pubkey in relevant_pubkeys_yet_to_receive { let should_mute: bool = { - let mute_manager_mutex_guard = self.mute_manager.lock().await; - mute_manager_mutex_guard + self.nostr_network_helper .should_mute_notification_for_pubkey(event, &pubkey) .await }; @@ -251,11 +265,39 @@ impl NotificationManager { ) -> Result<(), Box> { let user_device_tokens = self.get_user_device_tokens(pubkey).await?; for device_token in user_device_tokens { + if !self.user_wants_notification(pubkey, device_token.clone(), event).await? { + continue; + } self.send_event_notification_to_device_token(event, &device_token) .await?; } Ok(()) } + + async fn user_wants_notification( + &self, + pubkey: &PublicKey, + device_token: String, + event: &Event, + ) -> Result> { + let notification_preferences = self.get_user_notification_settings(pubkey, device_token).await?; + if notification_preferences.only_notifications_from_following_enabled { + if !self.nostr_network_helper.does_pubkey_follow_pubkey(pubkey, &event.author()).await { + return Ok(false); + } + } + match event.kind { + Kind::TextNote => Ok(notification_preferences.mention_notifications_enabled), // TODO: Not 100% accurate + Kind::EncryptedDirectMessage => Ok(notification_preferences.dm_notifications_enabled), + Kind::Repost => Ok(notification_preferences.repost_notifications_enabled), + Kind::GenericRepost => Ok(notification_preferences.repost_notifications_enabled), + Kind::Reaction => Ok(notification_preferences.reaction_notifications_enabled), + Kind::ZapPrivateMessage => Ok(notification_preferences.zap_notifications_enabled), + Kind::ZapRequest => Ok(notification_preferences.zap_notifications_enabled), + Kind::ZapReceipt => Ok(notification_preferences.zap_notifications_enabled), + _ => Ok(false), + } + } async fn get_user_device_tokens( &self, @@ -343,6 +385,8 @@ impl NotificationManager { }; (title, "".to_string(), body) } + + // MARK: - User device info and settings pub async fn save_user_device_info( &self, @@ -375,6 +419,65 @@ impl NotificationManager { )?; Ok(()) } + + pub async fn get_user_notification_settings( + &self, + pubkey: &PublicKey, + device_token: String, + ) -> Result> { + let db_mutex_guard = self.db.lock().await; + let connection = db_mutex_guard.get()?; + let mut stmt = connection.prepare( + "SELECT zap_notifications_enabled, mention_notifications_enabled, repost_notifications_enabled, reaction_notifications_enabled, dm_notifications_enabled, only_notifications_from_following_enabled FROM user_info WHERE pubkey = ? AND device_token = ?", + )?; + let settings = stmt + .query_row([pubkey.to_sql_string(), device_token], |row| { + Ok(UserNotificationSettings { + zap_notifications_enabled: row.get(0)?, + mention_notifications_enabled: row.get(1)?, + repost_notifications_enabled: row.get(2)?, + reaction_notifications_enabled: row.get(3)?, + dm_notifications_enabled: row.get(4)?, + only_notifications_from_following_enabled: row.get(5)?, + }) + })?; + + Ok(settings) + } + + pub async fn save_user_notification_settings( + &self, + pubkey: &PublicKey, + device_token: String, + settings: UserNotificationSettings, + ) -> Result<(), Box> { + let db_mutex_guard = self.db.lock().await; + let connection = db_mutex_guard.get()?; + connection.execute( + "UPDATE user_info SET zap_notifications_enabled = ?, mention_notifications_enabled = ?, repost_notifications_enabled = ?, reaction_notifications_enabled = ?, dm_notifications_enabled = ?, only_notifications_from_following_enabled = ? WHERE pubkey = ? AND device_token = ?", + params![ + settings.zap_notifications_enabled, + settings.mention_notifications_enabled, + settings.repost_notifications_enabled, + settings.reaction_notifications_enabled, + settings.dm_notifications_enabled, + settings.only_notifications_from_following_enabled, + pubkey.to_sql_string(), + device_token, + ], + )?; + Ok(()) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct UserNotificationSettings { + zap_notifications_enabled: bool, + mention_notifications_enabled: bool, + repost_notifications_enabled: bool, + reaction_notifications_enabled: bool, + dm_notifications_enabled: bool, + only_notifications_from_following_enabled: bool } struct NotificationStatus { diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..5a78e80 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod time_delta; diff --git a/src/utils/time_delta.rs b/src/utils/time_delta.rs new file mode 100644 index 0000000..5bfabe7 --- /dev/null +++ b/src/utils/time_delta.rs @@ -0,0 +1,34 @@ +use nostr_sdk::Timestamp; + +pub struct TimeDelta { + pub delta_abs_seconds: u64, + pub negative: bool, +} + +impl TimeDelta { + /// Safely calculate the difference between two timestamps in seconds + /// This function is safer against overflows than subtracting the timestamps directly + pub fn subtracting(t1: Timestamp, t2: Timestamp) -> TimeDelta { + if t1 > t2 { + TimeDelta { + delta_abs_seconds: (t1 - t2).as_u64(), + negative: false, + } + } else { + TimeDelta { + delta_abs_seconds: (t2 - t1).as_u64(), + negative: true, + } + } + } +} + +impl std::fmt::Display for TimeDelta { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.negative { + write!(f, "-{}", self.delta_abs_seconds) + } else { + write!(f, "{}", self.delta_abs_seconds) + } + } +}