Unified API + Websockets server

Both API and Websockets are running on the same port

Signed-off-by: Daniel D’Aquino <daniel@daquino.me>
This commit is contained in:
Daniel D’Aquino
2024-08-02 15:58:49 -07:00
parent daa2408cfd
commit db088b1aa3
10 changed files with 153 additions and 161 deletions

17
Cargo.lock generated
View File

@ -873,6 +873,21 @@ dependencies = [
"webpki-roots", "webpki-roots",
] ]
[[package]]
name = "hyper-tungstenite"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69ce21dae6ce6e5f336a444d846e592faf42c5c28f70a5c8ff67893cbcb304d3"
dependencies = [
"http-body-util",
"hyper",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]] [[package]]
name = "hyper-util" name = "hyper-util"
version = "0.1.6" version = "0.1.6"
@ -1191,8 +1206,10 @@ dependencies = [
"chrono", "chrono",
"dotenv", "dotenv",
"env_logger", "env_logger",
"futures",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-tungstenite",
"hyper-util", "hyper-util",
"log", "log",
"nostr", "nostr",

View File

@ -35,3 +35,5 @@ hyper-util = "0.1.6"
http-body-util = "0.1.2" http-body-util = "0.1.2"
uuid = { version = "1.10.0", features = ["v4"] } uuid = { version = "1.10.0", features = ["v4"] }
thiserror = "1.0.63" thiserror = "1.0.63"
hyper-tungstenite = "0.14.0"
futures = "0.3.30"

View File

@ -18,10 +18,8 @@ APNS_ENVIRONMENT="development" # The environment to use with the APNS server.
APPLE_TEAM_ID=1248163264 # The ID of the team. Can be found in AppStore Connect. APPLE_TEAM_ID=1248163264 # The ID of the team. Can be found in AppStore Connect.
DB_PATH=./apns_notifications.db # Path to the SQLite database file that will be used to store data about sent notifications, relative to the working directory DB_PATH=./apns_notifications.db # Path to the SQLite database file that will be used to store data about sent notifications, relative to the working directory
RELAY_URL=wss://relay.damus.io # URL to the relay server which will be consulted to get information such as mute lists. RELAY_URL=wss://relay.damus.io # URL to the relay server which will be consulted to get information such as mute lists.
RELAY_HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces) HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces)
RELAY_PORT=9001 # The port to bind the server to. Defaults to 9001 PORT=8000 # The port to bind the server to. Defaults to 8000
API_HOST="0.0.0.0" # The host to bind the API server to (Defaults to 0.0.0.0 to bind to all interfaces)
API_PORT=8000 # The port to bind the API server to. Defaults to 8000
API_BASE_URL=http://localhost:8000 # Base URL from the API is allowed access (used by the server to perform NIP-98 authentication) API_BASE_URL=http://localhost:8000 # Base URL from the API is allowed access (used by the server to perform NIP-98 authentication)
``` ```
@ -50,6 +48,6 @@ You can use `test/test-inputs` with a websockets test tool such as `websocat` to
```sh ```sh
$ nix-shell $ nix-shell
[nix-shell] $ websocat ws://localhost:9001 [nix-shell] $ websocat ws://localhost:8000
<ENTER_FULL_JSON_PAYLOAD_HERE_AND_PRESS_ENTER> <ENTER_FULL_JSON_PAYLOAD_HERE_AND_PRESS_ENTER>
``` ```

View File

