mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some wasm profiling. (#173)
This commit is contained in:
@ -1,5 +1,28 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Compile-time switch for the profiling timers: `Timer` only calls into
/// `web_sys::console` when this is `true`.
pub const WITH_TIMER: bool = true;
|
||||
|
||||
struct Timer {
|
||||
label: &'static str,
|
||||
}
|
||||
|
||||
impl Timer {
|
||||
fn new(label: &'static str) -> Self {
|
||||
if WITH_TIMER {
|
||||
web_sys::console::time_with_label(label);
|
||||
}
|
||||
Self { label }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Timer {
|
||||
fn drop(&mut self) {
|
||||
if WITH_TIMER {
|
||||
web_sys::console::time_end_with_label(self.label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod app;
|
||||
mod audio;
|
||||
mod model;
|
||||
|
@ -3,7 +3,7 @@
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
@ -39,6 +39,36 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
|
||||
// specific monitoring.
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let _timer = crate::Timer::new("Linear::forward");
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = {
|
||||
let _timer = crate::Timer::new("Linear::matmul");
|
||||
x.matmul(&w)?
|
||||
};
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
@ -124,6 +154,7 @@ impl MultiHeadAttention {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("MultiHeadAttention::forward");
|
||||
let q = self.query.forward(x)?;
|
||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||
@ -197,6 +228,7 @@ impl ResidualAttentionBlock {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
@ -268,6 +300,7 @@ impl AudioEncoder {
|
||||
})
|
||||
}
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("AudioEncoder::forward");
|
||||
let x = self.conv1.forward(x)?.gelu()?;
|
||||
let x = self.conv2.forward(&x)?.gelu()?;
|
||||
let x = x.transpose(1, 2)?;
|
||||
@ -293,6 +326,7 @@ pub struct TextDecoder {
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let _timer = crate::Timer::new("TextDecoder::forward");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.decoder_attention_heads;
|
||||
let n_ctx = cfg.max_target_positions;
|
||||
|
Reference in New Issue
Block a user