refactor: simplify to single range handler

This commit is contained in:
kieran 2024-12-16 11:25:34 +00:00
parent 021dc59382
commit 7c0bf5a755
No known key found for this signature in database
GPG Key ID: DE71CEB3925BE941

View File

@ -9,10 +9,8 @@ use crate::settings::Settings;
#[cfg(feature = "void-cat-redirects")] #[cfg(feature = "void-cat-redirects")]
use crate::void_db::VoidCatDb; use crate::void_db::VoidCatDb;
use anyhow::Error; use anyhow::Error;
use http_range_header::{ use http_range_header::{parse_range_header, EndPosition, StartPosition};
parse_range_header, EndPosition, StartPosition, SyntacticallyCorrectRange, use log::warn;
};
use log::{debug, warn};
use nostr::Event; use nostr::Event;
use rocket::fs::NamedFile; use rocket::fs::NamedFile;
use rocket::http::{ContentType, Header, Status}; use rocket::http::{ContentType, Header, Status};
@ -22,6 +20,7 @@ use rocket::response::Responder;
use rocket::serde::Serialize; use rocket::serde::Serialize;
use rocket::{Request, Response, State}; use rocket::{Request, Response, State};
use std::io::SeekFrom; use std::io::SeekFrom;
use std::ops::Range;
use std::pin::{pin, Pin}; use std::pin::{pin, Pin};
use std::str::FromStr; use std::str::FromStr;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -104,21 +103,19 @@ impl Nip94Event {
/// Range request handler over file handle /// Range request handler over file handle
struct RangeBody { struct RangeBody {
file: File, file: File,
file_size: u64, range_start: u64,
ranges: Vec<SyntacticallyCorrectRange>, range_end: u64,
current_range_index: usize,
current_offset: u64, current_offset: u64,
poll_complete: bool, poll_complete: bool,
} }
impl RangeBody { impl RangeBody {
pub fn new(file: File, file_size: u64, ranges: Vec<SyntacticallyCorrectRange>) -> Self { pub fn new(file: File, range: Range<u64>) -> Self {
Self { Self {
file, file,
file_size, range_start: range.start,
ranges, range_end: range.end,
current_offset: 0, current_offset: 0,
current_range_index: 0,
poll_complete: false, poll_complete: false,
} }
} }
@ -130,36 +127,22 @@ impl AsyncRead for RangeBody {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
if self.current_range_index >= self.ranges.len() { let range_start = self.range_start + self.current_offset;
return Poll::Ready(Ok(())); let range_len = self.range_end - range_start;
}
let current_range = &self.ranges[self.current_range_index];
let start_pos = match current_range.start {
StartPosition::Index(i) => i,
StartPosition::FromLast(i) => self.file_size - i,
};
let end_pos = match current_range.end {
EndPosition::Index(i) => i,
EndPosition::LastByte => self.file_size,
};
let range_start = start_pos + self.current_offset;
let range_len = end_pos - range_start;
let bytes_to_read = buf.remaining().min(range_len as usize) as u64; let bytes_to_read = buf.remaining().min(range_len as usize) as u64;
if bytes_to_read == 0 { if bytes_to_read == 0 {
self.current_offset = 0; return Poll::Ready(Ok(()));
self.current_range_index += 1;
return self.poll_read(cx, buf);
} }
// when no pending poll, seek to starting position
if !self.poll_complete { if !self.poll_complete {
// start seeking to our read position
let pinned = pin!(&mut self.file); let pinned = pin!(&mut self.file);
pinned.start_seek(SeekFrom::Start(range_start))?; pinned.start_seek(SeekFrom::Start(range_start))?;
self.poll_complete = true; self.poll_complete = true;
} }
// check poll completion
if self.poll_complete { if self.poll_complete {
let pinned = pin!(&mut self.file); let pinned = pin!(&mut self.file);
match pinned.poll_complete(cx) { match pinned.poll_complete(cx) {
@ -173,13 +156,16 @@ impl AsyncRead for RangeBody {
// Read data from the file // Read data from the file
let pinned = pin!(&mut self.file); let pinned = pin!(&mut self.file);
let n = pinned.poll_read(cx, buf); match pinned.poll_read(cx, buf) {
if let Poll::Ready(Ok(())) = n { Poll::Ready(Ok(_)) => {
self.current_offset += bytes_to_read; self.current_offset += bytes_to_read;
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { }
self.poll_complete = true; Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending Poll::Pending => {
self.poll_complete = true;
Poll::Pending
}
} }
} }
} }
@ -198,6 +184,7 @@ impl<'r> Responder<'r, 'static> for FilePayload {
warn!("Multipart ranges are not supported, fallback to non-range request"); warn!("Multipart ranges are not supported, fallback to non-range request");
response.set_streamed_body(self.file); response.set_streamed_body(self.file);
} else { } else {
const MAX_UNBOUNDED_RANGE: u64 = 1024 * 256;
let single_range = ranges.ranges.first().unwrap(); let single_range = ranges.ranges.first().unwrap();
let range_start = match single_range.start { let range_start = match single_range.start {
StartPosition::Index(i) => i, StartPosition::Index(i) => i,
@ -205,17 +192,18 @@ impl<'r> Responder<'r, 'static> for FilePayload {
}; };
let range_end = match single_range.end { let range_end = match single_range.end {
EndPosition::Index(i) => i, EndPosition::Index(i) => i,
EndPosition::LastByte => self.info.size, EndPosition::LastByte => {
(range_start + MAX_UNBOUNDED_RANGE).min(self.info.size)
}
}; };
debug!("Range: {:?} {:?}", range_start..range_end, single_range);
let r_len = range_end - range_start; let r_len = range_end - range_start;
let r_body = RangeBody::new(self.file, self.info.size, ranges.ranges); let r_body = RangeBody::new(self.file, range_start..range_end);
response.set_status(Status::PartialContent); response.set_status(Status::PartialContent);
response.set_header(Header::new("content-length", r_len.to_string())); response.set_header(Header::new("content-length", r_len.to_string()));
response.set_header(Header::new( response.set_header(Header::new(
"content-range", "content-range",
format!("bytes {}-{}/{}", range_start, range_end, self.info.size), format!("bytes {}-{}/{}", range_start, range_end - 1, self.info.size),
)); ));
response.set_streamed_body(Box::pin(r_body)); response.set_streamed_body(Box::pin(r_body));
} }