Merge pull request #5 from damus-io/#3

Improve performance with nostr event caching and better mutex design
This commit is contained in:
Daniel D’Aquino
2024-08-09 14:59:23 -07:00
committed by GitHub
13 changed files with 777 additions and 241 deletions

4
Cargo.lock generated
View File

@ -1562,9 +1562,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.5"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",

View File

@ -1,4 +1,5 @@
use crate::nip98_auth;
use crate::notification_manager::notification_manager::UserNotificationSettings;
use crate::relay_connection::RelayConnection;
use http_body_util::Full;
use hyper::body::Buf;
@ -9,50 +10,21 @@ use hyper_tungstenite;
use http_body_util::BodyExt;
use nostr;
use serde_json::from_value;
use crate::notification_manager::NotificationManager;
use hyper::Method;
use log;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
struct ParsedRequest {
uri: String,
method: Method,
body_bytes: Option<Vec<u8>>,
authorized_pubkey: nostr::PublicKey,
}
impl ParsedRequest {
fn body_json(&self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
if let Some(body_bytes) = &self.body_bytes {
Ok(serde_json::from_slice(body_bytes)?)
} else {
Ok(json!({}))
}
}
}
struct APIResponse {
status: StatusCode,
body: Value,
}
pub struct APIHandler {
notification_manager: Arc<NotificationManager>,
base_url: String,
}
impl Clone for APIHandler {
fn clone(&self) -> Self {
APIHandler {
notification_manager: self.notification_manager.clone(),
base_url: self.base_url.clone(),
}
}
}
impl APIHandler {
pub fn new(notification_manager: Arc<NotificationManager>, base_url: String) -> Self {
APIHandler {
@ -60,6 +32,8 @@ impl APIHandler {
base_url,
}
}
// MARK: - HTTP handling
pub async fn handle_http_request(
&self,
@ -185,22 +159,37 @@ impl APIHandler {
authorized_pubkey,
})
}
// MARK: - Router
async fn handle_parsed_http_request(
&self,
parsed_request: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
match (&parsed_request.method, parsed_request.uri.as_str()) {
(&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await,
(&Method::POST, "/user-info/remove") => {
self.handle_user_info_remove(parsed_request).await
}
_ => Ok(APIResponse {
status: StatusCode::NOT_FOUND,
body: json!({ "error": "Not found" }),
}),
if let Some(url_params) = route_match(&Method::PUT, "/user-info/:pubkey/:deviceToken", &parsed_request) {
return self.handle_user_info(parsed_request, &url_params).await;
}
if let Some(url_params) = route_match(&Method::DELETE, "/user-info/:pubkey/:deviceToken", &parsed_request) {
return self.handle_user_info_remove(parsed_request, &url_params).await;
}
if let Some(url_params) = route_match(&Method::GET, "/user-info/:pubkey/:deviceToken/preferences", &parsed_request) {
return self.get_user_settings(parsed_request, &url_params).await;
}
if let Some(url_params) = route_match(&Method::PUT, "/user-info/:pubkey/:deviceToken/preferences", &parsed_request) {
return self.set_user_settings(parsed_request, &url_params).await;
}
Ok(APIResponse {
status: StatusCode::NOT_FOUND,
body: json!({ "error": "Not found" }),
})
}
// MARK: - Authentication
async fn authenticate(
&self,
@ -220,51 +209,283 @@ impl APIHandler {
)
.await)
}
// MARK: - Endpoint handlers
async fn handle_user_info(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
let body = req.body_json()?;
if let Some(device_token) = body["deviceToken"].as_str() {
self.notification_manager.save_user_device_info(req.authorized_pubkey, device_token).await?;
return Ok(APIResponse {
status: StatusCode::OK,
body: json!({ "message": "User info saved successfully" }),
});
} else {
return Ok(APIResponse {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required" }),
body: json!({ "error": "deviceToken is required on the URL" }),
}),
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
}),
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
}),
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks
self.notification_manager.save_user_device_info(pubkey, device_token).await?;
Ok(APIResponse {
status: StatusCode::OK,
body: json!({ "message": "User info saved successfully" }),
})
}
async fn handle_user_info_remove(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
let body: Value = req.body_json()?;
if let Some(device_token) = body["deviceToken"].as_str() {
self.notification_manager.remove_user_device_info(req.authorized_pubkey, device_token).await?;
return Ok(APIResponse {
status: StatusCode::OK,
body: json!({ "message": "User info removed successfully" }),
});
} else {
return Ok(APIResponse {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required" }),
body: json!({ "error": "deviceToken is required on the URL" }),
}),
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
}),
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
}),
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks
self.notification_manager.remove_user_device_info(pubkey, device_token).await?;
Ok(APIResponse {
status: StatusCode::OK,
body: json!({ "message": "User info removed successfully" }),
})
}
async fn set_user_settings(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
}),
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
}),
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
}),
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks
let body = req.body_json()?;
let settings: UserNotificationSettings = match from_value(body.clone()) {
Ok(settings) => settings,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid settings" }),
});
}
};
self.notification_manager.save_user_notification_settings(&req.authorized_pubkey, device_token.to_string(), settings).await?;
return Ok(APIResponse {
status: StatusCode::OK,
body: json!({ "message": "User settings saved successfully" }),
});
}
async fn get_user_settings(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
}),
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
}),
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
}),
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks
let settings = self.notification_manager.get_user_notification_settings(&req.authorized_pubkey, device_token.to_string()).await?;
Ok(APIResponse {
status: StatusCode::OK,
body: json!(settings),
})
}
}
// MARK: - Extensions
impl Clone for APIHandler {
fn clone(&self) -> Self {
APIHandler {
notification_manager: self.notification_manager.clone(),
base_url: self.base_url.clone(),
}
}
}
// MARK: - Helper types
// Define enum error types including authentication error
#[derive(Debug, Error)]
enum APIError {
#[error("Authentication error: {0}")]
AuthenticationError(String),
}
struct ParsedRequest {
uri: String,
method: Method,
body_bytes: Option<Vec<u8>>,
authorized_pubkey: nostr::PublicKey,
}
impl ParsedRequest {
fn body_json(&self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
if let Some(body_bytes) = &self.body_bytes {
Ok(serde_json::from_slice(body_bytes)?)
} else {
Ok(json!({}))
}
}
}
struct APIResponse {
status: StatusCode,
body: Value,
}
// MARK: - Helper functions
/// Matches the request to a specified route, returning a hashmap of the route parameters
/// e.g. GET /user/:id/info route against request GET /user/123/info matches to { "id": "123" }
fn route_match<'a>(method: &Method, path: &'a str, req: &ParsedRequest) -> Option<HashMap<&'a str, String>> {
if method != req.method {
return None;
}
let mut params = HashMap::new();
let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
let req_segments: Vec<&str> = req.uri.split('/').filter(|s| !s.is_empty()).collect();
if path_segments.len() != req_segments.len() {
return None;
}
for (i, segment) in path_segments.iter().enumerate() {
if segment.starts_with(':') {
let key = &segment[1..];
let value = req_segments[i].to_string();
params.insert(key, value);
} else if segment != &req_segments[i] {
return None;
}
}
Some(params)
}