@ -1,7 +1,11 @@
use super::nip98_auth; use crate::nip98_auth;
use crate::relay_connection::RelayConnection;
use http_body_util::Full;
use hyper::body::Buf; use hyper::body::Buf;
use hyper::body::Bytes;
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use hyper_tungstenite;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use nostr; use nostr;
@ -61,7 +65,23 @@ impl APIHandler {
pub async fn handle_http_request( pub async fn handle_http_request(
&self, &self,
req: Request<Incoming>, req: Request<Incoming>,
) -> Result<Response<String>, hyper::http::Error> { ) -> Result<Response<Full<Bytes>>, hyper::http::Error> {
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&req) {
return match self.handle_websocket_upgrade(req).await {
Ok(response) => Ok(response),
Err(err) => {
log::error!("Error handling websocket upgrade request: {}", err);
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(http_body_util::Full::new(Bytes::from(
"Internal server error",
)))?)
}
};
}
// If not, handle the request as a normal API request.
let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await { let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await {
Ok(api_response) => APIResponse { Ok(api_response) => APIResponse {
status: api_response.status, status: api_response.status,
@ -91,11 +111,34 @@ impl APIHandler {
} }
} }
}; };
Ok(Response::builder() Ok(Response::builder()
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.header("Access-Control-Allow-Origin", "*") .header("Access-Control-Allow-Origin", "*")
.status(final_api_response.status) .status(final_api_response.status)
.body(final_api_response.body.to_string())?) .body(http_body_util::Full::new(Bytes::from(
final_api_response.body.to_string(),
)))?)
}
async fn handle_websocket_upgrade(
&self,
mut req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, Box<dyn std::error::Error>> {
let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?;
log::info!("New websocket connection.");
let new_notification_manager = self.notification_manager.clone();
tokio::spawn(async move {
match RelayConnection::run(websocket, new_notification_manager).await {
Ok(_) => {}
Err(e) => {
log::error!("Error with websocket connection: {:?}", e);
}
}
});
Ok(response)
} }
async fn try_to_handle_http_request( async fn try_to_handle_http_request(

View File

@ -1,52 +0,0 @@
use super::api_request_handler::APIHandler;
use crate::notification_manager::NotificationManager;
use hyper::{server::conn::http1, service::service_fn};
use hyper_util::rt::TokioIo;
use log;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
pub struct APIServer {
host: String,
port: String,
api_handler: APIHandler,
}
impl APIServer {
pub async fn run(
host: String,
port: String,
notification_manager: Arc<Mutex<NotificationManager>>,
base_url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let api_handler = APIHandler::new(notification_manager, base_url);
let server = APIServer {
host,
port,
api_handler,
};
server.start().await
}
async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let address = format!("{}:{}", self.host, self.port);
let listener = TcpListener::bind(&address).await?;
log::info!("HTTP server running at {}", address);
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let api_handler = self.api_handler.clone();
tokio::task::spawn(async move {
let service = service_fn(|req| api_handler.handle_http_request(req));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Failed to serve connection: {:?}", err);
}
});
}
}
}

View File

@ -1,3 +0,0 @@
pub mod api_request_handler;
pub mod api_server;
pub mod nip98_auth;

View File

@ -1,7 +1,7 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use api_server::api_server::APIServer; use hyper_util::rt::TokioIo;
use std::net::TcpListener;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex; use tokio::sync::Mutex;
mod notification_manager; mod notification_manager;
use env_logger; use env_logger;
@ -9,19 +9,22 @@ use log;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
mod relay_connection; mod relay_connection;
use r2d2; use r2d2;
use relay_connection::RelayConnection;
mod notepush_env; mod notepush_env;
use notepush_env::NotePushEnv; use notepush_env::NotePushEnv;
mod api_server; mod api_request_handler;
mod nip98_auth;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// MARK: - Setup basics // MARK: - Setup basics
env_logger::init(); env_logger::init();
let env = NotePushEnv::load_env().expect("Failed to load environment variables"); let env = NotePushEnv::load_env().expect("Failed to load environment variables");
let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address"); let listener = TcpListener::bind(&env.relay_address())
.await
.expect("Failed to bind to address");
log::info!("Server running at {}", env.relay_address());
let manager = SqliteConnectionManager::file(env.db_path.clone()); let manager = SqliteConnectionManager::file(env.db_path.clone());
let pool: r2d2::Pool<SqliteConnectionManager> = let pool: r2d2::Pool<SqliteConnectionManager> =
@ -41,45 +44,27 @@ async fn main() {
.await .await
.expect("Failed to create notification manager"), .expect("Failed to create notification manager"),
)); ));
let api_handler = Arc::new(api_request_handler::APIHandler::new(
notification_manager.clone(),
env.api_base_url.clone(),
));
// MARK: - Start the API server loop {
{ let (stream, _) = listener.accept().await?;
let notification_manager = notification_manager.clone(); let io = TokioIo::new(stream);
let api_host = env.api_host.clone(); let api_handler_clone = api_handler.clone();
let api_port = env.api_port.clone(); let mut http = hyper::server::conn::http1::Builder::new();
let api_base_url = env.api_base_url.clone(); http.keep_alive(true);
tokio::spawn(async move {
APIServer::run(api_host, api_port, notification_manager, api_base_url) tokio::task::spawn(async move {
.await let service =
.expect("Failed to start API server"); hyper::service::service_fn(|req| api_handler_clone.handle_http_request(req));
let connection = http.serve_connection(io, service).with_upgrades();
if let Err(err) = connection.await {
log::error!("Failed to serve connection: {:?}", err);
}
}); });
} }
// MARK: - Start handling incoming connections
log::info!("Relay server listening on {}", env.relay_address().clone());
for stream in server.incoming() {
if let Ok(stream) = stream {
let peer_address_string = stream
.peer_addr()
.map_or("unknown".to_string(), |addr| addr.to_string());
log::info!("New connection from {}", peer_address_string);
let notification_manager = notification_manager.clone();
tokio::spawn(async move {
match RelayConnection::run(stream, notification_manager).await {
Ok(_) => {}
Err(e) => {
log::error!(
"Error with websocket connection from {}: {:?}",
peer_address_string,
e
);
}
}
});
} else if let Err(e) = stream {
log::error!("Error in incoming connection stream: {:?}", e);
}
}
} }

