run cargo fmt --all

we should try to stick to rustfmt style, many editors save it like this
by default.

Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
William Casarin
2024-07-31 17:44:34 -07:00
parent e704f6e24c
commit 36cc9f8742
11 changed files with 419 additions and 240 deletions

View File

@ -1,18 +1,18 @@
use super::nip98_auth; use super::nip98_auth;
use hyper::{Request, Response, StatusCode};
use hyper::body::Buf; use hyper::body::Buf;
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::BodyExt; use http_body_util::BodyExt;
use nostr; use nostr;
use thiserror::Error;
use std::sync::Arc;
use log;
use hyper::Method;
use tokio::sync::Mutex;
use serde_json::{json, Value};
use crate::notification_manager::NotificationManager; use crate::notification_manager::NotificationManager;
use hyper::Method;
use log;
use serde_json::{json, Value};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Mutex;
struct ParsedRequest { struct ParsedRequest {
uri: String, uri: String,
@ -57,31 +57,33 @@ impl APIHandler {
base_url, base_url,
} }
} }
pub async fn handle_http_request(&self, req: Request<Incoming>) -> Result<Response<String>, hyper::http::Error> { pub async fn handle_http_request(
&self,
req: Request<Incoming>,
) -> Result<Response<String>, hyper::http::Error> {
let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await { let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await {
Ok(api_response) => { Ok(api_response) => APIResponse {
APIResponse { status: api_response.status,
status: api_response.status, body: api_response.body,
body: api_response.body,
}
}, },
Err(err) => { Err(err) => {
// Detect if error is a APIError::AuthenticationError and return a 401 status code // Detect if error is a APIError::AuthenticationError and return a 401 status code
if let Some(api_error) = err.downcast_ref::<APIError>() { if let Some(api_error) = err.downcast_ref::<APIError>() {
match api_error { match api_error {
APIError::AuthenticationError(message) => { APIError::AuthenticationError(message) => APIResponse {
APIResponse { status: StatusCode::UNAUTHORIZED,
status: StatusCode::UNAUTHORIZED, body: json!({ "error": "Unauthorized", "message": message }),
body: json!({ "error": "Unauthorized", "message": message }),
}
}, },
} }
} } else {
else {
// Otherwise, return a 500 status code // Otherwise, return a 500 status code
let random_case_uuid = uuid::Uuid::new_v4(); let random_case_uuid = uuid::Uuid::new_v4();
log::error!("Error handling request: {} (Case ID: {})", err, random_case_uuid); log::error!(
"Error handling request: {} (Case ID: {})",
err,
random_case_uuid
);
APIResponse { APIResponse {
status: StatusCode::INTERNAL_SERVER_ERROR, status: StatusCode::INTERNAL_SERVER_ERROR,
body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }), body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }),
@ -93,33 +95,46 @@ impl APIHandler {
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Access-Control-Allow-Origin", "*") .header("Access-Control-Allow-Origin", "*")
.status(final_api_response.status) .status(final_api_response.status)
.body(final_api_response.body.to_string())? .body(final_api_response.body.to_string())?)
)
} }
async fn try_to_handle_http_request(&self, mut req: Request<Incoming>) -> Result<APIResponse, Box<dyn std::error::Error>> { async fn try_to_handle_http_request(
let parsed_request = self.parse_http_request(&mut req).await?; &self,
let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?; mut req: Request<Incoming>,
log::info!("[{}] {} (Authorized pubkey: {}): {}", req.method(), req.uri(), parsed_request.authorized_pubkey, api_response.status); ) -> Result<APIResponse, Box<dyn std::error::Error>> {
let parsed_request = self.parse_http_request(&mut req).await?;
let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?;
log::info!(
"[{}] {} (Authorized pubkey: {}): {}",
req.method(),
req.uri(),
parsed_request.authorized_pubkey,
api_response.status
);
Ok(api_response) Ok(api_response)
} }
async fn parse_http_request(&self, req: &mut Request<Incoming>) -> Result<ParsedRequest, Box<dyn std::error::Error>> { async fn parse_http_request(
&self,
req: &mut Request<Incoming>,
) -> Result<ParsedRequest, Box<dyn std::error::Error>> {
// 1. Read the request body // 1. Read the request body
let body_buffer = req.body_mut().collect().await?.aggregate(); let body_buffer = req.body_mut().collect().await?.aggregate();
let body_bytes = body_buffer.chunk(); let body_bytes = body_buffer.chunk();
let body_bytes = if body_bytes.is_empty() { None } else { Some(body_bytes) }; let body_bytes = if body_bytes.is_empty() {
None
} else {
Some(body_bytes)
};
// 2. NIP-98 authentication // 2. NIP-98 authentication
let authorized_pubkey = match self.authenticate(&req, body_bytes).await? { let authorized_pubkey = match self.authenticate(&req, body_bytes).await? {
Ok(pubkey) => { Ok(pubkey) => pubkey,
pubkey
},
Err(auth_error) => { Err(auth_error) => {
return Err(Box::new(APIError::AuthenticationError(auth_error))); return Err(Box::new(APIError::AuthenticationError(auth_error)));
} }
}; };
// 3. Parse the request // 3. Parse the request
Ok(ParsedRequest { Ok(ParsedRequest {
uri: req.uri().path().to_string(), uri: req.uri().path().to_string(),
@ -128,37 +143,48 @@ impl APIHandler {
authorized_pubkey, authorized_pubkey,
}) })
} }
async fn handle_parsed_http_request(&self, parsed_request: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> { async fn handle_parsed_http_request(
&self,
parsed_request: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
match (&parsed_request.method, parsed_request.uri.as_str()) { match (&parsed_request.method, parsed_request.uri.as_str()) {
(&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await, (&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await,
(&Method::POST, "/user-info/remove") => self.handle_user_info_remove(parsed_request).await, (&Method::POST, "/user-info/remove") => {
_ => { self.handle_user_info_remove(parsed_request).await
Ok(APIResponse {
status: StatusCode::NOT_FOUND,
body: json!({ "error": "Not found" }),
})
} }
_ => Ok(APIResponse {
status: StatusCode::NOT_FOUND,
body: json!({ "error": "Not found" }),
}),
} }
} }
async fn authenticate(&self, req: &Request<Incoming>, body_bytes: Option<&[u8]>) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> { async fn authenticate(
&self,
req: &Request<Incoming>,
body_bytes: Option<&[u8]>,
) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> {
let auth_header = match req.headers().get("Authorization") { let auth_header = match req.headers().get("Authorization") {
Some(header) => header, Some(header) => header,
None => return Ok(Err("Authorization header not found".to_string())), None => return Ok(Err("Authorization header not found".to_string())),
}; };
Ok(nip98_auth::nip98_verify_auth_header( Ok(nip98_auth::nip98_verify_auth_header(
auth_header.to_str()?.to_string(), auth_header.to_str()?.to_string(),
&format!("{}{}", self.base_url, req.uri().path()), &format!("{}{}", self.base_url, req.uri().path()),
req.method().as_str(), req.method().as_str(),
body_bytes body_bytes,
).await) )
.await)
} }
async fn handle_user_info(&self, req: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> { async fn handle_user_info(
&self,
req: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
let body = req.body_json()?; let body = req.body_json()?;
if let Some(device_token) = body["deviceToken"].as_str() { if let Some(device_token) = body["deviceToken"].as_str() {
let notification_manager = self.notification_manager.lock().await; let notification_manager = self.notification_manager.lock().await;
notification_manager.save_user_device_info(req.authorized_pubkey, device_token)?; notification_manager.save_user_device_info(req.authorized_pubkey, device_token)?;
@ -173,10 +199,13 @@ impl APIHandler {
}); });
} }
} }
async fn handle_user_info_remove(&self, req: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> { async fn handle_user_info_remove(
&self,
req: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
let body: Value = req.body_json()?; let body: Value = req.body_json()?;
if let Some(device_token) = body["deviceToken"].as_str() { if let Some(device_token) = body["deviceToken"].as_str() {
let notification_manager = self.notification_manager.lock().await; let notification_manager = self.notification_manager.lock().await;
notification_manager.remove_user_device_info(req.authorized_pubkey, device_token)?; notification_manager.remove_user_device_info(req.authorized_pubkey, device_token)?;
@ -194,8 +223,7 @@ impl APIHandler {
} }
// Define enum error types including authentication error // Define enum error types including authentication error
#[derive(Debug)] #[derive(Debug, Error)]
#[derive(Error)]
enum APIError { enum APIError {
#[error("Authentication error: {0}")] #[error("Authentication error: {0}")]
AuthenticationError(String), AuthenticationError(String),

View File

@ -1,11 +1,11 @@
use super::api_request_handler::APIHandler;
use crate::notification_manager::NotificationManager;
use hyper::{server::conn::http1, service::service_fn}; use hyper::{server::conn::http1, service::service_fn};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use log;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use log;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::notification_manager::NotificationManager;
use super::api_request_handler::APIHandler;
pub struct APIServer { pub struct APIServer {
host: String, host: String,
@ -14,7 +14,12 @@ pub struct APIServer {
} }
impl APIServer { impl APIServer {
pub async fn run(host: String, port: String, notification_manager: Arc<Mutex<NotificationManager>>, base_url: String) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub async fn run(
host: String,
port: String,
notification_manager: Arc<Mutex<NotificationManager>>,
base_url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let api_handler = APIHandler::new(notification_manager, base_url); let api_handler = APIHandler::new(notification_manager, base_url);
let server = APIServer { let server = APIServer {
host, host,
@ -23,21 +28,21 @@ impl APIServer {
}; };
server.start().await server.start().await
} }
async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let address = format!("{}:{}", self.host, self.port); let address = format!("{}:{}", self.host, self.port);
let listener = TcpListener::bind(&address).await?; let listener = TcpListener::bind(&address).await?;
log::info!("HTTP server running at {}", address); log::info!("HTTP server running at {}", address);
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let api_handler = self.api_handler.clone(); let api_handler = self.api_handler.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
let service = service_fn(|req| api_handler.handle_http_request(req)); let service = service_fn(|req| api_handler.handle_http_request(req));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await { if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Failed to serve connection: {:?}", err); log::error!("Failed to serve connection: {:?}", err);
} }

View File

@ -1,3 +1,3 @@
pub mod api_request_handler;
pub mod api_server; pub mod api_server;
pub mod nip98_auth; pub mod nip98_auth;
pub mod api_request_handler;

View File

@ -1,16 +1,16 @@
use base64::prelude::*; use base64::prelude::*;
use serde_json::Value; use nostr;
use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash; use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash;
use nostr::bitcoin::hashes::Hash; use nostr::bitcoin::hashes::Hash;
use nostr::util::hex; use nostr::util::hex;
use nostr::Timestamp; use nostr::Timestamp;
use nostr; use serde_json::Value;
pub async fn nip98_verify_auth_header( pub async fn nip98_verify_auth_header(
auth_header: String, auth_header: String,
url: &str, url: &str,
method: &str, method: &str,
body: Option<&[u8]> body: Option<&[u8]>,
) -> Result<nostr::PublicKey, String> { ) -> Result<nostr::PublicKey, String> {
if auth_header.is_empty() { if auth_header.is_empty() {
return Err("Nostr authorization header missing".to_string()); return Err("Nostr authorization header missing".to_string());
@ -30,23 +30,30 @@ pub async fn nip98_verify_auth_header(
return Err("Nostr authorization header does not have a base64 encoded note".to_string()); return Err("Nostr authorization header does not have a base64 encoded note".to_string());
} }
let decoded_note_json = BASE64_STANDARD.decode(base64_encoded_note.as_bytes()) let decoded_note_json = BASE64_STANDARD
.map_err(|_| format!("Failed to decode base64 encoded note from Nostr authorization header"))?; .decode(base64_encoded_note.as_bytes())
.map_err(|_| {
format!("Failed to decode base64 encoded note from Nostr authorization header")
})?;
let note_value: Value = serde_json::from_slice(&decoded_note_json) let note_value: Value = serde_json::from_slice(&decoded_note_json)
.map_err(|_| format!("Could not parse JSON note from authorization header"))?; .map_err(|_| format!("Could not parse JSON note from authorization header"))?;
let note: nostr::Event = nostr::Event::from_value(note_value) let note: nostr::Event = nostr::Event::from_value(note_value)
.map_err(|_| format!("Could not parse Nostr note from JSON"))?; .map_err(|_| format!("Could not parse Nostr note from JSON"))?;
if note.kind != nostr::Kind::HttpAuth { if note.kind != nostr::Kind::HttpAuth {
return Err("Nostr note kind in authorization header is incorrect".to_string()); return Err("Nostr note kind in authorization header is incorrect".to_string());
} }
let authorized_url = note.get_tag_content(nostr::TagKind::SingleLetter(nostr::SingleLetterTag::lowercase(nostr::Alphabet::U))) let authorized_url = note
.get_tag_content(nostr::TagKind::SingleLetter(
nostr::SingleLetterTag::lowercase(nostr::Alphabet::U),
))
.ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?; .ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?;
let authorized_method = note.get_tag_content(nostr::TagKind::Method) let authorized_method = note
.get_tag_content(nostr::TagKind::Method)
.ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?; .ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?;
if authorized_url != url || authorized_method != method { if authorized_url != url || authorized_method != method {
@ -59,7 +66,9 @@ pub async fn nip98_verify_auth_header(
let current_time: nostr::Timestamp = nostr::Timestamp::now(); let current_time: nostr::Timestamp = nostr::Timestamp::now();
let note_created_at: nostr::Timestamp = note.created_at(); let note_created_at: nostr::Timestamp = note.created_at();
let time_delta = TimeDelta::subtracting(current_time, note_created_at); let time_delta = TimeDelta::subtracting(current_time, note_created_at);
if (time_delta.negative && time_delta.delta_abs_seconds > 30) || (!time_delta.negative && time_delta.delta_abs_seconds > 60) { if (time_delta.negative && time_delta.delta_abs_seconds > 30)
|| (!time_delta.negative && time_delta.delta_abs_seconds > 60)
{
return Err(format!( return Err(format!(
"Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds", "Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds",
current_time, note_created_at, time_delta current_time, note_created_at, time_delta
@ -69,13 +78,16 @@ pub async fn nip98_verify_auth_header(
if let Some(body_data) = body { if let Some(body_data) = body {
let authorized_content_hash_bytes: Vec<u8> = hex::decode( let authorized_content_hash_bytes: Vec<u8> = hex::decode(
note.get_tag_content(nostr::TagKind::Payload) note.get_tag_content(nostr::TagKind::Payload)
.ok_or("Missing 'payload' tag from Nostr authorization header")? .ok_or("Missing 'payload' tag from Nostr authorization header")?,
) )
.map_err(|_| format!("Failed to decode hex encoded payload from Nostr authorization header"))?; .map_err(|_| {
format!("Failed to decode hex encoded payload from Nostr authorization header")
})?;
let authorized_content_hash: Sha256Hash =
Sha256Hash::from_slice(&authorized_content_hash_bytes)
.map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?;
let authorized_content_hash: Sha256Hash = Sha256Hash::from_slice(&authorized_content_hash_bytes)
.map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?;
let body_hash = Sha256Hash::hash(body_data); let body_hash = Sha256Hash::hash(body_data);
if authorized_content_hash != body_hash { if authorized_content_hash != body_hash {
return Err("Auth note payload hash does not match request body hash".to_string()); return Err("Auth note payload hash does not match request body hash".to_string());
@ -91,13 +103,13 @@ pub async fn nip98_verify_auth_header(
if note.verify().is_err() { if note.verify().is_err() {
return Err("Auth note id or signature is invalid".to_string()); return Err("Auth note id or signature is invalid".to_string());
} }
Ok(note.pubkey) Ok(note.pubkey)
} }
struct TimeDelta { struct TimeDelta {
delta_abs_seconds: u64, delta_abs_seconds: u64,
negative: bool negative: bool,
} }
impl TimeDelta { impl TimeDelta {
@ -107,12 +119,12 @@ impl TimeDelta {
if t1 > t2 { if t1 > t2 {
TimeDelta { TimeDelta {
delta_abs_seconds: (t1 - t2).as_u64(), delta_abs_seconds: (t1 - t2).as_u64(),
negative: false negative: false,
} }
} else { } else {
TimeDelta { TimeDelta {
delta_abs_seconds: (t2 - t1).as_u64(), delta_abs_seconds: (t2 - t1).as_u64(),
negative: true negative: true,
} }
} }
} }

View File

@ -1,43 +1,47 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use api_server::api_server::APIServer;
use std::net::TcpListener; use std::net::TcpListener;
use std::sync::Arc; use std::sync::Arc;
use api_server::api_server::APIServer;
use tokio::sync::Mutex; use tokio::sync::Mutex;
mod notification_manager; mod notification_manager;
use log;
use env_logger; use env_logger;
use log;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
mod relay_connection; mod relay_connection;
use relay_connection::RelayConnection;
use r2d2; use r2d2;
use relay_connection::RelayConnection;
mod notepush_env; mod notepush_env;
use notepush_env::NotePushEnv; use notepush_env::NotePushEnv;
mod api_server; mod api_server;
#[tokio::main] #[tokio::main]
async fn main () { async fn main() {
// MARK: - Setup basics // MARK: - Setup basics
env_logger::init(); env_logger::init();
let env = NotePushEnv::load_env().expect("Failed to load environment variables"); let env = NotePushEnv::load_env().expect("Failed to load environment variables");
let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address"); let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address");
let manager = SqliteConnectionManager::file(env.db_path.clone()); let manager = SqliteConnectionManager::file(env.db_path.clone());
let pool: r2d2::Pool<SqliteConnectionManager> = r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); let pool: r2d2::Pool<SqliteConnectionManager> =
r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool");
// Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter. // Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter.
// This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections. // This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections.
let notification_manager = Arc::new(Mutex::new(notification_manager::NotificationManager::new( let notification_manager = Arc::new(Mutex::new(
pool, notification_manager::NotificationManager::new(
env.relay_url.clone(), pool,
env.apns_private_key_path.clone(), env.relay_url.clone(),
env.apns_private_key_id.clone(), env.apns_private_key_path.clone(),
env.apns_team_id.clone(), env.apns_private_key_id.clone(),
env.apns_environment.clone(), env.apns_team_id.clone(),
env.apns_topic.clone(), env.apns_environment.clone(),
).await.expect("Failed to create notification manager"))); env.apns_topic.clone(),
)
.await
.expect("Failed to create notification manager"),
));
// MARK: - Start the API server // MARK: - Start the API server
{ {
let notification_manager = notification_manager.clone(); let notification_manager = notification_manager.clone();
@ -45,24 +49,32 @@ async fn main () {
let api_port = env.api_port.clone(); let api_port = env.api_port.clone();
let api_base_url = env.api_base_url.clone(); let api_base_url = env.api_base_url.clone();
tokio::spawn(async move { tokio::spawn(async move {
APIServer::run(api_host, api_port, notification_manager, api_base_url).await.expect("Failed to start API server"); APIServer::run(api_host, api_port, notification_manager, api_base_url)
.await
.expect("Failed to start API server");
}); });
} }
// MARK: - Start handling incoming connections // MARK: - Start handling incoming connections
log::info!("Relay server listening on {}", env.relay_address().clone()); log::info!("Relay server listening on {}", env.relay_address().clone());
for stream in server.incoming() { for stream in server.incoming() {
if let Ok(stream) = stream { if let Ok(stream) = stream {
let peer_address_string = stream.peer_addr().map_or("unknown".to_string(), |addr| addr.to_string()); let peer_address_string = stream
.peer_addr()
.map_or("unknown".to_string(), |addr| addr.to_string());
log::info!("New connection from {}", peer_address_string); log::info!("New connection from {}", peer_address_string);
let notification_manager = notification_manager.clone(); let notification_manager = notification_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
match RelayConnection::run(stream, notification_manager).await { match RelayConnection::run(stream, notification_manager).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
log::error!("Error with websocket connection from {}: {:?}", peer_address_string, e); log::error!(
"Error with websocket connection from {}: {:?}",
peer_address_string,
e
);
} }
} }
}); });

View File

@ -1,6 +1,6 @@
use std::env;
use dotenv::dotenv;
use a2; use a2;
use dotenv::dotenv;
use std::env;
const DEFAULT_DB_PATH: &str = "./apns_notifications.db"; const DEFAULT_DB_PATH: &str = "./apns_notifications.db";
const DEFAULT_RELAY_HOST: &str = "0.0.0.0"; const DEFAULT_RELAY_HOST: &str = "0.0.0.0";
@ -28,7 +28,7 @@ pub struct NotePushEnv {
// The host and port to bind the API server to // The host and port to bind the API server to
pub api_host: String, pub api_host: String,
pub api_port: String, pub api_port: String,
pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks
// The URL of the Nostr relay server to connect to for getting mutelists // The URL of the Nostr relay server to connect to for getting mutelists
pub relay_url: String, pub relay_url: String,
} }
@ -43,17 +43,19 @@ impl NotePushEnv {
let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string()); let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string());
let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string()); let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string());
let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string()); let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string());
let apns_environment_string = env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); let apns_environment_string =
env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string());
let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string()); let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string());
let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string()); let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string());
let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port)); let api_base_url =
env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port));
let apns_environment = match apns_environment_string.as_str() { let apns_environment = match apns_environment_string.as_str() {
"development" => a2::client::Endpoint::Sandbox, "development" => a2::client::Endpoint::Sandbox,
"production" => a2::client::Endpoint::Production, "production" => a2::client::Endpoint::Production,
_ => a2::client::Endpoint::Sandbox, _ => a2::client::Endpoint::Sandbox,
}; };
let apns_topic = env::var("APNS_TOPIC")?; let apns_topic = env::var("APNS_TOPIC")?;
Ok(NotePushEnv { Ok(NotePushEnv {
apns_private_key_path, apns_private_key_path,
apns_private_key_id, apns_private_key_id,
@ -69,7 +71,7 @@ impl NotePushEnv {
relay_url, relay_url,
}) })
} }
pub fn relay_address(&self) -> String { pub fn relay_address(&self) -> String {
format!("{}:{}", self.relay_host, self.relay_port) format!("{}:{}", self.relay_host, self.relay_port)
} }

View File

@ -1,7 +1,7 @@
pub mod notification_manager;
pub mod mute_manager; pub mod mute_manager;
mod nostr_event_extensions; mod nostr_event_extensions;
pub mod notification_manager;
pub use notification_manager::NotificationManager;
pub use mute_manager::MuteManager; pub use mute_manager::MuteManager;
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible}; use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
pub use notification_manager::NotificationManager;

View File

@ -1,5 +1,5 @@
use nostr_sdk::prelude::*;
use super::ExtendedEvent; use super::ExtendedEvent;
use nostr_sdk::prelude::*;
pub struct MuteManager { pub struct MuteManager {
relay_url: String, relay_url: String,
@ -11,45 +11,67 @@ impl MuteManager {
let client = Client::new(&Keys::generate()); let client = Client::new(&Keys::generate());
client.add_relay(relay_url.clone()).await?; client.add_relay(relay_url.clone()).await?;
client.connect().await; client.connect().await;
Ok(MuteManager { Ok(MuteManager { relay_url, client })
relay_url,
client
})
} }
pub async fn should_mute_notification_for_pubkey(&self, event: &Event, pubkey: &PublicKey) -> bool { pub async fn should_mute_notification_for_pubkey(
&self,
event: &Event,
pubkey: &PublicKey,
) -> bool {
if let Some(mute_list) = self.get_public_mute_list(pubkey).await { if let Some(mute_list) = self.get_public_mute_list(pubkey).await {
for tag in mute_list.tags() { for tag in mute_list.tags() {
match tag.kind() { match tag.kind() {
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::P, uppercase: false }) => { TagKind::SingleLetter(SingleLetterTag {
let tagged_pubkey: Option<PublicKey> = tag.content().and_then(|h| { PublicKey::from_hex(h).ok() }); character: Alphabet::P,
uppercase: false,
}) => {
let tagged_pubkey: Option<PublicKey> =
tag.content().and_then(|h| PublicKey::from_hex(h).ok());
if let Some(tagged_pubkey) = tagged_pubkey { if let Some(tagged_pubkey) = tagged_pubkey {
if event.pubkey == tagged_pubkey { if event.pubkey == tagged_pubkey {
return true return true;
} }
} }
} }
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::E, uppercase: false }) => { TagKind::SingleLetter(SingleLetterTag {
let tagged_event_id: Option<EventId> = tag.content().and_then(|h| { EventId::from_hex(h).ok() }); character: Alphabet::E,
uppercase: false,
}) => {
let tagged_event_id: Option<EventId> =
tag.content().and_then(|h| EventId::from_hex(h).ok());
if let Some(tagged_event_id) = tagged_event_id { if let Some(tagged_event_id) = tagged_event_id {
if event.id == tagged_event_id || event.referenced_event_ids().contains(&tagged_event_id) { if event.id == tagged_event_id
return true || event.referenced_event_ids().contains(&tagged_event_id)
{
return true;
} }
} }
} }
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false }) => { TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::T,
uppercase: false,
}) => {
let tagged_hashtag: Option<String> = tag.content().map(|h| h.to_string()); let tagged_hashtag: Option<String> = tag.content().map(|h| h.to_string());
if let Some(tagged_hashtag) = tagged_hashtag { if let Some(tagged_hashtag) = tagged_hashtag {
let tags_content = event.get_tags_content(TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false })); let tags_content =
event.get_tags_content(TagKind::SingleLetter(SingleLetterTag {
character: Alphabet::T,
uppercase: false,
}));
let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag); let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag);
return should_mute return should_mute;
} }
} }
TagKind::Word => { TagKind::Word => {
let tagged_word: Option<String> = tag.content().map(|h| h.to_string()); let tagged_word: Option<String> = tag.content().map(|h| h.to_string());
if let Some(tagged_word) = tagged_word { if let Some(tagged_word) = tagged_word {
if event.content.to_lowercase().contains(&tagged_word.to_lowercase()) { if event
return true .content
.to_lowercase()
.contains(&tagged_word.to_lowercase())
{
return true;
} }
} }
} }
@ -66,15 +88,23 @@ impl MuteManager {
.authors(vec![pubkey.clone()]) .authors(vec![pubkey.clone()])
.limit(1); .limit(1);
let this_subscription_id = self.client.subscribe(Vec::from([subscription_filter]), None).await; let this_subscription_id = self
.client
.subscribe(Vec::from([subscription_filter]), None)
.await;
let mut mute_list: Option<Event> = None; let mut mute_list: Option<Event> = None;
let mut notifications = self.client.notifications(); let mut notifications = self.client.notifications();
while let Ok(notification) = notifications.recv().await { while let Ok(notification) = notifications.recv().await {
if let RelayPoolNotification::Event { subscription_id, event, .. } = notification { if let RelayPoolNotification::Event {
subscription_id,
event,
..
} = notification
{
if this_subscription_id == subscription_id && event.kind == Kind::MuteList { if this_subscription_id == subscription_id && event.kind == Kind::MuteList {
mute_list = Some((*event).clone()); mute_list = Some((*event).clone());
break break;
} }
} }
} }

View File

@ -1,39 +1,37 @@
use nostr::{self, key::PublicKey, TagKind::SingleLetter, Alphabet, SingleLetterTag}; use nostr::{self, key::PublicKey, Alphabet, SingleLetterTag, TagKind::SingleLetter};
/// Temporary scaffolding of old methods that have not been ported to use native Event methods /// Temporary scaffolding of old methods that have not been ported to use native Event methods
pub trait ExtendedEvent { pub trait ExtendedEvent {
/// Checks if the note references a given pubkey /// Checks if the note references a given pubkey
fn references_pubkey(&self, pubkey: &PublicKey) -> bool; fn references_pubkey(&self, pubkey: &PublicKey) -> bool;
/// Retrieves a set of pubkeys referenced by the note /// Retrieves a set of pubkeys referenced by the note
fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey>; fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey>;
/// Retrieves a set of pubkeys relevant to the note /// Retrieves a set of pubkeys relevant to the note
fn relevant_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey>; fn relevant_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey>;
/// Retrieves a set of event IDs referenced by the note /// Retrieves a set of event IDs referenced by the note
fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId>; fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId>;
} }
// This is a wrapper around the Event type from strfry-policies, which adds some useful methods // This is a wrapper around the Event type from strfry-policies, which adds some useful methods
impl ExtendedEvent for nostr::Event { impl ExtendedEvent for nostr::Event {
/// Checks if the note references a given pubkey /// Checks if the note references a given pubkey
fn references_pubkey(&self, pubkey: &PublicKey) -> bool { fn references_pubkey(&self, pubkey: &PublicKey) -> bool {
self.referenced_pubkeys().contains(pubkey) self.referenced_pubkeys().contains(pubkey)
} }
/// Retrieves a set of pubkeys referenced by the note /// Retrieves a set of pubkeys referenced by the note
fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> { fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> {
self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::P))) self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::P)))
.iter() .iter()
.filter_map(|tag| { .filter_map(|tag| PublicKey::from_hex(tag).ok())
PublicKey::from_hex(tag).ok()
})
.collect() .collect()
} }
/// Retrieves a set of pubkeys relevant to the note /// Retrieves a set of pubkeys relevant to the note
fn relevant_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> { fn relevant_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> {
let mut pubkeys = self.referenced_pubkeys(); let mut pubkeys = self.referenced_pubkeys();
pubkeys.insert(self.pubkey.clone()); pubkeys.insert(self.pubkey.clone());
pubkeys pubkeys
@ -43,9 +41,7 @@ impl ExtendedEvent for nostr::Event {
fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId> { fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId> {
self.get_tag_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::E))) self.get_tag_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::E)))
.iter() .iter()
.filter_map(|tag| { .filter_map(|tag| nostr::EventId::from_hex(tag).ok())
nostr::EventId::from_hex(tag).ok()
})
.collect() .collect()
} }
} }
@ -54,14 +50,16 @@ impl ExtendedEvent for nostr::Event {
pub trait SqlStringConvertible { pub trait SqlStringConvertible {
fn to_sql_string(&self) -> String; fn to_sql_string(&self) -> String;
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> where Self: Sized; fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>>
where
Self: Sized;
} }
impl SqlStringConvertible for nostr::EventId { impl SqlStringConvertible for nostr::EventId {
fn to_sql_string(&self) -> String { fn to_sql_string(&self) -> String {
self.to_hex() self.to_hex()
} }
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> { fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> {
nostr::EventId::from_hex(s).map_err(|e| e.into()) nostr::EventId::from_hex(s).map_err(|e| e.into())
} }
@ -71,7 +69,7 @@ impl SqlStringConvertible for nostr::PublicKey {
fn to_sql_string(&self) -> String { fn to_sql_string(&self) -> String {
self.to_hex() self.to_hex()
} }
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> { fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> {
nostr::PublicKey::from_hex(s).map_err(|e| e.into()) nostr::PublicKey::from_hex(s).map_err(|e| e.into())
} }
@ -81,7 +79,7 @@ impl SqlStringConvertible for nostr::Timestamp {
fn to_sql_string(&self) -> String { fn to_sql_string(&self) -> String {
self.as_u64().to_string() self.as_u64().to_string()
} }
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> { fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> {
let u64_timestamp: u64 = s.parse()?; let u64_timestamp: u64 = s.parse()?;
Ok(nostr::Timestamp::from(u64_timestamp)) Ok(nostr::Timestamp::from(u64_timestamp))

View File

@ -1,19 +1,19 @@
use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder}; use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder};
use log;
use nostr::event::EventId; use nostr::event::EventId;
use nostr::key::PublicKey; use nostr::key::PublicKey;
use nostr::types::Timestamp; use nostr::types::Timestamp;
use rusqlite; use rusqlite;
use rusqlite::params; use rusqlite::params;
use std::collections::HashSet; use std::collections::HashSet;
use log;
use std::fs::File;
use super::mute_manager::MuteManager; use super::mute_manager::MuteManager;
use nostr::Event;
use super::SqlStringConvertible;
use super::ExtendedEvent; use super::ExtendedEvent;
use r2d2_sqlite::SqliteConnectionManager; use super::SqlStringConvertible;
use nostr::Event;
use r2d2; use r2d2;
use r2d2_sqlite::SqliteConnectionManager;
use std::fs::File;
// MARK: - NotificationManager // MARK: - NotificationManager
@ -31,20 +31,28 @@ pub struct NotificationManager {
impl NotificationManager { impl NotificationManager {
// MARK: - Initialization // MARK: - Initialization
pub async fn new(db: r2d2::Pool<SqliteConnectionManager>, relay_url: String, apns_private_key_path: String, apns_private_key_id: String, apns_team_id: String, apns_environment: a2::client::Endpoint, apns_topic: String) -> Result<Self, Box<dyn std::error::Error>> { pub async fn new(
db: r2d2::Pool<SqliteConnectionManager>,
relay_url: String,
apns_private_key_path: String,
apns_private_key_id: String,
apns_team_id: String,
apns_environment: a2::client::Endpoint,
apns_topic: String,
) -> Result<Self, Box<dyn std::error::Error>> {
let mute_manager = MuteManager::new(relay_url.clone()).await?; let mute_manager = MuteManager::new(relay_url.clone()).await?;
let connection = db.get()?; let connection = db.get()?;
Self::setup_database(&connection)?; Self::setup_database(&connection)?;
let mut file = File::open(&apns_private_key_path)?; let mut file = File::open(&apns_private_key_path)?;
let client = Client::token( let client = Client::token(
&mut file, &mut file,
&apns_private_key_id, &apns_private_key_id,
&apns_team_id, &apns_team_id,
ClientConfig::new(apns_environment.clone()) ClientConfig::new(apns_environment.clone()),
)?; )?;
Ok(Self { Ok(Self {
@ -58,7 +66,7 @@ impl NotificationManager {
mute_manager, mute_manager,
}) })
} }
// MARK: - Database setup operations // MARK: - Database setup operations
pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> { pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
@ -93,39 +101,58 @@ impl NotificationManager {
Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER")?; Self::add_column_if_not_exists(&db, "notifications", "sent_at", "INTEGER")?;
Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER")?; Self::add_column_if_not_exists(&db, "user_info", "added_at", "INTEGER")?;
Ok(()) Ok(())
} }
fn add_column_if_not_exists(db: &rusqlite::Connection, table_name: &str, column_name: &str, column_type: &str) -> Result<(), rusqlite::Error> { fn add_column_if_not_exists(
db: &rusqlite::Connection,
table_name: &str,
column_name: &str,
column_type: &str,
) -> Result<(), rusqlite::Error> {
let query = format!("PRAGMA table_info({})", table_name); let query = format!("PRAGMA table_info({})", table_name);
let mut stmt = db.prepare(&query)?; let mut stmt = db.prepare(&query)?;
let column_names: Vec<String> = stmt.query_map([], |row| row.get(1))? let column_names: Vec<String> = stmt
.query_map([], |row| row.get(1))?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect(); .collect();
if !column_names.contains(&column_name.to_string()) { if !column_names.contains(&column_name.to_string()) {
let query = format!("ALTER TABLE {} ADD COLUMN {} {}", table_name, column_name, column_type); let query = format!(
"ALTER TABLE {} ADD COLUMN {} {}",
table_name, column_name, column_type
);
db.execute(&query, [])?; db.execute(&query, [])?;
} }
Ok(()) Ok(())
} }
// MARK: - Business logic // MARK: - Business logic
pub async fn send_notifications_if_needed(&self, event: &Event) -> Result<(), Box<dyn std::error::Error>> { pub async fn send_notifications_if_needed(
log::debug!("Checking if notifications need to be sent for event: {}", event.id); &self,
event: &Event,
) -> Result<(), Box<dyn std::error::Error>> {
log::debug!(
"Checking if notifications need to be sent for event: {}",
event.id
);
let one_week_ago = nostr::Timestamp::now() - 7 * 24 * 60 * 60; let one_week_ago = nostr::Timestamp::now() - 7 * 24 * 60 * 60;
if event.created_at < one_week_ago { if event.created_at < one_week_ago {
return Ok(()); return Ok(());
} }
let pubkeys_to_notify = self.pubkeys_to_notify_for_event(event).await?; let pubkeys_to_notify = self.pubkeys_to_notify_for_event(event).await?;
log::debug!("Sending notifications to {} pubkeys", pubkeys_to_notify.len()); log::debug!(
"Sending notifications to {} pubkeys",
pubkeys_to_notify.len()
);
for pubkey in pubkeys_to_notify { for pubkey in pubkeys_to_notify {
self.send_event_notifications_to_pubkey(event, &pubkey).await?; self.send_event_notifications_to_pubkey(event, &pubkey)
.await?;
self.db.get()?.execute( self.db.get()?.execute(
"INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at) "INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at)
VALUES (?, ?, ?, ?, ?)", VALUES (?, ?, ?, ?, ?)",
@ -141,10 +168,14 @@ impl NotificationManager {
Ok(()) Ok(())
} }
async fn pubkeys_to_notify_for_event(&self, event: &Event) -> Result<HashSet<nostr::PublicKey>, Box<dyn std::error::Error>> { async fn pubkeys_to_notify_for_event(
&self,
event: &Event,
) -> Result<HashSet<nostr::PublicKey>, Box<dyn std::error::Error>> {
let notification_status = self.get_notification_status(event)?; let notification_status = self.get_notification_status(event)?;
let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?; let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?;
let pubkeys_that_received_notification = notification_status.pubkeys_that_received_notification(); let pubkeys_that_received_notification =
notification_status.pubkeys_that_received_notification();
let relevant_pubkeys_yet_to_receive: HashSet<PublicKey> = relevant_pubkeys let relevant_pubkeys_yet_to_receive: HashSet<PublicKey> = relevant_pubkeys
.difference(&pubkeys_that_received_notification) .difference(&pubkeys_that_received_notification)
.filter(|&x| *x != event.pubkey) .filter(|&x| *x != event.pubkey)
@ -153,7 +184,10 @@ impl NotificationManager {
let mut pubkeys_to_notify = HashSet::new(); let mut pubkeys_to_notify = HashSet::new();
for pubkey in relevant_pubkeys_yet_to_receive { for pubkey in relevant_pubkeys_yet_to_receive {
let should_mute: bool = self.mute_manager.should_mute_notification_for_pubkey(event, &pubkey).await; let should_mute: bool = self
.mute_manager
.should_mute_notification_for_pubkey(event, &pubkey)
.await;
if !should_mute { if !should_mute {
pubkeys_to_notify.insert(pubkey); pubkeys_to_notify.insert(pubkey);
} }
@ -161,53 +195,79 @@ impl NotificationManager {
Ok(pubkeys_to_notify) Ok(pubkeys_to_notify)
} }
fn pubkeys_relevant_to_event(&self, event: &Event) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> { fn pubkeys_relevant_to_event(
&self,
event: &Event,
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
let mut relevant_pubkeys = event.relevant_pubkeys(); let mut relevant_pubkeys = event.relevant_pubkeys();
let referenced_event_ids = event.referenced_event_ids(); let referenced_event_ids = event.referenced_event_ids();
for referenced_event_id in referenced_event_ids { for referenced_event_id in referenced_event_ids {
let pubkeys_relevant_to_referenced_event = self.pubkeys_subscribed_to_event_id(&referenced_event_id)?; let pubkeys_relevant_to_referenced_event =
self.pubkeys_subscribed_to_event_id(&referenced_event_id)?;
relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event); relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event);
} }
Ok(relevant_pubkeys) Ok(relevant_pubkeys)
} }
fn pubkeys_subscribed_to_event(&self, event: &Event) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> { fn pubkeys_subscribed_to_event(
&self,
event: &Event,
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
self.pubkeys_subscribed_to_event_id(&event.id) self.pubkeys_subscribed_to_event_id(&event.id)
} }
fn pubkeys_subscribed_to_event_id(&self, event_id: &EventId) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> { fn pubkeys_subscribed_to_event_id(
&self,
event_id: &EventId,
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
let connection = self.db.get()?; let connection = self.db.get()?;
let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?; let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?;
let pubkeys = stmt.query_map([event_id.to_sql_string()], |row| row.get(0))? let pubkeys = stmt
.query_map([event_id.to_sql_string()], |row| row.get(0))?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter_map(|r: String| PublicKey::from_sql_string(r).ok()) .filter_map(|r: String| PublicKey::from_sql_string(r).ok())
.collect(); .collect();
Ok(pubkeys) Ok(pubkeys)
} }
async fn send_event_notifications_to_pubkey(&self, event: &Event, pubkey: &PublicKey) -> Result<(), Box<dyn std::error::Error>> { async fn send_event_notifications_to_pubkey(
&self,
event: &Event,
pubkey: &PublicKey,
) -> Result<(), Box<dyn std::error::Error>> {
let user_device_tokens = self.get_user_device_tokens(pubkey)?; let user_device_tokens = self.get_user_device_tokens(pubkey)?;
for device_token in user_device_tokens { for device_token in user_device_tokens {
self.send_event_notification_to_device_token(event, &device_token).await?; self.send_event_notification_to_device_token(event, &device_token)
.await?;
} }
Ok(()) Ok(())
} }
fn get_user_device_tokens(&self, pubkey: &PublicKey) -> Result<Vec<String>, Box<dyn std::error::Error>> { fn get_user_device_tokens(
&self,
pubkey: &PublicKey,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let connection = self.db.get()?; let connection = self.db.get()?;
let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?; let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?;
let device_tokens = stmt.query_map([pubkey.to_sql_string()], |row| row.get(0))? let device_tokens = stmt
.query_map([pubkey.to_sql_string()], |row| row.get(0))?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect(); .collect();
Ok(device_tokens) Ok(device_tokens)
} }
fn get_notification_status(&self, event: &Event) -> Result<NotificationStatus, Box<dyn std::error::Error>> { fn get_notification_status(
&self,
event: &Event,
) -> Result<NotificationStatus, Box<dyn std::error::Error>> {
let connection = self.db.get()?; let connection = self.db.get()?;
let mut stmt = connection.prepare("SELECT pubkey, received_notification FROM notifications WHERE event_id = ?")?; let mut stmt = connection.prepare(
let rows: std::collections::HashMap<PublicKey, bool> = stmt.query_map([event.id.to_sql_string()], |row| { "SELECT pubkey, received_notification FROM notifications WHERE event_id = ?",
Ok((row.get(0)?, row.get(1)?)) )?;
})? let rows: std::collections::HashMap<PublicKey, bool> = stmt
.query_map([event.id.to_sql_string()], |row| {
Ok((row.get(0)?, row.get(1)?))
})?
.filter_map(|r: Result<(String, bool), rusqlite::Error>| r.ok()) .filter_map(|r: Result<(String, bool), rusqlite::Error>| r.ok())
.filter_map(|r: (String, bool)| { .filter_map(|r: (String, bool)| {
let pubkey = PublicKey::from_sql_string(r.0).ok()?; let pubkey = PublicKey::from_sql_string(r.0).ok()?;
@ -225,9 +285,13 @@ impl NotificationManager {
Ok(NotificationStatus { status_info }) Ok(NotificationStatus { status_info })
} }
async fn send_event_notification_to_device_token(&self, event: &Event, device_token: &str) -> Result<(), Box<dyn std::error::Error>> { async fn send_event_notification_to_device_token(
&self,
event: &Event,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let (title, subtitle, body) = self.format_notification_message(event); let (title, subtitle, body) = self.format_notification_message(event);
log::debug!("Sending notification to device token: {}", device_token); log::debug!("Sending notification to device token: {}", device_token);
let builder = DefaultNotificationBuilder::new() let builder = DefaultNotificationBuilder::new()
@ -236,18 +300,15 @@ impl NotificationManager {
.set_body(&body) .set_body(&body)
.set_mutable_content() .set_mutable_content()
.set_content_available(); .set_content_available();
let mut payload = builder.build( let mut payload = builder.build(device_token, Default::default());
device_token,
Default::default()
);
let _ = payload.add_custom_data("nostr_event", event); let _ = payload.add_custom_data("nostr_event", event);
payload.options.apns_topic = Some(self.apns_topic.as_str()); payload.options.apns_topic = Some(self.apns_topic.as_str());
let _response = self.apns_client.send(payload).await?; let _response = self.apns_client.send(payload).await?;
log::info!("Notification sent to device token: {}", device_token); log::info!("Notification sent to device token: {}", device_token);
Ok(()) Ok(())
} }
@ -258,7 +319,11 @@ impl NotificationManager {
(title, subtitle, body) (title, subtitle, body)
} }
pub fn save_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box<dyn std::error::Error>> { pub fn save_user_device_info(
&self,
pubkey: nostr::PublicKey,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let current_time_unix = Timestamp::now(); let current_time_unix = Timestamp::now();
self.db.get()?.execute( self.db.get()?.execute(
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)",
@ -272,7 +337,11 @@ impl NotificationManager {
Ok(()) Ok(())
} }
pub fn remove_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box<dyn std::error::Error>> { pub fn remove_user_device_info(
&self,
pubkey: nostr::PublicKey,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
self.db.get()?.execute( self.db.get()?.execute(
"DELETE FROM user_info WHERE pubkey = ? AND device_token = ?", "DELETE FROM user_info WHERE pubkey = ? AND device_token = ?",
params![pubkey.to_sql_string(), device_token], params![pubkey.to_sql_string(), device_token],

View File

@ -1,43 +1,48 @@
use crate::notification_manager::NotificationManager;
use log;
use nostr::util::JsonUtil; use nostr::util::JsonUtil;
use nostr::{RelayMessage, ClientMessage}; use nostr::{ClientMessage, RelayMessage};
use serde_json::Value;
use std::fmt::{self, Debug};
use std::net::TcpStream;
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use serde_json::Value;
use crate::notification_manager::NotificationManager;
use std::str::FromStr;
use std::net::TcpStream;
use tungstenite::{accept, WebSocket}; use tungstenite::{accept, WebSocket};
use log;
use std::fmt::{self, Debug};
const MAX_CONSECUTIVE_ERRORS: u32 = 10; const MAX_CONSECUTIVE_ERRORS: u32 = 10;
pub struct RelayConnection { pub struct RelayConnection {
websocket: WebSocket<TcpStream>, websocket: WebSocket<TcpStream>,
notification_manager: Arc<Mutex<NotificationManager>> notification_manager: Arc<Mutex<NotificationManager>>,
} }
impl RelayConnection { impl RelayConnection {
// MARK: - Initializers // MARK: - Initializers
pub fn new(stream: TcpStream, notification_manager: Arc<Mutex<NotificationManager>>) -> Result<Self, Box<dyn std::error::Error>> { pub fn new(
stream: TcpStream,
notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<Self, Box<dyn std::error::Error>> {
let address = stream.peer_addr()?; let address = stream.peer_addr()?;
let websocket = accept(stream)?; let websocket = accept(stream)?;
log::info!("Accepted connection from {:?}", address); log::info!("Accepted connection from {:?}", address);
Ok(RelayConnection { Ok(RelayConnection {
websocket, websocket,
notification_manager notification_manager,
}) })
} }
pub async fn run(stream: TcpStream, notification_manager: Arc<Mutex<NotificationManager>>) -> Result<(), Box<dyn std::error::Error>> { pub async fn run(
stream: TcpStream,
notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut connection = RelayConnection::new(stream, notification_manager)?; let mut connection = RelayConnection::new(stream, notification_manager)?;
Ok(connection.run_loop().await?) Ok(connection.run_loop().await?)
} }
// MARK: - Connection Runtime management // MARK: - Connection Runtime management
pub async fn run_loop(&mut self) -> Result<(), Box<dyn std::error::Error>> { pub async fn run_loop(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let mut consecutive_errors = 0; let mut consecutive_errors = 0;
log::debug!("Starting run loop for connection with {:?}", self.websocket); log::debug!("Starting run loop for connection with {:?}", self.websocket);
@ -47,31 +52,43 @@ impl RelayConnection {
consecutive_errors = 0; consecutive_errors = 0;
} }
Err(e) => { Err(e) => {
log::error!("Error in websocket connection with {:?}: {:?}", self.websocket, e); log::error!(
"Error in websocket connection with {:?}: {:?}",
self.websocket,
e
);
consecutive_errors += 1; consecutive_errors += 1;
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS { if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
log::error!("Too many consecutive errors, closing connection with {:?}", self.websocket); log::error!(
"Too many consecutive errors, closing connection with {:?}",
self.websocket
);
return Err(e); return Err(e);
} }
} }
} }
} }
} }
pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box<dyn std::error::Error>> { pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box<dyn std::error::Error>> {
let websocket = &mut self.websocket; let websocket = &mut self.websocket;
let raw_message = websocket.read()?; let raw_message = websocket.read()?;
if raw_message.is_text() { if raw_message.is_text() {
let message: ClientMessage = ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?; let message: ClientMessage =
ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?;
let response = self.handle_client_message(message).await?; let response = self.handle_client_message(message).await?;
self.websocket.send(tungstenite::Message::text(response.try_as_json()?))?; self.websocket
.send(tungstenite::Message::text(response.try_as_json()?))?;
} }
Ok(()) Ok(())
} }
// MARK: - Message handling // MARK: - Message handling
async fn handle_client_message<'b>(&'b self, message: ClientMessage) -> Result<RelayMessage, Box<dyn std::error::Error>> { async fn handle_client_message<'b>(
&'b self,
message: ClientMessage,
) -> Result<RelayMessage, Box<dyn std::error::Error>> {
match message { match message {
ClientMessage::Event(event) => { ClientMessage::Event(event) => {
log::info!("Received event: {:?}", event); log::info!("Received event: {:?}", event);
@ -79,15 +96,21 @@ impl RelayConnection {
// TODO: Reduce resource contention by reducing the scope of the mutex into NotificationManager logic. // TODO: Reduce resource contention by reducing the scope of the mutex into NotificationManager logic.
let mutex_guard = self.notification_manager.lock().await; let mutex_guard = self.notification_manager.lock().await;
mutex_guard.send_notifications_if_needed(&event).await?; mutex_guard.send_notifications_if_needed(&event).await?;
}; // Only hold the mutex for as little time as possible. }; // Only hold the mutex for as little time as possible.
let notice_message = format!("blocked: This relay does not store events"); let notice_message = format!("blocked: This relay does not store events");
let response = RelayMessage::Ok { event_id: event.id, status: false, message: notice_message }; let response = RelayMessage::Ok {
event_id: event.id,
status: false,
message: notice_message,
};
Ok(response) Ok(response)
} }
_ => { _ => {
log::info!("Received unsupported message: {:?}", message); log::info!("Received unsupported message: {:?}", message);
let notice_message = format!("Unsupported message: {:?}", message); let notice_message = format!("Unsupported message: {:?}", message);
let response = RelayMessage::Notice { message: notice_message }; let response = RelayMessage::Notice {
message: notice_message,
};
Ok(response) Ok(response)
} }
} }