Merge branch 'main' into readme_fixes

This commit is contained in:
Matt
2023-08-10 00:19:20 +01:00
committed by GitHub
8 changed files with 106 additions and 66 deletions

View File

@ -30,7 +30,7 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.5", package = "candle-gemm" }
gemm = { version = "0.15.6", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }

View File

@ -0,0 +1,24 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
pub const N_ITERS: usize = 5;
fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?;
let res = inp.conv1d(&w, 0, 1);
println!("{res:?}");
let start = std::time::Instant::now();
for i in 0..N_ITERS {
let res = inp.conv1d(&w, 0, 1);
println!("{i} {res:?}");
}
println!("{:?}", start.elapsed() / N_ITERS as u32);
Ok(())
}

View File

@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@ -19,10 +19,7 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
match self.b_size {
None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out],
}
vec![self.b_size, self.c_out, l_out]
}
}

View File

@ -1037,10 +1037,10 @@ impl<'a> Map2 for Conv1D<'a> {
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
for b_idx in 0..p.b_size.unwrap_or(1) {
for b_idx in 0..p.b_size {
let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {

View File

@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.

View File

@ -773,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = match *self.dims() {
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
[c_in, l_in] => (None, c_in, l_in),
_ => Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
k_shape: kernel.shape().clone(),
padding,
stride,
msg: "input rank is not 2 or 3",
}
.bt())?,
};
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),

View File

@ -1,4 +1,3 @@
#![allow(dead_code)]
// https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO:
// - kv-cache support?
@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@ -66,6 +62,7 @@ struct DecodingResult {
compression_ratio: f64,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
@ -243,10 +240,25 @@ struct Args {
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());

View File

@ -1,8 +1,5 @@
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use candle::{Device, Result, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@ -22,6 +19,7 @@ pub struct Config {
}
impl Config {
#[allow(dead_code)]
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn conv1d(
@ -66,32 +80,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
Ok(Conv1d::new(weight, None, config))
}
struct Dropout {
pr: f64,
}
impl Dropout {
fn new(pr: f64) -> Self {
Self { pr }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
@ -105,10 +93,12 @@ struct MultiHeadAttention {
value: Linear,
out: Linear,
n_head: usize,
span: tracing::Span,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@ -119,10 +109,12 @@ impl MultiHeadAttention {
value,
out,
n_head,
span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
@ -134,7 +126,7 @@ impl MultiHeadAttention {
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
x.reshape(target_dims)?.transpose(1, 2)
}
fn qkv_attention(
@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
span: tracing::Span,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
mlp_linear1,
mlp_linear2,
mlp_ln,
span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
Ok((x + mlp)?)
x + mlp
}
}
@ -234,10 +230,16 @@ pub struct AudioEncoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
span: tracing::Span,
conv1_span: tracing::Span,
conv2_span: tracing::Span,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@ -264,11 +266,22 @@ impl AudioEncoder {
positional_embedding,
blocks,
ln_post,
conv1_span,
conv2_span,
span,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let _enter = self.span.enter();
let x = {
let _enter = self.conv1_span.enter();
self.conv1.forward(x)?.gelu()?
};
let x = {
let _enter = self.conv2_span.enter();
self.conv2.forward(&x)?.gelu()?
};
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
@ -288,10 +301,12 @@ pub struct TextDecoder {
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
span: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@ -314,10 +329,12 @@ impl TextDecoder {
blocks,
ln,
mask,
span,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
@ -354,6 +371,7 @@ impl Whisper {
})
}
#[allow(dead_code)]
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;