diff --git a/Cargo.lock b/Cargo.lock index 2ea3e99..144233d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -804,6 +804,12 @@ version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "humantime" version = "2.1.0" @@ -823,6 +829,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1180,9 +1187,13 @@ name = "notepush" version = "0.1.0" dependencies = [ "a2", + "base64 0.22.1", "chrono", "dotenv", "env_logger", + "http-body-util", + "hyper", + "hyper-util", "log", "nostr", "nostr-sdk", @@ -1191,10 +1202,12 @@ dependencies = [ "rusqlite", "serde", "serde_json", + "thiserror", "tokio", "toml", "tracing", "tungstenite", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ad72e60..4e0b76c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ chrono = { version = "0.4.38" } a2 = { version = "0.10.0" } tokio = { version = "1.38.0", features = ["full"] } tungstenite = "0.23.0" +hyper = { version = "1.4.1", features = ["server"] } nostr = "0.32.1" log = "0.4" env_logger = "0.11.3" @@ -29,3 +30,8 @@ nostr-sdk = "0.32.0" r2d2_sqlite = "0.24.0" r2d2 = "0.8.10" dotenv = "0.15.0" +base64 = "0.22.1" +hyper-util = "0.1.6" +http-body-util = "0.1.2" +uuid = { version = "1.10.0", features = ["v4"] } +thiserror = "1.0.63" diff --git a/README.md b/README.md index 3e3d73e..653a393 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,12 @@ APNS_AUTH_PRIVATE_KEY_ID=1234567890 # The ID of the private key used to generate APNS_ENVIRONMENT="development" # The environment to use with the APNS server. Can be "development" or "production" APPLE_TEAM_ID=1248163264 # The ID of the team. Can be found in AppStore Connect. DB_PATH=./apns_notifications.db # Path to the SQLite database file that will be used to store data about sent notifications, relative to the working directory -RELAY_URL=ws://localhost:7777 # URL to the relay server which will be consulted to get information such as mute lists. +RELAY_URL=wss://relay.damus.io # URL to the relay server which will be consulted to get information such as mute lists. +RELAY_HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces) +RELAY_PORT=9001 # The port to bind the server to. Defaults to 9001 +API_HOST="0.0.0.0" # The host to bind the API server to (Defaults to 0.0.0.0 to bind to all interfaces) +API_PORT=8000 # The port to bind the API server to. Defaults to 8000 API_BASE_URL=http://localhost:8000 # Base URL from the API is allowed access (used by the server to perform NIP-98 authentication) -HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces) -PORT=9001 # The port to bind the server to. Defaults to 9001 ``` 6. Run this relay using the built binary or the `cargo run` command. If you want to change the log level, you can set the `RUST_LOG` environment variable to `DEBUG` or `INFO` before running the relay. diff --git a/src/api_server/api_request_handler.rs b/src/api_server/api_request_handler.rs new file mode 100644 index 0000000..9ebfc92 --- /dev/null +++ b/src/api_server/api_request_handler.rs @@ -0,0 +1,202 @@ +use super::nip98_auth; +use hyper::{Request, Response, StatusCode}; +use hyper::body::Buf; +use hyper::body::Incoming; + +use http_body_util::BodyExt; +use nostr; + +use thiserror::Error; +use std::sync::Arc; +use log; +use hyper::Method; +use tokio::sync::Mutex; +use serde_json::{json, Value}; +use crate::notification_manager::NotificationManager; + +struct ParsedRequest { + uri: String, + method: Method, + body_bytes: Option>, + authorized_pubkey: nostr::PublicKey, +} + +impl ParsedRequest { + fn body_json(&self) -> Result> { + if let Some(body_bytes) = &self.body_bytes { + Ok(serde_json::from_slice(body_bytes)?) + } else { + Ok(json!({})) + } + } +} + +struct APIResponse { + status: StatusCode, + body: Value, +} + +pub struct APIHandler { + notification_manager: Arc>, + base_url: String, +} + +impl Clone for APIHandler { + fn clone(&self) -> Self { + APIHandler { + notification_manager: self.notification_manager.clone(), + base_url: self.base_url.clone(), + } + } +} + +impl APIHandler { + pub fn new(notification_manager: Arc>, base_url: String) -> Self { + APIHandler { + notification_manager, + base_url, + } + } + + pub async fn handle_http_request(&self, req: Request) -> Result, hyper::http::Error> { + let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await { + Ok(api_response) => { + APIResponse { + status: api_response.status, + body: api_response.body, + } + }, + Err(err) => { + // Detect if error is a APIError::AuthenticationError and return a 401 status code + if let Some(api_error) = err.downcast_ref::() { + match api_error { + APIError::AuthenticationError(message) => { + APIResponse { + status: StatusCode::UNAUTHORIZED, + body: json!({ "error": "Unauthorized", "message": message }), + } + }, + } + } + else { + // Otherwise, return a 500 status code + let random_case_uuid = uuid::Uuid::new_v4(); + log::error!("Error handling request: {} (Case ID: {})", err, random_case_uuid); + APIResponse { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: json!({ "error": "Internal server error", "message": format!("Case ID: {}", random_case_uuid) }), + } + } + } + }; + Ok(Response::builder() + .header("Content-Type", "application/json") + .header("Access-Control-Allow-Origin", "*") + .status(final_api_response.status) + .body(final_api_response.body.to_string())? + ) + } + + async fn try_to_handle_http_request(&self, mut req: Request) -> Result> { + let parsed_request = self.parse_http_request(&mut req).await?; + let api_response: APIResponse = self.handle_parsed_http_request(&parsed_request).await?; + log::info!("[{}] {} (Authorized pubkey: {}): {}", req.method(), req.uri(), parsed_request.authorized_pubkey, api_response.status); + Ok(api_response) + } + + async fn parse_http_request(&self, req: &mut Request) -> Result> { + // 1. Read the request body + let body_buffer = req.body_mut().collect().await?.aggregate(); + let body_bytes = body_buffer.chunk(); + let body_bytes = if body_bytes.is_empty() { None } else { Some(body_bytes) }; + + // 2. NIP-98 authentication + let authorized_pubkey = match self.authenticate(&req, body_bytes).await? { + Ok(pubkey) => { + pubkey + }, + Err(auth_error) => { + return Err(Box::new(APIError::AuthenticationError(auth_error))); + } + }; + + // 3. Parse the request + Ok(ParsedRequest { + uri: req.uri().path().to_string(), + method: req.method().clone(), + body_bytes: body_bytes.map(|b| b.to_vec()), + authorized_pubkey, + }) + } + + async fn handle_parsed_http_request(&self, parsed_request: &ParsedRequest) -> Result> { + match (&parsed_request.method, parsed_request.uri.as_str()) { + (&Method::POST, "/user-info") => self.handle_user_info(parsed_request).await, + (&Method::POST, "/user-info/remove") => self.handle_user_info_remove(parsed_request).await, + _ => { + Ok(APIResponse { + status: StatusCode::NOT_FOUND, + body: json!({ "error": "Not found" }), + }) + } + } + } + + async fn authenticate(&self, req: &Request, body_bytes: Option<&[u8]>) -> Result, Box> { + let auth_header = match req.headers().get("Authorization") { + Some(header) => header, + None => return Ok(Err("Authorization header not found".to_string())), + }; + + Ok(nip98_auth::nip98_verify_auth_header( + auth_header.to_str()?.to_string(), + &format!("{}{}", self.base_url, req.uri().path()), + req.method().as_str(), + body_bytes + ).await) + } + + async fn handle_user_info(&self, req: &ParsedRequest) -> Result> { + let body = req.body_json()?; + + if let Some(device_token) = body["deviceToken"].as_str() { + let notification_manager = self.notification_manager.lock().await; + notification_manager.save_user_device_info(req.authorized_pubkey, device_token)?; + return Ok(APIResponse { + status: StatusCode::OK, + body: json!({ "message": "User info saved successfully" }), + }); + } else { + return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "deviceToken is required" }), + }); + } + } + + async fn handle_user_info_remove(&self, req: &ParsedRequest) -> Result> { + let body: Value = req.body_json()?; + + if let Some(device_token) = body["deviceToken"].as_str() { + let notification_manager = self.notification_manager.lock().await; + notification_manager.remove_user_device_info(req.authorized_pubkey, device_token)?; + return Ok(APIResponse { + status: StatusCode::OK, + body: json!({ "message": "User info removed successfully" }), + }); + } else { + return Ok(APIResponse { + status: StatusCode::BAD_REQUEST, + body: json!({ "error": "deviceToken is required" }), + }); + } + } +} + +// Define enum error types including authentication error +#[derive(Debug)] +#[derive(Error)] +enum APIError { + #[error("Authentication error: {0}")] + AuthenticationError(String), +} diff --git a/src/api_server/api_server.rs b/src/api_server/api_server.rs new file mode 100644 index 0000000..9ec4498 --- /dev/null +++ b/src/api_server/api_server.rs @@ -0,0 +1,47 @@ +use hyper::{server::conn::http1, service::service_fn}; +use hyper_util::rt::TokioIo; +use std::sync::Arc; +use tokio::net::TcpListener; +use log; +use tokio::sync::Mutex; +use crate::notification_manager::NotificationManager; +use super::api_request_handler::APIHandler; + +pub struct APIServer { + host: String, + port: String, + api_handler: APIHandler, +} + +impl APIServer { + pub async fn run(host: String, port: String, notification_manager: Arc>, base_url: String) -> Result<(), Box> { + let api_handler = APIHandler::new(notification_manager, base_url); + let server = APIServer { + host, + port, + api_handler, + }; + server.start().await + } + + async fn start(&self) -> Result<(), Box> { + let address = format!("{}:{}", self.host, self.port); + let listener = TcpListener::bind(&address).await?; + + log::info!("HTTP server running at {}", address); + + loop { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + let api_handler = self.api_handler.clone(); + + tokio::task::spawn(async move { + let service = service_fn(|req| api_handler.handle_http_request(req)); + + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + log::error!("Failed to serve connection: {:?}", err); + } + }); + } + } +} diff --git a/src/api_server/mod.rs b/src/api_server/mod.rs new file mode 100644 index 0000000..f3dc4bb --- /dev/null +++ b/src/api_server/mod.rs @@ -0,0 +1,3 @@ +pub mod api_server; +pub mod nip98_auth; +pub mod api_request_handler; diff --git a/src/api_server/nip98_auth.rs b/src/api_server/nip98_auth.rs new file mode 100644 index 0000000..71b7d12 --- /dev/null +++ b/src/api_server/nip98_auth.rs @@ -0,0 +1,129 @@ +use base64::prelude::*; +use serde_json::Value; +use nostr::bitcoin::hashes::sha256::Hash as Sha256Hash; +use nostr::bitcoin::hashes::Hash; +use nostr::util::hex; +use nostr::Timestamp; +use nostr; + +pub async fn nip98_verify_auth_header( + auth_header: String, + url: &str, + method: &str, + body: Option<&[u8]> +) -> Result { + if auth_header.is_empty() { + return Err("Nostr authorization header missing".to_string()); + } + + let auth_header_parts: Vec<&str> = auth_header.split_whitespace().collect(); + if auth_header_parts.len() != 2 { + return Err("Nostr authorization header does not have 2 parts".to_string()); + } + + if auth_header_parts[0] != "Nostr" { + return Err("Nostr authorization header does not start with `Nostr`".to_string()); + } + + let base64_encoded_note = auth_header_parts[1]; + if base64_encoded_note.is_empty() { + return Err("Nostr authorization header does not have a base64 encoded note".to_string()); + } + + let decoded_note_json = BASE64_STANDARD.decode(base64_encoded_note.as_bytes()) + .map_err(|_| format!("Failed to decode base64 encoded note from Nostr authorization header"))?; + + let note_value: Value = serde_json::from_slice(&decoded_note_json) + .map_err(|_| format!("Could not parse JSON note from authorization header"))?; + + let note: nostr::Event = nostr::Event::from_value(note_value) + .map_err(|_| format!("Could not parse Nostr note from JSON"))?; + + if note.kind != nostr::Kind::HttpAuth { + return Err("Nostr note kind in authorization header is incorrect".to_string()); + } + + let authorized_url = note.get_tag_content(nostr::TagKind::SingleLetter(nostr::SingleLetterTag::lowercase(nostr::Alphabet::U))) + .ok_or_else(|| "Missing 'u' tag from Nostr authorization header".to_string())?; + + let authorized_method = note.get_tag_content(nostr::TagKind::Method) + .ok_or_else(|| "Missing 'method' tag from Nostr authorization header".to_string())?; + + if authorized_url != url || authorized_method != method { + return Err(format!( + "Auth note url and/or method does not match request. Auth note url: {}; Request url: {}; Auth note method: {}; Request method: {}", + authorized_url, url, authorized_method, method + )); + } + + let current_time: nostr::Timestamp = nostr::Timestamp::now(); + let note_created_at: nostr::Timestamp = note.created_at(); + let time_delta = TimeDelta::subtracting(current_time, note_created_at); + if (time_delta.negative && time_delta.delta_abs_seconds > 30) || (!time_delta.negative && time_delta.delta_abs_seconds > 60) { + return Err(format!( + "Auth note is too old. Current time: {}; Note created at: {}; Time delta: {} seconds", + current_time, note_created_at, time_delta + )); + } + + if let Some(body_data) = body { + let authorized_content_hash_bytes: Vec = hex::decode( + note.get_tag_content(nostr::TagKind::Payload) + .ok_or("Missing 'payload' tag from Nostr authorization header")? + ) + .map_err(|_| format!("Failed to decode hex encoded payload from Nostr authorization header"))?; + + let authorized_content_hash: Sha256Hash = Sha256Hash::from_slice(&authorized_content_hash_bytes) + .map_err(|_| format!("Failed to convert hex encoded payload to Sha256Hash"))?; + + let body_hash = Sha256Hash::hash(body_data); + if authorized_content_hash != body_hash { + return Err("Auth note payload hash does not match request body hash".to_string()); + } + } else { + let authorized_content_hash_string = note.get_tag_content(nostr::TagKind::Payload); + if authorized_content_hash_string.is_some() { + return Err("Auth note has payload tag but request has no body".to_string()); + } + } + + // Verify both the Event ID and the cryptographic signature + if note.verify().is_err() { + return Err("Auth note id or signature is invalid".to_string()); + } + + Ok(note.pubkey) +} + +struct TimeDelta { + delta_abs_seconds: u64, + negative: bool +} + +impl TimeDelta { + /// Safely calculate the difference between two timestamps in seconds + /// This function is safer against overflows than subtracting the timestamps directly + fn subtracting(t1: Timestamp, t2: Timestamp) -> TimeDelta { + if t1 > t2 { + TimeDelta { + delta_abs_seconds: (t1 - t2).as_u64(), + negative: false + } + } else { + TimeDelta { + delta_abs_seconds: (t2 - t1).as_u64(), + negative: true + } + } + } +} + +impl std::fmt::Display for TimeDelta { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.negative { + write!(f, "-{}", self.delta_abs_seconds) + } else { + write!(f, "{}", self.delta_abs_seconds) + } + } +} diff --git a/src/main.rs b/src/main.rs index 94d34ac..524aea9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,7 @@ +#![forbid(unsafe_code)] use std::net::TcpListener; use std::sync::Arc; +use api_server::api_server::APIServer; use tokio::sync::Mutex; mod notification_manager; use log; @@ -10,6 +12,7 @@ use relay_connection::RelayConnection; use r2d2; mod notepush_env; use notepush_env::NotePushEnv; +mod api_server; #[tokio::main] async fn main () { @@ -19,7 +22,7 @@ async fn main () { env_logger::init(); let env = NotePushEnv::load_env().expect("Failed to load environment variables"); - let server = TcpListener::bind(&env.address()).expect("Failed to bind to address"); + let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address"); let manager = SqliteConnectionManager::file(env.db_path.clone()); let pool: r2d2::Pool = r2d2::Pool::new(manager).expect("Failed to create SQLite connection pool"); @@ -35,10 +38,21 @@ async fn main () { env.apns_topic.clone(), ).await.expect("Failed to create notification manager"))); - log::info!("Server listening on {}", env.address().clone()); + // MARK: - Start the API server + { + let notification_manager = notification_manager.clone(); + let api_host = env.api_host.clone(); + let api_port = env.api_port.clone(); + let api_base_url = env.api_base_url.clone(); + tokio::spawn(async move { + APIServer::run(api_host, api_port, notification_manager, api_base_url).await.expect("Failed to start API server"); + }); + } // MARK: - Start handling incoming connections + log::info!("Relay server listening on {}", env.relay_address().clone()); + for stream in server.incoming() { if let Ok(stream) = stream { let peer_address_string = stream.peer_addr().map_or("unknown".to_string(), |addr| addr.to_string()); diff --git a/src/notepush_env.rs b/src/notepush_env.rs index f28972b..4923e95 100644 --- a/src/notepush_env.rs +++ b/src/notepush_env.rs @@ -3,19 +3,33 @@ use dotenv::dotenv; use a2; const DEFAULT_DB_PATH: &str = "./apns_notifications.db"; -const DEFAULT_HOST: &str = "0.0.0.0"; -const DEFAULT_PORT: &str = "9001"; -const DEFAULT_RELAY_URL: &str = "ws://localhost:7777"; +const DEFAULT_RELAY_HOST: &str = "0.0.0.0"; +const DEFAULT_RELAY_PORT: &str = "9001"; +const DEFAULT_RELAY_URL: &str = "wss://relay.damus.io"; +const DEFAULT_API_HOST: &str = "0.0.0.0"; +const DEFAULT_API_PORT: &str = "8000"; pub struct NotePushEnv { + // The path to the Apple private key .p8 file pub apns_private_key_path: String, + // The Apple private key ID pub apns_private_key_id: String, + // The Apple team ID pub apns_team_id: String, + // The APNS environment to send notifications to (Sandbox or Production) pub apns_environment: a2::client::Endpoint, + // The topic to send notifications to (The Apple app bundle ID) pub apns_topic: String, + // The path to the SQLite database file pub db_path: String, - pub host: String, - pub port: String, + // The host and port to bind the relay server to + pub relay_host: String, + pub relay_port: String, + // The host and port to bind the API server to + pub api_host: String, + pub api_port: String, + pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks + // The URL of the Nostr relay server to connect to for getting mutelists pub relay_url: String, } @@ -26,10 +40,13 @@ impl NotePushEnv { let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?; let apns_team_id = env::var("APPLE_TEAM_ID")?; let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string()); - let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string()); - let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string()); + let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string()); + let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string()); let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string()); let apns_environment_string = env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); + let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string()); + let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string()); + let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port)); let apns_environment = match apns_environment_string.as_str() { "development" => a2::client::Endpoint::Sandbox, "production" => a2::client::Endpoint::Production, @@ -44,13 +61,16 @@ impl NotePushEnv { apns_environment, apns_topic, db_path, - host, - port, + relay_host, + relay_port, + api_host, + api_port, + api_base_url, relay_url, }) } - pub fn address(&self) -> String { - format!("{}:{}", self.host, self.port) + pub fn relay_address(&self) -> String { + format!("{}:{}", self.relay_host, self.relay_port) } } diff --git a/src/notification_manager/notification_manager.rs b/src/notification_manager/notification_manager.rs index 943158a..2b2db7c 100644 --- a/src/notification_manager/notification_manager.rs +++ b/src/notification_manager/notification_manager.rs @@ -249,13 +249,13 @@ impl NotificationManager { (title, subtitle, body) } - pub fn save_user_device_info(&self, pubkey: &str, device_token: &str) -> Result<(), Box> { + pub fn save_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box> { let current_time_unix = Timestamp::now(); self.db.get()?.execute( "INSERT OR REPLACE INTO user_info (id, pubkey, device_token, added_at) VALUES (?, ?, ?, ?)", params![ - format!("{}:{}", pubkey, device_token), - pubkey, + format!("{}:{}", pubkey.to_sql_string(), device_token), + pubkey.to_sql_string(), device_token, current_time_unix.to_sql_string() ], @@ -263,10 +263,10 @@ impl NotificationManager { Ok(()) } - pub fn remove_user_device_info(&self, pubkey: &str, device_token: &str) -> Result<(), Box> { + pub fn remove_user_device_info(&self, pubkey: nostr::PublicKey, device_token: &str) -> Result<(), Box> { self.db.get()?.execute( "DELETE FROM user_info WHERE pubkey = ? AND device_token = ?", - params![pubkey, device_token], + params![pubkey.to_sql_string(), device_token], )?; Ok(()) }