mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Merge pull request #67 from LaurentMazare/whisper
Sketch the whisper model.
This commit is contained in:
@ -33,7 +33,12 @@ impl Tensor {
|
|||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
nodes
|
nodes
|
||||||
}
|
}
|
||||||
Op::Add(lhs, rhs)
|
Op::Conv1D {
|
||||||
|
arg: lhs,
|
||||||
|
kernel: rhs,
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| Op::Add(lhs, rhs)
|
||||||
| Op::Mul(lhs, rhs)
|
| Op::Mul(lhs, rhs)
|
||||||
| Op::Sub(lhs, rhs)
|
| Op::Sub(lhs, rhs)
|
||||||
| Op::Div(lhs, rhs)
|
| Op::Div(lhs, rhs)
|
||||||
@ -147,6 +152,7 @@ impl Tensor {
|
|||||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||||
Op::Embedding(_lhs, _rhs) => {
|
Op::Embedding(_lhs, _rhs) => {
|
||||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||||
}
|
}
|
||||||
|
27
candle-core/src/conv.rs
Normal file
27
candle-core/src/conv.rs
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct ParamsConv1D {
|
||||||
|
pub(crate) b_size: Option<usize>,
|
||||||
|
// Maybe we should have a version without l_in as this bit depends on the input and not only on
|
||||||
|
// the weights.
|
||||||
|
pub(crate) l_in: usize,
|
||||||
|
pub(crate) c_out: usize,
|
||||||
|
pub(crate) c_in: usize,
|
||||||
|
pub(crate) k_size: usize,
|
||||||
|
pub(crate) padding: usize,
|
||||||
|
pub(crate) stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParamsConv1D {
|
||||||
|
pub(crate) fn l_out(&self) -> usize {
|
||||||
|
let dilation = 1;
|
||||||
|
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||||
|
let l_out = self.l_out();
|
||||||
|
match self.b_size {
|
||||||
|
None => vec![self.c_out, l_out],
|
||||||
|
Some(n) => vec![n, self.c_out, l_out],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -202,6 +202,63 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||||
|
|
||||||
|
impl<'a> Map2 for Conv1D<'a> {
|
||||||
|
const OP: &'static str = "conv1d";
|
||||||
|
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||||
|
&self,
|
||||||
|
inp: &[T],
|
||||||
|
inp_l: &Layout,
|
||||||
|
k: &[T],
|
||||||
|
k_l: &Layout,
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
|
||||||
|
let p = self.0;
|
||||||
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
|
let k = &k[k_l.start_offset()..];
|
||||||
|
let inp_stride = inp_l.stride();
|
||||||
|
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
|
||||||
|
(inp_stride[0], &inp_stride[1..])
|
||||||
|
} else {
|
||||||
|
(0, inp_stride) // This value never gets used anyway
|
||||||
|
};
|
||||||
|
let k_stride = k_l.stride();
|
||||||
|
let k_over_2 = p.k_size / 2;
|
||||||
|
let l_out = p.l_out();
|
||||||
|
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||||
|
let mut dst = vec![T::zero(); dst_elems];
|
||||||
|
// The output shape is [b_size, c_out, l_out]
|
||||||
|
for b_idx in 0..p.b_size.unwrap_or(1) {
|
||||||
|
let inp_idx = b_idx * inp_stride0;
|
||||||
|
let dst_idx = b_idx * p.c_out * l_out;
|
||||||
|
for dst_c_idx in 0..p.c_out {
|
||||||
|
let dst_idx = dst_idx + dst_c_idx * l_out;
|
||||||
|
for dst_l in 0..l_out {
|
||||||
|
let dst_idx = dst_idx + dst_l;
|
||||||
|
let mut d = T::zero();
|
||||||
|
for offset in 0..p.k_size {
|
||||||
|
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||||
|
if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
|
||||||
|
let src_l = dst_l + offset - k_over_2;
|
||||||
|
for src_c_idx in 0..p.c_in {
|
||||||
|
let inp_idx =
|
||||||
|
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
|
||||||
|
let k_idx = dst_c_idx * k_stride[0]
|
||||||
|
+ src_c_idx * k_stride[1]
|
||||||
|
+ offset * k_stride[2];
|
||||||
|
d += inp[inp_idx] * k[k_idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst[dst_idx] = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct MatMul((usize, usize, usize, usize));
|
struct MatMul((usize, usize, usize, usize));
|
||||||
|
|
||||||
impl Map2 for MatMul {
|
impl Map2 for MatMul {
|
||||||
@ -627,6 +684,16 @@ impl CpuStorage {
|
|||||||
WCond(pred, layout).map(t, t_l, f, f_l)
|
WCond(pred, layout).map(t, t_l, f, f_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
l: &Layout,
|
||||||
|
kernel: &Self,
|
||||||
|
kernel_l: &Layout,
|
||||||
|
params: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = self.as_slice::<u32>()?;
|
let ids = self.as_slice::<u32>()?;
|
||||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||||
|
@ -801,6 +801,16 @@ impl CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||||
|
@ -100,6 +100,16 @@ impl CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
_l: &Layout,
|
||||||
|
_kernel: &Self,
|
||||||
|
_kernel_l: &Layout,
|
||||||
|
_params: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
mod backprop;
|
mod backprop;
|
||||||
|
mod conv;
|
||||||
mod cpu_backend;
|
mod cpu_backend;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
mod cuda_backend;
|
mod cuda_backend;
|
||||||
|
@ -12,6 +12,14 @@ pub(crate) enum Op {
|
|||||||
Embedding(Tensor, Tensor),
|
Embedding(Tensor, Tensor),
|
||||||
WhereCond(Tensor, Tensor, Tensor),
|
WhereCond(Tensor, Tensor, Tensor),
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Conv1D {
|
||||||
|
arg: Tensor,
|
||||||
|
kernel: Tensor,
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
},
|
||||||
|
|
||||||
Cat(Vec<Tensor>, usize),
|
Cat(Vec<Tensor>, usize),
|
||||||
|
|
||||||
#[allow(dead_code)] // add is currently unused.
|
#[allow(dead_code)] // add is currently unused.
|
||||||
|
@ -144,6 +144,32 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn conv1d(
|
||||||
|
&self,
|
||||||
|
l: &Layout,
|
||||||
|
kernel: &Self,
|
||||||
|
kernel_l: &Layout,
|
||||||
|
params: &crate::conv::ParamsConv1D,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.same_device(kernel, "conv1d")?;
|
||||||
|
self.same_dtype(kernel, "conv1d")?;
|
||||||
|
match (self, &kernel) {
|
||||||
|
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||||
|
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Cpu(s))
|
||||||
|
}
|
||||||
|
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||||
|
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||||
|
Ok(Self::Cuda(s))
|
||||||
|
}
|
||||||
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: lhs.device().location(),
|
||||||
|
rhs: rhs.device().location(),
|
||||||
|
op: "conv1d",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn where_cond(
|
pub(crate) fn where_cond(
|
||||||
&self,
|
&self,
|
||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
|
@ -432,6 +432,42 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, dims, op, false))
|
Ok(from_storage(storage, dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||||
|
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
||||||
|
let (b_size, c_in, l_in) = match *self.dims() {
|
||||||
|
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
|
||||||
|
[c_in, l_in] => (None, c_in, l_in),
|
||||||
|
_ => todo!("proper error message"),
|
||||||
|
};
|
||||||
|
if c_in != c_in_k {
|
||||||
|
todo!("proper error message")
|
||||||
|
}
|
||||||
|
let params = crate::conv::ParamsConv1D {
|
||||||
|
b_size,
|
||||||
|
l_in,
|
||||||
|
c_out,
|
||||||
|
c_in,
|
||||||
|
k_size,
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
};
|
||||||
|
let storage =
|
||||||
|
self.storage
|
||||||
|
.conv1d(self.layout(), &kernel.storage, kernel.layout(), ¶ms)?;
|
||||||
|
let op = if self.track_op() || kernel.track_op() {
|
||||||
|
Some(Op::Conv1D {
|
||||||
|
arg: self.clone(),
|
||||||
|
kernel: kernel.clone(),
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let out_dims = params.out_dims();
|
||||||
|
Ok(from_storage(storage, out_dims, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||||
let a_dims = self.shape().dims();
|
let a_dims = self.shape().dims();
|
||||||
let b_dims = rhs.shape().dims();
|
let b_dims = rhs.shape().dims();
|
||||||
|
13
candle-examples/examples/whisper/extract_weights.py
Normal file
13
candle-examples/examples/whisper/extract_weights.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Get the checkpoint from
|
||||||
|
# https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
data = torch.load("tiny.en.pt")
|
||||||
|
weights = {}
|
||||||
|
for k, v in data["model_state_dict"].items():
|
||||||
|
weights[k] = v.contiguous()
|
||||||
|
print(k, v.shape, v.dtype)
|
||||||
|
save_file(weights, "tiny.en.safetensors")
|
||||||
|
print(data["dims"])
|
573
candle-examples/examples/whisper/main.rs
Normal file
573
candle-examples/examples/whisper/main.rs
Normal file
@ -0,0 +1,573 @@
|
|||||||
|
#![allow(dead_code)]
|
||||||
|
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||||
|
// TODO:
|
||||||
|
// - kv-cache support?
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||||
|
use clap::Parser;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
|
struct VarBuilder<'a> {
|
||||||
|
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> VarBuilder<'a> {
|
||||||
|
pub fn from_safetensors(
|
||||||
|
safetensors: Vec<SafeTensors<'a>>,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
) -> Self {
|
||||||
|
let mut routing = HashMap::new();
|
||||||
|
for (index, sf) in safetensors.iter().enumerate() {
|
||||||
|
for k in sf.names() {
|
||||||
|
routing.insert(k.to_string(), index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
safetensors: Some((routing, safetensors)),
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn zeros(dtype: DType, device: Device) -> Self {
|
||||||
|
Self {
|
||||||
|
safetensors: None,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
||||||
|
let s: Shape = s.into();
|
||||||
|
match &self.safetensors {
|
||||||
|
None => Tensor::zeros(s, self.dtype, &self.device),
|
||||||
|
Some((routing, safetensors)) => {
|
||||||
|
// Unwrap or 0 just to let the proper error flow.
|
||||||
|
let index = routing.get(tensor_name).unwrap_or(&0);
|
||||||
|
let tensor = safetensors[*index]
|
||||||
|
.tensor(tensor_name, &self.device)?
|
||||||
|
.to_dtype(self.dtype)?;
|
||||||
|
if *tensor.shape() != s {
|
||||||
|
let msg = format!("shape mismatch for {tensor_name}");
|
||||||
|
Err(candle::Error::UnexpectedShape {
|
||||||
|
msg,
|
||||||
|
expected: s,
|
||||||
|
got: tensor.shape().clone(),
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum HiddenAct {
|
||||||
|
Gelu,
|
||||||
|
Relu,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HiddenAct {
|
||||||
|
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Gelu => xs.gelu(),
|
||||||
|
Self::Relu => xs.relu(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
struct Config {
|
||||||
|
n_mels: usize,
|
||||||
|
n_audio_ctx: usize,
|
||||||
|
n_audio_state: usize,
|
||||||
|
n_audio_head: usize,
|
||||||
|
n_audio_layer: usize,
|
||||||
|
n_vocab: usize,
|
||||||
|
n_text_ctx: usize,
|
||||||
|
n_text_state: usize,
|
||||||
|
n_text_head: usize,
|
||||||
|
n_text_layer: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
fn tiny_en() -> Self {
|
||||||
|
Self {
|
||||||
|
n_mels: 80,
|
||||||
|
n_vocab: 51864,
|
||||||
|
n_audio_ctx: 1500,
|
||||||
|
n_audio_state: 384,
|
||||||
|
n_audio_head: 6,
|
||||||
|
n_audio_layer: 4,
|
||||||
|
n_text_ctx: 448,
|
||||||
|
n_text_state: 384,
|
||||||
|
n_text_head: 6,
|
||||||
|
n_text_layer: 4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Embedding {
|
||||||
|
embeddings: Tensor,
|
||||||
|
hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
embeddings,
|
||||||
|
hidden_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
||||||
|
Ok(Self::new(embeddings, hidden_size))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut final_dims = indexes.dims().to_vec();
|
||||||
|
final_dims.push(self.hidden_size);
|
||||||
|
let indexes = indexes.flatten_all()?;
|
||||||
|
let values = Tensor::embedding(&indexes, &self.embeddings)?;
|
||||||
|
let values = values.reshape(final_dims)?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Linear {
|
||||||
|
weight: Tensor,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear {
|
||||||
|
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||||
|
let bias = vb.get(size2, &format!("{p}.bias"))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
||||||
|
Ok(Self { weight, bias: None })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let (bsize, _, _) = x.shape().r3()?;
|
||||||
|
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||||
|
let x = x.matmul(&w)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
struct ConvConfig {
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConvConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
padding: 0,
|
||||||
|
stride: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Conv1D {
|
||||||
|
weight: Tensor,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
config: ConvConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Conv1D {
|
||||||
|
fn load(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: ConvConfig,
|
||||||
|
p: &str,
|
||||||
|
vb: &VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb.get(
|
||||||
|
(out_channels, in_channels, kernel_size),
|
||||||
|
&format!("{p}.weight"),
|
||||||
|
)?;
|
||||||
|
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_no_bias(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: ConvConfig,
|
||||||
|
p: &str,
|
||||||
|
vb: &VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb.get(
|
||||||
|
(out_channels, in_channels, kernel_size),
|
||||||
|
&format!("{p}.weight"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias: None,
|
||||||
|
config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => {
|
||||||
|
let b = bias.shape().r1()?;
|
||||||
|
let bias = bias.reshape((1, b, 1))?;
|
||||||
|
Ok(x.broadcast_add(&bias)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Dropout {
|
||||||
|
pr: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Dropout {
|
||||||
|
fn new(pr: f64) -> Self {
|
||||||
|
Self { pr }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
// TODO
|
||||||
|
Ok(x.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This layer norm version handles both weight and bias so removes the mean.
|
||||||
|
struct LayerNorm {
|
||||||
|
weight: Tensor,
|
||||||
|
bias: Tensor,
|
||||||
|
eps: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LayerNorm {
|
||||||
|
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let weight = vb.get(size, &format!("{p}.weight"))?;
|
||||||
|
let bias = vb.get(size, &format!("{p}.bias"))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
eps: 1e-5,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||||
|
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
|
||||||
|
let x = x.broadcast_sub(&mean_x)?;
|
||||||
|
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||||
|
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||||
|
let x = x_normed
|
||||||
|
.broadcast_mul(&self.weight)?
|
||||||
|
.broadcast_add(&self.bias)?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
|
||||||
|
struct MultiHeadAttention {
|
||||||
|
query: Linear,
|
||||||
|
key: Linear,
|
||||||
|
value: Linear,
|
||||||
|
out: Linear,
|
||||||
|
n_head: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiHeadAttention {
|
||||||
|
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
|
||||||
|
let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
|
||||||
|
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
|
||||||
|
let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
|
||||||
|
Ok(Self {
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
out,
|
||||||
|
n_head,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let q = self.query.forward(x)?;
|
||||||
|
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||||
|
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||||
|
let wv = self.qkv_attention(&q, &k, &v)?;
|
||||||
|
let out = self.out.forward(&wv)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let (n_batch, n_ctx, n_state) = x.shape().r3()?;
|
||||||
|
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||||
|
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||||
|
let (_, _, n_state) = q.shape().r3()?;
|
||||||
|
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||||
|
let q = (self.reshape_head(q)? * scale)?;
|
||||||
|
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||||
|
let v = self.reshape_head(v)?.contiguous()?;
|
||||||
|
let qk = q.matmul(&k)?;
|
||||||
|
let w = qk.softmax(qk.rank() - 1)?;
|
||||||
|
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
|
||||||
|
Ok(wv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||||
|
struct ResidualAttentionBlock {
|
||||||
|
attn: MultiHeadAttention,
|
||||||
|
attn_ln: LayerNorm,
|
||||||
|
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
|
||||||
|
mlp_linear1: Linear,
|
||||||
|
mlp_linear2: Linear,
|
||||||
|
mlp_ln: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualAttentionBlock {
|
||||||
|
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
|
||||||
|
let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
|
||||||
|
let cross_attn = if ca {
|
||||||
|
let cross_attn =
|
||||||
|
MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
|
||||||
|
let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
|
||||||
|
Some((cross_attn, cross_attn_ln))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let n_mlp = n_state * 4;
|
||||||
|
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
|
||||||
|
let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
|
||||||
|
let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
|
||||||
|
Ok(Self {
|
||||||
|
attn,
|
||||||
|
attn_ln,
|
||||||
|
cross_attn,
|
||||||
|
mlp_linear1,
|
||||||
|
mlp_linear2,
|
||||||
|
mlp_ln,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
|
||||||
|
let mut x = (x + attn)?;
|
||||||
|
if let Some((attn, ln)) = &self.cross_attn {
|
||||||
|
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
|
||||||
|
}
|
||||||
|
let mlp = self.mlp_linear2.forward(
|
||||||
|
&self
|
||||||
|
.mlp_linear1
|
||||||
|
.forward(&self.mlp_ln.forward(&x)?)?
|
||||||
|
.gelu()?,
|
||||||
|
)?;
|
||||||
|
Ok((x + mlp)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||||
|
let max_timescale = 10000f32;
|
||||||
|
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||||
|
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||||
|
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||||
|
.collect();
|
||||||
|
let arange: Vec<_> = (0..length).map(|c| c as f32).collect();
|
||||||
|
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||||
|
let arange = Tensor::new(arange.as_slice(), &Device::Cpu)?.unsqueeze(1)?;
|
||||||
|
let sh = (length, channels / 2);
|
||||||
|
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||||
|
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||||
|
Ok(sincos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||||
|
struct AudioEncoder {
|
||||||
|
conv1: Conv1D,
|
||||||
|
conv2: Conv1D,
|
||||||
|
positional_embedding: Tensor,
|
||||||
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
|
ln_post: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioEncoder {
|
||||||
|
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let n_state = cfg.n_audio_state;
|
||||||
|
let n_head = cfg.n_audio_head;
|
||||||
|
let n_ctx = cfg.n_audio_ctx;
|
||||||
|
let cfg1 = ConvConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
};
|
||||||
|
let cfg2 = ConvConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 2,
|
||||||
|
};
|
||||||
|
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
||||||
|
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
||||||
|
let positional_embedding = if true {
|
||||||
|
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
|
||||||
|
} else {
|
||||||
|
/* The positional embeddings could be regenerated via the following. */
|
||||||
|
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
|
||||||
|
};
|
||||||
|
let blocks = (0..cfg.n_audio_layer)
|
||||||
|
.map(|i| {
|
||||||
|
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
|
||||||
|
Ok(Self {
|
||||||
|
conv1,
|
||||||
|
conv2,
|
||||||
|
positional_embedding,
|
||||||
|
blocks,
|
||||||
|
ln_post,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let x = self.conv1.forward(x)?.gelu()?;
|
||||||
|
let x = self.conv2.forward(&x)?.gelu()?;
|
||||||
|
let x = x.transpose(1, 2)?;
|
||||||
|
let mut x = x.broadcast_add(&self.positional_embedding)?;
|
||||||
|
for block in self.blocks.iter() {
|
||||||
|
x = block.forward(&x, None)?
|
||||||
|
}
|
||||||
|
let x = self.ln_post.forward(&x)?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||||
|
struct TextDecoder {
|
||||||
|
token_embedding: Embedding,
|
||||||
|
positional_embedding: Tensor,
|
||||||
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
|
ln: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextDecoder {
|
||||||
|
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let n_state = cfg.n_text_state;
|
||||||
|
let n_head = cfg.n_text_head;
|
||||||
|
let n_ctx = cfg.n_text_ctx;
|
||||||
|
let token_embedding =
|
||||||
|
Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
|
||||||
|
let positional_embedding =
|
||||||
|
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
|
||||||
|
let blocks = (0..cfg.n_text_layer)
|
||||||
|
.map(|i| {
|
||||||
|
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
|
||||||
|
Ok(Self {
|
||||||
|
token_embedding,
|
||||||
|
positional_embedding,
|
||||||
|
blocks,
|
||||||
|
ln,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||||
|
let x_dims = x.dims();
|
||||||
|
let last = x_dims[x_dims.len() - 1];
|
||||||
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
|
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||||
|
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||||
|
for block in self.blocks.iter() {
|
||||||
|
x = block.forward(&x, Some(xa))?;
|
||||||
|
}
|
||||||
|
let x = self.ln.forward(&x)?;
|
||||||
|
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
|
||||||
|
let logits = x.matmul(&w.t()?)?;
|
||||||
|
Ok(logits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||||
|
struct Whisper {
|
||||||
|
encoder: AudioEncoder,
|
||||||
|
decoder: TextDecoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Whisper {
|
||||||
|
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let encoder = AudioEncoder::load("encoder", vb, cfg)?;
|
||||||
|
let decoder = TextDecoder::load("decoder", vb, cfg)?;
|
||||||
|
Ok(Self { encoder, decoder })
|
||||||
|
}
|
||||||
|
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||||
|
let enc = self.encoder.forward(mel)?;
|
||||||
|
let dec = self.decoder.forward(tokens, &enc)?;
|
||||||
|
Ok(dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weights: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
input: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = if args.cpu {
|
||||||
|
Device::Cpu
|
||||||
|
} else {
|
||||||
|
Device::new_cuda(0)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
||||||
|
let input = input.deserialize()?;
|
||||||
|
let tokens = input.tensor("tokens", &device)?.to_dtype(DType::U32)?;
|
||||||
|
let mel = input.tensor("mel", &device)?;
|
||||||
|
|
||||||
|
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||||
|
let weights = weights.deserialize()?;
|
||||||
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||||
|
let cfg = Config::tiny_en();
|
||||||
|
|
||||||
|
let model = Whisper::load(&vb, &cfg)?;
|
||||||
|
let logits = model.forward(&mel, &tokens)?;
|
||||||
|
println!("{logits}");
|
||||||
|
println!("python logits: {}", input.tensor("dec", &device)?);
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user