From e2bdbf3c0c7fd30975e7c267ff3931eb7c251b31 Mon Sep 17 00:00:00 2001 From: Kieran Date: Wed, 21 May 2025 18:03:42 +0100 Subject: [PATCH] feat: setup alt backends feat: FCM backend --- Cargo.lock | 287 ++++++++++++++++++++++++++- Cargo.toml | 3 + src/api_request_handler.rs | 339 +++++++++++--------------------- src/main.rs | 45 +++-- src/notepush_env.rs | 24 ++- src/notification_manager/mod.rs | 218 ++++++++++++++++---- 6 files changed, 629 insertions(+), 287 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0d7c826..3a02680 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "a2" @@ -316,6 +316,12 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.6.0" @@ -413,6 +419,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-targets 0.52.6", ] @@ -434,6 +441,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -489,6 +506,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "env_filter" version = "0.1.0" @@ -527,6 +553,16 @@ dependencies = [ "serde", ] +[[package]] +name = "errno" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -539,6 +575,25 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "fcm-service" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96aadee26986d122d1b8b4cdd6e7fe1310fdceeb1fd8db4d612d6510d96d7106" +dependencies = [ + "gcp_auth", + "reqwest", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "fnv" version = "1.0.7" @@ -658,6 +713,33 @@ dependencies = [ "slab", ] +[[package]] +name = "gcp_auth" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf67f30198e045a039264c01fb44659ce82402d7771c50938beb41a5ac87733" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "chrono", + "home", + "http", + "http-body-util", + "hyper", + "hyper-rustls 0.27.2", + "hyper-util", + "ring", + "rustls-pemfile", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-futures", + "url", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -764,6 +846,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "http" version = "1.1.0" @@ -866,6 +957,7 @@ dependencies = [ "hyper", "hyper-util", "rustls 0.23.11", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -873,6 +965,22 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-tungstenite" version = "0.14.0" @@ -1017,6 +1125,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "lnurl-pay" version = "0.5.0" @@ -1054,6 +1168,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "matchit" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f926ade0c4e170215ae43342bf13b9310a437609c81f29f86c5df6657582ef9" + [[package]] name = "memchr" version = "2.7.4" @@ -1086,6 +1206,23 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "negentropy" version = "0.3.1" @@ -1206,12 +1343,14 @@ dependencies = [ "chrono", "dotenv", "env_logger", + "fcm-service", "futures", "http-body-util", "hyper", "hyper-tungstenite", "hyper-util", "log", + "matchit", "nostr", "nostr-sdk", "r2d2", @@ -1224,6 +1363,7 @@ dependencies = [ "toml", "tracing", "tungstenite", + "urlencoding", "uuid", ] @@ -1287,7 +1427,7 @@ version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ - "bitflags", + "bitflags 2.6.0", "cfg-if", "foreign-types", "libc", @@ -1307,6 +1447,12 @@ dependencies = [ "syn", ] +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + [[package]] name = "openssl-sys" version = "0.9.102" @@ -1557,7 +1703,7 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ - "bitflags", + "bitflags 2.6.0", ] [[package]] @@ -1597,18 +1743,22 @@ checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", "hyper", "hyper-rustls 0.27.2", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -1620,7 +1770,9 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper", + "system-configuration", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.0", "tokio-socks", "tower-service", @@ -1653,7 +1805,7 @@ version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" dependencies = [ - "bitflags", + "bitflags 2.6.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -1683,6 +1835,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.6.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "rustls" version = "0.22.4" @@ -1711,6 +1876,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "2.1.2" @@ -1753,6 +1931,15 @@ dependencies = [ "cipher", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scheduled-thread-pool" version = "0.2.7" @@ -1801,6 +1988,29 @@ dependencies = [ "cc", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.23" @@ -1942,6 +2152,40 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -2007,6 +2251,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.25.0" @@ -2137,6 +2391,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "pin-project", + "tracing", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -2218,6 +2482,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" @@ -2396,6 +2666,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index ba0b9fa..43b923a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,3 +37,6 @@ uuid = { version = "1.10.0", features = ["v4"] } thiserror = "1.0.63" hyper-tungstenite = "0.14.0" futures = "0.3.30" +fcm-service = "0.1.5" +matchit = "0.8.6" +urlencoding = "2.1.3" diff --git a/src/api_request_handler.rs b/src/api_request_handler.rs index c77ba8c..c764a8a 100644 --- a/src/api_request_handler.rs +++ b/src/api_request_handler.rs @@ -1,19 +1,24 @@ use crate::nip98_auth; -use crate::notification_manager::UserNotificationSettings; +use crate::notification_manager::{NotificationBackend, UserNotificationSettings}; use crate::relay_connection::RelayConnection; use http_body_util::Full; use hyper::body::Buf; use hyper::body::Bytes; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; +use std::borrow::Cow; use http_body_util::BodyExt; use serde_json::from_value; use crate::notification_manager::NotificationManager; use hyper::Method; +use log::warn; +use matchit::{Params, Router}; +use nostr::prelude::url::form_urlencoded; +use nostr::PublicKey; use serde_json::{json, Value}; -use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use thiserror::Error; @@ -65,6 +70,10 @@ impl APIHandler { status: StatusCode::UNAUTHORIZED, body: json!({ "error": "Unauthorized", "message": message }), }, + _ => APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": api_error.to_string() }), + }, } } else { // Otherwise, return a 500 status code @@ -154,6 +163,12 @@ impl APIHandler { method: req.method().clone(), body_bytes: body_bytes.map(|b| b.to_vec()), authorized_pubkey, + query: match req.uri().query() { + Some(q) => form_urlencoded::parse(q.as_bytes()) + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect(), + None => vec![], + }, }) } @@ -163,38 +178,38 @@ impl APIHandler { &self, parsed_request: &ParsedRequest, ) -> Result> { - if let Some(url_params) = route_match( - &Method::PUT, - "/user-info/:pubkey/:deviceToken", - parsed_request, - ) { - return self.handle_user_info(parsed_request, &url_params).await; + enum RouteMatch { + UserInfo, + UserSetting, } + let mut routes = Router::new(); + routes.insert("/user-info/{pubkey}/{deviceToken}", RouteMatch::UserInfo)?; + routes.insert( + "/user-info/{pubkey}/{deviceToken}/preference", + RouteMatch::UserSetting, + )?; - if let Some(url_params) = route_match( - &Method::DELETE, - "/user-info/:pubkey/:deviceToken", - parsed_request, - ) { - return self - .handle_user_info_remove(parsed_request, &url_params) - .await; - } - - if let Some(url_params) = route_match( - &Method::GET, - "/user-info/:pubkey/:deviceToken/preferences", - parsed_request, - ) { - return self.get_user_settings(parsed_request, &url_params).await; - } - - if let Some(url_params) = route_match( - &Method::PUT, - "/user-info/:pubkey/:deviceToken/preferences", - parsed_request, - ) { - return self.set_user_settings(parsed_request, &url_params).await; + match routes.at(&parsed_request.uri) { + Ok(m) => match m.value { + RouteMatch::UserInfo if parsed_request.method == Method::PUT => { + return self.handle_user_info(parsed_request, m.params).await; + } + RouteMatch::UserInfo if parsed_request.method == Method::DELETE => { + return self.handle_user_info_remove(parsed_request, m.params).await; + } + RouteMatch::UserSetting if parsed_request.method == Method::GET => { + return self.get_user_settings(parsed_request, m.params).await; + } + RouteMatch::UserSetting if parsed_request.method == Method::PUT => { + return self.set_user_settings(parsed_request, m.params).await; + } + _ => { + // fallthrough to 404 + } + }, + Err(e) => { + warn!("Match failed: {}", e); + } } Ok(APIResponse { @@ -209,7 +224,7 @@ impl APIHandler { &self, req: &Request, body_bytes: Option<&[u8]>, - ) -> Result, Box> { + ) -> Result, Box> { let auth_header = match req.headers().get("Authorization") { Some(header) => header, None => return Ok(Err("Authorization header not found".to_string())), @@ -217,7 +232,11 @@ impl APIHandler { Ok(nip98_auth::nip98_verify_auth_header( auth_header.to_str()?.to_string(), - &format!("{}{}", self.base_url, req.uri().path()), + &format!( + "{}{}", + self.base_url, + req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("") + ), req.method().as_str(), body_bytes, ) @@ -229,52 +248,39 @@ impl APIHandler { async fn handle_user_info( &self, req: &ParsedRequest, - url_params: &HashMap<&str, String>, + url_params: Params<'_, '_>, ) -> Result> { - // Early return if `deviceToken` is missing - let device_token = match url_params.get("deviceToken") { - Some(token) => token, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required on the URL" }), - }) - } - }; + let device_token = get_required_param(&url_params, "deviceToken")?; + let pubkey = get_required_param(&url_params, "pubkey")?; + let pubkey = check_pubkey(&pubkey, req)?; - // Early return if `pubkey` is missing - let pubkey = match url_params.get("pubkey") { - Some(key) => key, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "pubkey is required on the URL" }), - }) - } - }; - - // Validate the `pubkey` and prepare it for use - let pubkey = match nostr::PublicKey::from_hex(pubkey) { - Ok(key) => key, + let backend = match NotificationBackend::from_str( + req.query + .iter() + .find(|c| c.0 == "backend") + .map(|c| &c.1) + .unwrap_or(&"apns".to_string()), + ) { + Ok(token) => token, Err(_) => { return Ok(APIResponse { status: StatusCode::BAD_REQUEST, - body: json!({ "error": "Invalid pubkey" }), + body: json!({ "error": "backend is invalid" }), }) } }; - // Early return if `pubkey` does not match `req.authorized_pubkey` - if pubkey != req.authorized_pubkey { + // Early return if backend is not supported + if !self.notification_manager.has_backend(backend) { return Ok(APIResponse { - status: StatusCode::FORBIDDEN, - body: json!({ "error": "Forbidden" }), + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "Backend not supported" }), }); } // Proceed with the main logic after passing all checks self.notification_manager - .save_user_device_info_if_not_present(pubkey, device_token) + .save_user_device_info_if_not_present(pubkey, &device_token, backend) .await?; Ok(APIResponse { status: StatusCode::OK, @@ -285,52 +291,15 @@ impl APIHandler { async fn handle_user_info_remove( &self, req: &ParsedRequest, - url_params: &HashMap<&str, String>, + url_params: Params<'_, '_>, ) -> Result> { - // Early return if `deviceToken` is missing - let device_token = match url_params.get("deviceToken") { - Some(token) => token, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required on the URL" }), - }) - } - }; - - // Early return if `pubkey` is missing - let pubkey = match url_params.get("pubkey") { - Some(key) => key, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "pubkey is required on the URL" }), - }) - } - }; - - // Validate the `pubkey` and prepare it for use - let pubkey = match nostr::PublicKey::from_hex(pubkey) { - Ok(key) => key, - Err(_) => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "Invalid pubkey" }), - }) - } - }; - - // Early return if `pubkey` does not match `req.authorized_pubkey` - if pubkey != req.authorized_pubkey { - return Ok(APIResponse { - status: StatusCode::FORBIDDEN, - body: json!({ "error": "Forbidden" }), - }); - } + let device_token = get_required_param(&url_params, "deviceToken")?; + let pubkey = get_required_param(&url_params, "pubkey")?; + let pubkey = check_pubkey(&pubkey, req)?; // Proceed with the main logic after passing all checks self.notification_manager - .remove_user_device_info(pubkey, device_token) + .remove_user_device_info(&pubkey, &device_token) .await?; Ok(APIResponse { @@ -342,48 +311,11 @@ impl APIHandler { async fn set_user_settings( &self, req: &ParsedRequest, - url_params: &HashMap<&str, String>, + url_params: Params<'_, '_>, ) -> Result> { - // Early return if `deviceToken` is missing - let device_token = match url_params.get("deviceToken") { - Some(token) => token, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required on the URL" }), - }) - } - }; - - // Early return if `pubkey` is missing - let pubkey = match url_params.get("pubkey") { - Some(key) => key, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "pubkey is required on the URL" }), - }) - } - }; - - // Validate the `pubkey` and prepare it for use - let pubkey = match nostr::PublicKey::from_hex(pubkey) { - Ok(key) => key, - Err(_) => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "Invalid pubkey" }), - }) - } - }; - - // Early return if `pubkey` does not match `req.authorized_pubkey` - if pubkey != req.authorized_pubkey { - return Ok(APIResponse { - status: StatusCode::FORBIDDEN, - body: json!({ "error": "Forbidden" }), - }); - } + let device_token = get_required_param(&url_params, "deviceToken")?; + let pubkey = get_required_param(&url_params, "pubkey")?; + let pubkey = check_pubkey(&pubkey, req)?; // Proceed with the main logic after passing all checks let body = req.body_json()?; @@ -399,11 +331,7 @@ impl APIHandler { }; self.notification_manager - .save_user_notification_settings( - &req.authorized_pubkey, - device_token.to_string(), - settings, - ) + .save_user_notification_settings(&pubkey, device_token.to_string(), settings) .await?; Ok(APIResponse { @@ -415,53 +343,16 @@ impl APIHandler { async fn get_user_settings( &self, req: &ParsedRequest, - url_params: &HashMap<&str, String>, + url_params: Params<'_, '_>, ) -> Result> { - // Early return if `deviceToken` is missing - let device_token = match url_params.get("deviceToken") { - Some(token) => token, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "deviceToken is required on the URL" }), - }) - } - }; - - // Early return if `pubkey` is missing - let pubkey = match url_params.get("pubkey") { - Some(key) => key, - None => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "pubkey is required on the URL" }), - }) - } - }; - - // Validate the `pubkey` and prepare it for use - let pubkey = match nostr::PublicKey::from_hex(pubkey) { - Ok(key) => key, - Err(_) => { - return Ok(APIResponse { - status: StatusCode::BAD_REQUEST, - body: json!({ "error": "Invalid pubkey" }), - }) - } - }; - - // Early return if `pubkey` does not match `req.authorized_pubkey` - if pubkey != req.authorized_pubkey { - return Ok(APIResponse { - status: StatusCode::FORBIDDEN, - body: json!({ "error": "Forbidden" }), - }); - } + let device_token = get_required_param(&url_params, "deviceToken")?; + let pubkey = get_required_param(&url_params, "pubkey")?; + let pubkey = check_pubkey(&pubkey, req)?; // Proceed with the main logic after passing all checks let settings = self .notification_manager - .get_user_notification_settings(&req.authorized_pubkey, device_token.to_string()) + .get_user_notification_settings(&pubkey, device_token.to_string()) .await?; Ok(APIResponse { @@ -487,6 +378,10 @@ impl Clone for APIHandler { // Define enum error types including authentication error #[derive(Debug, Error)] enum APIError { + #[error("Missing required parameter {0}")] + MissingParameter(String), + #[error("Invalid parameter {0}")] + InvalidParameter(String), #[error("Authentication error: {0}")] AuthenticationError(String), } @@ -496,6 +391,7 @@ struct ParsedRequest { method: Method, body_bytes: Option>, authorized_pubkey: nostr::PublicKey, + query: Vec<(String, String)>, } impl ParsedRequest { @@ -513,34 +409,29 @@ struct APIResponse { body: Value, } -// MARK: - Helper functions - -/// Matches the request to a specified route, returning a hashmap of the route parameters -/// e.g. GET /user/:id/info route against request GET /user/123/info matches to { "id": "123" } -fn route_match<'a>( - method: &Method, - path: &'a str, - req: &ParsedRequest, -) -> Option> { - if method != req.method { - return None; +fn get_required_param(params: &Params<'_, '_>, key: &str) -> Result { + match params.get(key) { + Some(token) => urlencoding::decode(token) + .map(|s| match s { + Cow::Borrowed(s) => s.to_owned(), + Cow::Owned(s) => s, + }) + .map_err(|_| APIError::InvalidParameter(key.to_string())), + None => Err(APIError::MissingParameter(key.to_string())), + } +} + +fn parse_pubkey(pubkey: &str) -> Result { + PublicKey::from_hex(pubkey).map_err(|_| APIError::InvalidParameter("pubkey".to_string())) +} + +fn check_pubkey(pubkey: &str, req: &ParsedRequest) -> Result { + let pubkey = parse_pubkey(pubkey)?; + if pubkey != req.authorized_pubkey { + Err(APIError::AuthenticationError( + "pubkey doesnt match authorized pubkey".to_string(), + )) + } else { + Ok(pubkey) } - let mut params = HashMap::new(); - let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); - let req_segments: Vec<&str> = req.uri.split('/').filter(|s| !s.is_empty()).collect(); - - if path_segments.len() != req_segments.len() { - return None; - } - - for (i, segment) in path_segments.iter().enumerate() { - if let Some(key) = segment.strip_prefix(':') { - let value = req_segments[i].to_string(); - params.insert(key, value); - } else if segment != &req_segments[i] { - return None; - } - } - - Some(params) } diff --git a/src/main.rs b/src/main.rs index ea57cdd..b1ea848 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,20 +28,37 @@ async fn main() -> Result<(), Box> { r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); // Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter. // This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections. - let notification_manager = Arc::new( - notification_manager::NotificationManager::new( - pool, - env.relay_url.clone(), - env.apns_private_key_path.clone(), - env.apns_private_key_id.clone(), - env.apns_team_id.clone(), - env.apns_environment.clone(), - env.apns_topic.clone(), - env.nostr_event_cache_max_age, - ) - .await - .expect("Failed to create notification manager"), - ); + let mut notification_manager = notification_manager::NotificationManager::new( + pool, + env.relay_url.clone(), + env.nostr_event_cache_max_age, + ) + .await + .expect("Failed to create notification manager"); + + // setup apns if key path is set + if let Some(apns_key) = &env.apns_private_key_path { + notification_manager = notification_manager + .with_apns( + apns_key.clone(), + env.apns_private_key_id.unwrap().clone(), + env.apns_team_id.unwrap().clone(), + env.apns_environment.clone(), + env.apns_topic.unwrap().clone(), + ) + .expect("Failed to set APNs"); + log::info!("APNS configured for notifications manager!"); + } + + // setup fcm if google services path is set + if let Some(gsp) = &env.google_services_file_path { + notification_manager = notification_manager + .with_fcm(gsp) + .expect("Failed to setup FCM"); + log::info!("FCM configured for notifications manager!"); + } + + let notification_manager = Arc::new(notification_manager); let api_handler = Arc::new(api_request_handler::APIHandler::new( notification_manager.clone(), env.api_base_url.clone(), diff --git a/src/notepush_env.rs b/src/notepush_env.rs index 5784c64..a726ba1 100644 --- a/src/notepush_env.rs +++ b/src/notepush_env.rs @@ -9,15 +9,15 @@ const DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE: u64 = 60 * 60; // 1 hour pub struct NotePushEnv { // The path to the Apple private key .p8 file - pub apns_private_key_path: String, + pub apns_private_key_path: Option, // The Apple private key ID - pub apns_private_key_id: String, + pub apns_private_key_id: Option, // The Apple team ID - pub apns_team_id: String, + pub apns_team_id: Option, // The APNS environment to send notifications to (Sandbox or Production) pub apns_environment: a2::client::Endpoint, // The topic to send notifications to (The Apple app bundle ID) - pub apns_topic: String, + pub apns_topic: Option, // The path to the SQLite database file pub db_path: String, // The host and port to bind the relay and API to @@ -28,14 +28,18 @@ pub struct NotePushEnv { pub relay_url: String, // The max age of the Nostr event cache, in seconds pub nostr_event_cache_max_age: std::time::Duration, + // Path to google_services.json for FCM + pub google_services_file_path: Option, + // VAAPI key for WebPush (via FCM) + pub vaapi_key: Option, } impl NotePushEnv { pub fn load_env() -> Result { dotenv().ok(); - let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH")?; - let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?; - let apns_team_id = env::var("APPLE_TEAM_ID")?; + let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH").ok(); + let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID").ok(); + let apns_team_id = env::var("APPLE_TEAM_ID").ok(); let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string()); let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string()); let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string()); @@ -48,7 +52,7 @@ impl NotePushEnv { "production" => a2::client::Endpoint::Production, _ => a2::client::Endpoint::Sandbox, }; - let apns_topic = env::var("APNS_TOPIC")?; + let apns_topic = env::var("APNS_TOPIC").ok(); let nostr_event_cache_max_age = env::var("NOSTR_EVENT_CACHE_MAX_AGE") .unwrap_or(DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE.to_string()) .parse::() @@ -56,6 +60,8 @@ impl NotePushEnv { .unwrap_or(std::time::Duration::from_secs( DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE, )); + let google_services_file_path = env::var("GOOGLE_SERVICES_FILE_PATH").ok(); + let vaapi_key = env::var("VAAPI_KEY").ok(); Ok(NotePushEnv { apns_private_key_path, @@ -69,6 +75,8 @@ impl NotePushEnv { api_base_url, relay_url, nostr_event_cache_max_age, + google_services_file_path, + vaapi_key, }) } diff --git a/src/notification_manager/mod.rs b/src/notification_manager/mod.rs index cda2099..84d2983 100644 --- a/src/notification_manager/mod.rs +++ b/src/notification_manager/mod.rs @@ -3,8 +3,8 @@ mod nostr_event_extensions; pub mod nostr_network_helper; pub mod utils; -use std::cmp::{max, min}; use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible}; +use std::cmp::{max, min}; use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder}; use nostr::key::PublicKey; @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::sync::Arc; use tokio::sync::Mutex; +use fcm_service::{FcmMessage, FcmNotification, FcmService, Target}; use nostr::Event; use nostr_event_extensions::Codable; use nostr_event_extensions::MaybeConvertibleToMuteList; @@ -26,6 +27,8 @@ use nostr_event_extensions::TimestampedMuteList; use nostr_network_helper::NostrNetworkHelper; use r2d2_sqlite::SqliteConnectionManager; use std::fs::File; +use std::str::FromStr; +use thiserror::Error; use utils::should_mute_notification_for_mutelist; // MARK: - NotificationManager @@ -37,10 +40,56 @@ const HELLTHREAD_MIN_PUBKEYS: i8 = 6; // Maximum threshold the hellthread pubkey tag count setting can go up to. const HELLTHREAD_MAX_PUBKEYS: i8 = 24; +#[derive(Debug, Clone, Copy)] +pub enum NotificationBackend { + APNS, + FCM, +} + +impl From for NotificationBackend { + fn from(value: u8) -> Self { + match value { + 0 => NotificationBackend::APNS, + 1 => NotificationBackend::FCM, + _ => panic!("Invalid value for NotificationBackend"), + } + } +} + +impl Into for NotificationBackend { + fn into(self) -> u8 { + match self { + NotificationBackend::APNS => 0, + NotificationBackend::FCM => 1, + } + } +} + +impl FromStr for NotificationBackend { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "fcm" => Ok(NotificationBackend::FCM), + "apns" => Ok(NotificationBackend::APNS), + _ => Err(()), + } + } +} + +#[derive(Error, Debug)] +pub enum NotificationManagerError { + #[error("APNS is not configured")] + APNSMissing, + #[error("FCM is not configured")] + FCMMissing, +} + pub struct NotificationManager { db: Arc>>, - apns_topic: String, - apns_client: Mutex, + apns_topic: Option, + apns_client: Option>, + fcm_client: Option>, nostr_network_helper: NostrNetworkHelper, pub event_saver: EventSaver, } @@ -151,32 +200,19 @@ impl NotificationManager { pub async fn new( db: r2d2::Pool, relay_url: String, - apns_private_key_path: String, - apns_private_key_id: String, - apns_team_id: String, - apns_environment: a2::client::Endpoint, - apns_topic: String, cache_max_age: std::time::Duration, ) -> Result> { let connection = db.get()?; Self::setup_database(&connection)?; - let mut file = File::open(&apns_private_key_path)?; - - let client = Client::token( - &mut file, - &apns_private_key_id, - &apns_team_id, - ClientConfig::new(apns_environment.clone()), - )?; - let db = Arc::new(Mutex::new(db)); let event_saver = EventSaver::new(db.clone()); let manager = NotificationManager { db, - apns_topic, - apns_client: Mutex::new(client), + apns_topic: None, + apns_client: None, + fcm_client: None, nostr_network_helper: NostrNetworkHelper::new( relay_url.clone(), cache_max_age, @@ -189,6 +225,46 @@ impl NotificationManager { Ok(manager) } + /// Adds APNS configuration + pub fn with_apns( + mut self, + apns_private_key_path: String, + apns_private_key_id: String, + apns_team_id: String, + apns_environment: a2::client::Endpoint, + apns_topic: String, + ) -> Result> { + let mut file = File::open(&apns_private_key_path)?; + + let client = Client::token( + &mut file, + &apns_private_key_id, + &apns_team_id, + ClientConfig::new(apns_environment.clone()), + )?; + + self.apns_client.replace(Mutex::new(client)); + self.apns_topic.replace(apns_topic); + + Ok(self) + } + + pub fn with_fcm( + mut self, + google_services_file_path: impl Into, + ) -> Result> { + self.fcm_client + .replace(Mutex::new(FcmService::new(google_services_file_path))); + Ok(self) + } + + pub fn has_backend(&self, backend: NotificationBackend) -> bool { + match backend { + NotificationBackend::APNS if !self.apns_client.is_some() => true, + NotificationBackend::FCM if self.fcm_client.is_some() => true, + _ => false, + } + } // MARK: - Database setup operations pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> { @@ -296,6 +372,10 @@ impl NotificationManager { [], )?; + // Migration for FCM + + Self::add_column_if_not_exists(db, "user_info", "backend", "TINYINT", Some("0"))?; + Ok(()) } @@ -507,14 +587,14 @@ impl NotificationManager { pubkey: &PublicKey, ) -> Result<(), Box> { let user_device_tokens = self.get_user_device_tokens(pubkey).await?; - for device_token in user_device_tokens { + for (device_token, backend) in user_device_tokens { if !self .user_wants_notification(pubkey, device_token.clone(), event) .await? { continue; } - self.send_event_notification_to_device_token(event, &device_token) + self.send_event_notification_to_device_token(event, &device_token, backend) .await?; } Ok(()) @@ -564,7 +644,12 @@ impl NotificationManager { pubkey: &PublicKey, device_token: &str, ) -> Result> { - let current_device_tokens = self.get_user_device_tokens(pubkey).await?; + let current_device_tokens: Vec = self + .get_user_device_tokens(pubkey) + .await? + .into_iter() + .map(|(d, _)| d) + .collect(); Ok(current_device_tokens.contains(&device_token.to_string())) } @@ -578,12 +663,18 @@ impl NotificationManager { async fn get_user_device_tokens( &self, pubkey: &PublicKey, - ) -> Result, Box> { + ) -> Result, Box> { let db_mutex_guard = self.db.lock().await; let connection = db_mutex_guard.get()?; - let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?; + let mut stmt = + connection.prepare("SELECT device_token,backend FROM user_info WHERE pubkey = ?")?; let device_tokens = stmt - .query_map([pubkey.to_sql_string()], |row| row.get(0))? + .query_map([pubkey.to_sql_string()], |row| { + Ok(( + row.get::(0)?, + row.get::(1)?.into(), + )) + })? .filter_map(|r| r.ok()) .collect(); Ok(device_tokens) @@ -623,11 +714,36 @@ impl NotificationManager { &self, event: &Event, device_token: &str, + backend: NotificationBackend, ) -> Result<(), Box> { - let (title, subtitle, body) = self.format_notification_message(event); - log::debug!("Sending notification to device token: {}", device_token); + match &backend { + NotificationBackend::APNS => { + self.send_event_notification_apns(event, device_token) + .await?; + } + NotificationBackend::FCM => { + self.send_event_notification_fcm(event, device_token) + .await?; + } + } + + log::info!("Notification sent to device token: {}", device_token); + Ok(()) + } + + async fn send_event_notification_apns( + &self, + event: &Event, + device_token: &str, + ) -> Result<(), Box> { + let client = self + .apns_client + .as_ref() + .ok_or(NotificationManagerError::APNSMissing)?; + + let (title, subtitle, body) = self.format_notification_message(event); let mut payload = DefaultNotificationBuilder::new() .set_title(&title) .set_subtitle(&subtitle) @@ -636,14 +752,15 @@ impl NotificationManager { .set_content_available() .build(device_token, Default::default()); - payload.options.apns_topic = Some(self.apns_topic.as_str()); + if let Some(t) = self.apns_topic.as_ref() { + payload.options.apns_topic = Some(t); + } payload.data.insert( "nostr_event", serde_json::Value::String(event.try_as_json()?), ); - let apns_client_mutex_guard = self.apns_client.lock().await; - + let apns_client_mutex_guard = client.lock().await; match apns_client_mutex_guard.send(payload).await { Ok(_response) => {} Err(e) => log::error!( @@ -653,7 +770,29 @@ impl NotificationManager { ), } - log::info!("Notification sent to device token: {}", device_token); + Ok(()) + } + + async fn send_event_notification_fcm( + &self, + event: &Event, + device_token: &str, + ) -> Result<(), Box> { + let client = self + .fcm_client + .as_ref() + .ok_or(NotificationManagerError::FCMMissing)?; + + let (title, _, body) = self.format_notification_message(event); + + let client = client.lock().await; + let mut msg = FcmMessage::new(); + let mut notification = FcmNotification::new(); + notification.set_title(title); + notification.set_body(body); + msg.set_notification(Some(notification)); + msg.set_target(Target::Token(device_token.into())); + client.send_notification(msg).await?; Ok(()) } @@ -693,6 +832,7 @@ impl NotificationManager { &self, pubkey: nostr::PublicKey, device_token: &str, + backend: NotificationBackend, ) -> Result<(), Box> { if self .is_pubkey_token_pair_registered(&pubkey, device_token) @@ -700,23 +840,26 @@ impl NotificationManager { { return Ok(()); } - self.save_user_device_info(pubkey, device_token).await + self.save_user_device_info(pubkey, device_token, backend) + .await } pub async fn save_user_device_info( &self, pubkey: nostr::PublicKey, device_token: &str, + backend: NotificationBackend, ) -> Result<(), Box> { let current_time_unix = Timestamp::now(); let db_mutex_guard = self.db.lock().await; db_mutex_guard.get()?.execute( - "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", + "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at, backend) VALUES (?, ?, ?, ?, ?)", params![ format!("{}:{}", pubkey.to_sql_string(), device_token), pubkey.to_sql_string(), device_token, - current_time_unix.to_sql_string() + current_time_unix.to_sql_string(), + >::into(backend) ], )?; Ok(()) @@ -724,7 +867,7 @@ impl NotificationManager { pub async fn remove_user_device_info( &self, - pubkey: nostr::PublicKey, + pubkey: &PublicKey, device_token: &str, ) -> Result<(), Box> { let db_mutex_guard = self.db.lock().await; @@ -754,7 +897,9 @@ impl NotificationManager { dm_notifications_enabled: row.get(4)?, only_notifications_from_following_enabled: row.get(5)?, hellthread_notifications_disabled: row.get::<_, Option>(6)?.unwrap_or(false), - hellthread_notifications_max_pubkeys: row.get::<_, Option>(7)?.unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS), + hellthread_notifications_max_pubkeys: row + .get::<_, Option>(7)? + .unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS), }) })?; @@ -792,7 +937,6 @@ fn default_hellthread_max_pubkeys() -> i8 { DEFAULT_HELLTHREAD_MAX_PUBKEYS } - #[derive(Serialize, Deserialize, Debug)] pub struct UserNotificationSettings { zap_notifications_enabled: bool,