Add some wasm profiling. (#173)

Laurent Mazare
2023-07-15 09:16:15 +01:00
committed by GitHub
parent 66750f9827
commit ad91415b4f
2 changed files with 58 additions and 1 deletion

First changed file:

@ -1,5 +1,28 @@
#![allow(dead_code)]
pub const WITH_TIMER: bool = true;
struct Timer {
label: &'static str,
}
impl Timer {
fn new(label: &'static str) -> Self {
if WITH_TIMER {
web_sys::console::time_with_label(label);
}
Self { label }
}
}
impl Drop for Timer {
fn drop(&mut self) {
if WITH_TIMER {
web_sys::console::time_end_with_label(self.label)
}
}
}
mod app;
mod audio;
mod model;
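The Timer above is a scope-based (RAII) guard around the browser's console.time / console.timeEnd pair: constructing it starts a named console timer and dropping it ends the timer, so nested scopes yield nested timings. Note that web_sys::console::time_with_label and time_end_with_label require the "console" feature of the web-sys crate. A minimal sketch of the pattern (the labels and the work inside are made up for illustration):

fn run_layer() {
    // Starts console.time("layer") when created.
    let _timer = crate::Timer::new("layer");
    {
        // A nested guard times just the inner section.
        let _inner = crate::Timer::new("layer::matmul");
        // ... the expensive work being measured ...
    } // `_inner` dropped here -> console.timeEnd("layer::matmul")
} // `_timer` dropped here -> console.timeEnd("layer")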

Second changed file:

@ -3,7 +3,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@ -39,6 +39,36 @@ impl Config {
}
}
// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
// specific monitoring.
#[derive(Debug)]
struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _timer = crate::Timer::new("Linear::forward");
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = {
let _timer = crate::Timer::new("Linear::matmul");
x.matmul(&w)?
};
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
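The hunk doesn't show how the duplicated Linear is constructed from the checkpoint. A plausible loader in the same style as the embedding helper, following candle's convention of storing linear weights with shape (out_dim, in_dim) (this helper and its shapes are an assumption, not part of the commit):

fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    // candle stores linear weights as (out_dim, in_dim); `forward` transposes.
    let weight = vb.get((out_dim, in_dim), "weight")?;
    let bias = vb.get(out_dim, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}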
@ -124,6 +154,7 @@ impl MultiHeadAttention {
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let _timer = crate::Timer::new("MultiHeadAttention::forward");
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
@ -197,6 +228,7 @@ impl ResidualAttentionBlock {
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
@ -268,6 +300,7 @@ impl AudioEncoder {
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _timer = crate::Timer::new("AudioEncoder::forward");
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let x = x.transpose(1, 2)?;
@ -293,6 +326,7 @@ pub struct TextDecoder {
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let _timer = crate::Timer::new("TextDecoder::forward");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
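Taken together, the guards produce per-label timings for nested scopes: Linear::matmul runs inside Linear::forward, which in turn runs inside MultiHeadAttention::forward and ResidualAttentionBlock::forward. Because the guard relies only on Drop, any scope can be timed the same way, and flipping WITH_TIMER to false leaves only a cheap branch. A small closure-based wrapper built on the same idea (hypothetical, not part of the commit) would let call sites time an expression inline:

fn timed<T>(label: &'static str, f: impl FnOnce() -> T) -> T {
    // Starts the console timer here; the guard ends it when `f` returns.
    let _timer = crate::Timer::new(label);
    f()
}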