From 6de7345e392ad4eb5c41685c4a18747e42f90fec Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 16 Jul 2023 14:22:40 +0100 Subject: [PATCH] Improve the wasm ui. (#178) * Improve the wasm ui. * Improve the UI. * Cosmetic changes. --- .cargo/config.toml | 3 ++ candle-wasm-example/Cargo.toml | 1 + candle-wasm-example/src/app.rs | 83 ++++++++++++++++++++++++------- candle-wasm-example/src/worker.rs | 48 ++++++++++-------- 4 files changed, 95 insertions(+), 40 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index a6c6276e..8ff190a4 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,3 +3,6 @@ rustflags = ["-C", "target-cpu=native"] [target.aarch64-apple-darwin] rustflags = ["-C", "target-cpu=native"] + +[target.wasm32-unknown-unknown] +rustflags = ["-C", "target-feature=+simd128"] diff --git a/candle-wasm-example/Cargo.toml b/candle-wasm-example/Cargo.toml index 4a3777ea..c4003efb 100644 --- a/candle-wasm-example/Cargo.toml +++ b/candle-wasm-example/Cargo.toml @@ -48,4 +48,5 @@ features = [ 'RequestInit', 'RequestMode', 'Response', + 'Performance', ] diff --git a/candle-wasm-example/src/app.rs b/candle-wasm-example/src/app.rs index 5a88ba2e..23519ebd 100644 --- a/candle-wasm-example/src/app.rs +++ b/candle-wasm-example/src/app.rs @@ -1,5 +1,5 @@ use crate::console_log; -use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput}; use js_sys::Date; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; @@ -38,13 +38,17 @@ pub enum Msg { UpdateStatus(String), SetDecoder(ModelData), WorkerInMsg(WorkerInput), - WorkerOutMsg(WorkerOutput), + WorkerOutMsg(Result), +} + +pub struct CurrentDecode { + start_time: Option, } pub struct App { status: String, - content: String, - decode_in_flight: bool, + segments: Vec, + current_decode: Option, worker: Box>, } @@ -60,6 +64,12 @@ async fn model_data_load() -> Result { }) } +fn performance_now() -> Option { + let window = web_sys::window()?; + let performance = window.performance()?; + Some(performance.now() / 1000.) +} + impl Component for App { type Message = Msg; type Properties = (); @@ -73,8 +83,8 @@ impl Component for App { let worker = Worker::bridge(std::rc::Rc::new(cb)); Self { status, - content: String::new(), - decode_in_flight: false, + segments: vec![], + current_decode: None, worker, } } @@ -103,18 +113,19 @@ impl Component for App { } Msg::Run(sample_index) => { let sample = SAMPLE_NAMES[sample_index]; - if self.decode_in_flight { - self.content = "already decoding some sample at the moment".to_string() + if self.current_decode.is_some() { + self.status = "already decoding some sample at the moment".to_string() } else { - self.decode_in_flight = true; + let start_time = performance_now(); + self.current_decode = Some(CurrentDecode { start_time }); self.status = format!("decoding {sample}"); - self.content = String::new(); + self.segments.clear(); ctx.link().send_future(async move { match fetch_url(sample).await { Err(err) => { - let value = Err(format!("decoding error: {err:?}")); - // Mimic a worker output to so as to release decode_in_flight - Msg::WorkerOutMsg(WorkerOutput { value }) + let output = Err(format!("decoding error: {err:?}")); + // Mimic a worker output to so as to release current_decode + Msg::WorkerOutMsg(output) } Ok(wav_bytes) => { Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes }) @@ -125,10 +136,26 @@ impl Component for App { // true } - Msg::WorkerOutMsg(WorkerOutput { value }) => { - self.status = "Worker responded!".to_string(); - self.content = format!("{value:?}"); - self.decode_in_flight = false; + Msg::WorkerOutMsg(output) => { + let dt = self.current_decode.as_ref().and_then(|current_decode| { + current_decode.start_time.and_then(|start_time| { + performance_now().map(|stop_time| stop_time - start_time) + }) + }); + self.current_decode = None; + match output { + Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), + Ok(WorkerOutput::Decoded(segments)) => { + self.status = match dt { + None => "decoding succeeded!".to_string(), + Some(dt) => format!("decoding succeeded in {:.2}s", dt), + }; + self.segments = segments; + } + Err(err) => { + self.status = format!("decoding error {err:?}"); + } + } true } Msg::WorkerInMsg(inp) => { @@ -170,12 +197,30 @@ impl Component for App { {&self.status} { - if self.decode_in_flight { + if self.current_decode.is_some() { html! { } } else { html!{

- {&self.content} + { + self.segments.iter().map(|segment| { html! { + <> + + { + format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})", + segment.start, + segment.start + segment.duration, + segment.dr.avg_logprob, + segment.dr.no_speech_prob, + ) + } + +
+ {&segment.dr.text} +
+ + } }).collect::() + }

} diff --git a/candle-wasm-example/src/worker.rs b/candle-wasm-example/src/worker.rs index c1074ecd..7b9ffbec 100644 --- a/candle-wasm-example/src/worker.rs +++ b/candle-wasm-example/src/worker.rs @@ -2,7 +2,7 @@ use crate::model::{Config, Whisper}; use anyhow::Error as E; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; -use rand::distributions::Distribution; +use rand::{distributions::Distribution, rngs::StdRng, SeedableRng}; use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; @@ -59,20 +59,20 @@ pub const SUPPRESS_TOKENS: [u32; 91] = [ ]; #[derive(Debug, Clone, Serialize, Deserialize)] -struct DecodingResult { - tokens: Vec, - text: String, - avg_logprob: f64, - no_speech_prob: f64, +pub struct DecodingResult { + pub tokens: Vec, + pub text: String, + pub avg_logprob: f64, + pub no_speech_prob: f64, temperature: f64, compression_ratio: f64, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Segment { - start: f64, - duration: f64, - dr: DecodingResult, + pub start: f64, + pub duration: f64, + pub dr: DecodingResult, } pub struct Decoder { @@ -107,7 +107,7 @@ impl Decoder { }) } - fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result { + fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result { let model = &self.model; let audio_features = model.encoder.forward(mel)?; console_log!("audio features: {:?}", audio_features.dims()); @@ -142,8 +142,7 @@ impl Decoder { let prs = (&logits / t)?.softmax(0)?; let logits_v: Vec = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?; - let mut rng = rand::thread_rng(); - distr.sample(&mut rng) as u32 + distr.sample(rng) as u32 } else { let logits_v: Vec = logits.to_vec1()?; logits_v @@ -179,9 +178,13 @@ impl Decoder { }) } - fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result { + fn decode_with_fallback( + &self, + segment: &Tensor, + rng: &mut StdRng, + ) -> anyhow::Result { for (i, &t) in TEMPERATURES.iter().enumerate() { - let dr: Result = self.decode(segment, t); + let dr: Result = self.decode(segment, t, rng); if i == TEMPERATURES.len() - 1 { return dr; } @@ -203,6 +206,7 @@ impl Decoder { } fn run(&self, mel: &Tensor) -> anyhow::Result> { + let mut rng = StdRng::seed_from_u64(299792458); let (_, _, content_frames) = mel.shape().r3()?; let mut seek = 0; let mut segments = vec![]; @@ -211,7 +215,7 @@ impl Decoder { let segment_size = usize::min(content_frames - seek, N_FRAMES); let mel_segment = mel.narrow(2, seek, segment_size)?; let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let dr = self.decode_with_fallback(&mel_segment)?; + let dr = self.decode_with_fallback(&mel_segment, &mut rng)?; seek += segment_size; if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { console_log!("no speech detected, skipping {seek} {dr:?}"); @@ -289,14 +293,15 @@ pub enum WorkerInput { } #[derive(Serialize, Deserialize)] -pub struct WorkerOutput { - pub value: Result, String>, +pub enum WorkerOutput { + Decoded(Vec), + WeightsLoaded, } impl yew_agent::Worker for Worker { type Input = WorkerInput; type Message = (); - type Output = WorkerOutput; + type Output = Result; type Reach = Public; fn create(link: WorkerLink) -> Self { @@ -311,11 +316,11 @@ impl yew_agent::Worker for Worker { } fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { - let value = match msg { + let output = match msg { WorkerInput::ModelData(md) => match Decoder::load(md) { Ok(decoder) => { self.decoder = Some(decoder); - Ok(vec![]) + Ok(WorkerOutput::WeightsLoaded) } Err(err) => Err(format!("model creation error {err:?}")), }, @@ -323,10 +328,11 @@ impl yew_agent::Worker for Worker { None => Err("model has not been set".to_string()), Some(decoder) => decoder .convert_and_run(&wav_bytes) + .map(WorkerOutput::Decoded) .map_err(|e| e.to_string()), }, }; - self.link.respond(id, WorkerOutput { value }); + self.link.respond(id, output); } fn name_of_resource() -> &'static str {