View File

@ -1 +1,2 @@
pub mod notification_manager;
mod utils;

View File

@ -12,6 +12,7 @@ mod notepush_env;
use notepush_env::NotePushEnv;
mod api_request_handler;
mod nip98_auth;
mod utils;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {

View File

@ -3,8 +3,8 @@ use nostr;
use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash;
use nostr::bitcoin::hashes::Hash;
use nostr::util::hex;
use nostr::Timestamp;
use serde_json::Value;
use super::utils::time_delta::TimeDelta;
pub async fn nip98_verify_auth_header(
auth_header: String,
@ -106,36 +106,3 @@ pub async fn nip98_verify_auth_header(
Ok(note.pubkey)
}
struct TimeDelta {
delta_abs_seconds: u64,
negative: bool,
}
impl TimeDelta {
/// Safely calculate the difference between two timestamps in seconds
/// This function is safer against overflows than subtracting the timestamps directly
fn subtracting(t1: Timestamp, t2: Timestamp) -> TimeDelta {
if t1 > t2 {
TimeDelta {
delta_abs_seconds: (t1 - t2).as_u64(),
negative: false,
}
} else {
TimeDelta {
delta_abs_seconds: (t2 - t1).as_u64(),
negative: true,
}
}
}
}
impl std::fmt::Display for TimeDelta {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.negative {
write!(f, "-{}", self.delta_abs_seconds)
} else {
write!(f, "{}", self.delta_abs_seconds)
}
}
}

View File

@ -1,7 +1,8 @@
pub mod mute_manager;
pub mod nostr_network_helper;
mod nostr_event_extensions;
mod nostr_event_cache;
pub mod notification_manager;
pub use mute_manager::MuteManager;
pub use nostr_network_helper::NostrNetworkHelper;
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
pub use notification_manager::NotificationManager;

View File

@ -1,128 +0,0 @@
use super::ExtendedEvent;
use nostr_sdk::prelude::*;
use tokio::time::{timeout, Duration};
pub struct MuteManager {
client: Client,
}
impl MuteManager {
pub async fn new(relay_url: String) -> Result<Self, Box<dyn std::error::Error>> {
let client = Client::new(&Keys::generate());
client.add_relay(relay_url.clone()).await?;
client.connect().await;
Ok(MuteManager { client })
}
pub async fn should_mute_notification_for_pubkey(
&self,
event: &Event,
pubkey: &PublicKey,
) -> bool {
log::debug!(
"Checking if event {:?} should be muted for pubkey {:?}",
event,
pubkey
);
if let Some(mute_list) = self.get_public_mute_list(pubkey).await {
for tag in mute_list.tags() {
match tag.kind() {
TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::P,
uppercase: false,
}) => {
let tagged_pubkey: Option<PublicKey> =
tag.content().and_then(|h| PublicKey::from_hex(h).ok());
if let Some(tagged_pubkey) = tagged_pubkey {
if event.pubkey == tagged_pubkey {
return true;
}
}
}
TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::E,
uppercase: false,
}) => {
let tagged_event_id: Option<EventId> =
tag.content().and_then(|h| EventId::from_hex(h).ok());
if let Some(tagged_event_id) = tagged_event_id {
if event.id == tagged_event_id
|| event.referenced_event_ids().contains(&tagged_event_id)
{
return true;
}
}
}
TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::T,
uppercase: false,
}) => {
let tagged_hashtag: Option<String> = tag.content().map(|h| h.to_string());
if let Some(tagged_hashtag) = tagged_hashtag {
let tags_content =
event.get_tags_content(TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::T,
uppercase: false,
}));
let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag);
return should_mute;
}
}
TagKind::Word => {
let tagged_word: Option<String> = tag.content().map(|h| h.to_string());
if let Some(tagged_word) = tagged_word {
if event
.content
.to_lowercase()
.contains(&tagged_word.to_lowercase())
{
return true;
}
}
}
_ => {}
}
}
}
false
}
pub async fn get_public_mute_list(&self, pubkey: &PublicKey) -> Option<Event> {
let subscription_filter = Filter::new()
.kinds(vec![Kind::MuteList])
.authors(vec![pubkey.clone()])
.limit(1);
let this_subscription_id = self
.client
.subscribe(Vec::from([subscription_filter]), None)
.await;
let mut mute_list: Option<Event> = None;
let mut notifications = self.client.notifications();
let timeout_duration = Duration::from_secs(10);
while let Ok(result) = timeout(timeout_duration, notifications.recv()).await {
if let Ok(notification) = result {
if let RelayPoolNotification::Event {
subscription_id,
event,
..
} = notification
{
if this_subscription_id == subscription_id && event.kind == Kind::MuteList {
mute_list = Some((*event).clone());
break;
}
}
}
}
if mute_list.is_none() {
log::debug!("Mute list not found for pubkey {:?}", pubkey);
}
self.client.unsubscribe(this_subscription_id).await;
mute_list
}
}

