Improve the wasm ui. (#178)

* Improve the wasm ui.

* Improve the UI.

* Cosmetic changes.
This commit is contained in:
Laurent Mazare
2023-07-16 14:22:40 +01:00
committed by GitHub
parent 104f89df31
commit 6de7345e39
4 changed files with 95 additions and 40 deletions

View File

@ -48,4 +48,5 @@ features = [
'RequestInit',
'RequestMode',
'Response',
'Performance',
]

View File

@ -1,5 +1,5 @@
use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
use js_sys::Date;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
@ -38,13 +38,17 @@ pub enum Msg {
UpdateStatus(String),
SetDecoder(ModelData),
WorkerInMsg(WorkerInput),
WorkerOutMsg(WorkerOutput),
WorkerOutMsg(Result<WorkerOutput, String>),
}
/// Bookkeeping for the decode request that is currently in flight.
///
/// The app stores this in its `current_decode` field: it is set to `Some(..)` when a sample
/// run is started and reset to `None` once a worker output message has been handled, so its
/// presence doubles as the "decoding in progress" flag.
pub struct CurrentDecode {
/// Value of `performance_now()` captured when the decode was started, used to report the
/// elapsed time on completion. `None` when no window/performance object was available.
start_time: Option<f64>,
}
pub struct App {
status: String,
content: String,
decode_in_flight: bool,
segments: Vec<Segment>,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
}
@ -60,6 +64,12 @@ async fn model_data_load() -> Result<ModelData, JsValue> {
})
}
/// Reads the browser's high-resolution clock and returns its value divided by 1000.
///
/// The division presumably converts the millisecond reading of `Performance::now` into
/// seconds (callers format differences of this value with an `s` suffix).
///
/// Returns `None` when there is no global `window` or the window exposes no
/// `performance` object.
fn performance_now() -> Option<f64> {
    web_sys::window()
        .and_then(|win| win.performance())
        .map(|perf| perf.now() / 1000.)
}
impl Component for App {
type Message = Msg;
type Properties = ();
@ -73,8 +83,8 @@ impl Component for App {
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
content: String::new(),
decode_in_flight: false,
segments: vec![],
current_decode: None,
worker,
}
}
@ -103,18 +113,19 @@ impl Component for App {
}
Msg::Run(sample_index) => {
let sample = SAMPLE_NAMES[sample_index];
if self.decode_in_flight {
self.content = "already decoding some sample at the moment".to_string()
if self.current_decode.is_some() {
self.status = "already decoding some sample at the moment".to_string()
} else {
self.decode_in_flight = true;
let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = format!("decoding {sample}");
self.content = String::new();
self.segments.clear();
ctx.link().send_future(async move {
match fetch_url(sample).await {
Err(err) => {
let value = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release decode_in_flight
Msg::WorkerOutMsg(WorkerOutput { value })
let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output so as to release current_decode
Msg::WorkerOutMsg(output)
}
Ok(wav_bytes) => {
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
@ -125,10 +136,26 @@ impl Component for App {
//
true
}
Msg::WorkerOutMsg(WorkerOutput { value }) => {
self.status = "Worker responded!".to_string();
self.content = format!("{value:?}");
self.decode_in_flight = false;
Msg::WorkerOutMsg(output) => {
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
})
});
self.current_decode = None;
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::Decoded(segments)) => {
self.status = match dt {
None => "decoding succeeded!".to_string(),
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
};
self.segments = segments;
}
Err(err) => {
self.status = format!("decoding error {err:?}");
}
}
true
}
Msg::WorkerInMsg(inp) => {
@ -170,12 +197,30 @@ impl Component for App {
{&self.status}
</h2>
{
if self.decode_in_flight {
if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
} else { html!{
<blockquote>
<p>
{&self.content}
{
self.segments.iter().map(|segment| { html! {
<>
<i>
{
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
segment.start,
segment.start + segment.duration,
segment.dr.avg_logprob,
segment.dr.no_speech_prob,
)
}
</i>
<br/ >
{&segment.dr.text}
<br/ >
</>
} }).collect::<Html>()
}
</p>
</blockquote>
}

View File

@ -2,7 +2,7 @@ use crate::model::{Config, Whisper};
use anyhow::Error as E;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use rand::distributions::Distribution;
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -59,20 +59,20 @@ pub const SUPPRESS_TOKENS: [u32; 91] = [
];
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DecodingResult {
tokens: Vec<u32>,
text: String,
avg_logprob: f64,
no_speech_prob: f64,
pub struct DecodingResult {
pub tokens: Vec<u32>,
pub text: String,
pub avg_logprob: f64,
pub no_speech_prob: f64,
temperature: f64,
compression_ratio: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment {
start: f64,
duration: f64,
dr: DecodingResult,
pub start: f64,
pub duration: f64,
pub dr: DecodingResult,
}
pub struct Decoder {
@ -107,7 +107,7 @@ impl Decoder {
})
}
fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
console_log!("audio features: {:?}", audio_features.dims());
@ -142,8 +142,7 @@ impl Decoder {
let prs = (&logits / t)?.softmax(0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let mut rng = rand::thread_rng();
distr.sample(&mut rng) as u32
distr.sample(rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
@ -179,9 +178,13 @@ impl Decoder {
})
}
fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
fn decode_with_fallback(
&self,
segment: &Tensor,
rng: &mut StdRng,
) -> anyhow::Result<DecodingResult> {
for (i, &t) in TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t);
let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
if i == TEMPERATURES.len() - 1 {
return dr;
}
@ -203,6 +206,7 @@ impl Decoder {
}
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let mut rng = StdRng::seed_from_u64(299792458);
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
@ -211,7 +215,7 @@ impl Decoder {
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
console_log!("no speech detected, skipping {seek} {dr:?}");
@ -289,14 +293,15 @@ pub enum WorkerInput {
}
#[derive(Serialize, Deserialize)]
pub struct WorkerOutput {
pub value: Result<Vec<Segment>, String>,
pub enum WorkerOutput {
Decoded(Vec<Segment>),
WeightsLoaded,
}
impl yew_agent::Worker for Worker {
type Input = WorkerInput;
type Message = ();
type Output = WorkerOutput;
type Output = Result<WorkerOutput, String>;
type Reach = Public<Self>;
fn create(link: WorkerLink<Self>) -> Self {
@ -311,11 +316,11 @@ impl yew_agent::Worker for Worker {
}
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
let value = match msg {
let output = match msg {
WorkerInput::ModelData(md) => match Decoder::load(md) {
Ok(decoder) => {
self.decoder = Some(decoder);
Ok(vec![])
Ok(WorkerOutput::WeightsLoaded)
}
Err(err) => Err(format!("model creation error {err:?}")),
},
@ -323,10 +328,11 @@ impl yew_agent::Worker for Worker {
None => Err("model has not been set".to_string()),
Some(decoder) => decoder
.convert_and_run(&wav_bytes)
.map(WorkerOutput::Decoded)
.map_err(|e| e.to_string()),
},
};
self.link.respond(id, WorkerOutput { value });
self.link.respond(id, output);
}
fn name_of_resource() -> &'static str {