From e6bddcf641125d4523e6a6f5a2fc5083fd54d904 Mon Sep 17 00:00:00 2001 From: Kieran Date: Fri, 20 Jun 2025 10:46:26 +0100 Subject: [PATCH] fix: tag recording in ended stream event --- crates/core/src/egress/recorder.rs | 4 +- crates/zap-stream-db/src/db.rs | 4 +- crates/zap-stream/src/overseer.rs | 118 ++++++++++++++++++----------- 3 files changed, 78 insertions(+), 48 deletions(-) diff --git a/crates/core/src/egress/recorder.rs b/crates/core/src/egress/recorder.rs index 75d5ac9..d84b3af 100644 --- a/crates/core/src/egress/recorder.rs +++ b/crates/core/src/egress/recorder.rs @@ -18,11 +18,13 @@ pub struct RecorderEgress { } impl RecorderEgress { + pub const FILENAME: &'static str = "recording.mp4"; + pub fn new<'a>( out_dir: PathBuf, variants: impl Iterator, ) -> Result { - let out_file = out_dir.join("recording.mp4"); + let out_file = out_dir.join(Self::FILENAME); let mut var_map = HashMap::new(); let muxer = unsafe { let mut m = Muxer::builder() diff --git a/crates/zap-stream-db/src/db.rs b/crates/zap-stream-db/src/db.rs index 4b9e9de..f817592 100644 --- a/crates/zap-stream-db/src/db.rs +++ b/crates/zap-stream-db/src/db.rs @@ -358,10 +358,10 @@ impl ZapStreamDb { } /// Get ingest endpoint by id - pub async fn get_ingest_endpoint(&self, endpoint_id: u64) -> Result> { + pub async fn get_ingest_endpoint(&self, endpoint_id: u64) -> Result { Ok(sqlx::query_as("select * from ingest_endpoint where id = ?") .bind(endpoint_id) - .fetch_optional(&self.db) + .fetch_one(&self.db) .await?) } diff --git a/crates/zap-stream/src/overseer.rs b/crates/zap-stream/src/overseer.rs index c8629aa..e127e9e 100644 --- a/crates/zap-stream/src/overseer.rs +++ b/crates/zap-stream/src/overseer.rs @@ -16,6 +16,7 @@ use tokio::sync::RwLock; use url::Url; use uuid::Uuid; use zap_stream_core::egress::hls::HlsEgress; +use zap_stream_core::egress::recorder::RecorderEgress; use zap_stream_core::egress::{EgressConfig, EgressSegment}; use zap_stream_core::ingress::ConnectionInfo; use zap_stream_core::overseer::{IngressInfo, IngressStream, IngressStreamType, Overseer}; @@ -230,19 +231,8 @@ impl ZapStreamOverseer { ) -> Result { // TODO: remove assumption that HLS is enabled let pipeline_dir = PathBuf::from(stream.id.to_string()); - let extra_tags = vec![ + let mut extra_tags = vec![ Tag::parse(["p", hex::encode(pubkey).as_str(), "", "host"])?, - Tag::parse([ - "streaming", - self.map_to_public_url( - pipeline_dir - .join(HlsEgress::PATH) - .join("live.m3u8") - .to_str() - .unwrap(), - )? - .as_str(), - ])?, Tag::parse([ "image", self.map_to_public_url(pipeline_dir.join("thumb.webp").to_str().unwrap())? @@ -250,6 +240,43 @@ impl ZapStreamOverseer { ])?, Tag::parse(["service", self.map_to_public_url("api/v1")?.as_str()])?, ]; + match stream.state { + UserStreamState::Live => { + extra_tags.push(Tag::parse([ + "streaming", + self.map_to_public_url( + pipeline_dir + .join(HlsEgress::PATH) + .join("live.m3u8") + .to_str() + .unwrap(), + )? + .as_str(), + ])?); + } + UserStreamState::Ended => { + if let Some(ep) = stream.endpoint_id { + let endpoint = self.db.get_ingest_endpoint(ep).await?; + let caps = parse_capabilities(&endpoint.capabilities); + let has_recording = caps + .iter() + .any(|c| matches!(c, EndpointCapability::DVR { .. })); + if has_recording { + extra_tags.push(Tag::parse([ + "recording", + self.map_to_public_url( + pipeline_dir + .join(RecorderEgress::FILENAME) + .to_str() + .unwrap(), + )? + .as_str(), + ])?); + } + } + } + _ => {} + } let ev = self .stream_to_event_builder(stream)? .tags(extra_tags) @@ -357,7 +384,7 @@ impl Overseer for ZapStreamOverseer { // Get ingest endpoint configuration based on connection type let endpoint = self.detect_endpoint(&connection).await?; - let caps = parse_capabilities(&endpoint.capabilities.unwrap_or("".to_string())); + let caps = parse_capabilities(&endpoint.capabilities); let cfg = get_variants_from_endpoint(&stream_info, &caps)?; if cfg.video_src.is_none() || cfg.variants.is_empty() { @@ -451,11 +478,8 @@ impl Overseer for ZapStreamOverseer { // Get the cost per minute from the ingest endpoint, or use default let cost_per_minute = if let Some(endpoint_id) = stream.endpoint_id { - if let Some(endpoint) = self.db.get_ingest_endpoint(endpoint_id).await? { - endpoint.cost - } else { - bail!("Endpoint doesnt exist"); - } + let ep = self.db.get_ingest_endpoint(endpoint_id).await?; + ep.cost } else { bail!("Endpoint id not set on stream"); }; @@ -586,36 +610,40 @@ enum EndpointCapability { DVR { height: u16 }, } -fn parse_capabilities(cap: &str) -> Vec { - cap.to_ascii_lowercase() - .split(',') - .map_while(|c| { - let cs = c.split(':').collect::>(); - match cs[0] { - "variant" if cs[1] == "source" => Some(EndpointCapability::SourceVariant), - "variant" if cs.len() == 3 => { - if let (Ok(h), Ok(br)) = (cs[1].parse(), cs[2].parse()) { - Some(EndpointCapability::Variant { - height: h, - bitrate: br, - }) - } else { - warn!("Invalid variant: {}", c); - None +fn parse_capabilities(cap: &Option) -> Vec { + if let Some(cap) = cap { + cap.to_ascii_lowercase() + .split(',') + .map_while(|c| { + let cs = c.split(':').collect::>(); + match cs[0] { + "variant" if cs[1] == "source" => Some(EndpointCapability::SourceVariant), + "variant" if cs.len() == 3 => { + if let (Ok(h), Ok(br)) = (cs[1].parse(), cs[2].parse()) { + Some(EndpointCapability::Variant { + height: h, + bitrate: br, + }) + } else { + warn!("Invalid variant: {}", c); + None + } } - } - "dvr" if cs.len() == 2 => { - if let Ok(h) = cs[1].parse() { - Some(EndpointCapability::DVR { height: h }) - } else { - warn!("Invalid dvr: {}", c); - None + "dvr" if cs.len() == 2 => { + if let Ok(h) = cs[1].parse() { + Some(EndpointCapability::DVR { height: h }) + } else { + warn!("Invalid dvr: {}", c); + None + } } + _ => None, } - _ => None, - } - }) - .collect() + }) + .collect() + } else { + vec![] + } } fn get_variants_from_endpoint<'a>(