View File

@ -0,0 +1,144 @@
use crate::utils::time_delta::TimeDelta;
use tokio::time::Duration;
use nostr_sdk::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use log;
use super::nostr_event_extensions::MaybeConvertibleToMuteList;
struct CacheEntry {
event: Option<Event>, // `None` means the event does not exist as far as we know (It does NOT mean expired)
added_at: nostr::Timestamp,
}
impl CacheEntry {
fn is_expired(&self, max_age: Duration) -> bool {
let time_delta = TimeDelta::subtracting(nostr::Timestamp::now(), self.added_at);
time_delta.negative || (time_delta.delta_abs_seconds > max_age.as_secs())
}
}
pub struct Cache {
entries: HashMap<EventId, Arc<CacheEntry>>,
mute_lists: HashMap<PublicKey, Arc<CacheEntry>>,
contact_lists: HashMap<PublicKey, Arc<CacheEntry>>,
max_age: Duration,
}
impl Cache {
// MARK: - Initialization
pub fn new(max_age: Duration) -> Self {
Cache {
entries: HashMap::new(),
mute_lists: HashMap::new(),
contact_lists: HashMap::new(),
max_age,
}
}
// MARK: - Adding items to the cache
pub fn add_optional_mute_list_with_author(&mut self, author: &PublicKey, mute_list: Option<Event>) {
if let Some(mute_list) = mute_list {
self.add_event(mute_list);
} else {
self.mute_lists.insert(
author.clone(),
Arc::new(CacheEntry {
event: None,
added_at: nostr::Timestamp::now(),
}),
);
}
}
pub fn add_optional_contact_list_with_author(&mut self, author: &PublicKey, contact_list: Option<Event>) {
if let Some(contact_list) = contact_list {
self.add_event(contact_list);
} else {
self.contact_lists.insert(
author.clone(),
Arc::new(CacheEntry {
event: None,
added_at: nostr::Timestamp::now(),
}),
);
}
}
pub fn add_event(&mut self, event: Event) {
let entry = Arc::new(CacheEntry {
event: Some(event.clone()),
added_at: nostr::Timestamp::now(),
});
self.entries.insert(event.id.clone(), entry.clone());
match event.kind {
Kind::MuteList => {
self.mute_lists.insert(event.pubkey.clone(), entry.clone());
log::debug!("Added mute list to the cache. Event ID: {}", event.id.to_hex());
}
Kind::ContactList => {
self.contact_lists
.insert(event.pubkey.clone(), entry.clone());
log::debug!("Added contact list to the cache. Event ID: {}", event.id.to_hex());
}
_ => {
log::debug!("Added event to the cache. Event ID: {}", event.id.to_hex());
}
}
}
// MARK: - Fetching items from the cache
pub fn get_mute_list(&mut self, pubkey: &PublicKey) -> Result<Option<MuteList>, CacheError> {
if let Some(entry) = self.mute_lists.get(pubkey) {
let entry = entry.clone(); // Clone the Arc to avoid borrowing issues
if !entry.is_expired(self.max_age) {
if let Some(event) = entry.event.clone() {
return Ok(event.to_mute_list());
}
} else {
log::debug!("Mute list for pubkey {} is expired, removing it from the cache", pubkey.to_hex());
self.mute_lists.remove(pubkey);
self.remove_event_from_all_maps(&entry.event);
}
}
Err(CacheError::NotFound)
}
pub fn get_contact_list(&mut self, pubkey: &PublicKey) -> Result<Option<Event>, CacheError> {
if let Some(entry) = self.contact_lists.get(pubkey) {
let entry = entry.clone(); // Clone the Arc to avoid borrowing issues
if !entry.is_expired(self.max_age) {
return Ok(entry.event.clone());
} else {
log::debug!("Contact list for pubkey {} is expired, removing it from the cache", pubkey.to_hex());
self.contact_lists.remove(pubkey);
self.remove_event_from_all_maps(&entry.event);
}
}
Err(CacheError::NotFound)
}
// MARK: - Removing items from the cache
fn remove_event_from_all_maps(&mut self, event: &Option<Event>) {
if let Some(event) = event {
let event_id = event.id.clone();
let pubkey = event.pubkey.clone();
self.entries.remove(&event_id);
self.mute_lists.remove(&pubkey);
self.contact_lists.remove(&pubkey);
}
// We can't remove an event from all maps if the event does not exist
}
}
// Error type
#[derive(Debug)]
pub enum CacheError {
NotFound,
}