View File

@ -3,11 +3,9 @@ use dotenv::dotenv;
use std::env; use std::env;
const DEFAULT_DB_PATH: &str = "./apns_notifications.db"; const DEFAULT_DB_PATH: &str = "./apns_notifications.db";
const DEFAULT_RELAY_HOST: &str = "0.0.0.0"; const DEFAULT_HOST: &str = "0.0.0.0";
const DEFAULT_RELAY_PORT: &str = "9001"; const DEFAULT_PORT: &str = "8000";
const DEFAULT_RELAY_URL: &str = "wss://relay.damus.io"; const DEFAULT_RELAY_URL: &str = "wss://relay.damus.io";
const DEFAULT_API_HOST: &str = "0.0.0.0";
const DEFAULT_API_PORT: &str = "8000";
pub struct NotePushEnv { pub struct NotePushEnv {
// The path to the Apple private key .p8 file // The path to the Apple private key .p8 file
@ -22,12 +20,9 @@ pub struct NotePushEnv {
pub apns_topic: String, pub apns_topic: String,
// The path to the SQLite database file // The path to the SQLite database file
pub db_path: String, pub db_path: String,
// The host and port to bind the relay server to // The host and port to bind the relay and API to
pub relay_host: String, pub host: String,
pub relay_port: String, pub port: String,
// The host and port to bind the API server to
pub api_host: String,
pub api_port: String,
pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks
// The URL of the Nostr relay server to connect to for getting mutelists // The URL of the Nostr relay server to connect to for getting mutelists
pub relay_url: String, pub relay_url: String,
@ -40,15 +35,12 @@ impl NotePushEnv {
let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?; let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?;
let apns_team_id = env::var("APPLE_TEAM_ID")?; let apns_team_id = env::var("APPLE_TEAM_ID")?;
let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string()); let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string());
let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string()); let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string()); let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string());
let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string()); let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string());
let apns_environment_string = let apns_environment_string =
env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string()); env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string());
let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string()); let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", host, port));
let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string());
let api_base_url =
env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port));
let apns_environment = match apns_environment_string.as_str() { let apns_environment = match apns_environment_string.as_str() {
"development" => a2::client::Endpoint::Sandbox, "development" => a2::client::Endpoint::Sandbox,
"production" => a2::client::Endpoint::Production, "production" => a2::client::Endpoint::Production,
@ -63,16 +55,14 @@ impl NotePushEnv {
apns_environment, apns_environment,
apns_topic, apns_topic,
db_path, db_path,
relay_host, host,
relay_port, port,
api_host,
api_port,
api_base_url, api_base_url,
relay_url, relay_url,
}) })
} }
pub fn relay_address(&self) -> String { pub fn relay_address(&self) -> String {
format!("{}:{}", self.relay_host, self.relay_port) format!("{}:{}", self.host, self.port)
} }
} }

View File

