Unified API + Websockets server

Both API and Websockets are running on the same port

Signed-off-by: Daniel D’Aquino <daniel@daquino.me>
This commit is contained in:
Daniel D’Aquino
2024-08-02 15:58:49 -07:00
parent daa2408cfd
commit db088b1aa3
10 changed files with 153 additions and 161 deletions

17
Cargo.lock generated
View File

@ -873,6 +873,21 @@ dependencies = [
"webpki-roots",
]
[[package]]
name = "hyper-tungstenite"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69ce21dae6ce6e5f336a444d846e592faf42c5c28f70a5c8ff67893cbcb304d3"
dependencies = [
"http-body-util",
"hyper",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
name = "hyper-util"
version = "0.1.6"
@ -1191,8 +1206,10 @@ dependencies = [
"chrono",
"dotenv",
"env_logger",
"futures",
"http-body-util",
"hyper",
"hyper-tungstenite",
"hyper-util",
"log",
"nostr",

View File

@ -35,3 +35,5 @@ hyper-util = "0.1.6"
http-body-util = "0.1.2"
uuid = { version = "1.10.0", features = ["v4"] }
thiserror = "1.0.63"
hyper-tungstenite = "0.14.0"
futures = "0.3.30"

View File

@ -18,10 +18,8 @@ APNS_ENVIRONMENT="development" # The environment to use with the APNS server.
APPLE_TEAM_ID=1248163264 # The ID of the team. Can be found in AppStore Connect.
DB_PATH=./apns_notifications.db # Path to the SQLite database file that will be used to store data about sent notifications, relative to the working directory
RELAY_URL=wss://relay.damus.io # URL to the relay server which will be consulted to get information such as mute lists.
RELAY_HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces)
RELAY_PORT=9001 # The port to bind the server to. Defaults to 9001
API_HOST="0.0.0.0" # The host to bind the API server to (Defaults to 0.0.0.0 to bind to all interfaces)
API_PORT=8000 # The port to bind the API server to. Defaults to 8000
HOST="0.0.0.0" # The host to bind the server to (Defaults to 0.0.0.0 to bind to all interfaces)
PORT=8000 # The port to bind the server to. Defaults to 8000
API_BASE_URL=http://localhost:8000 # Base URL from the API is allowed access (used by the server to perform NIP-98 authentication)
```
@ -50,6 +48,6 @@ You can use `test/test-inputs` with a websockets test tool such as `websocat` to
```sh
$ nix-shell
[nix-shell] $ websocat ws://localhost:9001
[nix-shell] $ websocat ws://localhost:8000
<ENTER_FULL_JSON_PAYLOAD_HERE_AND_PRESS_ENTER>
```

View File

@ -1,7 +1,11 @@
use super::nip98_auth;
use crate::nip98_auth;
use crate::relay_connection::RelayConnection;
use http_body_util::Full;
use hyper::body::Buf;
use hyper::body::Bytes;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use hyper_tungstenite;
use http_body_util::BodyExt;
use nostr;
@ -61,7 +65,23 @@ impl APIHandler {
pub async fn handle_http_request(
&self,
req: Request<Incoming>,
) -> Result<Response<String>, hyper::http::Error> {
) -> Result<Response<Full<Bytes>>, hyper::http::Error> {
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&req) {
return match self.handle_websocket_upgrade(req).await {
Ok(response) => Ok(response),
Err(err) => {
log::error!("Error handling websocket upgrade request: {}", err);
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(http_body_util::Full::new(Bytes::from(
"Internal server error",
)))?)
}
};
}
// If not, handle the request as a normal API request.
let final_api_response: APIResponse = match self.try_to_handle_http_request(req).await {
Ok(api_response) => APIResponse {
status: api_response.status,
@ -91,11 +111,34 @@ impl APIHandler {
}
}
};
Ok(Response::builder()
.header("Content-Type", "application/json")
.header("Access-Control-Allow-Origin", "*")
.status(final_api_response.status)
.body(final_api_response.body.to_string())?)
.body(http_body_util::Full::new(Bytes::from(
final_api_response.body.to_string(),
)))?)
}
async fn handle_websocket_upgrade(
&self,
mut req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, Box<dyn std::error::Error>> {
let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)?;
log::info!("New websocket connection.");
let new_notification_manager = self.notification_manager.clone();
tokio::spawn(async move {
match RelayConnection::run(websocket, new_notification_manager).await {
Ok(_) => {}
Err(e) => {
log::error!("Error with websocket connection: {:?}", e);
}
}
});
Ok(response)
}
async fn try_to_handle_http_request(

View File

@ -1,52 +0,0 @@
use super::api_request_handler::APIHandler;
use crate::notification_manager::NotificationManager;
use hyper::{server::conn::http1, service::service_fn};
use hyper_util::rt::TokioIo;
use log;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
pub struct APIServer {
host: String,
port: String,
api_handler: APIHandler,
}
impl APIServer {
pub async fn run(
host: String,
port: String,
notification_manager: Arc<Mutex<NotificationManager>>,
base_url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let api_handler = APIHandler::new(notification_manager, base_url);
let server = APIServer {
host,
port,
api_handler,
};
server.start().await
}
async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let address = format!("{}:{}", self.host, self.port);
let listener = TcpListener::bind(&address).await?;
log::info!("HTTP server running at {}", address);
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let api_handler = self.api_handler.clone();
tokio::task::spawn(async move {
let service = service_fn(|req| api_handler.handle_http_request(req));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Failed to serve connection: {:?}", err);
}
});
}
}
}

View File

@ -1,3 +0,0 @@
pub mod api_request_handler;
pub mod api_server;
pub mod nip98_auth;

View File

@ -1,7 +1,7 @@
#![forbid(unsafe_code)]
use api_server::api_server::APIServer;
use std::net::TcpListener;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
mod notification_manager;
use env_logger;
@ -9,19 +9,22 @@ use log;
use r2d2_sqlite::SqliteConnectionManager;
mod relay_connection;
use r2d2;
use relay_connection::RelayConnection;
mod notepush_env;
use notepush_env::NotePushEnv;
mod api_server;
mod api_request_handler;
mod nip98_auth;
#[tokio::main]
async fn main() {
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// MARK: - Setup basics
env_logger::init();
let env = NotePushEnv::load_env().expect("Failed to load environment variables");
let server = TcpListener::bind(&env.relay_address()).expect("Failed to bind to address");
let listener = TcpListener::bind(&env.relay_address())
.await
.expect("Failed to bind to address");
log::info!("Server running at {}", env.relay_address());
let manager = SqliteConnectionManager::file(env.db_path.clone());
let pool: r2d2::Pool<SqliteConnectionManager> =
@ -41,45 +44,27 @@ async fn main() {
.await
.expect("Failed to create notification manager"),
));
let api_handler = Arc::new(api_request_handler::APIHandler::new(
notification_manager.clone(),
env.api_base_url.clone(),
));
// MARK: - Start the API server
{
let notification_manager = notification_manager.clone();
let api_host = env.api_host.clone();
let api_port = env.api_port.clone();
let api_base_url = env.api_base_url.clone();
tokio::spawn(async move {
APIServer::run(api_host, api_port, notification_manager, api_base_url)
.await
.expect("Failed to start API server");
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let api_handler_clone = api_handler.clone();
let mut http = hyper::server::conn::http1::Builder::new();
http.keep_alive(true);
tokio::task::spawn(async move {
let service =
hyper::service::service_fn(|req| api_handler_clone.handle_http_request(req));
let connection = http.serve_connection(io, service).with_upgrades();
if let Err(err) = connection.await {
log::error!("Failed to serve connection: {:?}", err);
}
});
}
// MARK: - Start handling incoming connections
log::info!("Relay server listening on {}", env.relay_address().clone());
for stream in server.incoming() {
if let Ok(stream) = stream {
let peer_address_string = stream
.peer_addr()
.map_or("unknown".to_string(), |addr| addr.to_string());
log::info!("New connection from {}", peer_address_string);
let notification_manager = notification_manager.clone();
tokio::spawn(async move {
match RelayConnection::run(stream, notification_manager).await {
Ok(_) => {}
Err(e) => {
log::error!(
"Error with websocket connection from {}: {:?}",
peer_address_string,
e
);
}
}
});
} else if let Err(e) = stream {
log::error!("Error in incoming connection stream: {:?}", e);
}
}
}

View File

@ -3,11 +3,9 @@ use dotenv::dotenv;
use std::env;
const DEFAULT_DB_PATH: &str = "./apns_notifications.db";
const DEFAULT_RELAY_HOST: &str = "0.0.0.0";
const DEFAULT_RELAY_PORT: &str = "9001";
const DEFAULT_HOST: &str = "0.0.0.0";
const DEFAULT_PORT: &str = "8000";
const DEFAULT_RELAY_URL: &str = "wss://relay.damus.io";
const DEFAULT_API_HOST: &str = "0.0.0.0";
const DEFAULT_API_PORT: &str = "8000";
pub struct NotePushEnv {
// The path to the Apple private key .p8 file
@ -22,12 +20,9 @@ pub struct NotePushEnv {
pub apns_topic: String,
// The path to the SQLite database file
pub db_path: String,
// The host and port to bind the relay server to
pub relay_host: String,
pub relay_port: String,
// The host and port to bind the API server to
pub api_host: String,
pub api_port: String,
// The host and port to bind the relay and API to
pub host: String,
pub port: String,
pub api_base_url: String, // The base URL of where the API server is hosted for NIP-98 auth checks
// The URL of the Nostr relay server to connect to for getting mutelists
pub relay_url: String,
@ -40,15 +35,12 @@ impl NotePushEnv {
let apns_private_key_id = env::var("APNS_AUTH_PRIVATE_KEY_ID")?;
let apns_team_id = env::var("APPLE_TEAM_ID")?;
let db_path = env::var("DB_PATH").unwrap_or(DEFAULT_DB_PATH.to_string());
let relay_host = env::var("RELAY_HOST").unwrap_or(DEFAULT_RELAY_HOST.to_string());
let relay_port = env::var("RELAY_PORT").unwrap_or(DEFAULT_RELAY_PORT.to_string());
let host = env::var("HOST").unwrap_or(DEFAULT_HOST.to_string());
let port = env::var("PORT").unwrap_or(DEFAULT_PORT.to_string());
let relay_url = env::var("RELAY_URL").unwrap_or(DEFAULT_RELAY_URL.to_string());
let apns_environment_string =
env::var("APNS_ENVIRONMENT").unwrap_or("development".to_string());
let api_host = env::var("API_HOST").unwrap_or(DEFAULT_API_HOST.to_string());
let api_port = env::var("API_PORT").unwrap_or(DEFAULT_API_PORT.to_string());
let api_base_url =
env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", api_host, api_port));
let api_base_url = env::var("API_BASE_URL").unwrap_or(format!("https://{}:{}", host, port));
let apns_environment = match apns_environment_string.as_str() {
"development" => a2::client::Endpoint::Sandbox,
"production" => a2::client::Endpoint::Production,
@ -63,16 +55,14 @@ impl NotePushEnv {
apns_environment,
apns_topic,
db_path,
relay_host,
relay_port,
api_host,
api_port,
host,
port,
api_base_url,
relay_url,
})
}
pub fn relay_address(&self) -> String {
format!("{}:{}", self.relay_host, self.relay_port)
format!("{}:{}", self.host, self.port)
}
}

View File

@ -1,92 +1,104 @@
use crate::notification_manager::NotificationManager;
use futures::sink::SinkExt;
use futures::StreamExt;
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{HyperWebsocket, WebSocketStream};
use hyper_util::rt::TokioIo;
use log;
use nostr::util::JsonUtil;
use nostr::{ClientMessage, RelayMessage};
use serde_json::Value;
use std::fmt::{self, Debug};
use std::net::TcpStream;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tungstenite::{accept, WebSocket};
use tungstenite::{Error, Message};
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
pub struct RelayConnection {
websocket: WebSocket<TcpStream>,
notification_manager: Arc<Mutex<NotificationManager>>,
}
impl RelayConnection {
// MARK: - Initializers
pub fn new(
stream: TcpStream,
pub async fn new(
notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<Self, Box<dyn std::error::Error>> {
let address = stream.peer_addr()?;
let websocket = accept(stream)?;
log::info!("Accepted connection from {:?}", address);
log::info!("Accepted websocket connection");
Ok(RelayConnection {
websocket,
notification_manager,
})
}
pub async fn run(
stream: TcpStream,
websocket: HyperWebsocket,
notification_manager: Arc<Mutex<NotificationManager>>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut connection = RelayConnection::new(stream, notification_manager)?;
Ok(connection.run_loop().await?)
let mut connection = RelayConnection::new(notification_manager).await?;
Ok(connection.run_loop(websocket).await?)
}
// MARK: - Connection Runtime management
pub async fn run_loop(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run_loop(
&mut self,
websocket: HyperWebsocket,
) -> Result<(), Box<dyn std::error::Error>> {
let mut consecutive_errors = 0;
log::debug!("Starting run loop for connection with {:?}", self.websocket);
loop {
match self.run_loop_iteration().await {
log::debug!("Starting run loop for connection with {:?}", websocket);
let mut websocket_stream = websocket.await?;
while let Some(raw_message) = websocket_stream.next().await {
match self
.run_loop_iteration_if_raw_message_is_ok(raw_message, &mut websocket_stream)
.await
{
Ok(_) => {
consecutive_errors = 0;
}
Err(e) => {
log::error!(
"Error in websocket connection with {:?}: {:?}",
self.websocket,
e
);
log::error!("Error in websocket connection: {:?}", e);
consecutive_errors += 1;
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
log::error!(
"Too many consecutive errors, closing connection with {:?}",
self.websocket
);
log::error!("Too many consecutive errors, closing connection");
return Err(e);
}
}
}
}
Ok(())
}
pub async fn run_loop_iteration<'a>(&'a mut self) -> Result<(), Box<dyn std::error::Error>> {
let websocket = &mut self.websocket;
let raw_message = websocket.read()?;
pub async fn run_loop_iteration_if_raw_message_is_ok(
&mut self,
raw_message: Result<Message, Error>,
stream: &mut WebSocketStream<TokioIo<Upgraded>>,
) -> Result<(), Box<dyn std::error::Error>> {
let raw_message = raw_message?;
self.run_loop_iteration(raw_message, stream).await
}
pub async fn run_loop_iteration(
&mut self,
raw_message: Message,
stream: &mut WebSocketStream<TokioIo<Upgraded>>,
) -> Result<(), Box<dyn std::error::Error>> {
if raw_message.is_text() {
let message: ClientMessage =
ClientMessage::from_value(Value::from_str(raw_message.to_text()?)?)?;
let response = self.handle_client_message(message).await?;
self.websocket
.send(tungstenite::Message::text(response.try_as_json()?))?;
stream
.send(tungstenite::Message::text(response.try_as_json()?))
.await?;
}
Ok(())
}
// MARK: - Message handling
async fn handle_client_message<'b>(
&'b self,
async fn handle_client_message(
&self,
message: ClientMessage,
) -> Result<RelayMessage, Box<dyn std::error::Error>> {
match message {
@ -119,6 +131,6 @@ impl RelayConnection {
impl Debug for RelayConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RelayConnection with websocket: {:?}", self.websocket)
write!(f, "RelayConnection with websocket")
}
}