feat: setup alt backends

feat: FCM backend
This commit is contained in:
2025-05-21 18:03:42 +01:00
parent e73f5ce70e
commit e2bdbf3c0c
6 changed files with 629 additions and 287 deletions

287
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "a2"
@ -316,6 +316,12 @@ dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.6.0"
@ -413,6 +419,7 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets 0.52.6",
]
@ -434,6 +441,16 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.6"
@ -489,6 +506,15 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "env_filter"
version = "0.1.0"
@ -527,6 +553,16 @@ dependencies = [
"serde",
]
[[package]]
name = "errno"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
dependencies = [
"libc",
"windows-sys 0.59.0",
]
[[package]]
name = "fallible-iterator"
version = "0.3.0"
@ -539,6 +575,25 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fcm-service"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96aadee26986d122d1b8b4cdd6e7fe1310fdceeb1fd8db4d612d6510d96d7106"
dependencies = [
"gcp_auth",
"reqwest",
"serde",
"serde_json",
"tokio",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -658,6 +713,33 @@ dependencies = [
"slab",
]
[[package]]
name = "gcp_auth"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf67f30198e045a039264c01fb44659ce82402d7771c50938beb41a5ac87733"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"chrono",
"home",
"http",
"http-body-util",
"hyper",
"hyper-rustls 0.27.2",
"hyper-util",
"ring",
"rustls-pemfile",
"serde",
"serde_json",
"thiserror",
"tokio",
"tracing",
"tracing-futures",
"url",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -764,6 +846,15 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.1.0"
@ -866,6 +957,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls 0.23.11",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.0",
@ -873,6 +965,22 @@ dependencies = [
"webpki-roots",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-tungstenite"
version = "0.14.0"
@ -1017,6 +1125,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "lnurl-pay"
version = "0.5.0"
@ -1054,6 +1168,12 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "matchit"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f926ade0c4e170215ae43342bf13b9310a437609c81f29f86c5df6657582ef9"
[[package]]
name = "memchr"
version = "2.7.4"
@ -1086,6 +1206,23 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "native-tls"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "negentropy"
version = "0.3.1"
@ -1206,12 +1343,14 @@ dependencies = [
"chrono",
"dotenv",
"env_logger",
"fcm-service",
"futures",
"http-body-util",
"hyper",
"hyper-tungstenite",
"hyper-util",
"log",
"matchit",
"nostr",
"nostr-sdk",
"r2d2",
@ -1224,6 +1363,7 @@ dependencies = [
"toml",
"tracing",
"tungstenite",
"urlencoding",
"uuid",
]
@ -1287,7 +1427,7 @@ version = "0.10.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [
"bitflags",
"bitflags 2.6.0",
"cfg-if",
"foreign-types",
"libc",
@ -1307,6 +1447,12 @@ dependencies = [
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.102"
@ -1557,7 +1703,7 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [
"bitflags",
"bitflags 2.6.0",
]
[[package]]
@ -1597,18 +1743,22 @@ checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37"
dependencies = [
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls 0.27.2",
"hyper-tls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
@ -1620,7 +1770,9 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.0",
"tokio-socks",
"tower-service",
@ -1653,7 +1805,7 @@ version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae"
dependencies = [
"bitflags",
"bitflags 2.6.0",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
@ -1683,6 +1835,19 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.38.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
dependencies = [
"bitflags 2.6.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.22.4"
@ -1711,6 +1876,19 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.1.2"
@ -1753,6 +1931,15 @@ dependencies = [
"cipher",
]
[[package]]
name = "schannel"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "scheduled-thread-pool"
version = "0.2.7"
@ -1801,6 +1988,29 @@ dependencies = [
"cc",
]
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.23"
@ -1942,6 +2152,40 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tempfile"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [
"cfg-if",
"fastrand",
"once_cell",
"rustix",
"windows-sys 0.59.0",
]
[[package]]
name = "thiserror"
version = "1.0.63"
@ -2007,6 +2251,16 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
@ -2137,6 +2391,16 @@ dependencies = [
"once_cell",
]
[[package]]
name = "tracing-futures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
dependencies = [
"pin-project",
"tracing",
]
[[package]]
name = "try-lock"
version = "0.2.5"
@ -2218,6 +2482,12 @@ dependencies = [
"serde",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf-8"
version = "0.7.6"
@ -2396,6 +2666,15 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.48.5"

View File

@ -37,3 +37,6 @@ uuid = { version = "1.10.0", features = ["v4"] }
thiserror = "1.0.63"
hyper-tungstenite = "0.14.0"
futures = "0.3.30"
fcm-service = "0.1.5"
matchit = "0.8.6"
urlencoding = "2.1.3"

View File

@ -1,19 +1,24 @@
use crate::nip98_auth;
use crate::notification_manager::UserNotificationSettings;
use crate::notification_manager::{NotificationBackend, UserNotificationSettings};
use crate::relay_connection::RelayConnection;
use http_body_util::Full;
use hyper::body::Buf;
use hyper::body::Bytes;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use std::borrow::Cow;
use http_body_util::BodyExt;
use serde_json::from_value;
use crate::notification_manager::NotificationManager;
use hyper::Method;
use log::warn;
use matchit::{Params, Router};
use nostr::prelude::url::form_urlencoded;
use nostr::PublicKey;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
@ -65,6 +70,10 @@ impl APIHandler {
status: StatusCode::UNAUTHORIZED,
body: json!({ "error": "Unauthorized", "message": message }),
},
_ => APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": api_error.to_string() }),
},
}
} else {
// Otherwise, return a 500 status code
@ -154,6 +163,12 @@ impl APIHandler {
method: req.method().clone(),
body_bytes: body_bytes.map(|b| b.to_vec()),
authorized_pubkey,
query: match req.uri().query() {
Some(q) => form_urlencoded::parse(q.as_bytes())
.map(|(k, v)| (k.into_owned(), v.into_owned()))
.collect(),
None => vec![],
},
})
}
@ -163,38 +178,38 @@ impl APIHandler {
&self,
parsed_request: &ParsedRequest,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
if let Some(url_params) = route_match(
&Method::PUT,
"/user-info/:pubkey/:deviceToken",
parsed_request,
) {
return self.handle_user_info(parsed_request, &url_params).await;
enum RouteMatch {
UserInfo,
UserSetting,
}
let mut routes = Router::new();
routes.insert("/user-info/{pubkey}/{deviceToken}", RouteMatch::UserInfo)?;
routes.insert(
"/user-info/{pubkey}/{deviceToken}/preference",
RouteMatch::UserSetting,
)?;
if let Some(url_params) = route_match(
&Method::DELETE,
"/user-info/:pubkey/:deviceToken",
parsed_request,
) {
return self
.handle_user_info_remove(parsed_request, &url_params)
.await;
}
if let Some(url_params) = route_match(
&Method::GET,
"/user-info/:pubkey/:deviceToken/preferences",
parsed_request,
) {
return self.get_user_settings(parsed_request, &url_params).await;
}
if let Some(url_params) = route_match(
&Method::PUT,
"/user-info/:pubkey/:deviceToken/preferences",
parsed_request,
) {
return self.set_user_settings(parsed_request, &url_params).await;
match routes.at(&parsed_request.uri) {
Ok(m) => match m.value {
RouteMatch::UserInfo if parsed_request.method == Method::PUT => {
return self.handle_user_info(parsed_request, m.params).await;
}
RouteMatch::UserInfo if parsed_request.method == Method::DELETE => {
return self.handle_user_info_remove(parsed_request, m.params).await;
}
RouteMatch::UserSetting if parsed_request.method == Method::GET => {
return self.get_user_settings(parsed_request, m.params).await;
}
RouteMatch::UserSetting if parsed_request.method == Method::PUT => {
return self.set_user_settings(parsed_request, m.params).await;
}
_ => {
// fallthrough to 404
}
},
Err(e) => {
warn!("Match failed: {}", e);
}
}
Ok(APIResponse {
@ -209,7 +224,7 @@ impl APIHandler {
&self,
req: &Request<Incoming>,
body_bytes: Option<&[u8]>,
) -> Result<Result<nostr::PublicKey, String>, Box<dyn std::error::Error>> {
) -> Result<Result<PublicKey, String>, Box<dyn std::error::Error>> {
let auth_header = match req.headers().get("Authorization") {
Some(header) => header,
None => return Ok(Err("Authorization header not found".to_string())),
@ -217,7 +232,11 @@ impl APIHandler {
Ok(nip98_auth::nip98_verify_auth_header(
auth_header.to_str()?.to_string(),
&format!("{}{}", self.base_url, req.uri().path()),
&format!(
"{}{}",
self.base_url,
req.uri().path_and_query().map(|p| p.as_str()).unwrap_or("")
),
req.method().as_str(),
body_bytes,
)
@ -229,52 +248,39 @@ impl APIHandler {
async fn handle_user_info(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
let device_token = get_required_param(&url_params, "deviceToken")?;
let pubkey = get_required_param(&url_params, "pubkey")?;
let pubkey = check_pubkey(&pubkey, req)?;
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
let backend = match NotificationBackend::from_str(
req.query
.iter()
.find(|c| c.0 == "backend")
.map(|c| &c.1)
.unwrap_or(&"apns".to_string()),
) {
Ok(token) => token,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
body: json!({ "error": "backend is invalid" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
// Early return if backend is not supported
if !self.notification_manager.has_backend(backend) {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Backend not supported" }),
});
}
// Proceed with the main logic after passing all checks
self.notification_manager
.save_user_device_info_if_not_present(pubkey, device_token)
.save_user_device_info_if_not_present(pubkey, &device_token, backend)
.await?;
Ok(APIResponse {
status: StatusCode::OK,
@ -285,52 +291,15 @@ impl APIHandler {
async fn handle_user_info_remove(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
let device_token = get_required_param(&url_params, "deviceToken")?;
let pubkey = get_required_param(&url_params, "pubkey")?;
let pubkey = check_pubkey(&pubkey, req)?;
// Proceed with the main logic after passing all checks
self.notification_manager
.remove_user_device_info(pubkey, device_token)
.remove_user_device_info(&pubkey, &device_token)
.await?;
Ok(APIResponse {
@ -342,48 +311,11 @@ impl APIHandler {
async fn set_user_settings(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
let device_token = get_required_param(&url_params, "deviceToken")?;
let pubkey = get_required_param(&url_params, "pubkey")?;
let pubkey = check_pubkey(&pubkey, req)?;
// Proceed with the main logic after passing all checks
let body = req.body_json()?;
@ -399,11 +331,7 @@ impl APIHandler {
};
self.notification_manager
.save_user_notification_settings(
&req.authorized_pubkey,
device_token.to_string(),
settings,
)
.save_user_notification_settings(&pubkey, device_token.to_string(), settings)
.await?;
Ok(APIResponse {
@ -415,53 +343,16 @@ impl APIHandler {
async fn get_user_settings(
&self,
req: &ParsedRequest,
url_params: &HashMap<&str, String>,
url_params: Params<'_, '_>,
) -> Result<APIResponse, Box<dyn std::error::Error>> {
// Early return if `deviceToken` is missing
let device_token = match url_params.get("deviceToken") {
Some(token) => token,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "deviceToken is required on the URL" }),
})
}
};
// Early return if `pubkey` is missing
let pubkey = match url_params.get("pubkey") {
Some(key) => key,
None => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "pubkey is required on the URL" }),
})
}
};
// Validate the `pubkey` and prepare it for use
let pubkey = match nostr::PublicKey::from_hex(pubkey) {
Ok(key) => key,
Err(_) => {
return Ok(APIResponse {
status: StatusCode::BAD_REQUEST,
body: json!({ "error": "Invalid pubkey" }),
})
}
};
// Early return if `pubkey` does not match `req.authorized_pubkey`
if pubkey != req.authorized_pubkey {
return Ok(APIResponse {
status: StatusCode::FORBIDDEN,
body: json!({ "error": "Forbidden" }),
});
}
let device_token = get_required_param(&url_params, "deviceToken")?;
let pubkey = get_required_param(&url_params, "pubkey")?;
let pubkey = check_pubkey(&pubkey, req)?;
// Proceed with the main logic after passing all checks
let settings = self
.notification_manager
.get_user_notification_settings(&req.authorized_pubkey, device_token.to_string())
.get_user_notification_settings(&pubkey, device_token.to_string())
.await?;
Ok(APIResponse {
@ -487,6 +378,10 @@ impl Clone for APIHandler {
// Define enum error types including authentication error
#[derive(Debug, Error)]
enum APIError {
#[error("Missing required parameter {0}")]
MissingParameter(String),
#[error("Invalid parameter {0}")]
InvalidParameter(String),
#[error("Authentication error: {0}")]
AuthenticationError(String),
}
@ -496,6 +391,7 @@ struct ParsedRequest {
method: Method,
body_bytes: Option<Vec<u8>>,
authorized_pubkey: nostr::PublicKey,
query: Vec<(String, String)>,
}
impl ParsedRequest {
@ -513,34 +409,29 @@ struct APIResponse {
body: Value,
}
// MARK: - Helper functions
/// Matches the request to a specified route, returning a hashmap of the route parameters
/// e.g. GET /user/:id/info route against request GET /user/123/info matches to { "id": "123" }
fn route_match<'a>(
method: &Method,
path: &'a str,
req: &ParsedRequest,
) -> Option<HashMap<&'a str, String>> {
if method != req.method {
return None;
fn get_required_param(params: &Params<'_, '_>, key: &str) -> Result<String, APIError> {
match params.get(key) {
Some(token) => urlencoding::decode(token)
.map(|s| match s {
Cow::Borrowed(s) => s.to_owned(),
Cow::Owned(s) => s,
})
.map_err(|_| APIError::InvalidParameter(key.to_string())),
None => Err(APIError::MissingParameter(key.to_string())),
}
}
fn parse_pubkey(pubkey: &str) -> Result<PublicKey, APIError> {
PublicKey::from_hex(pubkey).map_err(|_| APIError::InvalidParameter("pubkey".to_string()))
}
fn check_pubkey(pubkey: &str, req: &ParsedRequest) -> Result<PublicKey, APIError> {
let pubkey = parse_pubkey(pubkey)?;
if pubkey != req.authorized_pubkey {
Err(APIError::AuthenticationError(
"pubkey doesnt match authorized pubkey".to_string(),
))
} else {
Ok(pubkey)
}
let mut params = HashMap::new();
let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
let req_segments: Vec<&str> = req.uri.split('/').filter(|s| !s.is_empty()).collect();
if path_segments.len() != req_segments.len() {
return None;
}
for (i, segment) in path_segments.iter().enumerate() {
if let Some(key) = segment.strip_prefix(':') {
let value = req_segments[i].to_string();
params.insert(key, value);
} else if segment != &req_segments[i] {
return None;
}
}
Some(params)
}

View File

@ -28,20 +28,37 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool");
// Notification manager is a shared resource that will be used by all connections via a mutex and an atomic reference counter.
// This is shared to avoid data races when reading/writing to the sqlite database, and reduce outgoing relay connections.
let notification_manager = Arc::new(
notification_manager::NotificationManager::new(
pool,
env.relay_url.clone(),
env.apns_private_key_path.clone(),
env.apns_private_key_id.clone(),
env.apns_team_id.clone(),
env.apns_environment.clone(),
env.apns_topic.clone(),
env.nostr_event_cache_max_age,
)
.await
.expect("Failed to create notification manager"),
);
let mut notification_manager = notification_manager::NotificationManager::new(
pool,
env.relay_url.clone(),
env.nostr_event_cache_max_age,
)
.await
.expect("Failed to create notification manager");
// setup apns if key path is set
if let Some(apns_key) = &env.apns_private_key_path {
notification_manager = notification_manager
.with_apns(
apns_key.clone(),
env.apns_private_key_id.unwrap().clone(),
env.apns_team_id.unwrap().clone(),
env.apns_environment.clone(),
env.apns_topic.unwrap().clone(),
)
.expect("Failed to set APNs");
log::info!("APNS configured for notifications manager!");
}
// setup fcm if google services path is set
if let Some(gsp) = &env.google_services_file_path {
notification_manager = notification_manager
.with_fcm(gsp)
.expect("Failed to setup FCM");
log::info!("FCM configured for notifications manager!");
}
let notification_manager = Arc::new(notification_manager);
let api_handler = Arc::new(api_request_handler::APIHandler::new(
notification_manager.clone(),
env.api_base_url.clone(),

View File

@ -9,15 +9,15 @@ const DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE: u64 = 60 * 60; // 1 hour
pub struct NotePushEnv {
// The path to the Apple private key .p8 file
pub apns_private_key_path: String,
pub apns_private_key_path: Option<String>,
// The Apple private key ID
pub apns_private_key_id: String,
pub apns_private_key_id: Option<String>,
// The Apple team ID
pub apns_team_id: String,
pub apns_team_id: Option<String>,
// The APNS environment to send notifications to (Sandbox or Production)
pub apns_environment: a2::client::Endpoint,
// The topic to send notifications to (The Apple app bundle ID)
pub apns_topic: String,
pub apns_topic: Option<String>,
// The path to the SQLite database file
pub db_path: String,
// The host and port to bind the relay and API to
@ -28,14 +28,18 @@ pub struct NotePushEnv {
pub relay_url: String,
// The max age of the Nostr event cache, in seconds
pub nostr_event_cache_max_age: std::time::Duration,
// Path to google_services.json for FCM
pub google_services_file_path: Option<String>,
// VAAPI key for WebPush (via FCM)
pub vaapi_key: Option<String>,
}
impl NotePushEnv {
pub fn load_env() -> Result<NotePushEnv, env::VarError> {
dotenv().ok();
let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH")?;
let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?;
let apns_team_id = env::var("APPLE_TEAM_ID")?;
let apns_private_key_path = env::var("APNS_AUTH_PRIVATE_KEY_FILE_PATH").ok();
let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID").ok();
let apns_team_id = env::var("APPLE_TEAM_ID").ok();
let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string());
let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string());
@ -48,7 +52,7 @@ impl NotePushEnv {
"production" => a2::client::Endpoint::Production,
_ => a2::client::Endpoint::Sandbox,
};
let apns_topic = env::var("APNS_TOPIC")?;
let apns_topic = env::var("APNS_TOPIC").ok();
let nostr_event_cache_max_age = env::var("NOSTR_EVENT_CACHE_MAX_AGE")
.unwrap_or(DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE.to_string())
.parse::<u64>()
@ -56,6 +60,8 @@ impl NotePushEnv {
.unwrap_or(std::time::Duration::from_secs(
DEFAULT_NOSTR_EVENT_CACHE_MAX_AGE,
));
let google_services_file_path = env::var("GOOGLE_SERVICES_FILE_PATH").ok();
let vaapi_key = env::var("VAAPI_KEY").ok();
Ok(NotePushEnv {
apns_private_key_path,
@ -69,6 +75,8 @@ impl NotePushEnv {
api_base_url,
relay_url,
nostr_event_cache_max_age,
google_services_file_path,
vaapi_key,
})
}

View File

@ -3,8 +3,8 @@ mod nostr_event_extensions;
pub mod nostr_network_helper;
pub mod utils;
use std::cmp::{max, min};
use nostr_event_extensions::{ExtendedEvent, SqlStringConvertible};
use std::cmp::{max, min};
use a2::{Client, ClientConfig, DefaultNotificationBuilder, NotificationBuilder};
use nostr::key::PublicKey;
@ -19,6 +19,7 @@ use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use fcm_service::{FcmMessage, FcmNotification, FcmService, Target};
use nostr::Event;
use nostr_event_extensions::Codable;
use nostr_event_extensions::MaybeConvertibleToMuteList;
@ -26,6 +27,8 @@ use nostr_event_extensions::TimestampedMuteList;
use nostr_network_helper::NostrNetworkHelper;
use r2d2_sqlite::SqliteConnectionManager;
use std::fs::File;
use std::str::FromStr;
use thiserror::Error;
use utils::should_mute_notification_for_mutelist;
// MARK: - NotificationManager
@ -37,10 +40,56 @@ const HELLTHREAD_MIN_PUBKEYS: i8 = 6;
// Maximum threshold the hellthread pubkey tag count setting can go up to.
const HELLTHREAD_MAX_PUBKEYS: i8 = 24;
#[derive(Debug, Clone, Copy)]
pub enum NotificationBackend {
APNS,
FCM,
}
impl From<u8> for NotificationBackend {
fn from(value: u8) -> Self {
match value {
0 => NotificationBackend::APNS,
1 => NotificationBackend::FCM,
_ => panic!("Invalid value for NotificationBackend"),
}
}
}
impl Into<u8> for NotificationBackend {
fn into(self) -> u8 {
match self {
NotificationBackend::APNS => 0,
NotificationBackend::FCM => 1,
}
}
}
impl FromStr for NotificationBackend {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"fcm" => Ok(NotificationBackend::FCM),
"apns" => Ok(NotificationBackend::APNS),
_ => Err(()),
}
}
}
#[derive(Error, Debug)]
pub enum NotificationManagerError {
#[error("APNS is not configured")]
APNSMissing,
#[error("FCM is not configured")]
FCMMissing,
}
pub struct NotificationManager {
db: Arc<Mutex<r2d2::Pool<SqliteConnectionManager>>>,
apns_topic: String,
apns_client: Mutex<Client>,
apns_topic: Option<String>,
apns_client: Option<Mutex<Client>>,
fcm_client: Option<Mutex<FcmService>>,
nostr_network_helper: NostrNetworkHelper,
pub event_saver: EventSaver,
}
@ -151,32 +200,19 @@ impl NotificationManager {
pub async fn new(
db: r2d2::Pool<SqliteConnectionManager>,
relay_url: String,
apns_private_key_path: String,
apns_private_key_id: String,
apns_team_id: String,
apns_environment: a2::client::Endpoint,
apns_topic: String,
cache_max_age: std::time::Duration,
) -> Result<Self, Box<dyn std::error::Error>> {
let connection = db.get()?;
Self::setup_database(&connection)?;
let mut file = File::open(&apns_private_key_path)?;
let client = Client::token(
&mut file,
&apns_private_key_id,
&apns_team_id,
ClientConfig::new(apns_environment.clone()),
)?;
let db = Arc::new(Mutex::new(db));
let event_saver = EventSaver::new(db.clone());
let manager = NotificationManager {
db,
apns_topic,
apns_client: Mutex::new(client),
apns_topic: None,
apns_client: None,
fcm_client: None,
nostr_network_helper: NostrNetworkHelper::new(
relay_url.clone(),
cache_max_age,
@ -189,6 +225,46 @@ impl NotificationManager {
Ok(manager)
}
/// Adds APNS configuration
pub fn with_apns(
mut self,
apns_private_key_path: String,
apns_private_key_id: String,
apns_team_id: String,
apns_environment: a2::client::Endpoint,
apns_topic: String,
) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = File::open(&apns_private_key_path)?;
let client = Client::token(
&mut file,
&apns_private_key_id,
&apns_team_id,
ClientConfig::new(apns_environment.clone()),
)?;
self.apns_client.replace(Mutex::new(client));
self.apns_topic.replace(apns_topic);
Ok(self)
}
pub fn with_fcm(
mut self,
google_services_file_path: impl Into<String>,
) -> Result<Self, Box<dyn std::error::Error>> {
self.fcm_client
.replace(Mutex::new(FcmService::new(google_services_file_path)));
Ok(self)
}
pub fn has_backend(&self, backend: NotificationBackend) -> bool {
match backend {
NotificationBackend::APNS if !self.apns_client.is_some() => true,
NotificationBackend::FCM if self.fcm_client.is_some() => true,
_ => false,
}
}
// MARK: - Database setup operations
pub fn setup_database(db: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
@ -296,6 +372,10 @@ impl NotificationManager {
[],
)?;
// Migration for FCM
Self::add_column_if_not_exists(db, "user_info", "backend", "TINYINT", Some("0"))?;
Ok(())
}
@ -507,14 +587,14 @@ impl NotificationManager {
pubkey: &PublicKey,
) -> Result<(), Box<dyn std::error::Error>> {
let user_device_tokens = self.get_user_device_tokens(pubkey).await?;
for device_token in user_device_tokens {
for (device_token, backend) in user_device_tokens {
if !self
.user_wants_notification(pubkey, device_token.clone(), event)
.await?
{
continue;
}
self.send_event_notification_to_device_token(event, &device_token)
self.send_event_notification_to_device_token(event, &device_token, backend)
.await?;
}
Ok(())
@ -564,7 +644,12 @@ impl NotificationManager {
pubkey: &PublicKey,
device_token: &str,
) -> Result<bool, Box<dyn std::error::Error>> {
let current_device_tokens = self.get_user_device_tokens(pubkey).await?;
let current_device_tokens: Vec<String> = self
.get_user_device_tokens(pubkey)
.await?
.into_iter()
.map(|(d, _)| d)
.collect();
Ok(current_device_tokens.contains(&device_token.to_string()))
}
@ -578,12 +663,18 @@ impl NotificationManager {
async fn get_user_device_tokens(
&self,
pubkey: &PublicKey,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
) -> Result<Vec<(String, NotificationBackend)>, Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await;
let connection = db_mutex_guard.get()?;
let mut stmt = connection.prepare("SELECT device_token FROM user_info WHERE pubkey = ?")?;
let mut stmt =
connection.prepare("SELECT device_token,backend FROM user_info WHERE pubkey = ?")?;
let device_tokens = stmt
.query_map([pubkey.to_sql_string()], |row| row.get(0))?
.query_map([pubkey.to_sql_string()], |row| {
Ok((
row.get::<usize, String>(0)?,
row.get::<usize, u8>(1)?.into(),
))
})?
.filter_map(|r| r.ok())
.collect();
Ok(device_tokens)
@ -623,11 +714,36 @@ impl NotificationManager {
&self,
event: &Event,
device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> {
let (title, subtitle, body) = self.format_notification_message(event);
log::debug!("Sending notification to device token: {}", device_token);
match &backend {
NotificationBackend::APNS => {
self.send_event_notification_apns(event, device_token)
.await?;
}
NotificationBackend::FCM => {
self.send_event_notification_fcm(event, device_token)
.await?;
}
}
log::info!("Notification sent to device token: {}", device_token);
Ok(())
}
async fn send_event_notification_apns(
&self,
event: &Event,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let client = self
.apns_client
.as_ref()
.ok_or(NotificationManagerError::APNSMissing)?;
let (title, subtitle, body) = self.format_notification_message(event);
let mut payload = DefaultNotificationBuilder::new()
.set_title(&title)
.set_subtitle(&subtitle)
@ -636,14 +752,15 @@ impl NotificationManager {
.set_content_available()
.build(device_token, Default::default());
payload.options.apns_topic = Some(self.apns_topic.as_str());
if let Some(t) = self.apns_topic.as_ref() {
payload.options.apns_topic = Some(t);
}
payload.data.insert(
"nostr_event",
serde_json::Value::String(event.try_as_json()?),
);
let apns_client_mutex_guard = self.apns_client.lock().await;
let apns_client_mutex_guard = client.lock().await;
match apns_client_mutex_guard.send(payload).await {
Ok(_response) => {}
Err(e) => log::error!(
@ -653,7 +770,29 @@ impl NotificationManager {
),
}
log::info!("Notification sent to device token: {}", device_token);
Ok(())
}
async fn send_event_notification_fcm(
&self,
event: &Event,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let client = self
.fcm_client
.as_ref()
.ok_or(NotificationManagerError::FCMMissing)?;
let (title, _, body) = self.format_notification_message(event);
let client = client.lock().await;
let mut msg = FcmMessage::new();
let mut notification = FcmNotification::new();
notification.set_title(title);
notification.set_body(body);
msg.set_notification(Some(notification));
msg.set_target(Target::Token(device_token.into()));
client.send_notification(msg).await?;
Ok(())
}
@ -693,6 +832,7 @@ impl NotificationManager {
&self,
pubkey: nostr::PublicKey,
device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> {
if self
.is_pubkey_token_pair_registered(&pubkey, device_token)
@ -700,23 +840,26 @@ impl NotificationManager {
{
return Ok(());
}
self.save_user_device_info(pubkey, device_token).await
self.save_user_device_info(pubkey, device_token, backend)
.await
}
pub async fn save_user_device_info(
&self,
pubkey: nostr::PublicKey,
device_token: &str,
backend: NotificationBackend,
) -> Result<(), Box<dyn std::error::Error>> {
let current_time_unix = Timestamp::now();
let db_mutex_guard = self.db.lock().await;
db_mutex_guard.get()?.execute(
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)",
"INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at, backend) VALUES (?, ?, ?, ?, ?)",
params![
format!("{}:{}", pubkey.to_sql_string(), device_token),
pubkey.to_sql_string(),
device_token,
current_time_unix.to_sql_string()
current_time_unix.to_sql_string(),
<NotificationBackend as Into<u8>>::into(backend)
],
)?;
Ok(())
@ -724,7 +867,7 @@ impl NotificationManager {
pub async fn remove_user_device_info(
&self,
pubkey: nostr::PublicKey,
pubkey: &PublicKey,
device_token: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let db_mutex_guard = self.db.lock().await;
@ -754,7 +897,9 @@ impl NotificationManager {
dm_notifications_enabled: row.get(4)?,
only_notifications_from_following_enabled: row.get(5)?,
hellthread_notifications_disabled: row.get::<_, Option<bool>>(6)?.unwrap_or(false),
hellthread_notifications_max_pubkeys: row.get::<_, Option<i8>>(7)?.unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS),
hellthread_notifications_max_pubkeys: row
.get::<_, Option<i8>>(7)?
.unwrap_or(DEFAULT_HELLTHREAD_MAX_PUBKEYS),
})
})?;
@ -792,7 +937,6 @@ fn default_hellthread_max_pubkeys() -> i8 {
DEFAULT_HELLTHREAD_MAX_PUBKEYS
}
#[derive(Serialize, Deserialize, Debug)]
pub struct UserNotificationSettings {
zap_notifications_enabled: bool,