From 4aa51fa41f6eb46170317c7b0b97564135e48484 Mon Sep 17 00:00:00 2001 From: Kieran Date: Sun, 15 Dec 2024 23:22:31 +0000 Subject: [PATCH] feat: (WIP) range requests --- Cargo.lock | 7 ++++ Cargo.toml | 1 + src/routes/mod.rs | 105 ++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 105 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d457594..9230899 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1611,6 +1611,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.9.4" @@ -2957,6 +2963,7 @@ dependencies = [ "config", "ffmpeg-rs-raw", "hex", + "http-range-header", "libc", "log", "mime2ext", diff --git a/Cargo.toml b/Cargo.toml index c0820f1..7b5f865 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,3 +52,4 @@ candle-nn = { git = "https://git.v0l.io/huggingface/candle.git", tag = "0.8.1", candle-transformers = { git = "https://git.v0l.io/huggingface/candle.git", tag = "0.8.1", optional = true } sqlx-postgres = { version = "0.8.2", optional = true, features = ["chrono", "uuid"] } mime2ext = "0.1.53" +http-range-header = "0.4.2" diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 92d1c9f..250ec82 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,7 +1,3 @@ -use std::fs; -use std::fs::File; -use std::str::FromStr; - use crate::db::{Database, FileUpload}; use crate::filesystem::FileStore; pub use crate::routes::admin::admin_routes; @@ -13,6 +9,9 @@ use crate::settings::Settings; #[cfg(feature = "void-cat-redirects")] use crate::void_db::VoidCatDb; use anyhow::Error; +use http_range_header::{ + parse_range_header, EndPosition, StartPosition, SyntacticallyCorrectRange, +}; use nostr::Event; use rocket::fs::NamedFile; use rocket::http::{ContentType, Header, Status}; @@ -20,7 +19,13 @@ use rocket::http::{ContentType, Header, Status}; use rocket::response::Redirect; use rocket::response::Responder; use rocket::serde::Serialize; -use rocket::{Request, State}; +use rocket::{Request, Response, State}; +use std::io::SeekFrom; +use std::pin::{pin, Pin}; +use std::str::FromStr; +use std::task::{Context, Poll}; +use tokio::fs::File; +use tokio::io::{AsyncRead, AsyncSeek, ReadBuf}; #[cfg(feature = "blossom")] mod blossom; @@ -95,9 +100,93 @@ impl Nip94Event { } } +struct RangeBody { + pub file: File, + pub file_size: u64, + pub ranges: Vec, + + current_range_index: usize, + current_offset: u64, +} + +impl AsyncRead for RangeBody { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.current_range_index >= self.ranges.len() { + return Poll::Ready(Ok(())); + } + + let current_range = &self.ranges[self.current_range_index]; + let start_pos = match current_range.start { + StartPosition::Index(i) => i, + StartPosition::FromLast(i) => self.file_size - i, + }; + let end_pos = match current_range.end { + EndPosition::Index(i) => i, + EndPosition::LastByte => self.file_size, + }; + let range_start = start_pos + self.current_offset; + let range_len = end_pos - range_start; + let bytes_to_read = buf.remaining().min(range_len as usize) as u64; + + if bytes_to_read == 0 { + self.current_offset = 0; + self.current_range_index += 1; + return self.poll_read(cx, buf); + } + + let pinned = pin!(&mut self.file); + pinned.start_seek(SeekFrom::Start(range_start))?; + + let pinned = pin!(&mut self.file); + match pinned.poll_complete(cx) { + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + + // Read data from the file + let pinned = pin!(&mut self.file); + let n = pinned.poll_read(cx, &mut buf.take(bytes_to_read as usize)); + if let Poll::Ready(Ok(())) = n { + self.current_offset += bytes_to_read; + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + impl<'r> Responder<'r, 'static> for FilePayload { fn respond_to(self, request: &'r Request<'_>) -> rocket::response::Result<'static> { - let mut response = self.file.respond_to(request)?; + let mut response = Response::new(); + + // handle ranges + #[cfg(feature = "ranges")] + { + response.set_header(Header::new("accept-ranges", "bytes")); + if let Some(r) = request.headers().get("range").next() { + if let Ok(ranges) = parse_range_header(r) { + let r_body = RangeBody { + file_size: self.info.size, // TODO: handle filesize mismatch + file: self.file, + ranges: ranges.ranges, + current_range_index: 0, + current_offset: 0, + }; + response.set_streamed_body(Box::pin(r_body)); + } + } else { + response.set_streamed_body(self.file); + } + } + #[cfg(not(feature = "ranges"))] + response.set_streamed_body(self.file); + response.set_header(Header::new("content-length", self.info.size.to_string())); + if let Ok(ct) = ContentType::from_str(&self.info.mime_type) { response.set_header(ct); } @@ -145,7 +234,7 @@ async fn delete_file( if let Err(e) = db.delete_file(&id).await { return Err(Error::msg(format!("Failed to delete (fs): {}", e))); } - if let Err(e) = fs::remove_file(fs.get(&id)) { + if let Err(e) = tokio::fs::remove_file(fs.get(&id)).await { return Err(Error::msg(format!("Failed to delete (fs): {}", e))); } } @@ -189,7 +278,7 @@ pub async fn get_blob( return Err(Status::NotFound); } if let Ok(Some(info)) = db.get_file(&id).await { - if let Ok(f) = File::open(fs.get(&id)) { + if let Ok(f) = File::open(fs.get(&id)).await { return Ok(FilePayload { file: f, info }); } }