diff --git a/Cargo.lock b/Cargo.lock index 144233d..ca5e842 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -873,6 +873,21 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ce21dae6ce6e5f336a444d846e592faf42c5c28f70a5c8ff67893cbcb304d3" +dependencies = [ + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + [[package]] name = "hyper-util" version = "0.1.6" @@ -1191,8 +1206,10 @@ dependencies = [ "chrono", "dotenv", "env_logger", + "futures", "http-body-util", "hyper", + "hyper-tungstenite", "hyper-util", "log", "nostr", diff --git a/Cargo.toml b/Cargo.toml index 4e0b76c..2eb28b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,5 @@ hyper-util = "0.1.6" http-body-util = "0.1.2" uuid = { version = "1.10.0", features = ["v4"] } thiserror = "1.0.63" +hyper-tungstenite = "0.14.0" +futures = "0.3.30" diff --git a/README.md b/README.md index 653a393..634b2e0 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,8 @@ APNS_ENVIRONMENT="development" # The environment to use with the APNS server. APPLE_TEAM_ID=1248163264 # The ID of the team. Can be found in AppStore Connect. DB_PATH=./apns_notifications.db # Path to the SQLite database file that will be used to store data about sent notifications, relative to the working directory RELAY_URL=wss://relay.damus.io # URL to the relay server which will be consulted to get information such as mute lists. -RELAY_HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces) -RELAY_PORT=9001 # The port to bind the server to. Defaults to 9001 -API_HOST="0.0.0.0" # The host to bind the API server to (Defaults to 0.0.0.0 to bind to all interfaces) -API_PORT=8000 # The port to bind the API server to. Defaults to 8000 +HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces) +PORT=8000 # The port to bind the server to. Defaults to 8000 API_BASE_URL=http://localhost:8000 # Base URL from the API is allowed access (used by the server to perform NIP-98 authentication) ``` @@ -50,6 +48,6 @@ You can use `test/test-inputs` with a websockets test tool such as `websocat` to ```sh $ nix-shell -[nix-shell] $ websocat ws://localhost:9001 +[nix-shell] $ websocat ws://localhost:8000 ``` diff --git a/src/api_server/api_request_handler.rs b/src/api_request_handler.rs similarity index 81% rename from src/api_server/api_request_handler.rs rename to src/api_request_handler.rs index 599cf35..e670b2b 100644 --- a/src/api_server/api_request_handler.rs +++ b/src/api_request_handler.rs @@ -1,7 +1,11 @@ -use super::nip98_auth; +use crate::nip98_auth; +use crate::relay_connection::RelayConnection; +use http_body_util::Full; use hyper::body::Buf; +use hyper::body::Bytes; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; +use hyper_tungstenite; use http_body_util::BodyExt; use nostr; @@ -61,7 +65,23 @@ impl APIHandler { pub async fn handle_http_request( &self, req: Request, - ) -> Result, hyper::http::Error> { + ) -> Result>, hyper::http::Error> { + // Check if the request is a websocket upgrade request. + if hyper_tungstenite::is_upgrade_request(&req) { + return match self.handle_websocket_upgrade(req).await { + Ok(response) => Ok(response), + Err(err) => { + log::error!("Error handling websocket upgrade request: {}", err); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(http_body_util::Full::new(Bytes::from( + "Internal server error", + )))?) + } + }; + } + + // If not, handle the request as a normal API request. let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await { Ok(api_response) => APIResponse { status: api_response.status, @@ -91,11 +111,34 @@ impl APIHandler { } } }; + Ok(Response::builder() .header("Content-Type", "application/json") .header("Access-Control-Allow-Origin", "*") .status(final_api_response.status) - .body(final_api_response.body.to_string())?) + .body(http_body_util::Full::new(Bytes::from( + final_api_response.body.to_string(), + )))?) + } + + async fn handle_websocket_upgrade( + &self, + mut req: Request, + ) -> Result>, Box> { + let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?; + log::info!("New websocket connection."); + + let new_notification_manager = self.notification_manager.clone(); + tokio::spawn(async move { + match RelayConnection::run(websocket, new_notification_manager).await { + Ok(_) => {} + Err(e) => { + log::error!("Error with websocket connection: {:?}", e); + } + } + }); + + Ok(response) } async fn try_to_handle_http_request( diff --git a/src/api_server/api_server.rs b/src/api_server/api_server.rs deleted file mode 100644 index 9709641..0000000 --- a/src/api_server/api_server.rs +++ /dev/null @@ -1,52 +0,0 @@ -use super::api_request_handler::APIHandler; -use crate::notification_manager::NotificationManager; -use hyper::{server::conn::http1, service::service_fn}; -use hyper_util::rt::TokioIo; -use log; -use std::sync::Arc; -use tokio::net::TcpListener; -use tokio::sync::Mutex; - -pub struct APIServer { - host: String, - port: String, - api_handler: APIHandler, -} - -impl APIServer { - pub async fn run( - host: String, - port: String, - notification_manager: Arc>, - base_url: String, - ) -> Result<(), Box> { - let api_handler = APIHandler::new(notification_manager, base_url); - let server = APIServer { - host, - port, - api_handler, - }; - server.start().await - } - - async fn start(&self) -> Result<(), Box> { - let address = format!("{}:{}", self.host, self.port); - let listener = TcpListener::bind(&address).await?; - - log::info!("HTTP server running at {}", address); - - loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - let api_handler = self.api_handler.clone(); - - tokio::task::spawn(async move { - let service = service_fn(|req| api_handler.handle_http_request(req)); - - if let Err(err) = http1::Builder::new().serve_connection(io, service).await { - log::error!("Failed to serve connection: {:?}", err); - } - }); - } - } -} diff --git a/src/api_server/mod.rs b/src/api_server/mod.rs deleted file mode 100644 index ffd1b5e..0000000 --- a/src/api_server/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod api_request_handler; -pub mod api_server; -pub mod nip98_auth; diff --git a/src/main.rs b/src/main.rs index 4e780e5..4ab9f86 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ #![forbid(unsafe_code)] -use api_server::api_server::APIServer; -use std::net::TcpListener; +use hyper_util::rt::TokioIo; use std::sync::Arc; +use tokio::net::TcpListener; use tokio::sync::Mutex; mod notification_manager; use env_logger; @@ -9,19 +9,22 @@ use log; use r2d2_sqlite::SqliteConnectionManager; mod relay_connection; use r2d2; -use relay_connection::RelayConnection; mod notepush_env; use notepush_env::NotePushEnv; -mod api_server; +mod api_request_handler; +mod nip98_auth; #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box> { // MARK: - Setup basics env_logger::init(); let env = NotePushEnv::load_env().expect("Failed to load environment variables"); - let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address"); + let listener = TcpListener::bind(&env.relay_address()) + .await + .expect("Failed to bind to address"); + log::info!("Server running at {}", env.relay_address()); let manager = SqliteConnectionManager::file(env.db_path.clone()); let pool: r2d2::Pool = @@ -41,45 +44,27 @@ async fn main() { .await .expect("Failed to create notification manager"), )); + let api_handler = Arc::new(api_request_handler::APIHandler::new( + notification_manager.clone(), + env.api_base_url.clone(), + )); - // MARK: - Start the API server - { - let notification_manager = notification_manager.clone(); - let api_host = env.api_host.clone(); - let api_port = env.api_port.clone(); - let api_base_url = env.api_base_url.clone(); - tokio::spawn(async move { - APIServer::run(api_host, api_port, notification_manager, api_base_url) - .await - .expect("Failed to start API server"); + loop { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + let api_handler_clone = api_handler.clone(); + let mut http = hyper::server::conn::http1::Builder::new(); + http.keep_alive(true); + + tokio::task::spawn(async move { + let service = + hyper::service::service_fn(|req| api_handler_clone.handle_http_request(req)); + + let connection = http.serve_connection(io, service).with_upgrades(); + + if let Err(err) = connection.await { + log::error!("Failed to serve connection: {:?}", err); + } }); } - - // MARK: - Start handling incoming connections - - log::info!("Relay server listening on {}", env.relay_address().clone()); - - for stream in server.incoming() { - if let Ok(stream) = stream { - let peer_address_string = stream - .peer_addr() - .map_or("unknown".to_string(), |addr| addr.to_string()); - log::info!("New connection from {}", peer_address_string); - let notification_manager = notification_manager.clone(); - tokio::spawn(async move { - match RelayConnection::run(stream, notification_manager).await { - Ok(_) => {} - Err(e) => { - log::error!( - "Error with websocket connection from {}: {:?}", - peer_address_string, - e - ); - } - } - }); - } else if let Err(e) = stream { - log::error!("Error in incoming connection stream: {:?}", e); - } - } } diff --git a/src/api_server/nip98_auth.rs b/src/nip98_auth.rs similarity index 100% rename from src/api_server/nip98_auth.rs rename to src/nip98_auth.rs diff --git a/src/notepush_env.rs b/src/notepush_env.rs index b62b08d..e4fc43a 100644 --- a/src/notepush_env.rs +++ b/src/notepush_env.rs @@ -3,11 +3,9 @@ use dotenv::dotenv; use std::env; const DEFAULT_DB_PATH: &str = "./apns_notifications.db"; -const DEFAULT_RELAY_HOST: &str = "0.0.0.0"; -const DEFAULT_RELAY_PORT: &str = "9001"; +const DEFAULT_HOST: &str = "0.0.0.0"; +const DEFAULT_PORT: &str = "8000"; const DEFAULT_RELAY_URL: &str = "wss://relay.damus.io"; -const DEFAULT_API_HOST: &str = "0.0.0.0"; -const DEFAULT_API_PORT: &str = "8000"; pub struct NotePushEnv { // The path to the Apple private key .p8 file @@ -22,12 +20,9 @@ pub struct NotePushEnv { pub apns_topic: String, // The path to the SQLite database file pub db_path: String, - // The host and port to bind the relay server to - pub relay_host: String, - pub relay_port: String, - // The host and port to bind the API server to - pub api_host: String, - pub api_port: String, + // The host and port to bind the relay and API to + pub host: String, + pub port: String, pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks // The URL of the Nostr relay server to connect to for getting mutelists pub relay_url: String, @@ -40,15 +35,12 @@ impl NotePushEnv { let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?; let apns_team_id = env::var("APPLE_TEAM_ID")?; let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string()); - let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string()); - let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string()); + let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string()); + let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string()); let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string()); let apns_environment_string = env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); - let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string()); - let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string()); - let api_base_url = - env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port)); + let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", host, port)); let apns_environment = match apns_environment_string.as_str() { "development" => a2::client::Endpoint::Sandbox, "production" => a2::client::Endpoint::Production, @@ -63,16 +55,14 @@ impl NotePushEnv { apns_environment, apns_topic, db_path, - relay_host, - relay_port, - api_host, - api_port, + host, + port, api_base_url, relay_url, }) } pub fn relay_address(&self) -> String { - format!("{}:{}", self.relay_host, self.relay_port) + format!("{}:{}", self.host, self.port) } } diff --git a/src/relay_connection.rs b/src/relay_connection.rs index aa4efc4..3fc154b 100644 --- a/src/relay_connection.rs +++ b/src/relay_connection.rs @@ -1,92 +1,104 @@ use crate::notification_manager::NotificationManager; +use futures::sink::SinkExt; +use futures::StreamExt; +use hyper::upgrade::Upgraded; +use hyper_tungstenite::{HyperWebsocket, WebSocketStream}; +use hyper_util::rt::TokioIo; use log; use nostr::util::JsonUtil; use nostr::{ClientMessage, RelayMessage}; use serde_json::Value; use std::fmt::{self, Debug}; -use std::net::TcpStream; use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; -use tungstenite::{accept, WebSocket}; +use tungstenite::{Error, Message}; const MAX_CONSECUTIVE_ERRORS: u32 = 10; pub struct RelayConnection { - websocket: WebSocket, notification_manager: Arc>, } impl RelayConnection { // MARK: - Initializers - pub fn new( - stream: TcpStream, + pub async fn new( notification_manager: Arc>, ) -> Result> { - let address = stream.peer_addr()?; - let websocket = accept(stream)?; - log::info!("Accepted connection from {:?}", address); + log::info!("Accepted websocket connection"); Ok(RelayConnection { - websocket, notification_manager, }) } pub async fn run( - stream: TcpStream, + websocket: HyperWebsocket, notification_manager: Arc>, ) -> Result<(), Box> { - let mut connection = RelayConnection::new(stream, notification_manager)?; - Ok(connection.run_loop().await?) + let mut connection = RelayConnection::new(notification_manager).await?; + Ok(connection.run_loop(websocket).await?) } // MARK: - Connection Runtime management - pub async fn run_loop(&mut self) -> Result<(), Box> { + pub async fn run_loop( + &mut self, + websocket: HyperWebsocket, + ) -> Result<(), Box> { let mut consecutive_errors = 0; - log::debug!("Starting run loop for connection with {:?}", self.websocket); - loop { - match self.run_loop_iteration().await { + log::debug!("Starting run loop for connection with {:?}", websocket); + let mut websocket_stream = websocket.await?; + while let Some(raw_message) = websocket_stream.next().await { + match self + .run_loop_iteration_if_raw_message_is_ok(raw_message, &mut websocket_stream) + .await + { Ok(_) => { consecutive_errors = 0; } Err(e) => { - log::error!( - "Error in websocket connection with {:?}: {:?}", - self.websocket, - e - ); + log::error!("Error in websocket connection: {:?}", e); consecutive_errors += 1; if consecutive_errors >= MAX_CONSECUTIVE_ERRORS { - log::error!( - "Too many consecutive errors, closing connection with {:?}", - self.websocket - ); + log::error!("Too many consecutive errors, closing connection"); return Err(e); } } } } + Ok(()) } - pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box> { - let websocket = &mut self.websocket; - let raw_message = websocket.read()?; + pub async fn run_loop_iteration_if_raw_message_is_ok( + &mut self, + raw_message: Result, + stream: &mut WebSocketStream>, + ) -> Result<(), Box> { + let raw_message = raw_message?; + self.run_loop_iteration(raw_message, stream).await + } + + pub async fn run_loop_iteration( + &mut self, + raw_message: Message, + stream: &mut WebSocketStream>, + ) -> Result<(), Box> { if raw_message.is_text() { let message: ClientMessage = ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?; let response = self.handle_client_message(message).await?; - self.websocket - .send(tungstenite::Message::text(response.try_as_json()?))?; + stream + .send(tungstenite::Message::text(response.try_as_json()?)) + .await?; } Ok(()) } // MARK: - Message handling - async fn handle_client_message<'b>( - &'b self, + async fn handle_client_message( + &self, message: ClientMessage, ) -> Result> { match message { @@ -119,6 +131,6 @@ impl RelayConnection { impl Debug for RelayConnection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "RelayConnection with websocket: {:?}", self.websocket) + write!(f, "RelayConnection with websocket") } }