feat: setup alt backends

feat: FCM backend
This commit is contained in:
2025-05-21 18:03:42 +01:00
parent e73f5ce70e
commit e2bdbf3c0c
6 changed files with 629 additions and 287 deletions

287
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "a2" name = "a2"
@ -316,6 +316,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.6.0" version = "2.6.0"
@ -413,6 +419,7 @@ dependencies = [
"iana-time-zone", "iana-time-zone",
"js-sys", "js-sys",
"num-traits", "num-traits",
"serde",
"wasm-bindgen", "wasm-bindgen",
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
@ -434,6 +441,16 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.6" version = "0.8.6"
@ -489,6 +506,15 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "env_filter" name = "env_filter"
version = "0.1.0" version = "0.1.0"
@ -527,6 +553,16 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "errno"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
dependencies = [
"libc",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "fallible-iterator" name = "fallible-iterator"
version = "0.3.0" version = "0.3.0"
@ -539,6 +575,25 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fcm-service"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96aadee26986d122d1b8b4cdd6e7fe1310fdceeb1fd8db4d612d6510d96d7106"
dependencies = [
"gcp_auth",
"reqwest",
"serde",
"serde_json",
"tokio",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@ -658,6 +713,33 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "gcp_auth"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf67f30198e045a039264c01fb44659ce82402d7771c50938beb41a5ac87733"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"chrono",
"home",
"http",
"http-body-util",
"hyper",
"hyper-rustls 0.27.2",
"hyper-util",
"ring",
"rustls-pemfile",
"serde",
"serde_json",
"thiserror",
"tokio",
"tracing",
"tracing-futures",
"url",
]
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@ -764,6 +846,15 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "http" name = "http"
version = "1.1.0" version = "1.1.0"
@ -866,6 +957,7 @@ dependencies = [
"hyper", "hyper",
"hyper-util", "hyper-util",
"rustls 0.23.11", "rustls 0.23.11",
"rustls-native-certs",
"rustls-pki-types", "rustls-pki-types",
"tokio", "tokio",
"tokio-rustls 0.26.0", "tokio-rustls 0.26.0",
@ -873,6 +965,22 @@ dependencies = [
"webpki-roots", "webpki-roots",
] ]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]] [[package]]
name = "hyper-tungstenite" name = "hyper-tungstenite"
version = "0.14.0" version = "0.14.0"
@ -1017,6 +1125,12 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]] [[package]]
name = "lnurl-pay" name = "lnurl-pay"
version = "0.5.0" version = "0.5.0"
@ -1054,6 +1168,12 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "matchit"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f926ade0c4e170215ae43342bf13b9310a437609c81f29f86c5df6657582ef9"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.7.4" version = "2.7.4"
@ -1086,6 +1206,23 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "native-tls"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]] [[package]]
name = "negentropy" name = "negentropy"
version = "0.3.1" version = "0.3.1"
@ -1206,12 +1343,14 @@ dependencies = [
"chrono", "chrono",
"dotenv", "dotenv",
"env_logger", "env_logger",
"fcm-service",
"futures", "futures",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-tungstenite", "hyper-tungstenite",
"hyper-util", "hyper-util",
"log", "log",
"matchit",
"nostr", "nostr",
"nostr-sdk", "nostr-sdk",
"r2d2", "r2d2",
@ -1224,6 +1363,7 @@ dependencies = [
"toml", "toml",
"tracing", "tracing",
"tungstenite", "tungstenite",
"urlencoding",
"uuid", "uuid",
] ]
@ -1287,7 +1427,7 @@ version = "0.10.64"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [ dependencies = [
"bitflags", "bitflags 2.6.0",
"cfg-if", "cfg-if",
"foreign-types", "foreign-types",
"libc", "libc",
@ -1307,6 +1447,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.102" version = "0.9.102"
@ -1557,7 +1703,7 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [ dependencies = [
"bitflags", "bitflags 2.6.0",
] ]
[[package]] [[package]]
@ -1597,18 +1743,22 @@ checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"encoding_rs",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2",
"http", "http",
"http-body", "http-body",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-rustls 0.27.2", "hyper-rustls 0.27.2",
"hyper-tls",
"hyper-util", "hyper-util",
"ipnet", "ipnet",
"js-sys", "js-sys",
"log", "log",
"mime", "mime",
"native-tls",
"once_cell", "once_cell",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
@ -1620,7 +1770,9 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper", "sync_wrapper",
"system-configuration",
"tokio", "tokio",
"tokio-native-tls",
"tokio-rustls 0.26.0", "tokio-rustls 0.26.0",
"tokio-socks", "tokio-socks",
"tower-service", "tower-service",
@ -1653,7 +1805,7 @@ version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae"
dependencies = [ dependencies = [
"bitflags", "bitflags 2.6.0",
"fallible-iterator", "fallible-iterator",
"fallible-streaming-iterator", "fallible-streaming-iterator",
"hashlink", "hashlink",
@ -1683,6 +1835,19 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustix"
version = "0.38.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
dependencies = [
"bitflags 2.6.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.22.4" version = "0.22.4"
@ -1711,6 +1876,19 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "rustls-native-certs"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "2.1.2" version = "2.1.2"
@ -1753,6 +1931,15 @@ dependencies = [
"cipher", "cipher",
] ]
[[package]]
name = "schannel"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "scheduled-thread-pool" name = "scheduled-thread-pool"
version = "0.2.7" version = "0.2.7"
@ -1801,6 +1988,29 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.23" version = "1.0.23"
@ -1942,6 +2152,40 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tempfile"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [
"cfg-if",
"fastrand",
"once_cell",
"rustix",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.63" version = "1.0.63"
@ -2007,6 +2251,16 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.25.0" version = "0.25.0"
@ -2137,6 +2391,16 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "tracing-futures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
dependencies = [
"pin-project",
"tracing",
]
[[package]] [[package]]
name = "try-lock" name = "try-lock"
version = "0.2.5" version = "0.2.5"
@ -2218,6 +2482,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]] [[package]]
name = "utf-8" name = "utf-8"
version = "0.7.6" version = "0.7.6"
@ -2396,6 +2666,15 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.48.5" version = "0.48.5"

View File

@ -37,3 +37,6 @@ uuid = { version = "1.10.0", features = ["v4"] }
thiserror = "1.0.63" thiserror = "1.0.63"
hyper-tungstenite = "0.14.0" hyper-tungstenite = "0.14.0"
futures = "0.3.30" futures = "0.3.30"
fcm-service = "0.1.5"
matchit = "0.8.6"
urlencoding = "2.1.3"

View File

@ -1,19 +1,24 @@
use crate::nip98_auth; use crate::nip98_auth;
use crate::notification_manager::UserNotificationSettings; use crate::notification_manager::{NotificationBackend, UserNotificationSettings};
use crate::relay_connection::RelayConnection; use crate::relay_connection::RelayConnection;
use http_body_util::Full; use http_body_util::Full;
use hyper::body::Buf; use hyper::body::Buf;
use hyper::body::Bytes; use hyper::body::Bytes;
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use std::borrow::Cow;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use serde_json::from_value; use serde_json::from_value;
use crate::notification_manager::NotificationManager; use crate::notification_manager::NotificationManager;
use hyper::Method; use hyper::Method;
use log::warn;
use matchit::{Params, Router};
use nostr::prelude::url::form_urlencoded;
use nostr::PublicKey;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
@ -65,6 +70,10 @@ impl APIHandler {
status: StatusCode::UNAUTHORIZED, status: StatusCode::UNAUTHORIZED,
body: json!({ "error": "Unauthorized", "message": message }), body: json!({ "error": "Unauthorized", "message": message }),
}, },
_ => APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": api_error.to_string() }),
},
} }
} else { } else {
// Otherwise, return a 500 status code // Otherwise, return a 500 status code
@ -154,6 +163,12 @@ impl APIHandler {
method: req.method().clone(), method: req.method().clone(),
body_bytes: body_bytes.map(|b| b.to_vec()), body_bytes: body_bytes.map(|b| b.to_vec()),
authorized_pubkey, authorized_pubkey,
query: match req.uri().query() {
Some(q) => form_urlencoded::parse(q.as_bytes())
.map(|(k, v)| (k.into_owned(), v.into_owned()))
.collect(),
None => vec![],
},
}) })
} }
@ -163,38 +178,38 @@ impl APIHandler {
&self, &self,
parsed_request: &ParsedRequest, parsed_request: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> { ) -> Result<APIResponse, Box<dyn std::error::Error>> {
if let Some(url_params) = route_match( enum RouteMatch {
&Method::PUT, UserInfo,
"/user-info/:pubkey/:deviceToken", UserSetting,
parsed_request,
) {
return self.handle_user_info(parsed_request, &url_params).await;
} }
let mut routes = Router::new();
routes.insert("/user-info/{pubkey}/{deviceToken}", RouteMatch::UserInfo)?;
routes.insert(
"/user-info/{pubkey}/{deviceToken}/preference",
RouteMatch::UserSetting,
)?;
if let Some(url_params) = route_match( match routes.at(&parsed_request.uri) {
&Method::DELETE, Ok(m) => match m.value {
"/user-info/:pubkey/:deviceToken", RouteMatch::UserInfo if parsed_request.method == Method::PUT => {
parsed_request, return self.handle_user_info(parsed_request, m.params).await;
) { }
return self RouteMatch::UserInfo if parsed_request.method == Method::DELETE => {
.handle_user_info_remove(parsed_request, &url_params) return self.handle_user_info_remove(parsed_request, m.params).await;
.await; }
} RouteMatch::UserSetting if parsed_request.method == Method::GET => {
return self.get_user_settings(parsed_request, m.params).await;
if let Some(url_params) = route_match( }
&Method::GET, RouteMatch::UserSetting if parsed_request.method == Method::PUT => {
"/user-info/:pubkey/:deviceToken/preferences", return self.set_user_settings(parsed_request, m.params).await;
parsed_request, }
) { _ => {
return self.get_user_settings(parsed_request, &url_params).await; // fallthrough to 404
} }
},
if let Some(url_params) = route_match( Err(e) => {
&Method::PUT, warn!("Match failed: {}", e);
"/user-info/:pubkey/:deviceToken/preferences", }
parsed_request,
) {
return self.set_user_settings(parsed_request, &url_params).await;
} }
Ok(APIResponse { Ok(APIResponse {
@ -209,7 +224,7 @@ impl APIHandler {
&self, &self,
req: &Request<Incoming>, req: &Request<Incoming>,
body_bytes: Option<&[u8]>, body_bytes: Option<&[u8]>,
) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> { ) -> Result<Result<PublicKey, String>, Box<dyn std::error::Error>> {
let auth_header = match req.headers().get("Authorization") { let auth_header = match req.headers().get("Authorization") {
Some(header) => header, Some(header) => header,
None => return Ok(Err("Authorization header not found".to_string())), None => return Ok(Err("Authorization header not found".to_string())),
@ -217,7 +232,11 @@ impl APIHandler {
Ok(nip98_auth::nip98_verify_auth_header( Ok(nip98_auth::nip98_verify_auth_header(
auth_header.to_str()?.to_string(), auth_header.to_str()?.to_string(),
&format!("{}{}", self.base_url, req.uri().path()), &format!(
"{}{}",
self.base_url,
req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("")
),
req.method().as_str(), req.method().as_str(),
body_bytes, body_bytes,
) )
@ -229,52 +248,39 @@ impl APIHandler {
async fn handle_user_info( async fn handle_user_info(
&self, &self,
req: &ParsedRequest, req: &ParsedRequest,
url_params: &HashMap<&str, String>, url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> { ) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing let device_token = get_required_param(&url_params, "deviceToken")?;
let device_token = match url_params.get("deviceToken") { let pubkey = get_required_param(&url_params, "pubkey")?;
Some(token) => token, let pubkey = check_pubkey(&pubkey, req)?;
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing let backend = match NotificationBackend::from_str(
let pubkey = match url_params.get("pubkey") { req.query
Some(key) => key, .iter()
None => { .find(|c| c.0 == "backend")
return Ok(APIResponse { .map(|c| &c.1)
status: StatusCode::BAD_REQUEST, .unwrap_or(&"apns".to_string()),
body: json!({ "error": "pubkey is required on the URL" }), ) {
}) Ok(token) => token,
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => { Err(_) => {
return Ok(APIResponse { return Ok(APIResponse {
status: StatusCode::BAD_REQUEST, status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }), body: json!({ "error": "backend is invalid" }),
}) })
} }
}; };
// Early return if `pubkey` does not match `req.authorized_pubkey` // Early return if backend is not supported
if pubkey != req.authorized_pubkey { if !self.notification_manager.has_backend(backend) {
return Ok(APIResponse { return Ok(APIResponse {
status: StatusCode::FORBIDDEN, status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Forbidden" }), body: json!({ "error": "Backend not supported" }),
}); });
} }
// Proceed with the main logic after passing all checks // Proceed with the main logic after passing all checks
self.notification_manager self.notification_manager
.save_user_device_info_if_not_present(pubkey, device_token) .save_user_device_info_if_not_present(pubkey, &device_token, backend)
.await?; .await?;
Ok(APIResponse { Ok(APIResponse {
status: StatusCode::OK, status: StatusCode::OK,
@ -285,52 +291,15 @@ impl APIHandler {
async fn handle_user_info_remove( async fn handle_user_info_remove(
&self, &self,
req: &ParsedRequest, req: &ParsedRequest,
url_params: &HashMap<&str, String>, url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> { ) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing let device_token = get_required_param(&url_params, "deviceToken")?;
let device_token = match url_params.get("deviceToken") { let pubkey = get_required_param(&url_params, "pubkey")?;
Some(token) => token, let pubkey = check_pubkey(&pubkey, req)?;
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks // Proceed with the main logic after passing all checks
self.notification_manager self.notification_manager
.remove_user_device_info(pubkey, device_token) .remove_user_device_info(&pubkey, &device_token)
.await?; .await?;
Ok(APIResponse { Ok(APIResponse {
@ -342,48 +311,11 @@ impl APIHandler {
async fn set_user_settings( async fn set_user_settings(
&self, &self,
req: &ParsedRequest, req: &ParsedRequest,
url_params: &HashMap<&str, String>, url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> { ) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing let device_token = get_required_param(&url_params, "deviceToken")?;
let device_token = match url_params.get("deviceToken") { let pubkey = get_required_param(&url_params, "pubkey")?;
Some(token) => token, let pubkey = check_pubkey(&pubkey, req)?;
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks // Proceed with the main logic after passing all checks
let body = req.body_json()?; let body = req.body_json()?;
@ -399,11 +331,7 @@ impl APIHandler {
}; };
self.notification_manager self.notification_manager
.save_user_notification_settings( .save_user_notification_settings(&pubkey, device_token.to_string(), settings)
&req.authorized_pubkey,
device_token.to_string(),
settings,
)
.await?; .await?;
Ok(APIResponse { Ok(APIResponse {
@ -415,53 +343,16 @@ impl APIHandler {
async fn get_user_settings( async fn get_user_settings(
&self, &self,
req: &ParsedRequest, req: &ParsedRequest,
url_params: &HashMap<&str, String>, url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> { ) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing let device_token = get_required_param(&url_params, "deviceToken")?;
let device_token = match url_params.get("deviceToken") { let pubkey = get_required_param(&url_params, "pubkey")?;
Some(token) => token, let pubkey = check_pubkey(&pubkey, req)?;
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
// Proceed with the main logic after passing all checks // Proceed with the main logic after passing all checks
let settings = self let settings = self
.notification_manager .notification_manager
.get_user_notification_settings(&req.authorized_pubkey, device_token.to_string()) .get_user_notification_settings(&pubkey, device_token.to_string())
.await?; .await?;
Ok(APIResponse { Ok(APIResponse {
@ -487,6 +378,10 @@ impl Clone for APIHandler {
// Define enum error types including authentication error // Define enum error types including authentication error
#[derive(Debug, Error)] #[derive(Debug, Error)]
enum APIError { enum APIError {
#[error("Missing required parameter {0}")]
MissingParameter(String),
#[error("Invalid parameter {0}")]
InvalidParameter(String),
#[error("Authentication error: {0}")] #[error("Authentication error: {0}")]
AuthenticationError(String), AuthenticationError(String),
} }
@ -496,6 +391,7 @@ struct ParsedRequest {
method: Method, method: Method,
body_bytes: Option<Vec<u8>>, body_bytes: Option<Vec<u8>>,
authorized_pubkey: nostr::PublicKey, authorized_pubkey: nostr::PublicKey,
query: Vec<(String, String)>,
} }
impl ParsedRequest { impl ParsedRequest {
@ -513,34 +409,29 @@ struct APIResponse {
body: Value, body: Value,
} }
// MARK: - Helper functions fn get_required_param(params: &Params<'_, '_>, key: &str) -> Result<String, APIError> {
match params.get(key) {
/// Matches the request to a specified route, returning a hashmap of the route parameters Some(token) => urlencoding::decode(token)
/// e.g. GET /user/:id/info route against request GET /user/123/info matches to { "id": "123" } .map(|s| match s {
fn route_match<'a>( Cow::Borrowed(s) => s.to_owned(),
method: &Method, Cow::Owned(s) => s,
path: &'a str, })
req: &ParsedRequest, .map_err(|_| APIError::InvalidParameter(key.to_string())),
) -> Option<HashMap<&'a str, String>> { None => Err(APIError::MissingParameter(key.to_string())),
if method != req.method { }
return None; }
fn parse_pubkey(pubkey: &str) -> Result<PublicKey, APIError> {
PublicKey::from_hex(pubkey).map_err(|_| APIError::InvalidParameter("pubkey".to_string()))
}
fn check_pubkey(pubkey: &str, req: &ParsedRequest) -> Result<PublicKey, APIError> {
let pubkey = parse_pubkey(pubkey)?;
if pubkey != req.authorized_pubkey {
Err(APIError::AuthenticationError(
"pubkey doesnt match authorized pubkey".to_string(),
))
} else {
Ok(pubkey)
} }
let mut params = HashMap::new();
let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
let req_segments: Vec<&str> = req.uri.split('/').filter(|s| !s.is_empty()).collect();
if path_segments.len() != req_segments.len() {
return None;
}
for (i, segment) in path_segments.iter().enumerate() {
if let Some(key) = segment.strip_prefix(':') {
let value = req_segments[i].to_string();
params.insert(key, value);
} else if segment != &req_segments[i] {
return None;
}
}
Some(params)
} }

View File

@ -28,20 +28,37 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool");
// Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter. // Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter.
// This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections. // This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections.
let notification_manager = Arc::new( let mut notification_manager = notification_manager::NotificationManager::new(
notification_manager::NotificationManager::new( pool,
pool, env.relay_url.clone(),
env.relay_url.clone(), env.nostr_event_cache_max_age,
env.apns_private_key_path.clone(), )
env.apns_private_key_id.clone(), .await
env.apns_team_id.clone(), .expect("Failed to create notification manager");
env.apns_environment.clone(),
env.apns_topic.clone(), // setup apns if key path is set
env.nostr_event_cache_max_age, if let Some(apns_key) = &env.apns_private_key_path {
) notification_manager = notification_manager
.await .with_apns(
.expect("Failed to create notification manager"), apns_key.clone(),
); env.apns_private_key_id.unwrap().clone(),
env.apns_team_id.unwrap().clone(),
env.apns_environment.clone(),
env.apns_topic.unwrap().clone(),
)
.expect("Failed to set APNs");
log::info!("APNS configured for notifications manager!");
}
// setup fcm if google services path is set
if let Some(gsp) = &env.google_services_file_path {
notification_manager = notification_manager
.with_fcm(gsp)
.expect("Failed to setup FCM");
log::info!("FCM configured for notifications manager!");
}
let notification_manager = Arc::new(notification_manager);
let api_handler = Arc::new(api_request_handler::APIHandler::new( let api_handler = Arc::new(api_request_handler::APIHandler::new(
notification_manager.clone(), notification_manager.clone(),
env.api_base_url.clone(), env.api_base_url.clone(),

View File

@ -9,15 +9,15 @@ const DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE: u64 = 60 * 60; // 1 hour
pub struct NotePushEnv { pub struct NotePushEnv {
// The path to the Apple private key .p8 file // The path to the Apple private key .p8 file
pub apns_private_key_path: String, pub apns_private_key_path: Option<String>,
// The Apple private key ID // The Apple private key ID
pub apns_private_key_id: String, pub apns_private_key_id: Option<String>,
// The Apple team ID // The Apple team ID
pub apns_team_id: String, pub apns_team_id: Option<String>,
// The APNS environment to send notifications to (Sandbox or Production) // The APNS environment to send notifications to (Sandbox or Production)
pub apns_environment: a2::client::Endpoint, pub apns_environment: a2::client::Endpoint,
// The topic to send notifications to (The Apple app bundle ID) // The topic to send notifications to (The Apple app bundle ID)
pub apns_topic: String, pub apns_topic: Option<String>,
// The path to the SQLite database file // The path to the SQLite database file
pub db_path: String, pub db_path: String,
// The host and port to bind the relay and API to // The host and port to bind the relay and API to
@ -28,14 +28,18 @@ pub struct NotePushEnv {
pub relay_url: String, pub relay_url: String,
// The max age of the Nostr event cache, in seconds // The max age of the Nostr event cache, in seconds
pub nostr_event_cache_max_age: std::time::Duration, pub nostr_event_cache_max_age: std::time::Duration,
// Path to google_services.json for FCM
pub google_services_file_path: Option<String>,
// VAAPI key for WebPush (via FCM)
pub vaapi_key: Option<String>,
} }
impl NotePushEnv { impl NotePushEnv {
pub fn load_env() -> Result<NotePushEnv, env::VarError> { pub fn load_env() -> Result<NotePushEnv, env::VarError> {
dotenv().ok(); dotenv().ok();
let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH")?; let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH").ok();
let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?; let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID").ok();
let apns_team_id = env::var("APPLE_TEAM_ID")?; let apns_team_id = env::var("APPLE_TEAM_ID").ok();
let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string()); let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string());
let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string()); let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string()); let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string());
@ -48,7 +52,7 @@ impl NotePushEnv {
"production" => a2::client::Endpoint::Production, "production" => a2::client::Endpoint::Production,
_ => a2::client::Endpoint::Sandbox, _ => a2::client::Endpoint::Sandbox,
}; };
let apns_topic = env::var("APNS_TOPIC")?; let apns_topic = env::var("APNS_TOPIC").ok();
let nostr_event_cache_max_age = env::var("NOSTR_EVENT_CACHE_MAX_AGE") let nostr_event_cache_max_age = env::var("NOSTR_EVENT_CACHE_MAX_AGE")
.unwrap_or(DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE.to_string()) .unwrap_or(DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE.to_string())
.parse::<u64>() .parse::<u64>()
@ -56,6 +60,8 @@ impl NotePushEnv {
.unwrap_or(std::time::Duration::from_secs( .unwrap_or(std::time::Duration::from_secs(
DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE, DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE,
)); ));
let google_services_file_path = env::var("GOOGLE_SERVICES_FILE_PATH").ok();
let vaapi_key = env::var("VAAPI_KEY").ok();
Ok(NotePushEnv { Ok(NotePushEnv {
apns_private_key_path, apns_private_key_path,
@ -69,6 +75,8 @@ impl NotePushEnv {
api_base_url, api_base_url,
relay_url, relay_url,
nostr_event_cache_max_age, nostr_event_cache_max_age,
google_services_file_path,
vaapi_key,
}) })
} }

View File

@ -3,8 +3,8 @@ mod nostr_event_extensions;
pub mod nostr_network_helper; pub mod nostr_network_helper;
pub mod utils; pub mod utils;
use std::cmp::{max, min};
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible}; use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
use std::cmp::{max, min};
use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder}; use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder};
use nostr::key::PublicKey; use nostr::key::PublicKey;
@ -19,6 +19,7 @@ use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use fcm_service::{FcmMessage, FcmNotification, FcmService, Target};
use nostr::Event; use nostr::Event;
use nostr_event_extensions::Codable; use nostr_event_extensions::Codable;
use nostr_event_extensions::MaybeConvertibleToMuteList; use nostr_event_extensions::MaybeConvertibleToMuteList;
@ -26,6 +27,8 @@ use nostr_event_extensions::TimestampedMuteList;
use nostr_network_helper::NostrNetworkHelper; use nostr_network_helper::NostrNetworkHelper;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use std::fs::File; use std::fs::File;
use std::str::FromStr;
use thiserror::Error;
use utils::should_mute_notification_for_mutelist; use utils::should_mute_notification_for_mutelist;
// MARK: - NotificationManager // MARK: - NotificationManager
@ -37,10 +40,56 @@ const HELLTHREAD_MIN_PUBKEYS: i8 = 6;
// Maximum threshold the hellthread pubkey tag count setting can go up to. // Maximum threshold the hellthread pubkey tag count setting can go up to.
const HELLTHREAD_MAX_PUBKEYS: i8 = 24; const HELLTHREAD_MAX_PUBKEYS: i8 = 24;
#[derive(Debug, Clone, Copy)]
pub enum NotificationBackend {
APNS,
FCM,
}
impl From<u8> for NotificationBackend {
fn from(value: u8) -> Self {
match value {
0 => NotificationBackend::APNS,
1 => NotificationBackend::FCM,
_ => panic!("Invalid value for NotificationBackend"),
}
}
}
impl Into<u8> for NotificationBackend {
fn into(self) -> u8 {
match self {
NotificationBackend::APNS => 0,
NotificationBackend::FCM => 1,
}
}
}
impl FromStr for NotificationBackend {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"fcm" => Ok(NotificationBackend::FCM),
"apns" => Ok(NotificationBackend::APNS),
_ => Err(()),
}
}
}
#[derive(Error, Debug)]
pub enum NotificationManagerError {
#[error("APNS is not configured")]
APNSMissing,
#[error("FCM is not configured")]
FCMMissing,
}
pub struct NotificationManager { pub struct NotificationManager {
db: Arc<Mutex<r2d2::Pool<SqliteConnectionManager>>>, db: Arc<Mutex<r2d2::Pool<SqliteConnectionManager>>>,
apns_topic: String, apns_topic: Option<String>,
apns_client: Mutex<Client>, apns_client: Option<Mutex<Client>>,
fcm_client: Option<Mutex<FcmService>>,
nostr_network_helper: NostrNetworkHelper, nostr_network_helper: NostrNetworkHelper,
pub event_saver: EventSaver, pub event_saver: EventSaver,
} }
@ -151,32 +200,19 @@ impl NotificationManager {
pub async fn new( pub async fn new(
db: r2d2::Pool<SqliteConnectionManager>, db: r2d2::Pool<SqliteConnectionManager>,
relay_url: String, relay_url: String,
apns_private_key_path: String,
apns_private_key_id: String,
apns_team_id: String,
apns_environment: a2::client::Endpoint,
apns_topic: String,
cache_max_age: std::time::Duration, cache_max_age: std::time::Duration,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
let connection = db.get()?; let connection = db.get()?;
Self::setup_database(&connection)?; Self::setup_database(&connection)?;
let mut file = File::open(&apns_private_key_path)?;
let client = Client::token(
&mut file,
&apns_private_key_id,
&apns_team_id,
ClientConfig::new(apns_environment.clone()),
)?;
let db = Arc::new(Mutex::new(db)); let db = Arc::new(Mutex::new(db));
let event_saver = EventSaver::new(db.clone()); let event_saver = EventSaver::new(db.clone());
let manager = NotificationManager { let manager = NotificationManager {
db, db,
apns_topic, apns_topic: None,
apns_client: Mutex::new(client), apns_client: None,
fcm_client: None,
nostr_network_helper: NostrNetworkHelper::new( nostr_network_helper: NostrNetworkHelper::new(
relay_url.clone(), relay_url.clone(),
cache_max_age, cache_max_age,
@ -189,6 +225,46 @@ impl NotificationManager {
Ok(manager) Ok(manager)
} }
/// Adds APNS configuration
pub fn with_apns(
mut self,
apns_private_key_path: String,
apns_private_key_id: String,
apns_team_id: String,
apns_environment: a2::client::Endpoint,
apns_topic: String,
) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = File::open(&apns_private_key_path)?;
let client = Client::token(
&mut file,
&apns_private_key_id,
&apns_team_id,
ClientConfig::new(apns_environment.clone()),
)?;
self.apns_client.replace(Mutex::new(client));
self.apns_topic.replace(apns_topic);
Ok(self)
}
pub fn with_fcm(
mut self,
google_services_file_path: impl Into<String>,
) -> Result<Self, Box<dyn std::error::Error>> {
self.fcm_client
.replace(Mutex::new(FcmService::new(google_services_file_path)));
Ok(self)
}
pub fn has_backend(&self, backend: NotificationBackend) -> bool {
match backend {
NotificationBackend::APNS if !self.apns_client.is_some() => true,
NotificationBackend::FCM if self.fcm_client.is_some() => true,
_ => false,
}
}
// MARK: - Database setup operations // MARK: - Database setup operations
pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> { pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
@ -296,6 +372,10 @@ impl NotificationManager {
[], [],
)?; )?;
// Migration for FCM
Self::add_column_if_not_exists(db, "user_info", "backend", "TINYINT", Some("0"))?;
Ok(()) Ok(())
} }
@ -507,14 +587,14 @@ impl NotificationManager {
pubkey: &PublicKey, pubkey: &PublicKey,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let user_device_tokens = self.get_user_device_tokens(pubkey).await?; let user_device_tokens = self.get_user_device_tokens(pubkey).await?;
for device_token in user_device_tokens { for (device_token, backend) in user_device_tokens {
if !self if !self
.user_wants_notification(pubkey, device_token.clone(), event) .user_wants_notification(pubkey, device_token.clone(), event)
.await? .await?
{ {
continue; continue;
} }
self.send_event_notification_to_device_token(event, &device_token) self.send_event_notification_to_device_token(event, &device_token, backend)
.await?; .await?;
} }
Ok(()) Ok(())
@ -564,7 +644,12 @@ impl NotificationManager {
pubkey: &PublicKey, pubkey: &PublicKey,
device_token: &str, device_token: &str,
) -> Result<bool, Box<dyn std::error::Error>> { ) -> Result<bool, Box<dyn std::error::Error>> {
let current_device_tokens = self.get_user_device_tokens(pubkey).await?; let current_device_tokens: Vec<String> = self
.get_user_device_tokens(pubkey)
.await?
.into_iter()
.map(|(d, _)| d)
.collect();
Ok(current_device_tokens.contains(&device_token.to_string())) Ok(current_device_tokens.contains(&device_token.to_string()))
} }
@ -578,12 +663,18 @@ impl NotificationManager {
async fn get_user_device_tokens( async fn get_user_device_tokens(
&self, &self,
pubkey: &PublicKey, pubkey: &PublicKey,
) -> Result<Vec<String>, Box<dyn std::error::Error>> { ) -> Result<Vec<(String, NotificationBackend)>, Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await; let db_mutex_guard = self.db.lock().await;
let connection = db_mutex_guard.get()?; let connection = db_mutex_guard.get()?;
let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?; let mut stmt =
connection.prepare("SELECT device_token,backend FROM user_info WHERE pubkey = ?")?;
let device_tokens = stmt let device_tokens = stmt
.query_map([pubkey.to_sql_string()], |row| row.get(0))? .query_map([pubkey.to_sql_string()], |row| {
Ok((
row.get::<usize, String>(0)?,
row.get::<usize, u8>(1)?.into(),
))
})?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect(); .collect();
Ok(device_tokens) Ok(device_tokens)
@ -623,11 +714,36 @@ impl NotificationManager {
&self, &self,
event: &Event, event: &Event,
device_token: &str, device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let (title, subtitle, body) = self.format_notification_message(event);
log::debug!("Sending notification to device token: {}", device_token); log::debug!("Sending notification to device token: {}", device_token);
match &backend {
NotificationBackend::APNS => {
self.send_event_notification_apns(event, device_token)
.await?;
}
NotificationBackend::FCM => {
self.send_event_notification_fcm(event, device_token)
.await?;
}
}
log::info!("Notification sent to device token: {}", device_token);
Ok(())
}
async fn send_event_notification_apns(
&self,
event: &Event,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let client = self
.apns_client
.as_ref()
.ok_or(NotificationManagerError::APNSMissing)?;
let (title, subtitle, body) = self.format_notification_message(event);
let mut payload = DefaultNotificationBuilder::new() let mut payload = DefaultNotificationBuilder::new()
.set_title(&title) .set_title(&title)
.set_subtitle(&subtitle) .set_subtitle(&subtitle)
@ -636,14 +752,15 @@ impl NotificationManager {
.set_content_available() .set_content_available()
.build(device_token, Default::default()); .build(device_token, Default::default());
payload.options.apns_topic = Some(self.apns_topic.as_str()); if let Some(t) = self.apns_topic.as_ref() {
payload.options.apns_topic = Some(t);
}
payload.data.insert( payload.data.insert(
"nostr_event", "nostr_event",
serde_json::Value::String(event.try_as_json()?), serde_json::Value::String(event.try_as_json()?),
); );
let apns_client_mutex_guard = self.apns_client.lock().await; let apns_client_mutex_guard = client.lock().await;
match apns_client_mutex_guard.send(payload).await { match apns_client_mutex_guard.send(payload).await {
Ok(_response) => {} Ok(_response) => {}
Err(e) => log::error!( Err(e) => log::error!(
@ -653,7 +770,29 @@ impl NotificationManager {
), ),
} }
log::info!("Notification sent to device token: {}", device_token); Ok(())
}
async fn send_event_notification_fcm(
&self,
event: &Event,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let client = self
.fcm_client
.as_ref()
.ok_or(NotificationManagerError::FCMMissing)?;
let (title, _, body) = self.format_notification_message(event);
let client = client.lock().await;
let mut msg = FcmMessage::new();
let mut notification = FcmNotification::new();
notification.set_title(title);
notification.set_body(body);
msg.set_notification(Some(notification));
msg.set_target(Target::Token(device_token.into()));
client.send_notification(msg).await?;
Ok(()) Ok(())
} }
@ -693,6 +832,7 @@ impl NotificationManager {
&self, &self,
pubkey: nostr::PublicKey, pubkey: nostr::PublicKey,
device_token: &str, device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
if self if self
.is_pubkey_token_pair_registered(&pubkey, device_token) .is_pubkey_token_pair_registered(&pubkey, device_token)
@ -700,23 +840,26 @@ impl NotificationManager {
{ {
return Ok(()); return Ok(());
} }
self.save_user_device_info(pubkey, device_token).await self.save_user_device_info(pubkey, device_token, backend)
.await
} }
pub async fn save_user_device_info( pub async fn save_user_device_info(
&self, &self,
pubkey: nostr::PublicKey, pubkey: nostr::PublicKey,
device_token: &str, device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let current_time_unix = Timestamp::now(); let current_time_unix = Timestamp::now();
let db_mutex_guard = self.db.lock().await; let db_mutex_guard = self.db.lock().await;
db_mutex_guard.get()?.execute( db_mutex_guard.get()?.execute(
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at, backend) VALUES (?, ?, ?, ?, ?)",
params![ params![
format!("{}:{}", pubkey.to_sql_string(), device_token), format!("{}:{}", pubkey.to_sql_string(), device_token),
pubkey.to_sql_string(), pubkey.to_sql_string(),
device_token, device_token,
current_time_unix.to_sql_string() current_time_unix.to_sql_string(),
<NotificationBackend as Into<u8>>::into(backend)
], ],
)?; )?;
Ok(()) Ok(())
@ -724,7 +867,7 @@ impl NotificationManager {
pub async fn remove_user_device_info( pub async fn remove_user_device_info(
&self, &self,
pubkey: nostr::PublicKey, pubkey: &PublicKey,
device_token: &str, device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await; let db_mutex_guard = self.db.lock().await;
@ -754,7 +897,9 @@ impl NotificationManager {
dm_notifications_enabled: row.get(4)?, dm_notifications_enabled: row.get(4)?,
only_notifications_from_following_enabled: row.get(5)?, only_notifications_from_following_enabled: row.get(5)?,
hellthread_notifications_disabled: row.get::<_, Option<bool>>(6)?.unwrap_or(false), hellthread_notifications_disabled: row.get::<_, Option<bool>>(6)?.unwrap_or(false),
hellthread_notifications_max_pubkeys: row.get::<_, Option<i8>>(7)?.unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS), hellthread_notifications_max_pubkeys: row
.get::<_, Option<i8>>(7)?
.unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS),
}) })
})?; })?;
@ -792,7 +937,6 @@ fn default_hellthread_max_pubkeys() -> i8 {
DEFAULT_HELLTHREAD_MAX_PUBKEYS DEFAULT_HELLTHREAD_MAX_PUBKEYS
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct UserNotificationSettings { pub struct UserNotificationSettings {
zap_notifications_enabled: bool, zap_notifications_enabled: bool,