View File

@ -1,4 +1,5 @@
use nostr::{self, key::PublicKey, Alphabet, SingleLetterTag, TagKind::SingleLetter};
use nostr::{self, key::PublicKey, nips::nip51::MuteList, Alphabet, SingleLetterTag, TagKind::SingleLetter};
use nostr_sdk::{Kind, TagKind};
/// Temporary scaffolding of old methods that have not been ported to use native Event methods
pub trait ExtendedEvent {
@ -13,6 +14,9 @@ pub trait ExtendedEvent {
/// Retrieves a set of event IDs referenced by the note
fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId>;
/// Retrieves a set of hashtags (t tags) referenced by the note
fn referenced_hashtags(&self) -> std::collections::HashSet<String>;
}
// This is a wrapper around the Event type from strfry-policies, which adds some useful methods
@ -44,6 +48,14 @@ impl ExtendedEvent for nostr::Event {
.filter_map(|tag| nostr::EventId::from_hex(tag).ok())
.collect()
}
/// Retrieves a set of hashtags (t tags) referenced by the note
fn referenced_hashtags(&self) -> std::collections::HashSet<String> {
self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::T)))
.iter()
.map(|tag| tag.to_string())
.collect()
}
}
// MARK: - SQL String Convertible
@ -85,3 +97,21 @@ impl SqlStringConvertible for nostr::Timestamp {
Ok(nostr::Timestamp::from(u64_timestamp))
}
}
pub trait MaybeConvertibleToMuteList {
fn to_mute_list(&self) -> Option<MuteList>;
}
impl MaybeConvertibleToMuteList for nostr::Event {
fn to_mute_list(&self) -> Option<MuteList> {
if self.kind != Kind::MuteList {
return None;
}
Some(MuteList {
public_keys: self.referenced_pubkeys().iter().map(|pk| pk.clone()).collect(),
hashtags: self.referenced_hashtags().iter().map(|tag| tag.clone()).collect(),
event_ids: self.referenced_event_ids().iter().map(|id| id.clone()).collect(),
words: self.get_tags_content(TagKind::Word).iter().map(|tag| tag.to_string()).collect(),
})
}
}

