From 36cc9f87427d241071dc95f200f765042356211f Mon Sep 17 00:00:00 2001 From: William Casarin Date: Wed, 31 Jul 2024 17:44:34 -0700 Subject: [PATCH] run cargo fmt --all we should try to stick to rustfmt style, many editors save it like this by default. Signed-off-by: William Casarin --- src/api_server/api_request_handler.rs | 144 +++++++++------ src/api_server/api_server.rs | 23 ++- src/api_server/mod.rs | 2 +- src/api_server/nip98_auth.rs | 52 ++++-- src/main.rs | 62 ++++--- src/notepush_env.rs | 16 +- src/notification_manager/mod.rs | 4 +- src/notification_manager/mute_manager.rs | 74 +++++--- .../nostr_event_extensions.rs | 34 ++-- .../notification_manager.rs | 171 ++++++++++++------ src/relay_connection.rs | 77 +++++--- 11 files changed, 419 insertions(+), 240 deletions(-) diff --git a/src/api_server/api_request_handler.rs b/src/api_server/api_request_handler.rs index 9ebfc92..599cf35 100644 --- a/src/api_server/api_request_handler.rs +++ b/src/api_server/api_request_handler.rs @@ -1,18 +1,18 @@ use super::nip98_auth; -use hyper::{Request, Response, StatusCode}; use hyper::body::Buf; use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; use http_body_util::BodyExt; use nostr; -use thiserror::Error; -use std::sync::Arc; -use log; -use hyper::Method; -use tokio::sync::Mutex; -use serde_json::{json, Value}; use crate::notification_manager::NotificationManager; +use hyper::Method; +use log; +use serde_json::{json, Value}; +use std::sync::Arc; +use thiserror::Error; +use tokio::sync::Mutex; struct ParsedRequest { uri: String, @@ -57,31 +57,33 @@ impl APIHandler { base_url, } } - - pub async fn handle_http_request(&self, req: Request) -> Result, hyper::http::Error> { + + pub async fn handle_http_request( + &self, + req: Request, + ) -> Result, hyper::http::Error> { let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await { - Ok(api_response) => { - APIResponse { - status: api_response.status, - body: api_response.body, - } + Ok(api_response) => APIResponse { + status: api_response.status, + body: api_response.body, }, Err(err) => { // Detect if error is a APIError::AuthenticationError and return a 401 status code if let Some(api_error) = err.downcast_ref::() { match api_error { - APIError::AuthenticationError(message) => { - APIResponse { - status: StatusCode::UNAUTHORIZED, - body: json!({ "error": "Unauthorized", "message": message }), - } + APIError::AuthenticationError(message) => APIResponse { + status: StatusCode::UNAUTHORIZED, + body: json!({ "error": "Unauthorized", "message": message }), }, } - } - else { + } else { // Otherwise, return a 500 status code let random_case_uuid = uuid::Uuid::new_v4(); - log::error!("Error handling request: {} (Case ID: {})", err, random_case_uuid); + log::error!( + "Error handling request: {} (Case ID: {})", + err, + random_case_uuid + ); APIResponse { status: StatusCode::INTERNAL_SERVER_ERROR, body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }), @@ -93,33 +95,46 @@ impl APIHandler { .header("Content-Type", "application/json") .header("Access-Control-Allow-Origin", "*") .status(final_api_response.status) - .body(final_api_response.body.to_string())? - ) + .body(final_api_response.body.to_string())?) } - - async fn try_to_handle_http_request(&self, mut req: Request) -> Result> { - let parsed_request = self.parse_http_request(&mut req).await?; - let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?; - log::info!("[{}] {} (Authorized pubkey: {}): {}", req.method(), req.uri(), parsed_request.authorized_pubkey, api_response.status); + + async fn try_to_handle_http_request( + &self, + mut req: Request, + ) -> Result> { + let parsed_request = self.parse_http_request(&mut req).await?; + let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?; + log::info!( + "[{}] {} (Authorized pubkey: {}): {}", + req.method(), + req.uri(), + parsed_request.authorized_pubkey, + api_response.status + ); Ok(api_response) } - - async fn parse_http_request(&self, req: &mut Request) -> Result> { + + async fn parse_http_request( + &self, + req: &mut Request, + ) -> Result> { // 1. Read the request body let body_buffer = req.body_mut().collect().await?.aggregate(); let body_bytes = body_buffer.chunk(); - let body_bytes = if body_bytes.is_empty() { None } else { Some(body_bytes) }; - + let body_bytes = if body_bytes.is_empty() { + None + } else { + Some(body_bytes) + }; + // 2. NIP-98 authentication let authorized_pubkey = match self.authenticate(&req, body_bytes).await? { - Ok(pubkey) => { - pubkey - }, + Ok(pubkey) => pubkey, Err(auth_error) => { return Err(Box::new(APIError::AuthenticationError(auth_error))); } }; - + // 3. Parse the request Ok(ParsedRequest { uri: req.uri().path().to_string(), @@ -128,37 +143,48 @@ impl APIHandler { authorized_pubkey, }) } - - async fn handle_parsed_http_request(&self, parsed_request: &ParsedRequest) -> Result> { + + async fn handle_parsed_http_request( + &self, + parsed_request: &ParsedRequest, + ) -> Result> { match (&parsed_request.method, parsed_request.uri.as_str()) { (&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await, - (&Method::POST, "/user-info/remove") => self.handle_user_info_remove(parsed_request).await, - _ => { - Ok(APIResponse { - status: StatusCode::NOT_FOUND, - body: json!({ "error": "Not found" }), - }) + (&Method::POST, "/user-info/remove") => { + self.handle_user_info_remove(parsed_request).await } + _ => Ok(APIResponse { + status: StatusCode::NOT_FOUND, + body: json!({ "error": "Not found" }), + }), } } - - async fn authenticate(&self, req: &Request, body_bytes: Option<&[u8]>) -> Result, Box> { + + async fn authenticate( + &self, + req: &Request, + body_bytes: Option<&[u8]>, + ) -> Result, Box> { let auth_header = match req.headers().get("Authorization") { Some(header) => header, None => return Ok(Err("Authorization header not found".to_string())), }; - + Ok(nip98_auth::nip98_verify_auth_header( auth_header.to_str()?.to_string(), &format!("{}{}", self.base_url, req.uri().path()), req.method().as_str(), - body_bytes - ).await) + body_bytes, + ) + .await) } - - async fn handle_user_info(&self, req: &ParsedRequest) -> Result> { + + async fn handle_user_info( + &self, + req: &ParsedRequest, + ) -> Result> { let body = req.body_json()?; - + if let Some(device_token) = body["deviceToken"].as_str() { let notification_manager = self.notification_manager.lock().await; notification_manager.save_user_device_info(req.authorized_pubkey, device_token)?; @@ -173,10 +199,13 @@ impl APIHandler { }); } } - - async fn handle_user_info_remove(&self, req: &ParsedRequest) -> Result> { + + async fn handle_user_info_remove( + &self, + req: &ParsedRequest, + ) -> Result> { let body: Value = req.body_json()?; - + if let Some(device_token) = body["deviceToken"].as_str() { let notification_manager = self.notification_manager.lock().await; notification_manager.remove_user_device_info(req.authorized_pubkey, device_token)?; @@ -194,8 +223,7 @@ impl APIHandler { } // Define enum error types including authentication error -#[derive(Debug)] -#[derive(Error)] +#[derive(Debug, Error)] enum APIError { #[error("Authentication error: {0}")] AuthenticationError(String), diff --git a/src/api_server/api_server.rs b/src/api_server/api_server.rs index 9ec4498..9709641 100644 --- a/src/api_server/api_server.rs +++ b/src/api_server/api_server.rs @@ -1,11 +1,11 @@ +use super::api_request_handler::APIHandler; +use crate::notification_manager::NotificationManager; use hyper::{server::conn::http1, service::service_fn}; use hyper_util::rt::TokioIo; +use log; use std::sync::Arc; use tokio::net::TcpListener; -use log; use tokio::sync::Mutex; -use crate::notification_manager::NotificationManager; -use super::api_request_handler::APIHandler; pub struct APIServer { host: String, @@ -14,7 +14,12 @@ pub struct APIServer { } impl APIServer { - pub async fn run(host: String, port: String, notification_manager: Arc>, base_url: String) -> Result<(), Box> { + pub async fn run( + host: String, + port: String, + notification_manager: Arc>, + base_url: String, + ) -> Result<(), Box> { let api_handler = APIHandler::new(notification_manager, base_url); let server = APIServer { host, @@ -23,21 +28,21 @@ impl APIServer { }; server.start().await } - + async fn start(&self) -> Result<(), Box> { let address = format!("{}:{}", self.host, self.port); let listener = TcpListener::bind(&address).await?; - + log::info!("HTTP server running at {}", address); - + loop { let (stream, _) = listener.accept().await?; let io = TokioIo::new(stream); let api_handler = self.api_handler.clone(); - + tokio::task::spawn(async move { let service = service_fn(|req| api_handler.handle_http_request(req)); - + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { log::error!("Failed to serve connection: {:?}", err); } diff --git a/src/api_server/mod.rs b/src/api_server/mod.rs index f3dc4bb..ffd1b5e 100644 --- a/src/api_server/mod.rs +++ b/src/api_server/mod.rs @@ -1,3 +1,3 @@ +pub mod api_request_handler; pub mod api_server; pub mod nip98_auth; -pub mod api_request_handler; diff --git a/src/api_server/nip98_auth.rs b/src/api_server/nip98_auth.rs index 71b7d12..396c7ca 100644 --- a/src/api_server/nip98_auth.rs +++ b/src/api_server/nip98_auth.rs @@ -1,16 +1,16 @@ use base64::prelude::*; -use serde_json::Value; +use nostr; use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash; use nostr::bitcoin::hashes::Hash; use nostr::util::hex; use nostr::Timestamp; -use nostr; +use serde_json::Value; pub async fn nip98_verify_auth_header( auth_header: String, url: &str, method: &str, - body: Option<&[u8]> + body: Option<&[u8]>, ) -> Result { if auth_header.is_empty() { return Err("Nostr authorization header missing".to_string()); @@ -30,23 +30,30 @@ pub async fn nip98_verify_auth_header( return Err("Nostr authorization header does not have a base64 encoded note".to_string()); } - let decoded_note_json = BASE64_STANDARD.decode(base64_encoded_note.as_bytes()) - .map_err(|_| format!("Failed to decode base64 encoded note from Nostr authorization header"))?; - + let decoded_note_json = BASE64_STANDARD + .decode(base64_encoded_note.as_bytes()) + .map_err(|_| { + format!("Failed to decode base64 encoded note from Nostr authorization header") + })?; + let note_value: Value = serde_json::from_slice(&decoded_note_json) .map_err(|_| format!("Could not parse JSON note from authorization header"))?; - + let note: nostr::Event = nostr::Event::from_value(note_value) .map_err(|_| format!("Could not parse Nostr note from JSON"))?; if note.kind != nostr::Kind::HttpAuth { return Err("Nostr note kind in authorization header is incorrect".to_string()); } - - let authorized_url = note.get_tag_content(nostr::TagKind::SingleLetter(nostr::SingleLetterTag::lowercase(nostr::Alphabet::U))) + + let authorized_url = note + .get_tag_content(nostr::TagKind::SingleLetter( + nostr::SingleLetterTag::lowercase(nostr::Alphabet::U), + )) .ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?; - let authorized_method = note.get_tag_content(nostr::TagKind::Method) + let authorized_method = note + .get_tag_content(nostr::TagKind::Method) .ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?; if authorized_url != url || authorized_method != method { @@ -59,7 +66,9 @@ pub async fn nip98_verify_auth_header( let current_time: nostr::Timestamp = nostr::Timestamp::now(); let note_created_at: nostr::Timestamp = note.created_at(); let time_delta = TimeDelta::subtracting(current_time, note_created_at); - if (time_delta.negative && time_delta.delta_abs_seconds > 30) || (!time_delta.negative && time_delta.delta_abs_seconds > 60) { + if (time_delta.negative && time_delta.delta_abs_seconds > 30) + || (!time_delta.negative && time_delta.delta_abs_seconds > 60) + { return Err(format!( "Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds", current_time, note_created_at, time_delta @@ -69,13 +78,16 @@ pub async fn nip98_verify_auth_header( if let Some(body_data) = body { let authorized_content_hash_bytes: Vec = hex::decode( note.get_tag_content(nostr::TagKind::Payload) - .ok_or("Missing 'payload' tag from Nostr authorization header")? + .ok_or("Missing 'payload' tag from Nostr authorization header")?, ) - .map_err(|_| format!("Failed to decode hex encoded payload from Nostr authorization header"))?; + .map_err(|_| { + format!("Failed to decode hex encoded payload from Nostr authorization header") + })?; + + let authorized_content_hash: Sha256Hash = + Sha256Hash::from_slice(&authorized_content_hash_bytes) + .map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?; - let authorized_content_hash: Sha256Hash = Sha256Hash::from_slice(&authorized_content_hash_bytes) - .map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?; - let body_hash = Sha256Hash::hash(body_data); if authorized_content_hash != body_hash { return Err("Auth note payload hash does not match request body hash".to_string()); @@ -91,13 +103,13 @@ pub async fn nip98_verify_auth_header( if note.verify().is_err() { return Err("Auth note id or signature is invalid".to_string()); } - + Ok(note.pubkey) } struct TimeDelta { delta_abs_seconds: u64, - negative: bool + negative: bool, } impl TimeDelta { @@ -107,12 +119,12 @@ impl TimeDelta { if t1 > t2 { TimeDelta { delta_abs_seconds: (t1 - t2).as_u64(), - negative: false + negative: false, } } else { TimeDelta { delta_abs_seconds: (t2 - t1).as_u64(), - negative: true + negative: true, } } } diff --git a/src/main.rs b/src/main.rs index 524aea9..4e780e5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,43 +1,47 @@ #![forbid(unsafe_code)] +use api_server::api_server::APIServer; use std::net::TcpListener; use std::sync::Arc; -use api_server::api_server::APIServer; use tokio::sync::Mutex; mod notification_manager; -use log; use env_logger; +use log; use r2d2_sqlite::SqliteConnectionManager; mod relay_connection; -use relay_connection::RelayConnection; use r2d2; +use relay_connection::RelayConnection; mod notepush_env; use notepush_env::NotePushEnv; mod api_server; #[tokio::main] -async fn main () { - +async fn main() { // MARK: - Setup basics - + env_logger::init(); - + let env = NotePushEnv::load_env().expect("Failed to load environment variables"); let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address"); - + let manager = SqliteConnectionManager::file(env.db_path.clone()); - let pool: r2d2::Pool = r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); + let pool: r2d2::Pool = + r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); // Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter. // This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections. - let notification_manager = Arc::new(Mutex::new(notification_manager::NotificationManager::new( - pool, - env.relay_url.clone(), - env.apns_private_key_path.clone(), - env.apns_private_key_id.clone(), - env.apns_team_id.clone(), - env.apns_environment.clone(), - env.apns_topic.clone(), - ).await.expect("Failed to create notification manager"))); - + let notification_manager = Arc::new(Mutex::new( + notification_manager::NotificationManager::new( + pool, + env.relay_url.clone(), + env.apns_private_key_path.clone(), + env.apns_private_key_id.clone(), + env.apns_team_id.clone(), + env.apns_environment.clone(), + env.apns_topic.clone(), + ) + .await + .expect("Failed to create notification manager"), + )); + // MARK: - Start the API server { let notification_manager = notification_manager.clone(); @@ -45,24 +49,32 @@ async fn main () { let api_port = env.api_port.clone(); let api_base_url = env.api_base_url.clone(); tokio::spawn(async move { - APIServer::run(api_host, api_port, notification_manager, api_base_url).await.expect("Failed to start API server"); + APIServer::run(api_host, api_port, notification_manager, api_base_url) + .await + .expect("Failed to start API server"); }); } - + // MARK: - Start handling incoming connections - + log::info!("Relay server listening on {}", env.relay_address().clone()); - + for stream in server.incoming() { if let Ok(stream) = stream { - let peer_address_string = stream.peer_addr().map_or("unknown".to_string(), |addr| addr.to_string()); + let peer_address_string = stream + .peer_addr() + .map_or("unknown".to_string(), |addr| addr.to_string()); log::info!("New connection from {}", peer_address_string); let notification_manager = notification_manager.clone(); tokio::spawn(async move { match RelayConnection::run(stream, notification_manager).await { Ok(_) => {} Err(e) => { - log::error!("Error with websocket connection from {}: {:?}", peer_address_string, e); + log::error!( + "Error with websocket connection from {}: {:?}", + peer_address_string, + e + ); } } }); diff --git a/src/notepush_env.rs b/src/notepush_env.rs index 4923e95..b62b08d 100644 --- a/src/notepush_env.rs +++ b/src/notepush_env.rs @@ -1,6 +1,6 @@ -use std::env; -use dotenv::dotenv; use a2; +use dotenv::dotenv; +use std::env; const DEFAULT_DB_PATH: &str = "./apns_notifications.db"; const DEFAULT_RELAY_HOST: &str = "0.0.0.0"; @@ -28,7 +28,7 @@ pub struct NotePushEnv { // The host and port to bind the API server to pub api_host: String, pub api_port: String, - pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks + pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks // The URL of the Nostr relay server to connect to for getting mutelists pub relay_url: String, } @@ -43,17 +43,19 @@ impl NotePushEnv { let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string()); let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string()); let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string()); - let apns_environment_string = env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); + let apns_environment_string = + env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string()); let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string()); - let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port)); + let api_base_url = + env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port)); let apns_environment = match apns_environment_string.as_str() { "development" => a2::client::Endpoint::Sandbox, "production" => a2::client::Endpoint::Production, _ => a2::client::Endpoint::Sandbox, }; let apns_topic = env::var("APNS_TOPIC")?; - + Ok(NotePushEnv { apns_private_key_path, apns_private_key_id, @@ -69,7 +71,7 @@ impl NotePushEnv { relay_url, }) } - + pub fn relay_address(&self) -> String { format!("{}:{}", self.relay_host, self.relay_port) } diff --git a/src/notification_manager/mod.rs b/src/notification_manager/mod.rs index f0bfb17..653c48e 100644 --- a/src/notification_manager/mod.rs +++ b/src/notification_manager/mod.rs @@ -1,7 +1,7 @@ -pub mod notification_manager; pub mod mute_manager; mod nostr_event_extensions; +pub mod notification_manager; -pub use notification_manager::NotificationManager; pub use mute_manager::MuteManager; use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible}; +pub use notification_manager::NotificationManager; diff --git a/src/notification_manager/mute_manager.rs b/src/notification_manager/mute_manager.rs index 8ec9e7f..0205ac6 100644 --- a/src/notification_manager/mute_manager.rs +++ b/src/notification_manager/mute_manager.rs @@ -1,5 +1,5 @@ -use nostr_sdk::prelude::*; use super::ExtendedEvent; +use nostr_sdk::prelude::*; pub struct MuteManager { relay_url: String, @@ -11,45 +11,67 @@ impl MuteManager { let client = Client::new(&Keys::generate()); client.add_relay(relay_url.clone()).await?; client.connect().await; - Ok(MuteManager { - relay_url, - client - }) + Ok(MuteManager { relay_url, client }) } - pub async fn should_mute_notification_for_pubkey(&self, event: &Event, pubkey: &PublicKey) -> bool { + pub async fn should_mute_notification_for_pubkey( + &self, + event: &Event, + pubkey: &PublicKey, + ) -> bool { if let Some(mute_list) = self.get_public_mute_list(pubkey).await { for tag in mute_list.tags() { match tag.kind() { - TagKind::SingleLetter(SingleLetterTag { character: Alphabet::P, uppercase: false }) => { - let tagged_pubkey: Option = tag.content().and_then(|h| { PublicKey::from_hex(h).ok() }); + TagKind::SingleLetter(SingleLetterTag { + character: Alphabet::P, + uppercase: false, + }) => { + let tagged_pubkey: Option = + tag.content().and_then(|h| PublicKey::from_hex(h).ok()); if let Some(tagged_pubkey) = tagged_pubkey { if event.pubkey == tagged_pubkey { - return true + return true; } } } - TagKind::SingleLetter(SingleLetterTag { character: Alphabet::E, uppercase: false }) => { - let tagged_event_id: Option = tag.content().and_then(|h| { EventId::from_hex(h).ok() }); + TagKind::SingleLetter(SingleLetterTag { + character: Alphabet::E, + uppercase: false, + }) => { + let tagged_event_id: Option = + tag.content().and_then(|h| EventId::from_hex(h).ok()); if let Some(tagged_event_id) = tagged_event_id { - if event.id == tagged_event_id || event.referenced_event_ids().contains(&tagged_event_id) { - return true + if event.id == tagged_event_id + || event.referenced_event_ids().contains(&tagged_event_id) + { + return true; } } } - TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false }) => { + TagKind::SingleLetter(SingleLetterTag { + character: Alphabet::T, + uppercase: false, + }) => { let tagged_hashtag: Option = tag.content().map(|h| h.to_string()); if let Some(tagged_hashtag) = tagged_hashtag { - let tags_content = event.get_tags_content(TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false })); + let tags_content = + event.get_tags_content(TagKind::SingleLetter(SingleLetterTag { + character: Alphabet::T, + uppercase: false, + })); let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag); - return should_mute + return should_mute; } } TagKind::Word => { let tagged_word: Option = tag.content().map(|h| h.to_string()); if let Some(tagged_word) = tagged_word { - if event.content.to_lowercase().contains(&tagged_word.to_lowercase()) { - return true + if event + .content + .to_lowercase() + .contains(&tagged_word.to_lowercase()) + { + return true; } } } @@ -66,15 +88,23 @@ impl MuteManager { .authors(vec![pubkey.clone()]) .limit(1); - let this_subscription_id = self.client.subscribe(Vec::from([subscription_filter]), None).await; - + let this_subscription_id = self + .client + .subscribe(Vec::from([subscription_filter]), None) + .await; + let mut mute_list: Option = None; let mut notifications = self.client.notifications(); while let Ok(notification) = notifications.recv().await { - if let RelayPoolNotification::Event { subscription_id, event, .. } = notification { + if let RelayPoolNotification::Event { + subscription_id, + event, + .. + } = notification + { if this_subscription_id == subscription_id && event.kind == Kind::MuteList { mute_list = Some((*event).clone()); - break + break; } } } diff --git a/src/notification_manager/nostr_event_extensions.rs b/src/notification_manager/nostr_event_extensions.rs index 1b791b6..ad9fca9 100644 --- a/src/notification_manager/nostr_event_extensions.rs +++ b/src/notification_manager/nostr_event_extensions.rs @@ -1,39 +1,37 @@ -use nostr::{self, key::PublicKey, TagKind::SingleLetter, Alphabet, SingleLetterTag}; +use nostr::{self, key::PublicKey, Alphabet, SingleLetterTag, TagKind::SingleLetter}; /// Temporary scaffolding of old methods that have not been ported to use native Event methods pub trait ExtendedEvent { /// Checks if the note references a given pubkey - fn references_pubkey(&self, pubkey: &PublicKey) -> bool; + fn references_pubkey(&self, pubkey: &PublicKey) -> bool; /// Retrieves a set of pubkeys referenced by the note - fn referenced_pubkeys(&self) -> std::collections::HashSet; + fn referenced_pubkeys(&self) -> std::collections::HashSet; /// Retrieves a set of pubkeys relevant to the note - fn relevant_pubkeys(&self) -> std::collections::HashSet; + fn relevant_pubkeys(&self) -> std::collections::HashSet; /// Retrieves a set of event IDs referenced by the note - fn referenced_event_ids(&self) -> std::collections::HashSet; + fn referenced_event_ids(&self) -> std::collections::HashSet; } // This is a wrapper around the Event type from strfry-policies, which adds some useful methods impl ExtendedEvent for nostr::Event { /// Checks if the note references a given pubkey - fn references_pubkey(&self, pubkey: &PublicKey) -> bool { + fn references_pubkey(&self, pubkey: &PublicKey) -> bool { self.referenced_pubkeys().contains(pubkey) } /// Retrieves a set of pubkeys referenced by the note - fn referenced_pubkeys(&self) -> std::collections::HashSet { + fn referenced_pubkeys(&self) -> std::collections::HashSet { self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::P))) .iter() - .filter_map(|tag| { - PublicKey::from_hex(tag).ok() - }) + .filter_map(|tag| PublicKey::from_hex(tag).ok()) .collect() } /// Retrieves a set of pubkeys relevant to the note - fn relevant_pubkeys(&self) -> std::collections::HashSet { + fn relevant_pubkeys(&self) -> std::collections::HashSet { let mut pubkeys = self.referenced_pubkeys(); pubkeys.insert(self.pubkey.clone()); pubkeys @@ -43,9 +41,7 @@ impl ExtendedEvent for nostr::Event { fn referenced_event_ids(&self) -> std::collections::HashSet { self.get_tag_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::E))) .iter() - .filter_map(|tag| { - nostr::EventId::from_hex(tag).ok() - }) + .filter_map(|tag| nostr::EventId::from_hex(tag).ok()) .collect() } } @@ -54,14 +50,16 @@ impl ExtendedEvent for nostr::Event { pub trait SqlStringConvertible { fn to_sql_string(&self) -> String; - fn from_sql_string(s: String) -> Result> where Self: Sized; + fn from_sql_string(s: String) -> Result> + where + Self: Sized; } impl SqlStringConvertible for nostr::EventId { fn to_sql_string(&self) -> String { self.to_hex() } - + fn from_sql_string(s: String) -> Result> { nostr::EventId::from_hex(s).map_err(|e| e.into()) } @@ -71,7 +69,7 @@ impl SqlStringConvertible for nostr::PublicKey { fn to_sql_string(&self) -> String { self.to_hex() } - + fn from_sql_string(s: String) -> Result> { nostr::PublicKey::from_hex(s).map_err(|e| e.into()) } @@ -81,7 +79,7 @@ impl SqlStringConvertible for nostr::Timestamp { fn to_sql_string(&self) -> String { self.as_u64().to_string() } - + fn from_sql_string(s: String) -> Result> { let u64_timestamp: u64 = s.parse()?; Ok(nostr::Timestamp::from(u64_timestamp)) diff --git a/src/notification_manager/notification_manager.rs b/src/notification_manager/notification_manager.rs index 6bb5477..ae477e9 100644 --- a/src/notification_manager/notification_manager.rs +++ b/src/notification_manager/notification_manager.rs @@ -1,19 +1,19 @@ use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder}; +use log; use nostr::event::EventId; use nostr::key::PublicKey; use nostr::types::Timestamp; use rusqlite; use rusqlite::params; use std::collections::HashSet; -use log; -use std::fs::File; use super::mute_manager::MuteManager; -use nostr::Event; -use super::SqlStringConvertible; use super::ExtendedEvent; -use r2d2_sqlite::SqliteConnectionManager; +use super::SqlStringConvertible; +use nostr::Event; use r2d2; +use r2d2_sqlite::SqliteConnectionManager; +use std::fs::File; // MARK: - NotificationManager @@ -31,20 +31,28 @@ pub struct NotificationManager { impl NotificationManager { // MARK: - Initialization - - pub async fn new(db: r2d2::Pool, relay_url: String, apns_private_key_path: String, apns_private_key_id: String, apns_team_id: String, apns_environment: a2::client::Endpoint, apns_topic: String) -> Result> { + + pub async fn new( + db: r2d2::Pool, + relay_url: String, + apns_private_key_path: String, + apns_private_key_id: String, + apns_team_id: String, + apns_environment: a2::client::Endpoint, + apns_topic: String, + ) -> Result> { let mute_manager = MuteManager::new(relay_url.clone()).await?; - + let connection = db.get()?; Self::setup_database(&connection)?; - + let mut file = File::open(&apns_private_key_path)?; - + let client = Client::token( &mut file, &apns_private_key_id, &apns_team_id, - ClientConfig::new(apns_environment.clone()) + ClientConfig::new(apns_environment.clone()), )?; Ok(Self { @@ -58,7 +66,7 @@ impl NotificationManager { mute_manager, }) } - + // MARK: - Database setup operations pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> { @@ -93,39 +101,58 @@ impl NotificationManager { Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER")?; Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER")?; - + Ok(()) } - fn add_column_if_not_exists(db: &rusqlite::Connection, table_name: &str, column_name: &str, column_type: &str) -> Result<(), rusqlite::Error> { + fn add_column_if_not_exists( + db: &rusqlite::Connection, + table_name: &str, + column_name: &str, + column_type: &str, + ) -> Result<(), rusqlite::Error> { let query = format!("PRAGMA table_info({})", table_name); let mut stmt = db.prepare(&query)?; - let column_names: Vec = stmt.query_map([], |row| row.get(1))? + let column_names: Vec = stmt + .query_map([], |row| row.get(1))? .filter_map(|r| r.ok()) .collect(); if !column_names.contains(&column_name.to_string()) { - let query = format!("ALTER TABLE {} ADD COLUMN {} {}", table_name, column_name, column_type); + let query = format!( + "ALTER TABLE {} ADD COLUMN {} {}", + table_name, column_name, column_type + ); db.execute(&query, [])?; } Ok(()) } - + // MARK: - Business logic - pub async fn send_notifications_if_needed(&self, event: &Event) -> Result<(), Box> { - log::debug!("Checking if notifications need to be sent for event: {}", event.id); + pub async fn send_notifications_if_needed( + &self, + event: &Event, + ) -> Result<(), Box> { + log::debug!( + "Checking if notifications need to be sent for event: {}", + event.id + ); let one_week_ago = nostr::Timestamp::now() - 7 * 24 * 60 * 60; if event.created_at < one_week_ago { return Ok(()); } let pubkeys_to_notify = self.pubkeys_to_notify_for_event(event).await?; - - log::debug!("Sending notifications to {} pubkeys", pubkeys_to_notify.len()); + + log::debug!( + "Sending notifications to {} pubkeys", + pubkeys_to_notify.len() + ); for pubkey in pubkeys_to_notify { - self.send_event_notifications_to_pubkey(event, &pubkey).await?; + self.send_event_notifications_to_pubkey(event, &pubkey) + .await?; self.db.get()?.execute( "INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at) VALUES (?, ?, ?, ?, ?)", @@ -141,10 +168,14 @@ impl NotificationManager { Ok(()) } - async fn pubkeys_to_notify_for_event(&self, event: &Event) -> Result, Box> { + async fn pubkeys_to_notify_for_event( + &self, + event: &Event, + ) -> Result, Box> { let notification_status = self.get_notification_status(event)?; let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?; - let pubkeys_that_received_notification = notification_status.pubkeys_that_received_notification(); + let pubkeys_that_received_notification = + notification_status.pubkeys_that_received_notification(); let relevant_pubkeys_yet_to_receive: HashSet = relevant_pubkeys .difference(&pubkeys_that_received_notification) .filter(|&x| *x != event.pubkey) @@ -153,7 +184,10 @@ impl NotificationManager { let mut pubkeys_to_notify = HashSet::new(); for pubkey in relevant_pubkeys_yet_to_receive { - let should_mute: bool = self.mute_manager.should_mute_notification_for_pubkey(event, &pubkey).await; + let should_mute: bool = self + .mute_manager + .should_mute_notification_for_pubkey(event, &pubkey) + .await; if !should_mute { pubkeys_to_notify.insert(pubkey); } @@ -161,53 +195,79 @@ impl NotificationManager { Ok(pubkeys_to_notify) } - fn pubkeys_relevant_to_event(&self, event: &Event) -> Result, Box> { + fn pubkeys_relevant_to_event( + &self, + event: &Event, + ) -> Result, Box> { let mut relevant_pubkeys = event.relevant_pubkeys(); let referenced_event_ids = event.referenced_event_ids(); for referenced_event_id in referenced_event_ids { - let pubkeys_relevant_to_referenced_event = self.pubkeys_subscribed_to_event_id(&referenced_event_id)?; + let pubkeys_relevant_to_referenced_event = + self.pubkeys_subscribed_to_event_id(&referenced_event_id)?; relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event); } Ok(relevant_pubkeys) } - fn pubkeys_subscribed_to_event(&self, event: &Event) -> Result, Box> { + fn pubkeys_subscribed_to_event( + &self, + event: &Event, + ) -> Result, Box> { self.pubkeys_subscribed_to_event_id(&event.id) } - fn pubkeys_subscribed_to_event_id(&self, event_id: &EventId) -> Result, Box> { + fn pubkeys_subscribed_to_event_id( + &self, + event_id: &EventId, + ) -> Result, Box> { let connection = self.db.get()?; let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?; - let pubkeys = stmt.query_map([event_id.to_sql_string()], |row| row.get(0))? + let pubkeys = stmt + .query_map([event_id.to_sql_string()], |row| row.get(0))? .filter_map(|r| r.ok()) .filter_map(|r: String| PublicKey::from_sql_string(r).ok()) .collect(); Ok(pubkeys) } - async fn send_event_notifications_to_pubkey(&self, event: &Event, pubkey: &PublicKey) -> Result<(), Box> { + async fn send_event_notifications_to_pubkey( + &self, + event: &Event, + pubkey: &PublicKey, + ) -> Result<(), Box> { let user_device_tokens = self.get_user_device_tokens(pubkey)?; for device_token in user_device_tokens { - self.send_event_notification_to_device_token(event, &device_token).await?; + self.send_event_notification_to_device_token(event, &device_token) + .await?; } Ok(()) } - fn get_user_device_tokens(&self, pubkey: &PublicKey) -> Result, Box> { + fn get_user_device_tokens( + &self, + pubkey: &PublicKey, + ) -> Result, Box> { let connection = self.db.get()?; let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?; - let device_tokens = stmt.query_map([pubkey.to_sql_string()], |row| row.get(0))? + let device_tokens = stmt + .query_map([pubkey.to_sql_string()], |row| row.get(0))? .filter_map(|r| r.ok()) .collect(); Ok(device_tokens) } - fn get_notification_status(&self, event: &Event) -> Result> { + fn get_notification_status( + &self, + event: &Event, + ) -> Result> { let connection = self.db.get()?; - let mut stmt = connection.prepare("SELECT pubkey, received_notification FROM notifications WHERE event_id = ?")?; - let rows: std::collections::HashMap = stmt.query_map([event.id.to_sql_string()], |row| { - Ok((row.get(0)?, row.get(1)?)) - })? + let mut stmt = connection.prepare( + "SELECT pubkey, received_notification FROM notifications WHERE event_id = ?", + )?; + let rows: std::collections::HashMap = stmt + .query_map([event.id.to_sql_string()], |row| { + Ok((row.get(0)?, row.get(1)?)) + })? .filter_map(|r: Result<(String, bool), rusqlite::Error>| r.ok()) .filter_map(|r: (String, bool)| { let pubkey = PublicKey::from_sql_string(r.0).ok()?; @@ -225,9 +285,13 @@ impl NotificationManager { Ok(NotificationStatus { status_info }) } - async fn send_event_notification_to_device_token(&self, event: &Event, device_token: &str) -> Result<(), Box> { + async fn send_event_notification_to_device_token( + &self, + event: &Event, + device_token: &str, + ) -> Result<(), Box> { let (title, subtitle, body) = self.format_notification_message(event); - + log::debug!("Sending notification to device token: {}", device_token); let builder = DefaultNotificationBuilder::new() @@ -236,18 +300,15 @@ impl NotificationManager { .set_body(&body) .set_mutable_content() .set_content_available(); - - let mut payload = builder.build( - device_token, - Default::default() - ); + + let mut payload = builder.build(device_token, Default::default()); let _ = payload.add_custom_data("nostr_event", event); payload.options.apns_topic = Some(self.apns_topic.as_str()); - + let _response = self.apns_client.send(payload).await?; - + log::info!("Notification sent to device token: {}", device_token); - + Ok(()) } @@ -258,7 +319,11 @@ impl NotificationManager { (title, subtitle, body) } - pub fn save_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box> { + pub fn save_user_device_info( + &self, + pubkey: nostr::PublicKey, + device_token: &str, + ) -> Result<(), Box> { let current_time_unix = Timestamp::now(); self.db.get()?.execute( "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", @@ -272,7 +337,11 @@ impl NotificationManager { Ok(()) } - pub fn remove_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box> { + pub fn remove_user_device_info( + &self, + pubkey: nostr::PublicKey, + device_token: &str, + ) -> Result<(), Box> { self.db.get()?.execute( "DELETE FROM user_info WHERE pubkey = ? AND device_token = ?", params![pubkey.to_sql_string(), device_token], diff --git a/src/relay_connection.rs b/src/relay_connection.rs index 9c41fbe..aa4efc4 100644 --- a/src/relay_connection.rs +++ b/src/relay_connection.rs @@ -1,43 +1,48 @@ +use crate::notification_manager::NotificationManager; +use log; use nostr::util::JsonUtil; -use nostr::{RelayMessage, ClientMessage}; +use nostr::{ClientMessage, RelayMessage}; +use serde_json::Value; +use std::fmt::{self, Debug}; +use std::net::TcpStream; +use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; -use serde_json::Value; -use crate::notification_manager::NotificationManager; -use std::str::FromStr; -use std::net::TcpStream; use tungstenite::{accept, WebSocket}; -use log; -use std::fmt::{self, Debug}; const MAX_CONSECUTIVE_ERRORS: u32 = 10; pub struct RelayConnection { websocket: WebSocket, - notification_manager: Arc> + notification_manager: Arc>, } impl RelayConnection { - // MARK: - Initializers - - pub fn new(stream: TcpStream, notification_manager: Arc>) -> Result> { + + pub fn new( + stream: TcpStream, + notification_manager: Arc>, + ) -> Result> { let address = stream.peer_addr()?; let websocket = accept(stream)?; log::info!("Accepted connection from {:?}", address); Ok(RelayConnection { websocket, - notification_manager + notification_manager, }) } - - pub async fn run(stream: TcpStream, notification_manager: Arc>) -> Result<(), Box> { + + pub async fn run( + stream: TcpStream, + notification_manager: Arc>, + ) -> Result<(), Box> { let mut connection = RelayConnection::new(stream, notification_manager)?; Ok(connection.run_loop().await?) } - + // MARK: - Connection Runtime management - + pub async fn run_loop(&mut self) -> Result<(), Box> { let mut consecutive_errors = 0; log::debug!("Starting run loop for connection with {:?}", self.websocket); @@ -47,31 +52,43 @@ impl RelayConnection { consecutive_errors = 0; } Err(e) => { - log::error!("Error in websocket connection with {:?}: {:?}", self.websocket, e); + log::error!( + "Error in websocket connection with {:?}: {:?}", + self.websocket, + e + ); consecutive_errors += 1; if consecutive_errors >= MAX_CONSECUTIVE_ERRORS { - log::error!("Too many consecutive errors, closing connection with {:?}", self.websocket); + log::error!( + "Too many consecutive errors, closing connection with {:?}", + self.websocket + ); return Err(e); } } } } } - + pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box> { let websocket = &mut self.websocket; let raw_message = websocket.read()?; if raw_message.is_text() { - let message: ClientMessage = ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?; + let message: ClientMessage = + ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?; let response = self.handle_client_message(message).await?; - self.websocket.send(tungstenite::Message::text(response.try_as_json()?))?; + self.websocket + .send(tungstenite::Message::text(response.try_as_json()?))?; } Ok(()) } - + // MARK: - Message handling - - async fn handle_client_message<'b>(&'b self, message: ClientMessage) -> Result> { + + async fn handle_client_message<'b>( + &'b self, + message: ClientMessage, + ) -> Result> { match message { ClientMessage::Event(event) => { log::info!("Received event: {:?}", event); @@ -79,15 +96,21 @@ impl RelayConnection { // TODO: Reduce resource contention by reducing the scope of the mutex into NotificationManager logic. let mutex_guard = self.notification_manager.lock().await; mutex_guard.send_notifications_if_needed(&event).await?; - }; // Only hold the mutex for as little time as possible. + }; // Only hold the mutex for as little time as possible. let notice_message = format!("blocked: This relay does not store events"); - let response = RelayMessage::Ok { event_id: event.id, status: false, message: notice_message }; + let response = RelayMessage::Ok { + event_id: event.id, + status: false, + message: notice_message, + }; Ok(response) } _ => { log::info!("Received unsupported message: {:?}", message); let notice_message = format!("Unsupported message: {:?}", message); - let response = RelayMessage::Notice { message: notice_message }; + let response = RelayMessage::Notice { + message: notice_message, + }; Ok(response) } }