Improve the wasm ui. (#178)

* Improve the wasm ui.

* Improve the UI.

* Cosmetic changes.
This commit is contained in:
Laurent Mazare
2023-07-16 14:22:40 +01:00
committed by GitHub
parent 104f89df31
commit 6de7345e39
4 changed files with 95 additions and 40 deletions

View File

@ -3,3 +3,6 @@ rustflags = ["-C", "target-cpu=native"]
[target.aarch64-apple-darwin] [target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=native"] rustflags = ["-C", "target-cpu=native"]
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]

View File

@ -48,4 +48,5 @@ features = [
'RequestInit', 'RequestInit',
'RequestMode', 'RequestMode',
'Response', 'Response',
'Performance',
] ]

View File

@ -1,5 +1,5 @@
use crate::console_log; use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
use js_sys::Date; use js_sys::Date;
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture; use wasm_bindgen_futures::JsFuture;
@ -38,13 +38,17 @@ pub enum Msg {
UpdateStatus(String), UpdateStatus(String),
SetDecoder(ModelData), SetDecoder(ModelData),
WorkerInMsg(WorkerInput), WorkerInMsg(WorkerInput),
WorkerOutMsg(WorkerOutput), WorkerOutMsg(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
start_time: Option<f64>,
} }
pub struct App { pub struct App {
status: String, status: String,
content: String, segments: Vec<Segment>,
decode_in_flight: bool, current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>, worker: Box<dyn Bridge<Worker>>,
} }
@ -60,6 +64,12 @@ async fn model_data_load() -> Result<ModelData, JsValue> {
}) })
} }
fn performance_now() -> Option<f64> {
let window = web_sys::window()?;
let performance = window.performance()?;
Some(performance.now() / 1000.)
}
impl Component for App { impl Component for App {
type Message = Msg; type Message = Msg;
type Properties = (); type Properties = ();
@ -73,8 +83,8 @@ impl Component for App {
let worker = Worker::bridge(std::rc::Rc::new(cb)); let worker = Worker::bridge(std::rc::Rc::new(cb));
Self { Self {
status, status,
content: String::new(), segments: vec![],
decode_in_flight: false, current_decode: None,
worker, worker,
} }
} }
@ -103,18 +113,19 @@ impl Component for App {
} }
Msg::Run(sample_index) => { Msg::Run(sample_index) => {
let sample = SAMPLE_NAMES[sample_index]; let sample = SAMPLE_NAMES[sample_index];
if self.decode_in_flight { if self.current_decode.is_some() {
self.content = "already decoding some sample at the moment".to_string() self.status = "already decoding some sample at the moment".to_string()
} else { } else {
self.decode_in_flight = true; let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = format!("decoding {sample}"); self.status = format!("decoding {sample}");
self.content = String::new(); self.segments.clear();
ctx.link().send_future(async move { ctx.link().send_future(async move {
match fetch_url(sample).await { match fetch_url(sample).await {
Err(err) => { Err(err) => {
let value = Err(format!("decoding error: {err:?}")); let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release decode_in_flight // Mimic a worker output to so as to release current_decode
Msg::WorkerOutMsg(WorkerOutput { value }) Msg::WorkerOutMsg(output)
} }
Ok(wav_bytes) => { Ok(wav_bytes) => {
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes }) Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
@ -125,10 +136,26 @@ impl Component for App {
// //
true true
} }
Msg::WorkerOutMsg(WorkerOutput { value }) => { Msg::WorkerOutMsg(output) => {
self.status = "Worker responded!".to_string(); let dt = self.current_decode.as_ref().and_then(|current_decode| {
self.content = format!("{value:?}"); current_decode.start_time.and_then(|start_time| {
self.decode_in_flight = false; performance_now().map(|stop_time| stop_time - start_time)
})
});
self.current_decode = None;
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::Decoded(segments)) => {
self.status = match dt {
None => "decoding succeeded!".to_string(),
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
};
self.segments = segments;
}
Err(err) => {
self.status = format!("decoding error {err:?}");
}
}
true true
} }
Msg::WorkerInMsg(inp) => { Msg::WorkerInMsg(inp) => {
@ -170,12 +197,30 @@ impl Component for App {
{&self.status} {&self.status}
</h2> </h2>
{ {
if self.decode_in_flight { if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="decoding…"></progress> } html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
} else { html!{ } else { html!{
<blockquote> <blockquote>
<p> <p>
{&self.content} {
self.segments.iter().map(|segment| { html! {
<>
<i>
{
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
segment.start,
segment.start + segment.duration,
segment.dr.avg_logprob,
segment.dr.no_speech_prob,
)
}
</i>
<br/ >
{&segment.dr.text}
<br/ >
</>
} }).collect::<Html>()
}
</p> </p>
</blockquote> </blockquote>
} }

View File

@ -2,7 +2,7 @@ use crate::model::{Config, Whisper};
use anyhow::Error as E; use anyhow::Error as E;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use rand::distributions::Distribution; use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
@ -59,20 +59,20 @@ pub const SUPPRESS_TOKENS: [u32; 91] = [
]; ];
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct DecodingResult { pub struct DecodingResult {
tokens: Vec<u32>, pub tokens: Vec<u32>,
text: String, pub text: String,
avg_logprob: f64, pub avg_logprob: f64,
no_speech_prob: f64, pub no_speech_prob: f64,
temperature: f64, temperature: f64,
compression_ratio: f64, compression_ratio: f64,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment { pub struct Segment {
start: f64, pub start: f64,
duration: f64, pub duration: f64,
dr: DecodingResult, pub dr: DecodingResult,
} }
pub struct Decoder { pub struct Decoder {
@ -107,7 +107,7 @@ impl Decoder {
}) })
} }
fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> { fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
let model = &self.model; let model = &self.model;
let audio_features = model.encoder.forward(mel)?; let audio_features = model.encoder.forward(mel)?;
console_log!("audio features: {:?}", audio_features.dims()); console_log!("audio features: {:?}", audio_features.dims());
@ -142,8 +142,7 @@ impl Decoder {
let prs = (&logits / t)?.softmax(0)?; let prs = (&logits / t)?.softmax(0)?;
let logits_v: Vec<f32> = prs.to_vec1()?; let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let mut rng = rand::thread_rng(); distr.sample(rng) as u32
distr.sample(&mut rng) as u32
} else { } else {
let logits_v: Vec<f32> = logits.to_vec1()?; let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v logits_v
@ -179,9 +178,13 @@ impl Decoder {
}) })
} }
fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> { fn decode_with_fallback(
&self,
segment: &Tensor,
rng: &mut StdRng,
) -> anyhow::Result<DecodingResult> {
for (i, &t) in TEMPERATURES.iter().enumerate() { for (i, &t) in TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t); let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
if i == TEMPERATURES.len() - 1 { if i == TEMPERATURES.len() - 1 {
return dr; return dr;
} }
@ -203,6 +206,7 @@ impl Decoder {
} }
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> { fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let mut rng = StdRng::seed_from_u64(299792458);
let (_, _, content_frames) = mel.shape().r3()?; let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0; let mut seek = 0;
let mut segments = vec![]; let mut segments = vec![];
@ -211,7 +215,7 @@ impl Decoder {
let segment_size = usize::min(content_frames - seek, N_FRAMES); let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?; let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?; let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
seek += segment_size; seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
console_log!("no speech detected, skipping {seek} {dr:?}"); console_log!("no speech detected, skipping {seek} {dr:?}");
@ -289,14 +293,15 @@ pub enum WorkerInput {
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct WorkerOutput { pub enum WorkerOutput {
pub value: Result<Vec<Segment>, String>, Decoded(Vec<Segment>),
WeightsLoaded,
} }
impl yew_agent::Worker for Worker { impl yew_agent::Worker for Worker {
type Input = WorkerInput; type Input = WorkerInput;
type Message = (); type Message = ();
type Output = WorkerOutput; type Output = Result<WorkerOutput, String>;
type Reach = Public<Self>; type Reach = Public<Self>;
fn create(link: WorkerLink<Self>) -> Self { fn create(link: WorkerLink<Self>) -> Self {
@ -311,11 +316,11 @@ impl yew_agent::Worker for Worker {
} }
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
let value = match msg { let output = match msg {
WorkerInput::ModelData(md) => match Decoder::load(md) { WorkerInput::ModelData(md) => match Decoder::load(md) {
Ok(decoder) => { Ok(decoder) => {
self.decoder = Some(decoder); self.decoder = Some(decoder);
Ok(vec![]) Ok(WorkerOutput::WeightsLoaded)
} }
Err(err) => Err(format!("model creation error {err:?}")), Err(err) => Err(format!("model creation error {err:?}")),
}, },
@ -323,10 +328,11 @@ impl yew_agent::Worker for Worker {
None => Err("model has not been set".to_string()), None => Err("model has not been set".to_string()),
Some(decoder) => decoder Some(decoder) => decoder
.convert_and_run(&wav_bytes) .convert_and_run(&wav_bytes)
.map(WorkerOutput::Decoded)
.map_err(|e| e.to_string()), .map_err(|e| e.to_string()),
}, },
}; };
self.link.respond(id, WorkerOutput { value }); self.link.respond(id, output);
} }
fn name_of_resource() -> &'static str { fn name_of_resource() -> &'static str {