View File

@ -0,0 +1,161 @@
use tokio::sync::Mutex;
use super::nostr_event_extensions::MaybeConvertibleToMuteList;
use super::ExtendedEvent;
use nostr_sdk::prelude::*;
use super::nostr_event_cache::Cache;
use tokio::time::{timeout, Duration};
const NOTE_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
const CACHE_MAX_AGE: Duration = Duration::from_secs(60);
pub struct NostrNetworkHelper {
client: Client,
cache: Mutex<Cache>,
}
impl NostrNetworkHelper {
// MARK: - Initialization
pub async fn new(relay_url: String) -> Result<Self, Box<dyn std::error::Error>> {
let client = Client::new(&Keys::generate());
client.add_relay(relay_url.clone()).await?;
client.connect().await;
Ok(NostrNetworkHelper { client, cache: Mutex::new(Cache::new(CACHE_MAX_AGE)) })
}
// MARK: - Answering questions about a user
pub async fn should_mute_notification_for_pubkey(
&self,
event: &Event,
pubkey: &PublicKey,
) -> bool {
log::debug!(
"Checking if event {:?} should be muted for pubkey {:?}",
event,
pubkey
);
if let Some(mute_list) = self.get_public_mute_list(pubkey).await {
for muted_public_key in mute_list.public_keys {
if event.pubkey == muted_public_key {
return true;
}
}
for muted_event_id in mute_list.event_ids {
if event.id == muted_event_id
|| event.referenced_event_ids().contains(&muted_event_id)
{
return true;
}
}
for muted_hashtag in mute_list.hashtags {
if event
.referenced_hashtags()
.iter()
.any(|t| t == &muted_hashtag)
{
return true;
}
}
for muted_word in mute_list.words {
if event
.content
.to_lowercase()
.contains(&muted_word.to_lowercase())
{
return true;
}
}
}
false
}
pub async fn does_pubkey_follow_pubkey(
&self,
source_pubkey: &PublicKey,
target_pubkey: &PublicKey,
) -> bool {
log::debug!(
"Checking if pubkey {:?} follows pubkey {:?}",
source_pubkey,
target_pubkey
);
if let Some(contact_list) = self.get_contact_list(source_pubkey).await {
return contact_list.referenced_pubkeys().contains(target_pubkey);
}
false
}
// MARK: - Getting specific event types with caching
pub async fn get_public_mute_list(&self, pubkey: &PublicKey) -> Option<MuteList> {
{
let mut cache_mutex_guard = self.cache.lock().await;
if let Ok(optional_mute_list) = cache_mutex_guard.get_mute_list(pubkey) {
return optional_mute_list;
}
} // Release the lock here for improved performance
// We don't have an answer from the cache, so we need to fetch it
let mute_list_event = self.fetch_single_event(pubkey, Kind::MuteList).await;
let mut cache_mutex_guard = self.cache.lock().await;
cache_mutex_guard.add_optional_mute_list_with_author(pubkey, mute_list_event.clone());
mute_list_event?.to_mute_list()
}
pub async fn get_contact_list(&self, pubkey: &PublicKey) -> Option<Event> {
{
let mut cache_mutex_guard = self.cache.lock().await;
if let Ok(optional_contact_list) = cache_mutex_guard.get_contact_list(pubkey) {
return optional_contact_list;
}
} // Release the lock here for improved performance
// We don't have an answer from the cache, so we need to fetch it
let contact_list_event = self.fetch_single_event(pubkey, Kind::ContactList).await;
let mut cache_mutex_guard = self.cache.lock().await;
cache_mutex_guard.add_optional_contact_list_with_author(pubkey, contact_list_event.clone());
contact_list_event
}
// MARK: - Lower level fetching functions
async fn fetch_single_event(&self, author: &PublicKey, kind: Kind) -> Option<Event> {
let subscription_filter = Filter::new()
.kinds(vec![kind])
.authors(vec![author.clone()])
.limit(1);
let mut notifications = self.client.notifications();
let this_subscription_id = self
.client
.subscribe(Vec::from([subscription_filter]), None)
.await;
let mut event: Option<Event> = None;
while let Ok(result) = timeout(NOTE_FETCH_TIMEOUT, notifications.recv()).await {
if let Ok(notification) = result {
if let RelayPoolNotification::Event {
subscription_id,
event: event_option,
..
} = notification
{
if this_subscription_id == subscription_id && event_option.kind == kind {
event = Some((*event_option).clone());
break;
}
}
}
}
if event.is_none() {
log::info!("Event of kind {:?} not found for pubkey {:?}", kind, author);
}
self.client.unsubscribe(this_subscription_id).await;
event
}
}