@ -1,92 +1,104 @@
use crate::notification_manager::NotificationManager; use crate::notification_manager::NotificationManager;
use futures::sink::SinkExt;
use futures::StreamExt;
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{HyperWebsocket, WebSocketStream};
use hyper_util::rt::TokioIo;
use log; use log;
use nostr::util::JsonUtil; use nostr::util::JsonUtil;
use nostr::{ClientMessage, RelayMessage}; use nostr::{ClientMessage, RelayMessage};
use serde_json::Value; use serde_json::Value;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::net::TcpStream;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tungstenite::{accept, WebSocket}; use tungstenite::{Error, Message};
const MAX_CONSECUTIVE_ERRORS: u32 = 10; const MAX_CONSECUTIVE_ERRORS: u32 = 10;
pub struct RelayConnection { pub struct RelayConnection {
websocket: WebSocket<TcpStream>,
notification_manager: Arc<Mutex<NotificationManager>>, notification_manager: Arc<Mutex<NotificationManager>>,
} }
impl RelayConnection { impl RelayConnection {
// MARK: - Initializers // MARK: - Initializers
pub fn new( pub async fn new(
stream: TcpStream,
notification_manager: Arc<Mutex<NotificationManager>>, notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
let address = stream.peer_addr()?; log::info!("Accepted websocket connection");
let websocket = accept(stream)?;
log::info!("Accepted connection from {:?}", address);
Ok(RelayConnection { Ok(RelayConnection {
websocket,
notification_manager, notification_manager,
}) })
} }
pub async fn run( pub async fn run(
stream: TcpStream, websocket: HyperWebsocket,
notification_manager: Arc<Mutex<NotificationManager>>, notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let mut connection = RelayConnection::new(stream, notification_manager)?; let mut connection = RelayConnection::new(notification_manager).await?;
Ok(connection.run_loop().await?) Ok(connection.run_loop(websocket).await?)
} }
// MARK: - Connection Runtime management // MARK: - Connection Runtime management
pub async fn run_loop(&mut self) -> Result<(), Box<dyn std::error::Error>> { pub async fn run_loop(
&mut self,
websocket: HyperWebsocket,
) -> Result<(), Box<dyn std::error::Error>> {
let mut consecutive_errors = 0; let mut consecutive_errors = 0;
log::debug!("Starting run loop for connection with {:?}", self.websocket); log::debug!("Starting run loop for connection with {:?}", websocket);
loop { let mut websocket_stream = websocket.await?;
match self.run_loop_iteration().await { while let Some(raw_message) = websocket_stream.next().await {
match self
.run_loop_iteration_if_raw_message_is_ok(raw_message, &mut websocket_stream)
.await
{
Ok(_) => { Ok(_) => {
consecutive_errors = 0; consecutive_errors = 0;
} }
Err(e) => { Err(e) => {
log::error!( log::error!("Error in websocket connection: {:?}", e);
"Error in websocket connection with {:?}: {:?}",
self.websocket,
e
);
consecutive_errors += 1; consecutive_errors += 1;
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS { if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
log::error!( log::error!("Too many consecutive errors, closing connection");
"Too many consecutive errors, closing connection with {:?}",
self.websocket
);
return Err(e); return Err(e);
} }
} }
} }
} }
Ok(())
} }
pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box<dyn std::error::Error>> { pub async fn run_loop_iteration_if_raw_message_is_ok(
let websocket = &mut self.websocket; &mut self,
let raw_message = websocket.read()?; raw_message: Result<Message, Error>,
stream: &mut WebSocketStream<TokioIo<Upgraded>>,
) -> Result<(), Box<dyn std::error::Error>> {
let raw_message = raw_message?;
self.run_loop_iteration(raw_message, stream).await
}
pub async fn run_loop_iteration(
&mut self,
raw_message: Message,
stream: &mut WebSocketStream<TokioIo<Upgraded>>,
) -> Result<(), Box<dyn std::error::Error>> {
if raw_message.is_text() { if raw_message.is_text() {
let message: ClientMessage = let message: ClientMessage =
ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?; ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?;
let response = self.handle_client_message(message).await?; let response = self.handle_client_message(message).await?;
self.websocket stream
.send(tungstenite::Message::text(response.try_as_json()?))?; .send(tungstenite::Message::text(response.try_as_json()?))
.await?;
} }
Ok(()) Ok(())
} }
// MARK: - Message handling // MARK: - Message handling
async fn handle_client_message<'b>( async fn handle_client_message(
&'b self, &self,
message: ClientMessage, message: ClientMessage,
) -> Result<RelayMessage, Box<dyn std::error::Error>> { ) -> Result<RelayMessage, Box<dyn std::error::Error>> {
match message { match message {
@ -119,6 +131,6 @@ impl RelayConnection {
impl Debug for RelayConnection { impl Debug for RelayConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RelayConnection with websocket: {:?}", self.websocket) write!(f, "RelayConnection with websocket")
} }
} }