From ad91415b4f5b400d3e07def99529af68eedf6387 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 15 Jul 2023 09:16:15 +0100
Subject: [PATCH] Add some wasm profiling. (#173)

---
 candle-wasm-example/src/lib.rs   | 23 ++++++++++++++++++++
 candle-wasm-example/src/model.rs | 36 +++++++++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/candle-wasm-example/src/lib.rs b/candle-wasm-example/src/lib.rs
index 54c2367c..b47d43ca 100644
--- a/candle-wasm-example/src/lib.rs
+++ b/candle-wasm-example/src/lib.rs
@@ -1,5 +1,28 @@
 #![allow(dead_code)]
 
+pub const WITH_TIMER: bool = true;
+
+struct Timer {
+    label: &'static str,
+}
+
+impl Timer {
+    fn new(label: &'static str) -> Self {
+        if WITH_TIMER {
+            web_sys::console::time_with_label(label);
+        }
+        Self { label }
+    }
+}
+
+impl Drop for Timer {
+    fn drop(&mut self) {
+        if WITH_TIMER {
+            web_sys::console::time_end_with_label(self.label)
+        }
+    }
+}
+
 mod app;
 mod audio;
 mod model;
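Note: the `Timer` above is an RAII guard. `Timer::new` starts a labeled browser-console timer via `web_sys::console::time_with_label` (this requires the `console` feature of the `web-sys` crate), and the `Drop` impl stops it with `time_end_with_label`, so the measured interval is exactly the guard's lexical scope. A minimal usage sketch follows; the `do_work` function is hypothetical and not part of this patch:

    fn do_work() {
        // Starts console.time("do_work") in the browser dev-tools console.
        let _timer = crate::Timer::new("do_work");
        // ... code being profiled ...
    } // `_timer` is dropped here, logging the elapsed time via console.timeEnd("do_work")

The binding must be a named one such as `_timer`: writing `let _ = crate::Timer::new("do_work")` would drop the guard immediately and measure nothing.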
diff --git a/candle-wasm-example/src/model.rs b/candle-wasm-example/src/model.rs
index b19ff90a..1aa98740 100644
--- a/candle-wasm-example/src/model.rs
+++ b/candle-wasm-example/src/model.rs
@@ -3,7 +3,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation:
@@ -39,6 +39,36 @@ impl Config {
     }
 }
 
+// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
+// specific monitoring.
+#[derive(Debug)]
+struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+}
+
+impl Linear {
+    fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+        Self { weight, bias }
+    }
+
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let _timer = crate::Timer::new("Linear::forward");
+        let w = match x.dims() {
+            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+            _ => self.weight.t()?,
+        };
+        let x = {
+            let _timer = crate::Timer::new("Linear::matmul");
+            x.matmul(&w)?
+        };
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
+
 fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
@@ -124,6 +154,7 @@ impl MultiHeadAttention {
     }
 
     fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _timer = crate::Timer::new("MultiHeadAttention::forward");
         let q = self.query.forward(x)?;
         let k = self.key.forward(xa.unwrap_or(x))?;
         let v = self.value.forward(xa.unwrap_or(x))?;
@@ -197,6 +228,7 @@ impl ResidualAttentionBlock {
     }
 
     fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
         let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
         let mut x = (x + attn)?;
         if let Some((attn, ln)) = &self.cross_attn {
@@ -268,6 +300,7 @@ impl AudioEncoder {
         })
     }
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _timer = crate::Timer::new("AudioEncoder::forward");
         let x = self.conv1.forward(x)?.gelu()?;
         let x = self.conv2.forward(&x)?.gelu()?;
         let x = x.transpose(1, 2)?;
@@ -293,6 +326,7 @@ pub struct TextDecoder {
 
 impl TextDecoder {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let _timer = crate::Timer::new("TextDecoder::load");
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
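A note on the shape handling in the duplicated `Linear::forward`: the 2-d weight of shape (out, in) cannot be multiplied directly with a batched 3-d activation, so it is first broadcast to a matching batch dimension and then transposed. Below is a minimal, self-contained sketch of the shapes involved, using the same candle ops as the patch; the sizes and the zero-initialized tensors are made up purely for illustration:

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        let x = Tensor::zeros((2, 4, 8), DType::F32, &dev)?; // activations: (bsize=2, seq=4, in=8)
        let w = Tensor::zeros((16, 8), DType::F32, &dev)?; // weight: (out=16, in=8)
        // broadcast_left(2): (16, 8) -> (2, 16, 8); t() swaps the last two dims: -> (2, 8, 16);
        // the batched matmul then computes one (4, 8) x (8, 16) product per batch entry.
        let y = x.matmul(&w.broadcast_left(2)?.t()?)?;
        assert_eq!(y.dims(), &[2, 4, 16]);
        Ok(())
    }

Note also that the inner `Linear::matmul` timer wraps only the `matmul` call, so the console output separates the cost of the multiplication itself from the broadcast/transpose bookkeeping and the bias addition around it.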