mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Improve the wasm ui. (#178)
* Improve the wasm ui. * Improve the UI. * Cosmetic changes.
This commit is contained in:
@ -3,3 +3,6 @@ rustflags = ["-C", "target-cpu=native"]
|
|||||||
|
|
||||||
[target.aarch64-apple-darwin]
|
[target.aarch64-apple-darwin]
|
||||||
rustflags = ["-C", "target-cpu=native"]
|
rustflags = ["-C", "target-cpu=native"]
|
||||||
|
|
||||||
|
[target.wasm32-unknown-unknown]
|
||||||
|
rustflags = ["-C", "target-feature=+simd128"]
|
||||||
|
@ -48,4 +48,5 @@ features = [
|
|||||||
'RequestInit',
|
'RequestInit',
|
||||||
'RequestMode',
|
'RequestMode',
|
||||||
'Response',
|
'Response',
|
||||||
|
'Performance',
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::console_log;
|
use crate::console_log;
|
||||||
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
|
use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
|
||||||
use js_sys::Date;
|
use js_sys::Date;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen_futures::JsFuture;
|
use wasm_bindgen_futures::JsFuture;
|
||||||
@ -38,13 +38,17 @@ pub enum Msg {
|
|||||||
UpdateStatus(String),
|
UpdateStatus(String),
|
||||||
SetDecoder(ModelData),
|
SetDecoder(ModelData),
|
||||||
WorkerInMsg(WorkerInput),
|
WorkerInMsg(WorkerInput),
|
||||||
WorkerOutMsg(WorkerOutput),
|
WorkerOutMsg(Result<WorkerOutput, String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CurrentDecode {
|
||||||
|
start_time: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct App {
|
pub struct App {
|
||||||
status: String,
|
status: String,
|
||||||
content: String,
|
segments: Vec<Segment>,
|
||||||
decode_in_flight: bool,
|
current_decode: Option<CurrentDecode>,
|
||||||
worker: Box<dyn Bridge<Worker>>,
|
worker: Box<dyn Bridge<Worker>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,6 +64,12 @@ async fn model_data_load() -> Result<ModelData, JsValue> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn performance_now() -> Option<f64> {
|
||||||
|
let window = web_sys::window()?;
|
||||||
|
let performance = window.performance()?;
|
||||||
|
Some(performance.now() / 1000.)
|
||||||
|
}
|
||||||
|
|
||||||
impl Component for App {
|
impl Component for App {
|
||||||
type Message = Msg;
|
type Message = Msg;
|
||||||
type Properties = ();
|
type Properties = ();
|
||||||
@ -73,8 +83,8 @@ impl Component for App {
|
|||||||
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
||||||
Self {
|
Self {
|
||||||
status,
|
status,
|
||||||
content: String::new(),
|
segments: vec![],
|
||||||
decode_in_flight: false,
|
current_decode: None,
|
||||||
worker,
|
worker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -103,18 +113,19 @@ impl Component for App {
|
|||||||
}
|
}
|
||||||
Msg::Run(sample_index) => {
|
Msg::Run(sample_index) => {
|
||||||
let sample = SAMPLE_NAMES[sample_index];
|
let sample = SAMPLE_NAMES[sample_index];
|
||||||
if self.decode_in_flight {
|
if self.current_decode.is_some() {
|
||||||
self.content = "already decoding some sample at the moment".to_string()
|
self.status = "already decoding some sample at the moment".to_string()
|
||||||
} else {
|
} else {
|
||||||
self.decode_in_flight = true;
|
let start_time = performance_now();
|
||||||
|
self.current_decode = Some(CurrentDecode { start_time });
|
||||||
self.status = format!("decoding {sample}");
|
self.status = format!("decoding {sample}");
|
||||||
self.content = String::new();
|
self.segments.clear();
|
||||||
ctx.link().send_future(async move {
|
ctx.link().send_future(async move {
|
||||||
match fetch_url(sample).await {
|
match fetch_url(sample).await {
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let value = Err(format!("decoding error: {err:?}"));
|
let output = Err(format!("decoding error: {err:?}"));
|
||||||
// Mimic a worker output to so as to release decode_in_flight
|
// Mimic a worker output to so as to release current_decode
|
||||||
Msg::WorkerOutMsg(WorkerOutput { value })
|
Msg::WorkerOutMsg(output)
|
||||||
}
|
}
|
||||||
Ok(wav_bytes) => {
|
Ok(wav_bytes) => {
|
||||||
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
|
Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
|
||||||
@ -125,10 +136,26 @@ impl Component for App {
|
|||||||
//
|
//
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
Msg::WorkerOutMsg(WorkerOutput { value }) => {
|
Msg::WorkerOutMsg(output) => {
|
||||||
self.status = "Worker responded!".to_string();
|
let dt = self.current_decode.as_ref().and_then(|current_decode| {
|
||||||
self.content = format!("{value:?}");
|
current_decode.start_time.and_then(|start_time| {
|
||||||
self.decode_in_flight = false;
|
performance_now().map(|stop_time| stop_time - start_time)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
self.current_decode = None;
|
||||||
|
match output {
|
||||||
|
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
|
||||||
|
Ok(WorkerOutput::Decoded(segments)) => {
|
||||||
|
self.status = match dt {
|
||||||
|
None => "decoding succeeded!".to_string(),
|
||||||
|
Some(dt) => format!("decoding succeeded in {:.2}s", dt),
|
||||||
|
};
|
||||||
|
self.segments = segments;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.status = format!("decoding error {err:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
Msg::WorkerInMsg(inp) => {
|
Msg::WorkerInMsg(inp) => {
|
||||||
@ -170,12 +197,30 @@ impl Component for App {
|
|||||||
{&self.status}
|
{&self.status}
|
||||||
</h2>
|
</h2>
|
||||||
{
|
{
|
||||||
if self.decode_in_flight {
|
if self.current_decode.is_some() {
|
||||||
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
|
html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
|
||||||
} else { html!{
|
} else { html!{
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<p>
|
<p>
|
||||||
{&self.content}
|
{
|
||||||
|
self.segments.iter().map(|segment| { html! {
|
||||||
|
<>
|
||||||
|
<i>
|
||||||
|
{
|
||||||
|
format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
|
||||||
|
segment.start,
|
||||||
|
segment.start + segment.duration,
|
||||||
|
segment.dr.avg_logprob,
|
||||||
|
segment.dr.no_speech_prob,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</i>
|
||||||
|
<br/ >
|
||||||
|
{&segment.dr.text}
|
||||||
|
<br/ >
|
||||||
|
</>
|
||||||
|
} }).collect::<Html>()
|
||||||
|
}
|
||||||
</p>
|
</p>
|
||||||
</blockquote>
|
</blockquote>
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ use crate::model::{Config, Whisper};
|
|||||||
use anyhow::Error as E;
|
use anyhow::Error as E;
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use rand::distributions::Distribution;
|
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
@ -59,20 +59,20 @@ pub const SUPPRESS_TOKENS: [u32; 91] = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct DecodingResult {
|
pub struct DecodingResult {
|
||||||
tokens: Vec<u32>,
|
pub tokens: Vec<u32>,
|
||||||
text: String,
|
pub text: String,
|
||||||
avg_logprob: f64,
|
pub avg_logprob: f64,
|
||||||
no_speech_prob: f64,
|
pub no_speech_prob: f64,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
compression_ratio: f64,
|
compression_ratio: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Segment {
|
pub struct Segment {
|
||||||
start: f64,
|
pub start: f64,
|
||||||
duration: f64,
|
pub duration: f64,
|
||||||
dr: DecodingResult,
|
pub dr: DecodingResult,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Decoder {
|
pub struct Decoder {
|
||||||
@ -107,7 +107,7 @@ impl Decoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
|
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||||
let model = &self.model;
|
let model = &self.model;
|
||||||
let audio_features = model.encoder.forward(mel)?;
|
let audio_features = model.encoder.forward(mel)?;
|
||||||
console_log!("audio features: {:?}", audio_features.dims());
|
console_log!("audio features: {:?}", audio_features.dims());
|
||||||
@ -142,8 +142,7 @@ impl Decoder {
|
|||||||
let prs = (&logits / t)?.softmax(0)?;
|
let prs = (&logits / t)?.softmax(0)?;
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||||
let mut rng = rand::thread_rng();
|
distr.sample(rng) as u32
|
||||||
distr.sample(&mut rng) as u32
|
|
||||||
} else {
|
} else {
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||||
logits_v
|
logits_v
|
||||||
@ -179,9 +178,13 @@ impl Decoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
|
fn decode_with_fallback(
|
||||||
|
&self,
|
||||||
|
segment: &Tensor,
|
||||||
|
rng: &mut StdRng,
|
||||||
|
) -> anyhow::Result<DecodingResult> {
|
||||||
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
||||||
let dr: Result<DecodingResult, _> = self.decode(segment, t);
|
let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
|
||||||
if i == TEMPERATURES.len() - 1 {
|
if i == TEMPERATURES.len() - 1 {
|
||||||
return dr;
|
return dr;
|
||||||
}
|
}
|
||||||
@ -203,6 +206,7 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||||
|
let mut rng = StdRng::seed_from_u64(299792458);
|
||||||
let (_, _, content_frames) = mel.shape().r3()?;
|
let (_, _, content_frames) = mel.shape().r3()?;
|
||||||
let mut seek = 0;
|
let mut seek = 0;
|
||||||
let mut segments = vec![];
|
let mut segments = vec![];
|
||||||
@ -211,7 +215,7 @@ impl Decoder {
|
|||||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
|
||||||
seek += segment_size;
|
seek += segment_size;
|
||||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||||
console_log!("no speech detected, skipping {seek} {dr:?}");
|
console_log!("no speech detected, skipping {seek} {dr:?}");
|
||||||
@ -289,14 +293,15 @@ pub enum WorkerInput {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct WorkerOutput {
|
pub enum WorkerOutput {
|
||||||
pub value: Result<Vec<Segment>, String>,
|
Decoded(Vec<Segment>),
|
||||||
|
WeightsLoaded,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl yew_agent::Worker for Worker {
|
impl yew_agent::Worker for Worker {
|
||||||
type Input = WorkerInput;
|
type Input = WorkerInput;
|
||||||
type Message = ();
|
type Message = ();
|
||||||
type Output = WorkerOutput;
|
type Output = Result<WorkerOutput, String>;
|
||||||
type Reach = Public<Self>;
|
type Reach = Public<Self>;
|
||||||
|
|
||||||
fn create(link: WorkerLink<Self>) -> Self {
|
fn create(link: WorkerLink<Self>) -> Self {
|
||||||
@ -311,11 +316,11 @@ impl yew_agent::Worker for Worker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
|
fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
|
||||||
let value = match msg {
|
let output = match msg {
|
||||||
WorkerInput::ModelData(md) => match Decoder::load(md) {
|
WorkerInput::ModelData(md) => match Decoder::load(md) {
|
||||||
Ok(decoder) => {
|
Ok(decoder) => {
|
||||||
self.decoder = Some(decoder);
|
self.decoder = Some(decoder);
|
||||||
Ok(vec![])
|
Ok(WorkerOutput::WeightsLoaded)
|
||||||
}
|
}
|
||||||
Err(err) => Err(format!("model creation error {err:?}")),
|
Err(err) => Err(format!("model creation error {err:?}")),
|
||||||
},
|
},
|
||||||
@ -323,10 +328,11 @@ impl yew_agent::Worker for Worker {
|
|||||||
None => Err("model has not been set".to_string()),
|
None => Err("model has not been set".to_string()),
|
||||||
Some(decoder) => decoder
|
Some(decoder) => decoder
|
||||||
.convert_and_run(&wav_bytes)
|
.convert_and_run(&wav_bytes)
|
||||||
|
.map(WorkerOutput::Decoded)
|
||||||
.map_err(|e| e.to_string()),
|
.map_err(|e| e.to_string()),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
self.link.respond(id, WorkerOutput { value });
|
self.link.respond(id, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn name_of_resource() -> &'static str {
|
fn name_of_resource() -> &'static str {
|
||||||
|
Reference in New Issue
Block a user