View File

@ -4,13 +4,16 @@ use nostr::event::EventId;
use nostr::key::PublicKey;
use nostr::types::Timestamp;
use nostr_sdk::JsonUtil;
use nostr_sdk::Kind;
use rusqlite;
use rusqlite::params;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Mutex;
use std::collections::HashSet;
use tokio;
use super::mute_manager::MuteManager;
use super::nostr_network_helper::NostrNetworkHelper;
use super::ExtendedEvent;
use super::SqlStringConvertible;
use nostr::Event;
@ -24,8 +27,7 @@ pub struct NotificationManager {
db: Mutex<r2d2::Pool<SqliteConnectionManager>>,
apns_topic: String,
apns_client: Mutex<Client>,
mute_manager: Mutex<MuteManager>,
nostr_network_helper: NostrNetworkHelper,
}
impl NotificationManager {
@ -40,8 +42,6 @@ impl NotificationManager {
apns_environment: a2::client::Endpoint,
apns_topic: String,
) -> Result<Self, Box<dyn std::error::Error>> {
let mute_manager = MuteManager::new(relay_url.clone()).await?;
let connection = db.get()?;
Self::setup_database(&connection)?;
@ -58,13 +58,15 @@ impl NotificationManager {
apns_topic,
apns_client: Mutex::new(client),
db: Mutex::new(db),
mute_manager: Mutex::new(mute_manager),
nostr_network_helper: NostrNetworkHelper::new(relay_url.clone()).await?,
})
}
// MARK: - Database setup operations
pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
// Initial schema setup
db.execute(
"CREATE TABLE IF NOT EXISTS notifications (
id TEXT PRIMARY KEY,
@ -94,8 +96,17 @@ impl NotificationManager {
[],
)?;
Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER")?;
Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER")?;
Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER", None)?;
Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER", None)?;
// Notification settings migration (https://github.com/damus-io/damus/issues/2360)
Self::add_column_if_not_exists(&db, "user_info", "zap_notifications_enabled", "BOOLEAN", Some("true"))?;
Self::add_column_if_not_exists(&db, "user_info", "mention_notifications_enabled", "BOOLEAN", Some("true"))?;
Self::add_column_if_not_exists(&db, "user_info", "repost_notifications_enabled", "BOOLEAN", Some("true"))?;
Self::add_column_if_not_exists(&db, "user_info", "reaction_notifications_enabled", "BOOLEAN", Some("true"))?;
Self::add_column_if_not_exists(&db, "user_info", "dm_notifications_enabled", "BOOLEAN", Some("true"))?;
Self::add_column_if_not_exists(&db, "user_info", "only_notifications_from_following_enabled", "BOOLEAN", Some("false"))?;
Ok(())
}
@ -105,6 +116,7 @@ impl NotificationManager {
table_name: &str,
column_name: &str,
column_type: &str,
default_value: Option<&str>,
) -> Result<(), rusqlite::Error> {
let query = format!("PRAGMA table_info({})", table_name);
let mut stmt = db.prepare(&query)?;
@ -115,8 +127,11 @@ impl NotificationManager {
if !column_names.contains(&column_name.to_string()) {
let query = format!(
"ALTER TABLE {} ADD COLUMN {} {}",
table_name, column_name, column_type
"ALTER TABLE {} ADD COLUMN {} {} {}",
table_name, column_name, column_type, match default_value {
Some(value) => format!("DEFAULT {}", value),
None => "".to_string(),
},
);
db.execute(&query, [])?;
}
@ -203,8 +218,7 @@ impl NotificationManager {
let mut pubkeys_to_notify = HashSet::new();
for pubkey in relevant_pubkeys_yet_to_receive {
let should_mute: bool = {
let mute_manager_mutex_guard = self.mute_manager.lock().await;
mute_manager_mutex_guard
self.nostr_network_helper
.should_mute_notification_for_pubkey(event, &pubkey)
.await
};
@ -251,11 +265,39 @@ impl NotificationManager {
) -> Result<(), Box<dyn std::error::Error>> {
let user_device_tokens = self.get_user_device_tokens(pubkey).await?;
for device_token in user_device_tokens {
if !self.user_wants_notification(pubkey, device_token.clone(), event).await? {
continue;
}
self.send_event_notification_to_device_token(event, &device_token)
.await?;
}
Ok(())
}
async fn user_wants_notification(
&self,
pubkey: &PublicKey,
device_token: String,
event: &Event,
) -> Result<bool, Box<dyn std::error::Error>> {
let notification_preferences = self.get_user_notification_settings(pubkey, device_token).await?;
if notification_preferences.only_notifications_from_following_enabled {
if !self.nostr_network_helper.does_pubkey_follow_pubkey(pubkey, &event.author()).await {
return Ok(false);
}
}
match event.kind {
Kind::TextNote => Ok(notification_preferences.mention_notifications_enabled), // TODO: Not 100% accurate
Kind::EncryptedDirectMessage => Ok(notification_preferences.dm_notifications_enabled),
Kind::Repost => Ok(notification_preferences.repost_notifications_enabled),
Kind::GenericRepost => Ok(notification_preferences.repost_notifications_enabled),
Kind::Reaction => Ok(notification_preferences.reaction_notifications_enabled),
Kind::ZapPrivateMessage => Ok(notification_preferences.zap_notifications_enabled),
Kind::ZapRequest => Ok(notification_preferences.zap_notifications_enabled),
Kind::ZapReceipt => Ok(notification_preferences.zap_notifications_enabled),
_ => Ok(false),
}
}
async fn get_user_device_tokens(
&self,
@ -343,6 +385,8 @@ impl NotificationManager {
};
(title, "".to_string(), body)
}
// MARK: - User device info and settings
pub async fn save_user_device_info(
&self,
@ -375,6 +419,65 @@ impl NotificationManager {
)?;
Ok(())
}
pub async fn get_user_notification_settings(
&self,
pubkey: &PublicKey,
device_token: String,
) -> Result<UserNotificationSettings, Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await;
let connection = db_mutex_guard.get()?;
let mut stmt = connection.prepare(
"SELECT zap_notifications_enabled, mention_notifications_enabled, repost_notifications_enabled, reaction_notifications_enabled, dm_notifications_enabled, only_notifications_from_following_enabled FROM user_info WHERE pubkey = ? AND device_token = ?",
)?;
let settings = stmt
.query_row([pubkey.to_sql_string(), device_token], |row| {
Ok(UserNotificationSettings {
zap_notifications_enabled: row.get(0)?,
mention_notifications_enabled: row.get(1)?,
repost_notifications_enabled: row.get(2)?,
reaction_notifications_enabled: row.get(3)?,
dm_notifications_enabled: row.get(4)?,
only_notifications_from_following_enabled: row.get(5)?,
})
})?;
Ok(settings)
}
pub async fn save_user_notification_settings(
&self,
pubkey: &PublicKey,
device_token: String,
settings: UserNotificationSettings,
) -> Result<(), Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await;
let connection = db_mutex_guard.get()?;
connection.execute(
"UPDATE user_info SET zap_notifications_enabled = ?, mention_notifications_enabled = ?, repost_notifications_enabled = ?, reaction_notifications_enabled = ?, dm_notifications_enabled = ?, only_notifications_from_following_enabled = ? WHERE pubkey = ? AND device_token = ?",
params![
settings.zap_notifications_enabled,
settings.mention_notifications_enabled,
settings.repost_notifications_enabled,
settings.reaction_notifications_enabled,
settings.dm_notifications_enabled,
settings.only_notifications_from_following_enabled,
pubkey.to_sql_string(),
device_token,
],
)?;
Ok(())
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct UserNotificationSettings {
zap_notifications_enabled: bool,
mention_notifications_enabled: bool,
repost_notifications_enabled: bool,
reaction_notifications_enabled: bool,
dm_notifications_enabled: bool,
only_notifications_from_following_enabled: bool
}
struct NotificationStatus {

1
src/utils/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod time_delta;

34
src/utils/time_delta.rs Normal file
View File

@ -0,0 +1,34 @@
use nostr_sdk::Timestamp;
pub struct TimeDelta {
pub delta_abs_seconds: u64,
pub negative: bool,
}
impl TimeDelta {
/// Safely calculate the difference between two timestamps in seconds
/// This function is safer against overflows than subtracting the timestamps directly
pub fn subtracting(t1: Timestamp, t2: Timestamp) -> TimeDelta {
if t1 > t2 {
TimeDelta {
delta_abs_seconds: (t1 - t2).as_u64(),
negative: false,
}
} else {
TimeDelta {
delta_abs_seconds: (t2 - t1).as_u64(),
negative: true,
}
}
}
}
impl std::fmt::Display for TimeDelta {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.negative {
write!(f, "-{}", self.delta_abs_seconds)
} else {
write!(f, "{}", self.delta_abs_seconds)
}
}
}