mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Improve the wasm ui. (#178)
* Improve the wasm ui. * Improve the UI. * Cosmetic changes.
This commit is contained in:
@ -3,3 +3,6 @@ rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ["-C", "target-feature=+simd128"]
|
||||
|
@ -48,4 +48,5 @@ features = [
|
||||
'RequestInit',
|
||||
'RequestMode',
|
||||
'Response',
|
||||
'Performance',
|
||||
]
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::console_log;
|
||||
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
|
||||
use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
|
||||
use js_sys::Date;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
@ -38,13 +38,17 @@ pub enum Msg {
|
||||
UpdateStatus(String),
|
||||
SetDecoder(ModelData),
|
||||
WorkerInMsg(WorkerInput),
|
||||
WorkerOutMsg(WorkerOutput),
|
||||
WorkerOutMsg(Result<WorkerOutput, String>),
|
||||
}
|
||||
|
||||
pub struct CurrentDecode {
|
||||
start_time: Option<f64>,
|
||||
}
|
||||
|
||||
pub struct App {
|
||||
status: String,
|
||||
content: String,
|
||||
decode_in_flight: bool,
|
||||
segments: Vec<Segment>,
|
||||
current_decode: Option<CurrentDecode>,
|
||||
worker: Box<dyn Bridge<Worker>>,
|
||||
}
|
||||
|
||||
@ -60,6 +64,12 @@ async fn model_data_load() -> Result<ModelData, JsValue> {
|
||||
})
|
||||
}
|
||||
|
||||
fn performance_now() -> Option<f64> {
|
||||
let window = web_sys::window()?;
|
||||
let performance = window.performance()?;
|
||||
Some(performance.now() / 1000.)
|
||||
}
|
||||
|
||||
impl Component for App {
|
||||
type Message = Msg;
|
||||
type Properties = ();
|
||||
@ -73,8 +83,8 @@ impl Component for App {
|
||||
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
||||
Self {
|
||||
status,
|
||||
content: String::new(),
|
||||
decode_in_flight: false,
|
||||
segments: vec![],
|
||||
current_decode: None,
|
||||
worker,
|
||||
}
|
||||
}
|
||||
@ -103,18 +113,19 @@ impl Component for App {
|
||||
}
|
||||
Msg::Run(sample_index) => {
|
||||
let sample = SAMPLE_NAMES[sample_index];
|
||||
if self.decode_in_flight {
|
||||
self.content = "already decoding some sample at the moment".to_string()
|
||||
if self.current_decode.is_some() {
|
||||
self.status = "already decoding some sample at the moment".to_string()
|
||||
} else {
|
||||
self.decode_in_flight = true;
|
||||
let start_time = performance_now();
|
||||
self.current_decode = Some(CurrentDecode { start_time });
|
||||
self.status = format!("decoding {sample}");
|
||||
self.content = String::new();
|
||||
self.segments.clear();
|
||||
ctx.link().send_future(async move {
|
||||
match fetch_url(sample).await {
|
||||
Err(err) => {
|
||||
let value = Err(format!("decoding error: {err:?}"));
|
||||
// Mimic a worker output to so as to release decode_in_flight
|
||||
Msg::WorkerOutMsg(WorkerOutput { value })
|
||||
let output = Err(format!("decoding error: {err:?}"));
|
||||
// Mimic a worker output to so as to release current_decode
|
||||
Msg::WorkerOutMsg(output)
|
||||
}
|
||||
Ok(wav_bytes) => {
|
||||
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
|
||||
@ -125,10 +136,26 @@ impl Component for App {
|
||||
//
|
||||
true
|
||||
}
|
||||
Msg::WorkerOutMsg(WorkerOutput { value }) => {
|
||||
self.status = "Worker responded!".to_string();
|
||||
self.content = format!("{value:?}");
|
||||
self.decode_in_flight = false;
|
||||
Msg::WorkerOutMsg(output) => {
|
||||
let dt = self.current_decode.as_ref().and_then(|current_decode| {
|
||||
current_decode.start_time.and_then(|start_time| {
|
||||
performance_now().map(|stop_time| stop_time - start_time)
|
||||
})
|
||||
});
|
||||
self.current_decode = None;
|
||||
match output {
|
||||
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
|
||||
Ok(WorkerOutput::Decoded(segments)) => {
|
||||
self.status = match dt {
|
||||
None => "decoding succeeded!".to_string(),
|
||||
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
|
||||
};
|
||||
self.segments = segments;
|
||||
}
|
||||
Err(err) => {
|
||||
self.status = format!("decoding error {err:?}");
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
Msg::WorkerInMsg(inp) => {
|
||||
@ -170,12 +197,30 @@ impl Component for App {
|
||||
{&self.status}
|
||||
</h2>
|
||||
{
|
||||
if self.decode_in_flight {
|
||||
if self.current_decode.is_some() {
|
||||
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
|
||||
} else { html!{
|
||||
<blockquote>
|
||||
<p>
|
||||
{&self.content}
|
||||
{
|
||||
self.segments.iter().map(|segment| { html! {
|
||||
<>
|
||||
<i>
|
||||
{
|
||||
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
|
||||
segment.start,
|
||||
segment.start + segment.duration,
|
||||
segment.dr.avg_logprob,
|
||||
segment.dr.no_speech_prob,
|
||||
)
|
||||
}
|
||||
</i>
|
||||
<br/ >
|
||||
{&segment.dr.text}
|
||||
<br/ >
|
||||
</>
|
||||
} }).collect::<Html>()
|
||||
}
|
||||
</p>
|
||||
</blockquote>
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use crate::model::{Config, Whisper};
|
||||
use anyhow::Error as E;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use rand::distributions::Distribution;
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
@ -59,20 +59,20 @@ pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct DecodingResult {
|
||||
tokens: Vec<u32>,
|
||||
text: String,
|
||||
avg_logprob: f64,
|
||||
no_speech_prob: f64,
|
||||
pub struct DecodingResult {
|
||||
pub tokens: Vec<u32>,
|
||||
pub text: String,
|
||||
pub avg_logprob: f64,
|
||||
pub no_speech_prob: f64,
|
||||
temperature: f64,
|
||||
compression_ratio: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Segment {
|
||||
start: f64,
|
||||
duration: f64,
|
||||
dr: DecodingResult,
|
||||
pub start: f64,
|
||||
pub duration: f64,
|
||||
pub dr: DecodingResult,
|
||||
}
|
||||
|
||||
pub struct Decoder {
|
||||
@ -107,7 +107,7 @@ impl Decoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
|
||||
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
console_log!("audio features: {:?}", audio_features.dims());
|
||||
@ -142,8 +142,7 @@ impl Decoder {
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
let mut rng = rand::thread_rng();
|
||||
distr.sample(&mut rng) as u32
|
||||
distr.sample(rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
logits_v
|
||||
@ -179,9 +178,13 @@ impl Decoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
|
||||
fn decode_with_fallback(
|
||||
&self,
|
||||
segment: &Tensor,
|
||||
rng: &mut StdRng,
|
||||
) -> anyhow::Result<DecodingResult> {
|
||||
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
||||
let dr: Result<DecodingResult, _> = self.decode(segment, t);
|
||||
let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
|
||||
if i == TEMPERATURES.len() - 1 {
|
||||
return dr;
|
||||
}
|
||||
@ -203,6 +206,7 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
@ -211,7 +215,7 @@ impl Decoder {
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||
let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||
console_log!("no speech detected, skipping {seek} {dr:?}");
|
||||
@ -289,14 +293,15 @@ pub enum WorkerInput {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct WorkerOutput {
|
||||
pub value: Result<Vec<Segment>, String>,
|
||||
pub enum WorkerOutput {
|
||||
Decoded(Vec<Segment>),
|
||||
WeightsLoaded,
|
||||
}
|
||||
|
||||
impl yew_agent::Worker for Worker {
|
||||
type Input = WorkerInput;
|
||||
type Message = ();
|
||||
type Output = WorkerOutput;
|
||||
type Output = Result<WorkerOutput, String>;
|
||||
type Reach = Public<Self>;
|
||||
|
||||
fn create(link: WorkerLink<Self>) -> Self {
|
||||
@ -311,11 +316,11 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
|
||||
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
|
||||
let value = match msg {
|
||||
let output = match msg {
|
||||
WorkerInput::ModelData(md) => match Decoder::load(md) {
|
||||
Ok(decoder) => {
|
||||
self.decoder = Some(decoder);
|
||||
Ok(vec![])
|
||||
Ok(WorkerOutput::WeightsLoaded)
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
@ -323,10 +328,11 @@ impl yew_agent::Worker for Worker {
|
||||
None => Err("model has not been set".to_string()),
|
||||
Some(decoder) => decoder
|
||||
.convert_and_run(&wav_bytes)
|
||||
.map(WorkerOutput::Decoded)
|
||||
.map_err(|e| e.to_string()),
|
||||
},
|
||||
};
|
||||
self.link.respond(id, WorkerOutput { value });
|
||||
self.link.respond(id, output);
|
||||
}
|
||||
|
||||
fn name_of_resource() -> &'static str {
|
||||
|
Reference in New Issue
Block a user