mirror of
https://github.com/nostrlabs-io/notepush.git
synced 2025-06-15 19:38:24 +00:00
@ -1,18 +1,18 @@
|
|||||||
use super::nip98_auth;
|
use super::nip98_auth;
|
||||||
use hyper::{Request, Response, StatusCode};
|
|
||||||
use hyper::body::Buf;
|
use hyper::body::Buf;
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
|
use hyper::{Request, Response, StatusCode};
|
||||||
|
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use nostr;
|
use nostr;
|
||||||
|
|
||||||
use thiserror::Error;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use log;
|
|
||||||
use hyper::Method;
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use crate::notification_manager::NotificationManager;
|
use crate::notification_manager::NotificationManager;
|
||||||
|
use hyper::Method;
|
||||||
|
use log;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
struct ParsedRequest {
|
struct ParsedRequest {
|
||||||
uri: String,
|
uri: String,
|
||||||
@ -58,30 +58,32 @@ impl APIHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_http_request(&self, req: Request<Incoming>) -> Result<Response<String>, hyper::http::Error> {
|
pub async fn handle_http_request(
|
||||||
|
&self,
|
||||||
|
req: Request<Incoming>,
|
||||||
|
) -> Result<Response<String>, hyper::http::Error> {
|
||||||
let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await {
|
let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await {
|
||||||
Ok(api_response) => {
|
Ok(api_response) => APIResponse {
|
||||||
APIResponse {
|
|
||||||
status: api_response.status,
|
status: api_response.status,
|
||||||
body: api_response.body,
|
body: api_response.body,
|
||||||
}
|
|
||||||
},
|
},
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
// Detect if error is a APIError::AuthenticationError and return a 401 status code
|
// Detect if error is a APIError::AuthenticationError and return a 401 status code
|
||||||
if let Some(api_error) = err.downcast_ref::<APIError>() {
|
if let Some(api_error) = err.downcast_ref::<APIError>() {
|
||||||
match api_error {
|
match api_error {
|
||||||
APIError::AuthenticationError(message) => {
|
APIError::AuthenticationError(message) => APIResponse {
|
||||||
APIResponse {
|
|
||||||
status: StatusCode::UNAUTHORIZED,
|
status: StatusCode::UNAUTHORIZED,
|
||||||
body: json!({ "error": "Unauthorized", "message": message }),
|
body: json!({ "error": "Unauthorized", "message": message }),
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
// Otherwise, return a 500 status code
|
// Otherwise, return a 500 status code
|
||||||
let random_case_uuid = uuid::Uuid::new_v4();
|
let random_case_uuid = uuid::Uuid::new_v4();
|
||||||
log::error!("Error handling request: {} (Case ID: {})", err, random_case_uuid);
|
log::error!(
|
||||||
|
"Error handling request: {} (Case ID: {})",
|
||||||
|
err,
|
||||||
|
random_case_uuid
|
||||||
|
);
|
||||||
APIResponse {
|
APIResponse {
|
||||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }),
|
body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }),
|
||||||
@ -93,28 +95,41 @@ impl APIHandler {
|
|||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Access-Control-Allow-Origin", "*")
|
.header("Access-Control-Allow-Origin", "*")
|
||||||
.status(final_api_response.status)
|
.status(final_api_response.status)
|
||||||
.body(final_api_response.body.to_string())?
|
.body(final_api_response.body.to_string())?)
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_to_handle_http_request(&self, mut req: Request<Incoming>) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
async fn try_to_handle_http_request(
|
||||||
|
&self,
|
||||||
|
mut req: Request<Incoming>,
|
||||||
|
) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
||||||
let parsed_request = self.parse_http_request(&mut req).await?;
|
let parsed_request = self.parse_http_request(&mut req).await?;
|
||||||
let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?;
|
let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?;
|
||||||
log::info!("[{}] {} (Authorized pubkey: {}): {}", req.method(), req.uri(), parsed_request.authorized_pubkey, api_response.status);
|
log::info!(
|
||||||
|
"[{}] {} (Authorized pubkey: {}): {}",
|
||||||
|
req.method(),
|
||||||
|
req.uri(),
|
||||||
|
parsed_request.authorized_pubkey,
|
||||||
|
api_response.status
|
||||||
|
);
|
||||||
Ok(api_response)
|
Ok(api_response)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn parse_http_request(&self, req: &mut Request<Incoming>) -> Result<ParsedRequest, Box<dyn std::error::Error>> {
|
async fn parse_http_request(
|
||||||
|
&self,
|
||||||
|
req: &mut Request<Incoming>,
|
||||||
|
) -> Result<ParsedRequest, Box<dyn std::error::Error>> {
|
||||||
// 1. Read the request body
|
// 1. Read the request body
|
||||||
let body_buffer = req.body_mut().collect().await?.aggregate();
|
let body_buffer = req.body_mut().collect().await?.aggregate();
|
||||||
let body_bytes = body_buffer.chunk();
|
let body_bytes = body_buffer.chunk();
|
||||||
let body_bytes = if body_bytes.is_empty() { None } else { Some(body_bytes) };
|
let body_bytes = if body_bytes.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(body_bytes)
|
||||||
|
};
|
||||||
|
|
||||||
// 2. NIP-98 authentication
|
// 2. NIP-98 authentication
|
||||||
let authorized_pubkey = match self.authenticate(&req, body_bytes).await? {
|
let authorized_pubkey = match self.authenticate(&req, body_bytes).await? {
|
||||||
Ok(pubkey) => {
|
Ok(pubkey) => pubkey,
|
||||||
pubkey
|
|
||||||
},
|
|
||||||
Err(auth_error) => {
|
Err(auth_error) => {
|
||||||
return Err(Box::new(APIError::AuthenticationError(auth_error)));
|
return Err(Box::new(APIError::AuthenticationError(auth_error)));
|
||||||
}
|
}
|
||||||
@ -129,20 +144,27 @@ impl APIHandler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_parsed_http_request(&self, parsed_request: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
async fn handle_parsed_http_request(
|
||||||
|
&self,
|
||||||
|
parsed_request: &ParsedRequest,
|
||||||
|
) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
||||||
match (&parsed_request.method, parsed_request.uri.as_str()) {
|
match (&parsed_request.method, parsed_request.uri.as_str()) {
|
||||||
(&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await,
|
(&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await,
|
||||||
(&Method::POST, "/user-info/remove") => self.handle_user_info_remove(parsed_request).await,
|
(&Method::POST, "/user-info/remove") => {
|
||||||
_ => {
|
self.handle_user_info_remove(parsed_request).await
|
||||||
Ok(APIResponse {
|
}
|
||||||
|
_ => Ok(APIResponse {
|
||||||
status: StatusCode::NOT_FOUND,
|
status: StatusCode::NOT_FOUND,
|
||||||
body: json!({ "error": "Not found" }),
|
body: json!({ "error": "Not found" }),
|
||||||
})
|
}),
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authenticate(&self, req: &Request<Incoming>, body_bytes: Option<&[u8]>) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> {
|
async fn authenticate(
|
||||||
|
&self,
|
||||||
|
req: &Request<Incoming>,
|
||||||
|
body_bytes: Option<&[u8]>,
|
||||||
|
) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> {
|
||||||
let auth_header = match req.headers().get("Authorization") {
|
let auth_header = match req.headers().get("Authorization") {
|
||||||
Some(header) => header,
|
Some(header) => header,
|
||||||
None => return Ok(Err("Authorization header not found".to_string())),
|
None => return Ok(Err("Authorization header not found".to_string())),
|
||||||
@ -152,11 +174,15 @@ impl APIHandler {
|
|||||||
auth_header.to_str()?.to_string(),
|
auth_header.to_str()?.to_string(),
|
||||||
&format!("{}{}", self.base_url, req.uri().path()),
|
&format!("{}{}", self.base_url, req.uri().path()),
|
||||||
req.method().as_str(),
|
req.method().as_str(),
|
||||||
body_bytes
|
body_bytes,
|
||||||
).await)
|
)
|
||||||
|
.await)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_user_info(&self, req: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
async fn handle_user_info(
|
||||||
|
&self,
|
||||||
|
req: &ParsedRequest,
|
||||||
|
) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
||||||
let body = req.body_json()?;
|
let body = req.body_json()?;
|
||||||
|
|
||||||
if let Some(device_token) = body["deviceToken"].as_str() {
|
if let Some(device_token) = body["deviceToken"].as_str() {
|
||||||
@ -174,7 +200,10 @@ impl APIHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_user_info_remove(&self, req: &ParsedRequest) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
async fn handle_user_info_remove(
|
||||||
|
&self,
|
||||||
|
req: &ParsedRequest,
|
||||||
|
) -> Result<APIResponse, Box<dyn std::error::Error>> {
|
||||||
let body: Value = req.body_json()?;
|
let body: Value = req.body_json()?;
|
||||||
|
|
||||||
if let Some(device_token) = body["deviceToken"].as_str() {
|
if let Some(device_token) = body["deviceToken"].as_str() {
|
||||||
@ -194,8 +223,7 @@ impl APIHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Define enum error types including authentication error
|
// Define enum error types including authentication error
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Error)]
|
||||||
#[derive(Error)]
|
|
||||||
enum APIError {
|
enum APIError {
|
||||||
#[error("Authentication error: {0}")]
|
#[error("Authentication error: {0}")]
|
||||||
AuthenticationError(String),
|
AuthenticationError(String),
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
|
use super::api_request_handler::APIHandler;
|
||||||
|
use crate::notification_manager::NotificationManager;
|
||||||
use hyper::{server::conn::http1, service::service_fn};
|
use hyper::{server::conn::http1, service::service_fn};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
|
use log;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use log;
|
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use crate::notification_manager::NotificationManager;
|
|
||||||
use super::api_request_handler::APIHandler;
|
|
||||||
|
|
||||||
pub struct APIServer {
|
pub struct APIServer {
|
||||||
host: String,
|
host: String,
|
||||||
@ -14,7 +14,12 @@ pub struct APIServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl APIServer {
|
impl APIServer {
|
||||||
pub async fn run(host: String, port: String, notification_manager: Arc<Mutex<NotificationManager>>, base_url: String) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
pub async fn run(
|
||||||
|
host: String,
|
||||||
|
port: String,
|
||||||
|
notification_manager: Arc<Mutex<NotificationManager>>,
|
||||||
|
base_url: String,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let api_handler = APIHandler::new(notification_manager, base_url);
|
let api_handler = APIHandler::new(notification_manager, base_url);
|
||||||
let server = APIServer {
|
let server = APIServer {
|
||||||
host,
|
host,
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
|
pub mod api_request_handler;
|
||||||
pub mod api_server;
|
pub mod api_server;
|
||||||
pub mod nip98_auth;
|
pub mod nip98_auth;
|
||||||
pub mod api_request_handler;
|
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
use base64::prelude::*;
|
use base64::prelude::*;
|
||||||
use serde_json::Value;
|
use nostr;
|
||||||
use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash;
|
use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash;
|
||||||
use nostr::bitcoin::hashes::Hash;
|
use nostr::bitcoin::hashes::Hash;
|
||||||
use nostr::util::hex;
|
use nostr::util::hex;
|
||||||
use nostr::Timestamp;
|
use nostr::Timestamp;
|
||||||
use nostr;
|
use serde_json::Value;
|
||||||
|
|
||||||
pub async fn nip98_verify_auth_header(
|
pub async fn nip98_verify_auth_header(
|
||||||
auth_header: String,
|
auth_header: String,
|
||||||
url: &str,
|
url: &str,
|
||||||
method: &str,
|
method: &str,
|
||||||
body: Option<&[u8]>
|
body: Option<&[u8]>,
|
||||||
) -> Result<nostr::PublicKey, String> {
|
) -> Result<nostr::PublicKey, String> {
|
||||||
if auth_header.is_empty() {
|
if auth_header.is_empty() {
|
||||||
return Err("Nostr authorization header missing".to_string());
|
return Err("Nostr authorization header missing".to_string());
|
||||||
@ -30,8 +30,11 @@ pub async fn nip98_verify_auth_header(
|
|||||||
return Err("Nostr authorization header does not have a base64 encoded note".to_string());
|
return Err("Nostr authorization header does not have a base64 encoded note".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
let decoded_note_json = BASE64_STANDARD.decode(base64_encoded_note.as_bytes())
|
let decoded_note_json = BASE64_STANDARD
|
||||||
.map_err(|_| format!("Failed to decode base64 encoded note from Nostr authorization header"))?;
|
.decode(base64_encoded_note.as_bytes())
|
||||||
|
.map_err(|_| {
|
||||||
|
format!("Failed to decode base64 encoded note from Nostr authorization header")
|
||||||
|
})?;
|
||||||
|
|
||||||
let note_value: Value = serde_json::from_slice(&decoded_note_json)
|
let note_value: Value = serde_json::from_slice(&decoded_note_json)
|
||||||
.map_err(|_| format!("Could not parse JSON note from authorization header"))?;
|
.map_err(|_| format!("Could not parse JSON note from authorization header"))?;
|
||||||
@ -43,10 +46,14 @@ pub async fn nip98_verify_auth_header(
|
|||||||
return Err("Nostr note kind in authorization header is incorrect".to_string());
|
return Err("Nostr note kind in authorization header is incorrect".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
let authorized_url = note.get_tag_content(nostr::TagKind::SingleLetter(nostr::SingleLetterTag::lowercase(nostr::Alphabet::U)))
|
let authorized_url = note
|
||||||
|
.get_tag_content(nostr::TagKind::SingleLetter(
|
||||||
|
nostr::SingleLetterTag::lowercase(nostr::Alphabet::U),
|
||||||
|
))
|
||||||
.ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?;
|
.ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?;
|
||||||
|
|
||||||
let authorized_method = note.get_tag_content(nostr::TagKind::Method)
|
let authorized_method = note
|
||||||
|
.get_tag_content(nostr::TagKind::Method)
|
||||||
.ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?;
|
.ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?;
|
||||||
|
|
||||||
if authorized_url != url || authorized_method != method {
|
if authorized_url != url || authorized_method != method {
|
||||||
@ -59,7 +66,9 @@ pub async fn nip98_verify_auth_header(
|
|||||||
let current_time: nostr::Timestamp = nostr::Timestamp::now();
|
let current_time: nostr::Timestamp = nostr::Timestamp::now();
|
||||||
let note_created_at: nostr::Timestamp = note.created_at();
|
let note_created_at: nostr::Timestamp = note.created_at();
|
||||||
let time_delta = TimeDelta::subtracting(current_time, note_created_at);
|
let time_delta = TimeDelta::subtracting(current_time, note_created_at);
|
||||||
if (time_delta.negative && time_delta.delta_abs_seconds > 30) || (!time_delta.negative && time_delta.delta_abs_seconds > 60) {
|
if (time_delta.negative && time_delta.delta_abs_seconds > 30)
|
||||||
|
|| (!time_delta.negative && time_delta.delta_abs_seconds > 60)
|
||||||
|
{
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds",
|
"Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds",
|
||||||
current_time, note_created_at, time_delta
|
current_time, note_created_at, time_delta
|
||||||
@ -69,11 +78,14 @@ pub async fn nip98_verify_auth_header(
|
|||||||
if let Some(body_data) = body {
|
if let Some(body_data) = body {
|
||||||
let authorized_content_hash_bytes: Vec<u8> = hex::decode(
|
let authorized_content_hash_bytes: Vec<u8> = hex::decode(
|
||||||
note.get_tag_content(nostr::TagKind::Payload)
|
note.get_tag_content(nostr::TagKind::Payload)
|
||||||
.ok_or("Missing 'payload' tag from Nostr authorization header")?
|
.ok_or("Missing 'payload' tag from Nostr authorization header")?,
|
||||||
)
|
)
|
||||||
.map_err(|_| format!("Failed to decode hex encoded payload from Nostr authorization header"))?;
|
.map_err(|_| {
|
||||||
|
format!("Failed to decode hex encoded payload from Nostr authorization header")
|
||||||
|
})?;
|
||||||
|
|
||||||
let authorized_content_hash: Sha256Hash = Sha256Hash::from_slice(&authorized_content_hash_bytes)
|
let authorized_content_hash: Sha256Hash =
|
||||||
|
Sha256Hash::from_slice(&authorized_content_hash_bytes)
|
||||||
.map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?;
|
.map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?;
|
||||||
|
|
||||||
let body_hash = Sha256Hash::hash(body_data);
|
let body_hash = Sha256Hash::hash(body_data);
|
||||||
@ -97,7 +109,7 @@ pub async fn nip98_verify_auth_header(
|
|||||||
|
|
||||||
struct TimeDelta {
|
struct TimeDelta {
|
||||||
delta_abs_seconds: u64,
|
delta_abs_seconds: u64,
|
||||||
negative: bool
|
negative: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TimeDelta {
|
impl TimeDelta {
|
||||||
@ -107,12 +119,12 @@ impl TimeDelta {
|
|||||||
if t1 > t2 {
|
if t1 > t2 {
|
||||||
TimeDelta {
|
TimeDelta {
|
||||||
delta_abs_seconds: (t1 - t2).as_u64(),
|
delta_abs_seconds: (t1 - t2).as_u64(),
|
||||||
negative: false
|
negative: false,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
TimeDelta {
|
TimeDelta {
|
||||||
delta_abs_seconds: (t2 - t1).as_u64(),
|
delta_abs_seconds: (t2 - t1).as_u64(),
|
||||||
negative: true
|
negative: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
32
src/main.rs
32
src/main.rs
@ -1,22 +1,21 @@
|
|||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
|
use api_server::api_server::APIServer;
|
||||||
use std::net::TcpListener;
|
use std::net::TcpListener;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use api_server::api_server::APIServer;
|
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
mod notification_manager;
|
mod notification_manager;
|
||||||
use log;
|
|
||||||
use env_logger;
|
use env_logger;
|
||||||
|
use log;
|
||||||
use r2d2_sqlite::SqliteConnectionManager;
|
use r2d2_sqlite::SqliteConnectionManager;
|
||||||
mod relay_connection;
|
mod relay_connection;
|
||||||
use relay_connection::RelayConnection;
|
|
||||||
use r2d2;
|
use r2d2;
|
||||||
|
use relay_connection::RelayConnection;
|
||||||
mod notepush_env;
|
mod notepush_env;
|
||||||
use notepush_env::NotePushEnv;
|
use notepush_env::NotePushEnv;
|
||||||
mod api_server;
|
mod api_server;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
|
|
||||||
// MARK: - Setup basics
|
// MARK: - Setup basics
|
||||||
|
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
@ -25,10 +24,12 @@ async fn main () {
|
|||||||
let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address");
|
let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address");
|
||||||
|
|
||||||
let manager = SqliteConnectionManager::file(env.db_path.clone());
|
let manager = SqliteConnectionManager::file(env.db_path.clone());
|
||||||
let pool: r2d2::Pool<SqliteConnectionManager> = r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool");
|
let pool: r2d2::Pool<SqliteConnectionManager> =
|
||||||
|
r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool");
|
||||||
// Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter.
|
// Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter.
|
||||||
// This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections.
|
// This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections.
|
||||||
let notification_manager = Arc::new(Mutex::new(notification_manager::NotificationManager::new(
|
let notification_manager = Arc::new(Mutex::new(
|
||||||
|
notification_manager::NotificationManager::new(
|
||||||
pool,
|
pool,
|
||||||
env.relay_url.clone(),
|
env.relay_url.clone(),
|
||||||
env.apns_private_key_path.clone(),
|
env.apns_private_key_path.clone(),
|
||||||
@ -36,7 +37,10 @@ async fn main () {
|
|||||||
env.apns_team_id.clone(),
|
env.apns_team_id.clone(),
|
||||||
env.apns_environment.clone(),
|
env.apns_environment.clone(),
|
||||||
env.apns_topic.clone(),
|
env.apns_topic.clone(),
|
||||||
).await.expect("Failed to create notification manager")));
|
)
|
||||||
|
.await
|
||||||
|
.expect("Failed to create notification manager"),
|
||||||
|
));
|
||||||
|
|
||||||
// MARK: - Start the API server
|
// MARK: - Start the API server
|
||||||
{
|
{
|
||||||
@ -45,7 +49,9 @@ async fn main () {
|
|||||||
let api_port = env.api_port.clone();
|
let api_port = env.api_port.clone();
|
||||||
let api_base_url = env.api_base_url.clone();
|
let api_base_url = env.api_base_url.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
APIServer::run(api_host, api_port, notification_manager, api_base_url).await.expect("Failed to start API server");
|
APIServer::run(api_host, api_port, notification_manager, api_base_url)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start API server");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,14 +61,20 @@ async fn main () {
|
|||||||
|
|
||||||
for stream in server.incoming() {
|
for stream in server.incoming() {
|
||||||
if let Ok(stream) = stream {
|
if let Ok(stream) = stream {
|
||||||
let peer_address_string = stream.peer_addr().map_or("unknown".to_string(), |addr| addr.to_string());
|
let peer_address_string = stream
|
||||||
|
.peer_addr()
|
||||||
|
.map_or("unknown".to_string(), |addr| addr.to_string());
|
||||||
log::info!("New connection from {}", peer_address_string);
|
log::info!("New connection from {}", peer_address_string);
|
||||||
let notification_manager = notification_manager.clone();
|
let notification_manager = notification_manager.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match RelayConnection::run(stream, notification_manager).await {
|
match RelayConnection::run(stream, notification_manager).await {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("Error with websocket connection from {}: {:?}", peer_address_string, e);
|
log::error!(
|
||||||
|
"Error with websocket connection from {}: {:?}",
|
||||||
|
peer_address_string,
|
||||||
|
e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use std::env;
|
|
||||||
use dotenv::dotenv;
|
|
||||||
use a2;
|
use a2;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
const DEFAULT_DB_PATH: &str = "./apns_notifications.db";
|
const DEFAULT_DB_PATH: &str = "./apns_notifications.db";
|
||||||
const DEFAULT_RELAY_HOST: &str = "0.0.0.0";
|
const DEFAULT_RELAY_HOST: &str = "0.0.0.0";
|
||||||
@ -43,10 +43,12 @@ impl NotePushEnv {
|
|||||||
let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string());
|
let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string());
|
||||||
let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string());
|
let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string());
|
||||||
let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string());
|
let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string());
|
||||||
let apns_environment_string = env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string());
|
let apns_environment_string =
|
||||||
|
env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string());
|
||||||
let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string());
|
let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string());
|
||||||
let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string());
|
let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string());
|
||||||
let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port));
|
let api_base_url =
|
||||||
|
env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port));
|
||||||
let apns_environment = match apns_environment_string.as_str() {
|
let apns_environment = match apns_environment_string.as_str() {
|
||||||
"development" => a2::client::Endpoint::Sandbox,
|
"development" => a2::client::Endpoint::Sandbox,
|
||||||
"production" => a2::client::Endpoint::Production,
|
"production" => a2::client::Endpoint::Production,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
pub mod notification_manager;
|
|
||||||
pub mod mute_manager;
|
pub mod mute_manager;
|
||||||
mod nostr_event_extensions;
|
mod nostr_event_extensions;
|
||||||
|
pub mod notification_manager;
|
||||||
|
|
||||||
pub use notification_manager::NotificationManager;
|
|
||||||
pub use mute_manager::MuteManager;
|
pub use mute_manager::MuteManager;
|
||||||
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
|
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
|
||||||
|
pub use notification_manager::NotificationManager;
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use nostr_sdk::prelude::*;
|
|
||||||
use super::ExtendedEvent;
|
use super::ExtendedEvent;
|
||||||
|
use nostr_sdk::prelude::*;
|
||||||
|
|
||||||
pub struct MuteManager {
|
pub struct MuteManager {
|
||||||
relay_url: String,
|
relay_url: String,
|
||||||
@ -11,45 +11,67 @@ impl MuteManager {
|
|||||||
let client = Client::new(&Keys::generate());
|
let client = Client::new(&Keys::generate());
|
||||||
client.add_relay(relay_url.clone()).await?;
|
client.add_relay(relay_url.clone()).await?;
|
||||||
client.connect().await;
|
client.connect().await;
|
||||||
Ok(MuteManager {
|
Ok(MuteManager { relay_url, client })
|
||||||
relay_url,
|
|
||||||
client
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn should_mute_notification_for_pubkey(&self, event: &Event, pubkey: &PublicKey) -> bool {
|
pub async fn should_mute_notification_for_pubkey(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
pubkey: &PublicKey,
|
||||||
|
) -> bool {
|
||||||
if let Some(mute_list) = self.get_public_mute_list(pubkey).await {
|
if let Some(mute_list) = self.get_public_mute_list(pubkey).await {
|
||||||
for tag in mute_list.tags() {
|
for tag in mute_list.tags() {
|
||||||
match tag.kind() {
|
match tag.kind() {
|
||||||
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::P, uppercase: false }) => {
|
TagKind::SingleLetter(SingleLetterTag {
|
||||||
let tagged_pubkey: Option<PublicKey> = tag.content().and_then(|h| { PublicKey::from_hex(h).ok() });
|
character: Alphabet::P,
|
||||||
|
uppercase: false,
|
||||||
|
}) => {
|
||||||
|
let tagged_pubkey: Option<PublicKey> =
|
||||||
|
tag.content().and_then(|h| PublicKey::from_hex(h).ok());
|
||||||
if let Some(tagged_pubkey) = tagged_pubkey {
|
if let Some(tagged_pubkey) = tagged_pubkey {
|
||||||
if event.pubkey == tagged_pubkey {
|
if event.pubkey == tagged_pubkey {
|
||||||
return true
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::E, uppercase: false }) => {
|
TagKind::SingleLetter(SingleLetterTag {
|
||||||
let tagged_event_id: Option<EventId> = tag.content().and_then(|h| { EventId::from_hex(h).ok() });
|
character: Alphabet::E,
|
||||||
|
uppercase: false,
|
||||||
|
}) => {
|
||||||
|
let tagged_event_id: Option<EventId> =
|
||||||
|
tag.content().and_then(|h| EventId::from_hex(h).ok());
|
||||||
if let Some(tagged_event_id) = tagged_event_id {
|
if let Some(tagged_event_id) = tagged_event_id {
|
||||||
if event.id == tagged_event_id || event.referenced_event_ids().contains(&tagged_event_id) {
|
if event.id == tagged_event_id
|
||||||
return true
|
|| event.referenced_event_ids().contains(&tagged_event_id)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false }) => {
|
TagKind::SingleLetter(SingleLetterTag {
|
||||||
|
character: Alphabet::T,
|
||||||
|
uppercase: false,
|
||||||
|
}) => {
|
||||||
let tagged_hashtag: Option<String> = tag.content().map(|h| h.to_string());
|
let tagged_hashtag: Option<String> = tag.content().map(|h| h.to_string());
|
||||||
if let Some(tagged_hashtag) = tagged_hashtag {
|
if let Some(tagged_hashtag) = tagged_hashtag {
|
||||||
let tags_content = event.get_tags_content(TagKind::SingleLetter(SingleLetterTag { character: Alphabet::T, uppercase: false }));
|
let tags_content =
|
||||||
|
event.get_tags_content(TagKind::SingleLetter(SingleLetterTag {
|
||||||
|
character: Alphabet::T,
|
||||||
|
uppercase: false,
|
||||||
|
}));
|
||||||
let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag);
|
let should_mute = tags_content.iter().any(|t| t == &tagged_hashtag);
|
||||||
return should_mute
|
return should_mute;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TagKind::Word => {
|
TagKind::Word => {
|
||||||
let tagged_word: Option<String> = tag.content().map(|h| h.to_string());
|
let tagged_word: Option<String> = tag.content().map(|h| h.to_string());
|
||||||
if let Some(tagged_word) = tagged_word {
|
if let Some(tagged_word) = tagged_word {
|
||||||
if event.content.to_lowercase().contains(&tagged_word.to_lowercase()) {
|
if event
|
||||||
return true
|
.content
|
||||||
|
.to_lowercase()
|
||||||
|
.contains(&tagged_word.to_lowercase())
|
||||||
|
{
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -66,15 +88,23 @@ impl MuteManager {
|
|||||||
.authors(vec![pubkey.clone()])
|
.authors(vec![pubkey.clone()])
|
||||||
.limit(1);
|
.limit(1);
|
||||||
|
|
||||||
let this_subscription_id = self.client.subscribe(Vec::from([subscription_filter]), None).await;
|
let this_subscription_id = self
|
||||||
|
.client
|
||||||
|
.subscribe(Vec::from([subscription_filter]), None)
|
||||||
|
.await;
|
||||||
|
|
||||||
let mut mute_list: Option<Event> = None;
|
let mut mute_list: Option<Event> = None;
|
||||||
let mut notifications = self.client.notifications();
|
let mut notifications = self.client.notifications();
|
||||||
while let Ok(notification) = notifications.recv().await {
|
while let Ok(notification) = notifications.recv().await {
|
||||||
if let RelayPoolNotification::Event { subscription_id, event, .. } = notification {
|
if let RelayPoolNotification::Event {
|
||||||
|
subscription_id,
|
||||||
|
event,
|
||||||
|
..
|
||||||
|
} = notification
|
||||||
|
{
|
||||||
if this_subscription_id == subscription_id && event.kind == Kind::MuteList {
|
if this_subscription_id == subscription_id && event.kind == Kind::MuteList {
|
||||||
mute_list = Some((*event).clone());
|
mute_list = Some((*event).clone());
|
||||||
break
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use nostr::{self, key::PublicKey, TagKind::SingleLetter, Alphabet, SingleLetterTag};
|
use nostr::{self, key::PublicKey, Alphabet, SingleLetterTag, TagKind::SingleLetter};
|
||||||
|
|
||||||
/// Temporary scaffolding of old methods that have not been ported to use native Event methods
|
/// Temporary scaffolding of old methods that have not been ported to use native Event methods
|
||||||
pub trait ExtendedEvent {
|
pub trait ExtendedEvent {
|
||||||
@ -26,9 +26,7 @@ impl ExtendedEvent for nostr::Event {
|
|||||||
fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> {
|
fn referenced_pubkeys(&self) -> std::collections::HashSet<nostr::PublicKey> {
|
||||||
self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::P)))
|
self.get_tags_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::P)))
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|tag| {
|
.filter_map(|tag| PublicKey::from_hex(tag).ok())
|
||||||
PublicKey::from_hex(tag).ok()
|
|
||||||
})
|
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,9 +41,7 @@ impl ExtendedEvent for nostr::Event {
|
|||||||
fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId> {
|
fn referenced_event_ids(&self) -> std::collections::HashSet<nostr::EventId> {
|
||||||
self.get_tag_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::E)))
|
self.get_tag_content(SingleLetter(SingleLetterTag::lowercase(Alphabet::E)))
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|tag| {
|
.filter_map(|tag| nostr::EventId::from_hex(tag).ok())
|
||||||
nostr::EventId::from_hex(tag).ok()
|
|
||||||
})
|
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,7 +50,9 @@ impl ExtendedEvent for nostr::Event {
|
|||||||
|
|
||||||
pub trait SqlStringConvertible {
|
pub trait SqlStringConvertible {
|
||||||
fn to_sql_string(&self) -> String;
|
fn to_sql_string(&self) -> String;
|
||||||
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>> where Self: Sized;
|
fn from_sql_string(s: String) -> Result<Self, Box<dyn std::error::Error>>
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SqlStringConvertible for nostr::EventId {
|
impl SqlStringConvertible for nostr::EventId {
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder};
|
use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder};
|
||||||
|
use log;
|
||||||
use nostr::event::EventId;
|
use nostr::event::EventId;
|
||||||
use nostr::key::PublicKey;
|
use nostr::key::PublicKey;
|
||||||
use nostr::types::Timestamp;
|
use nostr::types::Timestamp;
|
||||||
use rusqlite;
|
use rusqlite;
|
||||||
use rusqlite::params;
|
use rusqlite::params;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use log;
|
|
||||||
|
|
||||||
use std::fs::File;
|
|
||||||
use super::mute_manager::MuteManager;
|
use super::mute_manager::MuteManager;
|
||||||
use nostr::Event;
|
|
||||||
use super::SqlStringConvertible;
|
|
||||||
use super::ExtendedEvent;
|
use super::ExtendedEvent;
|
||||||
use r2d2_sqlite::SqliteConnectionManager;
|
use super::SqlStringConvertible;
|
||||||
|
use nostr::Event;
|
||||||
use r2d2;
|
use r2d2;
|
||||||
|
use r2d2_sqlite::SqliteConnectionManager;
|
||||||
|
use std::fs::File;
|
||||||
|
|
||||||
// MARK: - NotificationManager
|
// MARK: - NotificationManager
|
||||||
|
|
||||||
@ -32,7 +32,15 @@ pub struct NotificationManager {
|
|||||||
impl NotificationManager {
|
impl NotificationManager {
|
||||||
// MARK: - Initialization
|
// MARK: - Initialization
|
||||||
|
|
||||||
pub async fn new(db: r2d2::Pool<SqliteConnectionManager>, relay_url: String, apns_private_key_path: String, apns_private_key_id: String, apns_team_id: String, apns_environment: a2::client::Endpoint, apns_topic: String) -> Result<Self, Box<dyn std::error::Error>> {
|
pub async fn new(
|
||||||
|
db: r2d2::Pool<SqliteConnectionManager>,
|
||||||
|
relay_url: String,
|
||||||
|
apns_private_key_path: String,
|
||||||
|
apns_private_key_id: String,
|
||||||
|
apns_team_id: String,
|
||||||
|
apns_environment: a2::client::Endpoint,
|
||||||
|
apns_topic: String,
|
||||||
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
let mute_manager = MuteManager::new(relay_url.clone()).await?;
|
let mute_manager = MuteManager::new(relay_url.clone()).await?;
|
||||||
|
|
||||||
let connection = db.get()?;
|
let connection = db.get()?;
|
||||||
@ -44,7 +52,7 @@ impl NotificationManager {
|
|||||||
&mut file,
|
&mut file,
|
||||||
&apns_private_key_id,
|
&apns_private_key_id,
|
||||||
&apns_team_id,
|
&apns_team_id,
|
||||||
ClientConfig::new(apns_environment.clone())
|
ClientConfig::new(apns_environment.clone()),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -97,15 +105,24 @@ impl NotificationManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_column_if_not_exists(db: &rusqlite::Connection, table_name: &str, column_name: &str, column_type: &str) -> Result<(), rusqlite::Error> {
|
fn add_column_if_not_exists(
|
||||||
|
db: &rusqlite::Connection,
|
||||||
|
table_name: &str,
|
||||||
|
column_name: &str,
|
||||||
|
column_type: &str,
|
||||||
|
) -> Result<(), rusqlite::Error> {
|
||||||
let query = format!("PRAGMA table_info({})", table_name);
|
let query = format!("PRAGMA table_info({})", table_name);
|
||||||
let mut stmt = db.prepare(&query)?;
|
let mut stmt = db.prepare(&query)?;
|
||||||
let column_names: Vec<String> = stmt.query_map([], |row| row.get(1))?
|
let column_names: Vec<String> = stmt
|
||||||
|
.query_map([], |row| row.get(1))?
|
||||||
.filter_map(|r| r.ok())
|
.filter_map(|r| r.ok())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if !column_names.contains(&column_name.to_string()) {
|
if !column_names.contains(&column_name.to_string()) {
|
||||||
let query = format!("ALTER TABLE {} ADD COLUMN {} {}", table_name, column_name, column_type);
|
let query = format!(
|
||||||
|
"ALTER TABLE {} ADD COLUMN {} {}",
|
||||||
|
table_name, column_name, column_type
|
||||||
|
);
|
||||||
db.execute(&query, [])?;
|
db.execute(&query, [])?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -113,8 +130,14 @@ impl NotificationManager {
|
|||||||
|
|
||||||
// MARK: - Business logic
|
// MARK: - Business logic
|
||||||
|
|
||||||
pub async fn send_notifications_if_needed(&self, event: &Event) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn send_notifications_if_needed(
|
||||||
log::debug!("Checking if notifications need to be sent for event: {}", event.id);
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
log::debug!(
|
||||||
|
"Checking if notifications need to be sent for event: {}",
|
||||||
|
event.id
|
||||||
|
);
|
||||||
let one_week_ago = nostr::Timestamp::now() - 7 * 24 * 60 * 60;
|
let one_week_ago = nostr::Timestamp::now() - 7 * 24 * 60 * 60;
|
||||||
if event.created_at < one_week_ago {
|
if event.created_at < one_week_ago {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@ -122,10 +145,14 @@ impl NotificationManager {
|
|||||||
|
|
||||||
let pubkeys_to_notify = self.pubkeys_to_notify_for_event(event).await?;
|
let pubkeys_to_notify = self.pubkeys_to_notify_for_event(event).await?;
|
||||||
|
|
||||||
log::debug!("Sending notifications to {} pubkeys", pubkeys_to_notify.len());
|
log::debug!(
|
||||||
|
"Sending notifications to {} pubkeys",
|
||||||
|
pubkeys_to_notify.len()
|
||||||
|
);
|
||||||
|
|
||||||
for pubkey in pubkeys_to_notify {
|
for pubkey in pubkeys_to_notify {
|
||||||
self.send_event_notifications_to_pubkey(event, &pubkey).await?;
|
self.send_event_notifications_to_pubkey(event, &pubkey)
|
||||||
|
.await?;
|
||||||
self.db.get()?.execute(
|
self.db.get()?.execute(
|
||||||
"INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at)
|
"INSERT OR REPLACE INTO notifications (id, event_id, pubkey, received_notification, sent_at)
|
||||||
VALUES (?, ?, ?, ?, ?)",
|
VALUES (?, ?, ?, ?, ?)",
|
||||||
@ -141,10 +168,14 @@ impl NotificationManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn pubkeys_to_notify_for_event(&self, event: &Event) -> Result<HashSet<nostr::PublicKey>, Box<dyn std::error::Error>> {
|
async fn pubkeys_to_notify_for_event(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
) -> Result<HashSet<nostr::PublicKey>, Box<dyn std::error::Error>> {
|
||||||
let notification_status = self.get_notification_status(event)?;
|
let notification_status = self.get_notification_status(event)?;
|
||||||
let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?;
|
let relevant_pubkeys = self.pubkeys_relevant_to_event(event)?;
|
||||||
let pubkeys_that_received_notification = notification_status.pubkeys_that_received_notification();
|
let pubkeys_that_received_notification =
|
||||||
|
notification_status.pubkeys_that_received_notification();
|
||||||
let relevant_pubkeys_yet_to_receive: HashSet<PublicKey> = relevant_pubkeys
|
let relevant_pubkeys_yet_to_receive: HashSet<PublicKey> = relevant_pubkeys
|
||||||
.difference(&pubkeys_that_received_notification)
|
.difference(&pubkeys_that_received_notification)
|
||||||
.filter(|&x| *x != event.pubkey)
|
.filter(|&x| *x != event.pubkey)
|
||||||
@ -153,7 +184,10 @@ impl NotificationManager {
|
|||||||
|
|
||||||
let mut pubkeys_to_notify = HashSet::new();
|
let mut pubkeys_to_notify = HashSet::new();
|
||||||
for pubkey in relevant_pubkeys_yet_to_receive {
|
for pubkey in relevant_pubkeys_yet_to_receive {
|
||||||
let should_mute: bool = self.mute_manager.should_mute_notification_for_pubkey(event, &pubkey).await;
|
let should_mute: bool = self
|
||||||
|
.mute_manager
|
||||||
|
.should_mute_notification_for_pubkey(event, &pubkey)
|
||||||
|
.await;
|
||||||
if !should_mute {
|
if !should_mute {
|
||||||
pubkeys_to_notify.insert(pubkey);
|
pubkeys_to_notify.insert(pubkey);
|
||||||
}
|
}
|
||||||
@ -161,51 +195,77 @@ impl NotificationManager {
|
|||||||
Ok(pubkeys_to_notify)
|
Ok(pubkeys_to_notify)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pubkeys_relevant_to_event(&self, event: &Event) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
fn pubkeys_relevant_to_event(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
||||||
let mut relevant_pubkeys = event.relevant_pubkeys();
|
let mut relevant_pubkeys = event.relevant_pubkeys();
|
||||||
let referenced_event_ids = event.referenced_event_ids();
|
let referenced_event_ids = event.referenced_event_ids();
|
||||||
for referenced_event_id in referenced_event_ids {
|
for referenced_event_id in referenced_event_ids {
|
||||||
let pubkeys_relevant_to_referenced_event = self.pubkeys_subscribed_to_event_id(&referenced_event_id)?;
|
let pubkeys_relevant_to_referenced_event =
|
||||||
|
self.pubkeys_subscribed_to_event_id(&referenced_event_id)?;
|
||||||
relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event);
|
relevant_pubkeys.extend(pubkeys_relevant_to_referenced_event);
|
||||||
}
|
}
|
||||||
Ok(relevant_pubkeys)
|
Ok(relevant_pubkeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pubkeys_subscribed_to_event(&self, event: &Event) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
fn pubkeys_subscribed_to_event(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
||||||
self.pubkeys_subscribed_to_event_id(&event.id)
|
self.pubkeys_subscribed_to_event_id(&event.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pubkeys_subscribed_to_event_id(&self, event_id: &EventId) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
fn pubkeys_subscribed_to_event_id(
|
||||||
|
&self,
|
||||||
|
event_id: &EventId,
|
||||||
|
) -> Result<HashSet<PublicKey>, Box<dyn std::error::Error>> {
|
||||||
let connection = self.db.get()?;
|
let connection = self.db.get()?;
|
||||||
let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?;
|
let mut stmt = connection.prepare("SELECT pubkey FROM notifications WHERE event_id = ?")?;
|
||||||
let pubkeys = stmt.query_map([event_id.to_sql_string()], |row| row.get(0))?
|
let pubkeys = stmt
|
||||||
|
.query_map([event_id.to_sql_string()], |row| row.get(0))?
|
||||||
.filter_map(|r| r.ok())
|
.filter_map(|r| r.ok())
|
||||||
.filter_map(|r: String| PublicKey::from_sql_string(r).ok())
|
.filter_map(|r: String| PublicKey::from_sql_string(r).ok())
|
||||||
.collect();
|
.collect();
|
||||||
Ok(pubkeys)
|
Ok(pubkeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_event_notifications_to_pubkey(&self, event: &Event, pubkey: &PublicKey) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_event_notifications_to_pubkey(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
pubkey: &PublicKey,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let user_device_tokens = self.get_user_device_tokens(pubkey)?;
|
let user_device_tokens = self.get_user_device_tokens(pubkey)?;
|
||||||
for device_token in user_device_tokens {
|
for device_token in user_device_tokens {
|
||||||
self.send_event_notification_to_device_token(event, &device_token).await?;
|
self.send_event_notification_to_device_token(event, &device_token)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_user_device_tokens(&self, pubkey: &PublicKey) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
fn get_user_device_tokens(
|
||||||
|
&self,
|
||||||
|
pubkey: &PublicKey,
|
||||||
|
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||||
let connection = self.db.get()?;
|
let connection = self.db.get()?;
|
||||||
let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?;
|
let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?;
|
||||||
let device_tokens = stmt.query_map([pubkey.to_sql_string()], |row| row.get(0))?
|
let device_tokens = stmt
|
||||||
|
.query_map([pubkey.to_sql_string()], |row| row.get(0))?
|
||||||
.filter_map(|r| r.ok())
|
.filter_map(|r| r.ok())
|
||||||
.collect();
|
.collect();
|
||||||
Ok(device_tokens)
|
Ok(device_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_notification_status(&self, event: &Event) -> Result<NotificationStatus, Box<dyn std::error::Error>> {
|
fn get_notification_status(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
) -> Result<NotificationStatus, Box<dyn std::error::Error>> {
|
||||||
let connection = self.db.get()?;
|
let connection = self.db.get()?;
|
||||||
let mut stmt = connection.prepare("SELECT pubkey, received_notification FROM notifications WHERE event_id = ?")?;
|
let mut stmt = connection.prepare(
|
||||||
let rows: std::collections::HashMap<PublicKey, bool> = stmt.query_map([event.id.to_sql_string()], |row| {
|
"SELECT pubkey, received_notification FROM notifications WHERE event_id = ?",
|
||||||
|
)?;
|
||||||
|
let rows: std::collections::HashMap<PublicKey, bool> = stmt
|
||||||
|
.query_map([event.id.to_sql_string()], |row| {
|
||||||
Ok((row.get(0)?, row.get(1)?))
|
Ok((row.get(0)?, row.get(1)?))
|
||||||
})?
|
})?
|
||||||
.filter_map(|r: Result<(String, bool), rusqlite::Error>| r.ok())
|
.filter_map(|r: Result<(String, bool), rusqlite::Error>| r.ok())
|
||||||
@ -225,7 +285,11 @@ impl NotificationManager {
|
|||||||
Ok(NotificationStatus { status_info })
|
Ok(NotificationStatus { status_info })
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_event_notification_to_device_token(&self, event: &Event, device_token: &str) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_event_notification_to_device_token(
|
||||||
|
&self,
|
||||||
|
event: &Event,
|
||||||
|
device_token: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (title, subtitle, body) = self.format_notification_message(event);
|
let (title, subtitle, body) = self.format_notification_message(event);
|
||||||
|
|
||||||
log::debug!("Sending notification to device token: {}", device_token);
|
log::debug!("Sending notification to device token: {}", device_token);
|
||||||
@ -237,10 +301,7 @@ impl NotificationManager {
|
|||||||
.set_mutable_content()
|
.set_mutable_content()
|
||||||
.set_content_available();
|
.set_content_available();
|
||||||
|
|
||||||
let mut payload = builder.build(
|
let mut payload = builder.build(device_token, Default::default());
|
||||||
device_token,
|
|
||||||
Default::default()
|
|
||||||
);
|
|
||||||
let _ = payload.add_custom_data("nostr_event", event);
|
let _ = payload.add_custom_data("nostr_event", event);
|
||||||
payload.options.apns_topic = Some(self.apns_topic.as_str());
|
payload.options.apns_topic = Some(self.apns_topic.as_str());
|
||||||
|
|
||||||
@ -258,7 +319,11 @@ impl NotificationManager {
|
|||||||
(title, subtitle, body)
|
(title, subtitle, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box<dyn std::error::Error>> {
|
pub fn save_user_device_info(
|
||||||
|
&self,
|
||||||
|
pubkey: nostr::PublicKey,
|
||||||
|
device_token: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let current_time_unix = Timestamp::now();
|
let current_time_unix = Timestamp::now();
|
||||||
self.db.get()?.execute(
|
self.db.get()?.execute(
|
||||||
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)",
|
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)",
|
||||||
@ -272,7 +337,11 @@ impl NotificationManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box<dyn std::error::Error>> {
|
pub fn remove_user_device_info(
|
||||||
|
&self,
|
||||||
|
pubkey: nostr::PublicKey,
|
||||||
|
device_token: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
self.db.get()?.execute(
|
self.db.get()?.execute(
|
||||||
"DELETE FROM user_info WHERE pubkey = ? AND device_token = ?",
|
"DELETE FROM user_info WHERE pubkey = ? AND device_token = ?",
|
||||||
params![pubkey.to_sql_string(), device_token],
|
params![pubkey.to_sql_string(), device_token],
|
||||||
|
@ -1,37 +1,42 @@
|
|||||||
|
use crate::notification_manager::NotificationManager;
|
||||||
|
use log;
|
||||||
use nostr::util::JsonUtil;
|
use nostr::util::JsonUtil;
|
||||||
use nostr::{RelayMessage, ClientMessage};
|
use nostr::{ClientMessage, RelayMessage};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::fmt::{self, Debug};
|
||||||
|
use std::net::TcpStream;
|
||||||
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use serde_json::Value;
|
|
||||||
use crate::notification_manager::NotificationManager;
|
|
||||||
use std::str::FromStr;
|
|
||||||
use std::net::TcpStream;
|
|
||||||
use tungstenite::{accept, WebSocket};
|
use tungstenite::{accept, WebSocket};
|
||||||
use log;
|
|
||||||
use std::fmt::{self, Debug};
|
|
||||||
|
|
||||||
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
|
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
|
||||||
|
|
||||||
pub struct RelayConnection {
|
pub struct RelayConnection {
|
||||||
websocket: WebSocket<TcpStream>,
|
websocket: WebSocket<TcpStream>,
|
||||||
notification_manager: Arc<Mutex<NotificationManager>>
|
notification_manager: Arc<Mutex<NotificationManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RelayConnection {
|
impl RelayConnection {
|
||||||
|
|
||||||
// MARK: - Initializers
|
// MARK: - Initializers
|
||||||
|
|
||||||
pub fn new(stream: TcpStream, notification_manager: Arc<Mutex<NotificationManager>>) -> Result<Self, Box<dyn std::error::Error>> {
|
pub fn new(
|
||||||
|
stream: TcpStream,
|
||||||
|
notification_manager: Arc<Mutex<NotificationManager>>,
|
||||||
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
let address = stream.peer_addr()?;
|
let address = stream.peer_addr()?;
|
||||||
let websocket = accept(stream)?;
|
let websocket = accept(stream)?;
|
||||||
log::info!("Accepted connection from {:?}", address);
|
log::info!("Accepted connection from {:?}", address);
|
||||||
Ok(RelayConnection {
|
Ok(RelayConnection {
|
||||||
websocket,
|
websocket,
|
||||||
notification_manager
|
notification_manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(stream: TcpStream, notification_manager: Arc<Mutex<NotificationManager>>) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(
|
||||||
|
stream: TcpStream,
|
||||||
|
notification_manager: Arc<Mutex<NotificationManager>>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let mut connection = RelayConnection::new(stream, notification_manager)?;
|
let mut connection = RelayConnection::new(stream, notification_manager)?;
|
||||||
Ok(connection.run_loop().await?)
|
Ok(connection.run_loop().await?)
|
||||||
}
|
}
|
||||||
@ -47,10 +52,17 @@ impl RelayConnection {
|
|||||||
consecutive_errors = 0;
|
consecutive_errors = 0;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("Error in websocket connection with {:?}: {:?}", self.websocket, e);
|
log::error!(
|
||||||
|
"Error in websocket connection with {:?}: {:?}",
|
||||||
|
self.websocket,
|
||||||
|
e
|
||||||
|
);
|
||||||
consecutive_errors += 1;
|
consecutive_errors += 1;
|
||||||
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
|
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
|
||||||
log::error!("Too many consecutive errors, closing connection with {:?}", self.websocket);
|
log::error!(
|
||||||
|
"Too many consecutive errors, closing connection with {:?}",
|
||||||
|
self.websocket
|
||||||
|
);
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -62,16 +74,21 @@ impl RelayConnection {
|
|||||||
let websocket = &mut self.websocket;
|
let websocket = &mut self.websocket;
|
||||||
let raw_message = websocket.read()?;
|
let raw_message = websocket.read()?;
|
||||||
if raw_message.is_text() {
|
if raw_message.is_text() {
|
||||||
let message: ClientMessage = ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?;
|
let message: ClientMessage =
|
||||||
|
ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?;
|
||||||
let response = self.handle_client_message(message).await?;
|
let response = self.handle_client_message(message).await?;
|
||||||
self.websocket.send(tungstenite::Message::text(response.try_as_json()?))?;
|
self.websocket
|
||||||
|
.send(tungstenite::Message::text(response.try_as_json()?))?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Message handling
|
// MARK: - Message handling
|
||||||
|
|
||||||
async fn handle_client_message<'b>(&'b self, message: ClientMessage) -> Result<RelayMessage, Box<dyn std::error::Error>> {
|
async fn handle_client_message<'b>(
|
||||||
|
&'b self,
|
||||||
|
message: ClientMessage,
|
||||||
|
) -> Result<RelayMessage, Box<dyn std::error::Error>> {
|
||||||
match message {
|
match message {
|
||||||
ClientMessage::Event(event) => {
|
ClientMessage::Event(event) => {
|
||||||
log::info!("Received event: {:?}", event);
|
log::info!("Received event: {:?}", event);
|
||||||
@ -81,13 +98,19 @@ impl RelayConnection {
|
|||||||
mutex_guard.send_notifications_if_needed(&event).await?;
|
mutex_guard.send_notifications_if_needed(&event).await?;
|
||||||
}; // Only hold the mutex for as little time as possible.
|
}; // Only hold the mutex for as little time as possible.
|
||||||
let notice_message = format!("blocked: This relay does not store events");
|
let notice_message = format!("blocked: This relay does not store events");
|
||||||
let response = RelayMessage::Ok { event_id: event.id, status: false, message: notice_message };
|
let response = RelayMessage::Ok {
|
||||||
|
event_id: event.id,
|
||||||
|
status: false,
|
||||||
|
message: notice_message,
|
||||||
|
};
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
log::info!("Received unsupported message: {:?}", message);
|
log::info!("Received unsupported message: {:?}", message);
|
||||||
let notice_message = format!("Unsupported message: {:?}", message);
|
let notice_message = format!("Unsupported message: {:?}", message);
|
||||||
let response = RelayMessage::Notice { message: notice_message };
|
let response = RelayMessage::Notice {
|
||||||
|
message: notice_message,
|
||||||
|
};
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user